xmpp/stream.go
changeset 143 62166e57800e
parent 142 0ff033eed887
child 144 9d7fdb1d2fc1
equal deleted inserted replaced
142:0ff033eed887 143:62166e57800e
// This file contains the three layers of processing for the
// communication with the server: transport (where TLS happens), XML
// (where strings are converted to Go structures), and stream (where
// we either respond to XMPP events on behalf of the library client or
// pass those events on to the client).
       
     6 
       
     7 package xmpp
       
     8 
       
import (
	"crypto/md5"
	"crypto/rand"
	"crypto/tls"
	"encoding/base64"
	"encoding/xml"
	"fmt"
	"io"
	"math/big"
	"net"
	"reflect"
	"regexp"
	"sort"
	"strings"
	"time"
)
       
    24 
       
    25 // Callback to handle a stanza with a particular id.
       
    26 type stanzaHandler struct {
       
    27 	id string
       
    28 	// Return true means pass this to the application
       
    29 	f func(Stanza) bool
       
    30 }
       
    31 
       
    32 func (cl *Client) readTransport(w io.WriteCloser) {
       
    33 	defer w.Close()
       
    34 	p := make([]byte, 1024)
       
    35 	for {
       
    36 		if cl.socket == nil {
       
    37 			cl.waitForSocket()
       
    38 		}
       
    39 		cl.socket.SetReadDeadline(time.Now().Add(time.Second))
       
    40 		nr, err := cl.socket.Read(p)
       
    41 		if nr == 0 {
       
    42 			if errno, ok := err.(*net.OpError); ok {
       
    43 				if errno.Timeout() {
       
    44 					continue
       
    45 				}
       
    46 			}
       
    47 			Warn.Logf("read: %s", err)
       
    48 			break
       
    49 		}
       
    50 		nw, err := w.Write(p[:nr])
       
    51 		if nw < nr {
       
    52 			Warn.Logf("read: %s", err)
       
    53 			break
       
    54 		}
       
    55 	}
       
    56 }
       
    57 
       
    58 func (cl *Client) writeTransport(r io.Reader) {
       
    59 	defer cl.socket.Close()
       
    60 	p := make([]byte, 1024)
       
    61 	for {
       
    62 		nr, err := r.Read(p)
       
    63 		if nr == 0 {
       
    64 			Warn.Logf("write: %s", err)
       
    65 			break
       
    66 		}
       
    67 		nw, err := cl.socket.Write(p[:nr])
       
    68 		if nw < nr {
       
    69 			Warn.Logf("write: %s", err)
       
    70 			break
       
    71 		}
       
    72 	}
       
    73 }
       
    74 
       
    75 func readXml(r io.Reader, ch chan<- interface{},
       
    76 	extStanza map[xml.Name]reflect.Type) {
       
    77 	if _, ok := Debug.(*noLog); !ok {
       
    78 		pr, pw := io.Pipe()
       
    79 		go tee(r, pw, "S: ")
       
    80 		r = pr
       
    81 	}
       
    82 	defer close(ch)
       
    83 
       
    84 	// This trick loads our namespaces into the parser.
       
    85 	nsstr := fmt.Sprintf(`<a xmlns="%s" xmlns:stream="%s">`,
       
    86 		NsClient, NsStream)
       
    87 	nsrdr := strings.NewReader(nsstr)
       
    88 	p := xml.NewDecoder(io.MultiReader(nsrdr, r))
       
    89 	p.Token()
       
    90 
       
    91 Loop:
       
    92 	for {
       
    93 		// Sniff the next token on the stream.
       
    94 		t, err := p.Token()
       
    95 		if t == nil {
       
    96 			if err != io.EOF {
       
    97 				Warn.Logf("read: %s", err)
       
    98 			}
       
    99 			break
       
   100 		}
       
   101 		var se xml.StartElement
       
   102 		var ok bool
       
   103 		if se, ok = t.(xml.StartElement); !ok {
       
   104 			continue
       
   105 		}
       
   106 
       
   107 		// Allocate the appropriate structure for this token.
       
   108 		var obj interface{}
       
   109 		switch se.Name.Space + " " + se.Name.Local {
       
   110 		case NsStream + " stream":
       
   111 			st, err := parseStream(se)
       
   112 			if err != nil {
       
   113 				Warn.Logf("unmarshal stream: %s", err)
       
   114 				break Loop
       
   115 			}
       
   116 			ch <- st
       
   117 			continue
       
   118 		case "stream error", NsStream + " error":
       
   119 			obj = &streamError{}
       
   120 		case NsStream + " features":
       
   121 			obj = &Features{}
       
   122 		case NsTLS + " proceed", NsTLS + " failure":
       
   123 			obj = &starttls{}
       
   124 		case NsSASL + " challenge", NsSASL + " failure",
       
   125 			NsSASL + " success":
       
   126 			obj = &auth{}
       
   127 		case NsClient + " iq":
       
   128 			obj = &Iq{}
       
   129 		case NsClient + " message":
       
   130 			obj = &Message{}
       
   131 		case NsClient + " presence":
       
   132 			obj = &Presence{}
       
   133 		default:
       
   134 			obj = &Generic{}
       
   135 			Info.Logf("Ignoring unrecognized: %s %s", se.Name.Space,
       
   136 				se.Name.Local)
       
   137 		}
       
   138 
       
   139 		// Read the complete XML stanza.
       
   140 		err = p.DecodeElement(obj, &se)
       
   141 		if err != nil {
       
   142 			Warn.Logf("unmarshal: %s", err)
       
   143 			break Loop
       
   144 		}
       
   145 
       
   146 		// If it's a Stanza, we try to unmarshal its innerxml
       
   147 		// into objects of the appropriate respective
       
   148 		// types. This is specified by our extensions.
       
   149 		if st, ok := obj.(Stanza); ok {
       
   150 			err = parseExtended(st.GetHeader(), extStanza)
       
   151 			if err != nil {
       
   152 				Warn.Logf("ext unmarshal: %s", err)
       
   153 				break Loop
       
   154 			}
       
   155 		}
       
   156 
       
   157 		// Put it on the channel.
       
   158 		ch <- obj
       
   159 	}
       
   160 }
       
   161 
       
   162 func parseExtended(st *Header, extStanza map[xml.Name]reflect.Type) error {
       
   163 	// Now parse the stanza's innerxml to find the string that we
       
   164 	// can unmarshal this nested element from.
       
   165 	reader := strings.NewReader(st.Innerxml)
       
   166 	p := xml.NewDecoder(reader)
       
   167 	for {
       
   168 		t, err := p.Token()
       
   169 		if err == io.EOF {
       
   170 			break
       
   171 		}
       
   172 		if err != nil {
       
   173 			return err
       
   174 		}
       
   175 		if se, ok := t.(xml.StartElement); ok {
       
   176 			if typ, ok := extStanza[se.Name]; ok {
       
   177 				nested := reflect.New(typ).Interface()
       
   178 
       
   179 				// Unmarshal the nested element and
       
   180 				// stuff it back into the stanza.
       
   181 				err := p.DecodeElement(nested, &se)
       
   182 				if err != nil {
       
   183 					return err
       
   184 				}
       
   185 				st.Nested = append(st.Nested, nested)
       
   186 			}
       
   187 		}
       
   188 	}
       
   189 
       
   190 	return nil
       
   191 }
       
   192 
       
   193 func writeXml(w io.Writer, ch <-chan interface{}) {
       
   194 	if _, ok := Debug.(*noLog); !ok {
       
   195 		pr, pw := io.Pipe()
       
   196 		go tee(pr, w, "C: ")
       
   197 		w = pw
       
   198 	}
       
   199 	defer func(w io.Writer) {
       
   200 		if c, ok := w.(io.Closer); ok {
       
   201 			c.Close()
       
   202 		}
       
   203 	}(w)
       
   204 
       
   205 	enc := xml.NewEncoder(w)
       
   206 
       
   207 	for obj := range ch {
       
   208 		if st, ok := obj.(*stream); ok {
       
   209 			_, err := w.Write([]byte(st.String()))
       
   210 			if err != nil {
       
   211 				Warn.Logf("write: %s", err)
       
   212 			}
       
   213 		} else {
       
   214 			err := enc.Encode(obj)
       
   215 			if err != nil {
       
   216 				Warn.Logf("marshal: %s", err)
       
   217 				break
       
   218 			}
       
   219 		}
       
   220 	}
       
   221 }
       
   222 
       
   223 func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- Stanza) {
       
   224 	defer close(cliOut)
       
   225 
       
   226 	handlers := make(map[string]func(Stanza) bool)
       
   227 Loop:
       
   228 	for {
       
   229 		select {
       
   230 		case h := <-cl.handlers:
       
   231 			handlers[h.id] = h.f
       
   232 		case x, ok := <-srvIn:
       
   233 			if !ok {
       
   234 				break Loop
       
   235 			}
       
   236 			switch obj := x.(type) {
       
   237 			case *stream:
       
   238 				handleStream(obj)
       
   239 			case *streamError:
       
   240 				cl.handleStreamError(obj)
       
   241 			case *Features:
       
   242 				cl.handleFeatures(obj)
       
   243 			case *starttls:
       
   244 				cl.handleTls(obj)
       
   245 			case *auth:
       
   246 				cl.handleSasl(obj)
       
   247 			case Stanza:
       
   248 				send := true
       
   249 				id := obj.GetHeader().Id
       
   250 				if handlers[id] != nil {
       
   251 					f := handlers[id]
       
   252 					delete(handlers, id)
       
   253 					send = f(obj)
       
   254 				}
       
   255 				if send {
       
   256 					cliOut <- obj
       
   257 				}
       
   258 			default:
       
   259 				Warn.Logf("Unhandled non-stanza: %T %#v", x, x)
       
   260 			}
       
   261 		}
       
   262 	}
       
   263 }
       
   264 
       
   265 // This loop is paused until resource binding is complete. Otherwise
       
   266 // the app might inject something inappropriate into our negotiations
       
   267 // with the server. The control channel controls this loop's
       
   268 // activity.
       
   269 func writeStream(srvOut chan<- interface{}, cliIn <-chan Stanza,
       
   270 	control <-chan int) {
       
   271 	defer close(srvOut)
       
   272 
       
   273 	var input <-chan Stanza
       
   274 Loop:
       
   275 	for {
       
   276 		select {
       
   277 		case status := <-control:
       
   278 			switch status {
       
   279 			case 0:
       
   280 				input = nil
       
   281 			case 1:
       
   282 				input = cliIn
       
   283 			case -1:
       
   284 				break Loop
       
   285 			}
       
   286 		case x, ok := <-input:
       
   287 			if !ok {
       
   288 				break Loop
       
   289 			}
       
   290 			if x == nil {
       
   291 				Info.Log("Refusing to send nil stanza")
       
   292 				continue
       
   293 			}
       
   294 			srvOut <- x
       
   295 		}
       
   296 	}
       
   297 }
       
   298 
       
   299 func handleStream(ss *stream) {
       
   300 }
       
   301 
       
   302 func (cl *Client) handleStreamError(se *streamError) {
       
   303 	Info.Logf("Received stream error: %v", se)
       
   304 	cl.socket.Close()
       
   305 }
       
   306 
       
   307 func (cl *Client) handleFeatures(fe *Features) {
       
   308 	cl.Features = fe
       
   309 	if fe.Starttls != nil {
       
   310 		start := &starttls{XMLName: xml.Name{Space: NsTLS,
       
   311 			Local: "starttls"}}
       
   312 		cl.sendXml <- start
       
   313 		return
       
   314 	}
       
   315 
       
   316 	if len(fe.Mechanisms.Mechanism) > 0 {
       
   317 		cl.chooseSasl(fe)
       
   318 		return
       
   319 	}
       
   320 
       
   321 	if fe.Bind != nil {
       
   322 		cl.bind(fe.Bind)
       
   323 		return
       
   324 	}
       
   325 }
       
   326 
       
   327 // readTransport() is running concurrently. We need to stop it,
       
   328 // negotiate TLS, then start it again. It calls waitForSocket() in
       
   329 // its inner loop; see below.
       
   330 func (cl *Client) handleTls(t *starttls) {
       
   331 	tcp := cl.socket
       
   332 
       
   333 	// Set the socket to nil, and wait for the reader routine to
       
   334 	// signal that it's paused.
       
   335 	cl.socket = nil
       
   336 	cl.socketSync.Add(1)
       
   337 	cl.socketSync.Wait()
       
   338 
       
   339 	// Negotiate TLS with the server.
       
   340 	tls := tls.Client(tcp, &cl.tlsConfig)
       
   341 
       
   342 	// Make the TLS connection available to the reader, and wait
       
   343 	// for it to signal that it's working again.
       
   344 	cl.socketSync.Add(1)
       
   345 	cl.socket = tls
       
   346 	cl.socketSync.Wait()
       
   347 
       
   348 	Info.Log("TLS negotiation succeeded.")
       
   349 	cl.Features = nil
       
   350 
       
   351 	// Now re-send the initial handshake message to start the new
       
   352 	// session.
       
   353 	hsOut := &stream{To: cl.Jid.Domain, Version: XMPPVersion}
       
   354 	cl.sendXml <- hsOut
       
   355 }
       
   356 
       
   357 // Synchronize with handleTls(). Called from readTransport() when
       
   358 // cl.socket is nil.
       
   359 func (cl *Client) waitForSocket() {
       
   360 	// Signal that we've stopped reading from the socket.
       
   361 	cl.socketSync.Done()
       
   362 
       
   363 	// Wait until the socket is available again.
       
   364 	for cl.socket == nil {
       
   365 		time.Sleep(1e8)
       
   366 	}
       
   367 
       
   368 	// Signal that we're going back to the read loop.
       
   369 	cl.socketSync.Done()
       
   370 }
       
   371 
       
   372 // BUG(cjyar): Doesn't implement TLS/SASL EXTERNAL.
       
   373 func (cl *Client) chooseSasl(fe *Features) {
       
   374 	var digestMd5 bool
       
   375 	for _, m := range fe.Mechanisms.Mechanism {
       
   376 		switch strings.ToLower(m) {
       
   377 		case "digest-md5":
       
   378 			digestMd5 = true
       
   379 		}
       
   380 	}
       
   381 
       
   382 	if digestMd5 {
       
   383 		auth := &auth{XMLName: xml.Name{Space: NsSASL, Local: "auth"}, Mechanism: "DIGEST-MD5"}
       
   384 		cl.sendXml <- auth
       
   385 	}
       
   386 }
       
   387 
       
   388 func (cl *Client) handleSasl(srv *auth) {
       
   389 	switch strings.ToLower(srv.XMLName.Local) {
       
   390 	case "challenge":
       
   391 		b64 := base64.StdEncoding
       
   392 		str, err := b64.DecodeString(srv.Chardata)
       
   393 		if err != nil {
       
   394 			Warn.Logf("SASL challenge decode: %s", err)
       
   395 			return
       
   396 		}
       
   397 		srvMap := parseSasl(string(str))
       
   398 
       
   399 		if cl.saslExpected == "" {
       
   400 			cl.saslDigest1(srvMap)
       
   401 		} else {
       
   402 			cl.saslDigest2(srvMap)
       
   403 		}
       
   404 	case "failure":
       
   405 		Info.Log("SASL authentication failed")
       
   406 	case "success":
       
   407 		Info.Log("Sasl authentication succeeded")
       
   408 		cl.Features = nil
       
   409 		ss := &stream{To: cl.Jid.Domain, Version: XMPPVersion}
       
   410 		cl.sendXml <- ss
       
   411 	}
       
   412 }
       
   413 
       
   414 func (cl *Client) saslDigest1(srvMap map[string]string) {
       
   415 	// Make sure it supports qop=auth
       
   416 	var hasAuth bool
       
   417 	for _, qop := range strings.Fields(srvMap["qop"]) {
       
   418 		if qop == "auth" {
       
   419 			hasAuth = true
       
   420 		}
       
   421 	}
       
   422 	if !hasAuth {
       
   423 		Warn.Log("Server doesn't support SASL auth")
       
   424 		return
       
   425 	}
       
   426 
       
   427 	// Pick a realm.
       
   428 	var realm string
       
   429 	if srvMap["realm"] != "" {
       
   430 		realm = strings.Fields(srvMap["realm"])[0]
       
   431 	}
       
   432 
       
   433 	passwd := cl.password
       
   434 	nonce := srvMap["nonce"]
       
   435 	digestUri := "xmpp/" + cl.Jid.Domain
       
   436 	nonceCount := int32(1)
       
   437 	nonceCountStr := fmt.Sprintf("%08x", nonceCount)
       
   438 
       
   439 	// Begin building the response. Username is
       
   440 	// user@domain or just domain.
       
   441 	var username string
       
   442 	if cl.Jid.Node == "" {
       
   443 		username = cl.Jid.Domain
       
   444 	} else {
       
   445 		username = cl.Jid.Node
       
   446 	}
       
   447 
       
   448 	// Generate our own nonce from random data.
       
   449 	randSize := big.NewInt(0)
       
   450 	randSize.Lsh(big.NewInt(1), 64)
       
   451 	cnonce, err := rand.Int(rand.Reader, randSize)
       
   452 	if err != nil {
       
   453 		Warn.Logf("SASL rand: %s", err)
       
   454 		return
       
   455 	}
       
   456 	cnonceStr := fmt.Sprintf("%016x", cnonce)
       
   457 
       
   458 	/* Now encode the actual password response, as well as the
       
   459 	 * expected next challenge from the server. */
       
   460 	response := saslDigestResponse(username, realm, passwd, nonce,
       
   461 		cnonceStr, "AUTHENTICATE", digestUri, nonceCountStr)
       
   462 	next := saslDigestResponse(username, realm, passwd, nonce,
       
   463 		cnonceStr, "", digestUri, nonceCountStr)
       
   464 	cl.saslExpected = next
       
   465 
       
   466 	// Build the map which will be encoded.
       
   467 	clMap := make(map[string]string)
       
   468 	clMap["realm"] = `"` + realm + `"`
       
   469 	clMap["username"] = `"` + username + `"`
       
   470 	clMap["nonce"] = `"` + nonce + `"`
       
   471 	clMap["cnonce"] = `"` + cnonceStr + `"`
       
   472 	clMap["nc"] = nonceCountStr
       
   473 	clMap["qop"] = "auth"
       
   474 	clMap["digest-uri"] = `"` + digestUri + `"`
       
   475 	clMap["response"] = response
       
   476 	if srvMap["charset"] == "utf-8" {
       
   477 		clMap["charset"] = "utf-8"
       
   478 	}
       
   479 
       
   480 	// Encode the map and send it.
       
   481 	clStr := packSasl(clMap)
       
   482 	b64 := base64.StdEncoding
       
   483 	clObj := &auth{XMLName: xml.Name{Space: NsSASL, Local: "response"}, Chardata: b64.EncodeToString([]byte(clStr))}
       
   484 	cl.sendXml <- clObj
       
   485 }
       
   486 
       
   487 func (cl *Client) saslDigest2(srvMap map[string]string) {
       
   488 	if cl.saslExpected == srvMap["rspauth"] {
       
   489 		clObj := &auth{XMLName: xml.Name{Space: NsSASL, Local: "response"}}
       
   490 		cl.sendXml <- clObj
       
   491 	} else {
       
   492 		clObj := &auth{XMLName: xml.Name{Space: NsSASL, Local: "failure"}, Any: &Generic{XMLName: xml.Name{Space: NsSASL,
       
   493 			Local: "abort"}}}
       
   494 		cl.sendXml <- clObj
       
   495 	}
       
   496 }
       
   497 
       
   498 // Takes a string like `key1=value1,key2="value2"...` and returns a
       
   499 // key/value map.
       
   500 func parseSasl(in string) map[string]string {
       
   501 	re := regexp.MustCompile(`([^=]+)="?([^",]+)"?,?`)
       
   502 	strs := re.FindAllStringSubmatch(in, -1)
       
   503 	m := make(map[string]string)
       
   504 	for _, pair := range strs {
       
   505 		key := strings.ToLower(string(pair[1]))
       
   506 		value := string(pair[2])
       
   507 		m[key] = value
       
   508 	}
       
   509 	return m
       
   510 }
       
   511 
       
   512 // Inverse of parseSasl().
       
   513 func packSasl(m map[string]string) string {
       
   514 	var terms []string
       
   515 	for key, value := range m {
       
   516 		if key == "" || value == "" || value == `""` {
       
   517 			continue
       
   518 		}
       
   519 		terms = append(terms, key+"="+value)
       
   520 	}
       
   521 	return strings.Join(terms, ",")
       
   522 }
       
   523 
       
   524 // Computes the response string for digest authentication.
       
   525 func saslDigestResponse(username, realm, passwd, nonce, cnonceStr,
       
   526 	authenticate, digestUri, nonceCountStr string) string {
       
   527 	h := func(text string) []byte {
       
   528 		h := md5.New()
       
   529 		h.Write([]byte(text))
       
   530 		return h.Sum(nil)
       
   531 	}
       
   532 	hex := func(bytes []byte) string {
       
   533 		return fmt.Sprintf("%x", bytes)
       
   534 	}
       
   535 	kd := func(secret, data string) []byte {
       
   536 		return h(secret + ":" + data)
       
   537 	}
       
   538 
       
   539 	a1 := string(h(username+":"+realm+":"+passwd)) + ":" +
       
   540 		nonce + ":" + cnonceStr
       
   541 	a2 := authenticate + ":" + digestUri
       
   542 	response := hex(kd(hex(h(a1)), nonce+":"+
       
   543 		nonceCountStr+":"+cnonceStr+":auth:"+
       
   544 		hex(h(a2))))
       
   545 	return response
       
   546 }
       
   547 
       
   548 // Send a request to bind a resource. RFC 3920, section 7.
       
   549 func (cl *Client) bind(bindAdv *bindIq) {
       
   550 	res := cl.Jid.Resource
       
   551 	bindReq := &bindIq{}
       
   552 	if res != "" {
       
   553 		bindReq.Resource = &res
       
   554 	}
       
   555 	msg := &Iq{Header: Header{Type: "set", Id: NextId(),
       
   556 		Nested: []interface{}{bindReq}}}
       
   557 	f := func(st Stanza) bool {
       
   558 		iq, ok := st.(*Iq)
       
   559 		if !ok {
       
   560 			Warn.Log("non-iq response")
       
   561 		}
       
   562 		if iq.Type == "error" {
       
   563 			Warn.Log("Resource binding failed")
       
   564 			return false
       
   565 		}
       
   566 		var bindRepl *bindIq
       
   567 		for _, ele := range iq.Nested {
       
   568 			if b, ok := ele.(*bindIq); ok {
       
   569 				bindRepl = b
       
   570 				break
       
   571 			}
       
   572 		}
       
   573 		if bindRepl == nil {
       
   574 			Warn.Logf("Bad bind reply: %#v", iq)
       
   575 			return false
       
   576 		}
       
   577 		jidStr := bindRepl.Jid
       
   578 		if jidStr == nil || *jidStr == "" {
       
   579 			Warn.Log("Can't bind empty resource")
       
   580 			return false
       
   581 		}
       
   582 		jid := new(JID)
       
   583 		if err := jid.Set(*jidStr); err != nil {
       
   584 			Warn.Logf("Can't parse JID %s: %s", *jidStr, err)
       
   585 			return false
       
   586 		}
       
   587 		cl.Jid = *jid
       
   588 		Info.Logf("Bound resource: %s", cl.Jid.String())
       
   589 		cl.bindDone()
       
   590 		return false
       
   591 	}
       
   592 	cl.HandleStanza(msg.Id, f)
       
   593 	cl.sendXml <- msg
       
   594 }
       
   595 
       
   596 // Register a callback to handle the next XMPP stanza (iq, message, or
       
   597 // presence) with a given id. The provided function will not be called
       
   598 // more than once. If it returns false, the stanza will not be made
       
   599 // available on the normal Client.In channel. The stanza handler
       
   600 // must not read from that channel, as deliveries on it cannot proceed
       
   601 // until the handler returns true or false.
       
   602 func (cl *Client) HandleStanza(id string, f func(Stanza) bool) {
       
   603 	h := &stanzaHandler{id: id, f: f}
       
   604 	cl.handlers <- h
       
   605 }