xmpp/stream.go
changeset 126 367e76b3028e
parent 121 ebb86cbdd218
child 127 a8f9a0c07fc8
equal deleted inserted replaced
125:f464f14e39a7 126:367e76b3028e
       
     1 // Copyright 2011 The Go Authors.  All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 // This file contains the three layers of processing for the
       
     6 // communication with the server: transport (where TLS happens), XML
       
     7 // (where strings are converted to go structures), and Stream (where
       
     8 // we respond to XMPP events on behalf of the library client), or send
       
     9 // those events to the client.
       
    10 
       
    11 package xmpp
       
    12 
       
    13 import (
       
    14 	"crypto/md5"
       
    15 	"crypto/rand"
       
    16 	"crypto/tls"
       
    17 	"encoding/base64"
       
    18 	"encoding/xml"
       
    19 	"fmt"
       
    20 	"io"
       
    21 	"math/big"
       
    22 	"net"
       
    23 	"regexp"
       
    24 	"strings"
       
    25 	"time"
       
    26 )
       
    27 
       
    28 // Callback to handle a stanza with a particular id.
       
    29 type stanzaHandler struct {
       
    30 	id string
       
    31 	// Return true means pass this to the application
       
    32 	f func(Stanza) bool
       
    33 }
       
    34 
       
    35 func (cl *Client) readTransport(w io.WriteCloser) {
       
    36 	defer w.Close()
       
    37 	p := make([]byte, 1024)
       
    38 	for {
       
    39 		if cl.socket == nil {
       
    40 			cl.waitForSocket()
       
    41 		}
       
    42 		cl.socket.SetReadDeadline(time.Now().Add(time.Second))
       
    43 		nr, err := cl.socket.Read(p)
       
    44 		if nr == 0 {
       
    45 			if errno, ok := err.(*net.OpError); ok {
       
    46 				if errno.Timeout() {
       
    47 					continue
       
    48 				}
       
    49 			}
       
    50 			Warn.Logf("read: %s", err)
       
    51 			break
       
    52 		}
       
    53 		nw, err := w.Write(p[:nr])
       
    54 		if nw < nr {
       
    55 			Warn.Logf("read: %s", err)
       
    56 			break
       
    57 		}
       
    58 	}
       
    59 }
       
    60 
       
    61 func (cl *Client) writeTransport(r io.Reader) {
       
    62 	defer cl.socket.Close()
       
    63 	p := make([]byte, 1024)
       
    64 	for {
       
    65 		nr, err := r.Read(p)
       
    66 		if nr == 0 {
       
    67 			Warn.Logf("write: %s", err)
       
    68 			break
       
    69 		}
       
    70 		nw, err := cl.socket.Write(p[:nr])
       
    71 		if nw < nr {
       
    72 			Warn.Logf("write: %s", err)
       
    73 			break
       
    74 		}
       
    75 	}
       
    76 }
       
    77 
       
    78 func readXml(r io.Reader, ch chan<- interface{},
       
    79 	extStanza map[string]func(*xml.Name) interface{}) {
       
    80 	if _, ok := Debug.(*noLog); !ok {
       
    81 		pr, pw := io.Pipe()
       
    82 		go tee(r, pw, "S: ")
       
    83 		r = pr
       
    84 	}
       
    85 	defer close(ch)
       
    86 
       
    87 	// This trick loads our namespaces into the parser.
       
    88 	nsstr := fmt.Sprintf(`<a xmlns="%s" xmlns:stream="%s">`,
       
    89 		NsClient, NsStream)
       
    90 	nsrdr := strings.NewReader(nsstr)
       
    91 	p := xml.NewDecoder(io.MultiReader(nsrdr, r))
       
    92 	p.Token()
       
    93 
       
    94 Loop:
       
    95 	for {
       
    96 		// Sniff the next token on the stream.
       
    97 		t, err := p.Token()
       
    98 		if t == nil {
       
    99 			if err != io.EOF {
       
   100 				Warn.Logf("read: %s", err)
       
   101 			}
       
   102 			break
       
   103 		}
       
   104 		var se xml.StartElement
       
   105 		var ok bool
       
   106 		if se, ok = t.(xml.StartElement); !ok {
       
   107 			continue
       
   108 		}
       
   109 
       
   110 		// Allocate the appropriate structure for this token.
       
   111 		var obj interface{}
       
   112 		switch se.Name.Space + " " + se.Name.Local {
       
   113 		case NsStream + " stream":
       
   114 			st, err := parseStream(se)
       
   115 			if err != nil {
       
   116 				Warn.Logf("unmarshal stream: %s", err)
       
   117 				break Loop
       
   118 			}
       
   119 			ch <- st
       
   120 			continue
       
   121 		case "stream error", NsStream + " error":
       
   122 			obj = &streamError{}
       
   123 		case NsStream + " features":
       
   124 			obj = &Features{}
       
   125 		case NsTLS + " proceed", NsTLS + " failure":
       
   126 			obj = &starttls{}
       
   127 		case NsSASL + " challenge", NsSASL + " failure",
       
   128 			NsSASL + " success":
       
   129 			obj = &auth{}
       
   130 		case NsClient + " iq":
       
   131 			obj = &Iq{}
       
   132 		case NsClient + " message":
       
   133 			obj = &Message{}
       
   134 		case NsClient + " presence":
       
   135 			obj = &Presence{}
       
   136 		default:
       
   137 			obj = &Generic{}
       
   138 			Info.Logf("Ignoring unrecognized: %s %s", se.Name.Space,
       
   139 				se.Name.Local)
       
   140 		}
       
   141 
       
   142 		// Read the complete XML stanza.
       
   143 		err = p.DecodeElement(obj, &se)
       
   144 		if err != nil {
       
   145 			Warn.Logf("unmarshal: %s", err)
       
   146 			break Loop
       
   147 		}
       
   148 
       
   149 		// If it's a Stanza, we try to unmarshal its innerxml
       
   150 		// into objects of the appropriate respective
       
   151 		// types. This is specified by our extensions.
       
   152 		if st, ok := obj.(Stanza); ok {
       
   153 			err = parseExtended(st.GetHeader(), extStanza)
       
   154 			if err != nil {
       
   155 				Warn.Logf("ext unmarshal: %s", err)
       
   156 				break Loop
       
   157 			}
       
   158 		}
       
   159 
       
   160 		// Put it on the channel.
       
   161 		ch <- obj
       
   162 	}
       
   163 }
       
   164 
       
   165 func parseExtended(st *Header, extStanza map[string]func(*xml.Name) interface{}) error {
       
   166 	// Now parse the stanza's innerxml to find the string that we
       
   167 	// can unmarshal this nested element from.
       
   168 	reader := strings.NewReader(st.Innerxml)
       
   169 	p := xml.NewDecoder(reader)
       
   170 	for {
       
   171 		t, err := p.Token()
       
   172 		if err == io.EOF {
       
   173 			break
       
   174 		}
       
   175 		if err != nil {
       
   176 			return err
       
   177 		}
       
   178 		if se, ok := t.(xml.StartElement); ok {
       
   179 			if con, ok := extStanza[se.Name.Space]; ok {
       
   180 				// Call the indicated constructor.
       
   181 				nested := con(&se.Name)
       
   182 
       
   183 				// Unmarshal the nested element and
       
   184 				// stuff it back into the stanza.
       
   185 				err := p.DecodeElement(nested, &se)
       
   186 				if err != nil {
       
   187 					return err
       
   188 				}
       
   189 				st.Nested = append(st.Nested, nested)
       
   190 			}
       
   191 		}
       
   192 	}
       
   193 
       
   194 	return nil
       
   195 }
       
   196 
       
   197 func writeXml(w io.Writer, ch <-chan interface{}) {
       
   198 	if _, ok := Debug.(*noLog); !ok {
       
   199 		pr, pw := io.Pipe()
       
   200 		go tee(pr, w, "C: ")
       
   201 		w = pw
       
   202 	}
       
   203 	defer func(w io.Writer) {
       
   204 		if c, ok := w.(io.Closer); ok {
       
   205 			c.Close()
       
   206 		}
       
   207 	}(w)
       
   208 
       
   209 	enc := xml.NewEncoder(w)
       
   210 
       
   211 	for obj := range ch {
       
   212 		if st, ok := obj.(*stream); ok {
       
   213 			_, err := w.Write([]byte(st.String()))
       
   214 			if err != nil {
       
   215 				Warn.Logf("write: %s", err)
       
   216 			}
       
   217 		} else {
       
   218 			err := enc.Encode(obj)
       
   219 			if err != nil {
       
   220 				Warn.Logf("marshal: %s", err)
       
   221 				break
       
   222 			}
       
   223 		}
       
   224 	}
       
   225 }
       
   226 
       
   227 func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- Stanza) {
       
   228 	defer close(cliOut)
       
   229 
       
   230 	handlers := make(map[string]func(Stanza) bool)
       
   231 Loop:
       
   232 	for {
       
   233 		select {
       
   234 		case h := <-cl.handlers:
       
   235 			handlers[h.id] = h.f
       
   236 		case x, ok := <-srvIn:
       
   237 			if !ok {
       
   238 				break Loop
       
   239 			}
       
   240 			switch obj := x.(type) {
       
   241 			case *stream:
       
   242 				handleStream(obj)
       
   243 			case *streamError:
       
   244 				cl.handleStreamError(obj)
       
   245 			case *Features:
       
   246 				cl.handleFeatures(obj)
       
   247 			case *starttls:
       
   248 				cl.handleTls(obj)
       
   249 			case *auth:
       
   250 				cl.handleSasl(obj)
       
   251 			case Stanza:
       
   252 				send := true
       
   253 				id := obj.GetHeader().Id
       
   254 				if handlers[id] != nil {
       
   255 					f := handlers[id]
       
   256 					delete(handlers, id)
       
   257 					send = f(obj)
       
   258 				}
       
   259 				if send {
       
   260 					cliOut <- obj
       
   261 				}
       
   262 			default:
       
   263 				Warn.Logf("Unhandled non-stanza: %T %#v", x, x)
       
   264 			}
       
   265 		}
       
   266 	}
       
   267 }
       
   268 
       
   269 // This loop is paused until resource binding is complete. Otherwise
       
   270 // the app might inject something inappropriate into our negotiations
       
   271 // with the server. The control channel controls this loop's
       
   272 // activity.
       
   273 func writeStream(srvOut chan<- interface{}, cliIn <-chan Stanza,
       
   274 	control <-chan int) {
       
   275 	defer close(srvOut)
       
   276 
       
   277 	var input <-chan Stanza
       
   278 Loop:
       
   279 	for {
       
   280 		select {
       
   281 		case status := <-control:
       
   282 			switch status {
       
   283 			case 0:
       
   284 				input = nil
       
   285 			case 1:
       
   286 				input = cliIn
       
   287 			case -1:
       
   288 				break Loop
       
   289 			}
       
   290 		case x, ok := <-input:
       
   291 			if !ok {
       
   292 				break Loop
       
   293 			}
       
   294 			if x == nil {
       
   295 				Info.Log("Refusing to send nil stanza")
       
   296 				continue
       
   297 			}
       
   298 			srvOut <- x
       
   299 		}
       
   300 	}
       
   301 }
       
   302 
       
   303 func handleStream(ss *stream) {
       
   304 }
       
   305 
       
   306 func (cl *Client) handleStreamError(se *streamError) {
       
   307 	Info.Logf("Received stream error: %v", se)
       
   308 	close(cl.Out)
       
   309 }
       
   310 
       
   311 func (cl *Client) handleFeatures(fe *Features) {
       
   312 	cl.Features = fe
       
   313 	if fe.Starttls != nil {
       
   314 		start := &starttls{XMLName: xml.Name{Space: NsTLS,
       
   315 			Local: "starttls"}}
       
   316 		cl.xmlOut <- start
       
   317 		return
       
   318 	}
       
   319 
       
   320 	if len(fe.Mechanisms.Mechanism) > 0 {
       
   321 		cl.chooseSasl(fe)
       
   322 		return
       
   323 	}
       
   324 
       
   325 	if fe.Bind != nil {
       
   326 		cl.bind(fe.Bind)
       
   327 		return
       
   328 	}
       
   329 }
       
   330 
       
   331 // readTransport() is running concurrently. We need to stop it,
       
   332 // negotiate TLS, then start it again. It calls waitForSocket() in
       
   333 // its inner loop; see below.
       
   334 func (cl *Client) handleTls(t *starttls) {
       
   335 	tcp := cl.socket
       
   336 
       
   337 	// Set the socket to nil, and wait for the reader routine to
       
   338 	// signal that it's paused.
       
   339 	cl.socket = nil
       
   340 	cl.socketSync.Add(1)
       
   341 	cl.socketSync.Wait()
       
   342 
       
   343 	// Negotiate TLS with the server.
       
   344 	tls := tls.Client(tcp, &TlsConfig)
       
   345 
       
   346 	// Make the TLS connection available to the reader, and wait
       
   347 	// for it to signal that it's working again.
       
   348 	cl.socketSync.Add(1)
       
   349 	cl.socket = tls
       
   350 	cl.socketSync.Wait()
       
   351 
       
   352 	Info.Log("TLS negotiation succeeded.")
       
   353 	cl.Features = nil
       
   354 
       
   355 	// Now re-send the initial handshake message to start the new
       
   356 	// session.
       
   357 	hsOut := &stream{To: cl.Jid.Domain, Version: XMPPVersion}
       
   358 	cl.xmlOut <- hsOut
       
   359 }
       
   360 
       
   361 // Synchronize with handleTls(). Called from readTransport() when
       
   362 // cl.socket is nil.
       
   363 func (cl *Client) waitForSocket() {
       
   364 	// Signal that we've stopped reading from the socket.
       
   365 	cl.socketSync.Done()
       
   366 
       
   367 	// Wait until the socket is available again.
       
   368 	for cl.socket == nil {
       
   369 		time.Sleep(1e8)
       
   370 	}
       
   371 
       
   372 	// Signal that we're going back to the read loop.
       
   373 	cl.socketSync.Done()
       
   374 }
       
   375 
       
   376 // BUG(cjyar): Doesn't implement TLS/SASL EXTERNAL.
       
   377 func (cl *Client) chooseSasl(fe *Features) {
       
   378 	var digestMd5 bool
       
   379 	for _, m := range fe.Mechanisms.Mechanism {
       
   380 		switch strings.ToLower(m) {
       
   381 		case "digest-md5":
       
   382 			digestMd5 = true
       
   383 		}
       
   384 	}
       
   385 
       
   386 	if digestMd5 {
       
   387 		auth := &auth{XMLName: xml.Name{Space: NsSASL, Local: "auth"}, Mechanism: "DIGEST-MD5"}
       
   388 		cl.xmlOut <- auth
       
   389 	}
       
   390 }
       
   391 
       
   392 func (cl *Client) handleSasl(srv *auth) {
       
   393 	switch strings.ToLower(srv.XMLName.Local) {
       
   394 	case "challenge":
       
   395 		b64 := base64.StdEncoding
       
   396 		str, err := b64.DecodeString(srv.Chardata)
       
   397 		if err != nil {
       
   398 			Warn.Logf("SASL challenge decode: %s", err)
       
   399 			return
       
   400 		}
       
   401 		srvMap := parseSasl(string(str))
       
   402 
       
   403 		if cl.saslExpected == "" {
       
   404 			cl.saslDigest1(srvMap)
       
   405 		} else {
       
   406 			cl.saslDigest2(srvMap)
       
   407 		}
       
   408 	case "failure":
       
   409 		Info.Log("SASL authentication failed")
       
   410 	case "success":
       
   411 		Info.Log("Sasl authentication succeeded")
       
   412 		cl.Features = nil
       
   413 		ss := &stream{To: cl.Jid.Domain, Version: XMPPVersion}
       
   414 		cl.xmlOut <- ss
       
   415 	}
       
   416 }
       
   417 
       
   418 func (cl *Client) saslDigest1(srvMap map[string]string) {
       
   419 	// Make sure it supports qop=auth
       
   420 	var hasAuth bool
       
   421 	for _, qop := range strings.Fields(srvMap["qop"]) {
       
   422 		if qop == "auth" {
       
   423 			hasAuth = true
       
   424 		}
       
   425 	}
       
   426 	if !hasAuth {
       
   427 		Warn.Log("Server doesn't support SASL auth")
       
   428 		return
       
   429 	}
       
   430 
       
   431 	// Pick a realm.
       
   432 	var realm string
       
   433 	if srvMap["realm"] != "" {
       
   434 		realm = strings.Fields(srvMap["realm"])[0]
       
   435 	}
       
   436 
       
   437 	passwd := cl.password
       
   438 	nonce := srvMap["nonce"]
       
   439 	digestUri := "xmpp/" + cl.Jid.Domain
       
   440 	nonceCount := int32(1)
       
   441 	nonceCountStr := fmt.Sprintf("%08x", nonceCount)
       
   442 
       
   443 	// Begin building the response. Username is
       
   444 	// user@domain or just domain.
       
   445 	var username string
       
   446 	if cl.Jid.Node == "" {
       
   447 		username = cl.Jid.Domain
       
   448 	} else {
       
   449 		username = cl.Jid.Node
       
   450 	}
       
   451 
       
   452 	// Generate our own nonce from random data.
       
   453 	randSize := big.NewInt(0)
       
   454 	randSize.Lsh(big.NewInt(1), 64)
       
   455 	cnonce, err := rand.Int(rand.Reader, randSize)
       
   456 	if err != nil {
       
   457 		Warn.Logf("SASL rand: %s", err)
       
   458 		return
       
   459 	}
       
   460 	cnonceStr := fmt.Sprintf("%016x", cnonce)
       
   461 
       
   462 	/* Now encode the actual password response, as well as the
       
   463 	 * expected next challenge from the server. */
       
   464 	response := saslDigestResponse(username, realm, passwd, nonce,
       
   465 		cnonceStr, "AUTHENTICATE", digestUri, nonceCountStr)
       
   466 	next := saslDigestResponse(username, realm, passwd, nonce,
       
   467 		cnonceStr, "", digestUri, nonceCountStr)
       
   468 	cl.saslExpected = next
       
   469 
       
   470 	// Build the map which will be encoded.
       
   471 	clMap := make(map[string]string)
       
   472 	clMap["realm"] = `"` + realm + `"`
       
   473 	clMap["username"] = `"` + username + `"`
       
   474 	clMap["nonce"] = `"` + nonce + `"`
       
   475 	clMap["cnonce"] = `"` + cnonceStr + `"`
       
   476 	clMap["nc"] = nonceCountStr
       
   477 	clMap["qop"] = "auth"
       
   478 	clMap["digest-uri"] = `"` + digestUri + `"`
       
   479 	clMap["response"] = response
       
   480 	if srvMap["charset"] == "utf-8" {
       
   481 		clMap["charset"] = "utf-8"
       
   482 	}
       
   483 
       
   484 	// Encode the map and send it.
       
   485 	clStr := packSasl(clMap)
       
   486 	b64 := base64.StdEncoding
       
   487 	clObj := &auth{XMLName: xml.Name{Space: NsSASL, Local: "response"}, Chardata: b64.EncodeToString([]byte(clStr))}
       
   488 	cl.xmlOut <- clObj
       
   489 }
       
   490 
       
   491 func (cl *Client) saslDigest2(srvMap map[string]string) {
       
   492 	if cl.saslExpected == srvMap["rspauth"] {
       
   493 		clObj := &auth{XMLName: xml.Name{Space: NsSASL, Local: "response"}}
       
   494 		cl.xmlOut <- clObj
       
   495 	} else {
       
   496 		clObj := &auth{XMLName: xml.Name{Space: NsSASL, Local: "failure"}, Any: &Generic{XMLName: xml.Name{Space: NsSASL,
       
   497 			Local: "abort"}}}
       
   498 		cl.xmlOut <- clObj
       
   499 	}
       
   500 }
       
   501 
       
   502 // Takes a string like `key1=value1,key2="value2"...` and returns a
       
   503 // key/value map.
       
   504 func parseSasl(in string) map[string]string {
       
   505 	re := regexp.MustCompile(`([^=]+)="?([^",]+)"?,?`)
       
   506 	strs := re.FindAllStringSubmatch(in, -1)
       
   507 	m := make(map[string]string)
       
   508 	for _, pair := range strs {
       
   509 		key := strings.ToLower(string(pair[1]))
       
   510 		value := string(pair[2])
       
   511 		m[key] = value
       
   512 	}
       
   513 	return m
       
   514 }
       
   515 
       
   516 // Inverse of parseSasl().
       
   517 func packSasl(m map[string]string) string {
       
   518 	var terms []string
       
   519 	for key, value := range m {
       
   520 		if key == "" || value == "" || value == `""` {
       
   521 			continue
       
   522 		}
       
   523 		terms = append(terms, key+"="+value)
       
   524 	}
       
   525 	return strings.Join(terms, ",")
       
   526 }
       
   527 
       
   528 // Computes the response string for digest authentication.
       
   529 func saslDigestResponse(username, realm, passwd, nonce, cnonceStr,
       
   530 	authenticate, digestUri, nonceCountStr string) string {
       
   531 	h := func(text string) []byte {
       
   532 		h := md5.New()
       
   533 		h.Write([]byte(text))
       
   534 		return h.Sum(nil)
       
   535 	}
       
   536 	hex := func(bytes []byte) string {
       
   537 		return fmt.Sprintf("%x", bytes)
       
   538 	}
       
   539 	kd := func(secret, data string) []byte {
       
   540 		return h(secret + ":" + data)
       
   541 	}
       
   542 
       
   543 	a1 := string(h(username+":"+realm+":"+passwd)) + ":" +
       
   544 		nonce + ":" + cnonceStr
       
   545 	a2 := authenticate + ":" + digestUri
       
   546 	response := hex(kd(hex(h(a1)), nonce+":"+
       
   547 		nonceCountStr+":"+cnonceStr+":auth:"+
       
   548 		hex(h(a2))))
       
   549 	return response
       
   550 }
       
   551 
       
   552 // Send a request to bind a resource. RFC 3920, section 7.
       
   553 func (cl *Client) bind(bindAdv *bindIq) {
       
   554 	res := cl.Jid.Resource
       
   555 	bindReq := &bindIq{}
       
   556 	if res != "" {
       
   557 		bindReq.Resource = &res
       
   558 	}
       
   559 	msg := &Iq{Header: Header{Type: "set", Id: NextId(),
       
   560 		Nested: []interface{}{bindReq}}}
       
   561 	f := func(st Stanza) bool {
       
   562 		iq, ok := st.(*Iq)
       
   563 		if !ok {
       
   564 			Warn.Log("non-iq response")
       
   565 		}
       
   566 		if iq.Type == "error" {
       
   567 			Warn.Log("Resource binding failed")
       
   568 			return false
       
   569 		}
       
   570 		var bindRepl *bindIq
       
   571 		for _, ele := range iq.Nested {
       
   572 			if b, ok := ele.(*bindIq); ok {
       
   573 				bindRepl = b
       
   574 				break
       
   575 			}
       
   576 		}
       
   577 		if bindRepl == nil {
       
   578 			Warn.Logf("Bad bind reply: %#v", iq)
       
   579 			return false
       
   580 		}
       
   581 		jidStr := bindRepl.Jid
       
   582 		if jidStr == nil || *jidStr == "" {
       
   583 			Warn.Log("Can't bind empty resource")
       
   584 			return false
       
   585 		}
       
   586 		jid := new(JID)
       
   587 		if err := jid.Set(*jidStr); err != nil {
       
   588 			Warn.Logf("Can't parse JID %s: %s", *jidStr, err)
       
   589 			return false
       
   590 		}
       
   591 		cl.Jid = *jid
       
   592 		Info.Logf("Bound resource: %s", cl.Jid.String())
       
   593 		cl.bindDone()
       
   594 		return false
       
   595 	}
       
   596 	cl.HandleStanza(msg.Id, f)
       
   597 	cl.xmlOut <- msg
       
   598 }
       
   599 
       
   600 // Register a callback to handle the next XMPP stanza (iq, message, or
       
   601 // presence) with a given id. The provided function will not be called
       
   602 // more than once. If it returns false, the stanza will not be made
       
   603 // available on the normal Client.In channel. The stanza handler
       
   604 // must not read from that channel, as deliveries on it cannot proceed
       
   605 // until the handler returns true or false.
       
   606 func (cl *Client) HandleStanza(id string, f func(Stanza) bool) {
       
   607 	h := &stanzaHandler{id: id, f: f}
       
   608 	cl.handlers <- h
       
   609 }