stream.go
changeset 62 6e2eea62ccca
parent 61 16513974d273
child 63 c7f2edd25f4a
equal deleted inserted replaced
61:16513974d273 62:6e2eea62ccca
    16 	"crypto/rand"
    16 	"crypto/rand"
    17 	"crypto/tls"
    17 	"crypto/tls"
    18 	"encoding/base64"
    18 	"encoding/base64"
    19 	"fmt"
    19 	"fmt"
    20 	"io"
    20 	"io"
    21 	"log"
       
    22 	"net"
    21 	"net"
    23 	"os"
    22 	"os"
    24 	"regexp"
    23 	"regexp"
    25 	"strings"
    24 	"strings"
       
    25 	"syslog"
    26 	"time"
    26 	"time"
    27 	"xml"
    27 	"xml"
    28 )
    28 )
    29 
    29 
    30 // Callback to handle a stanza with a particular id.
    30 // Callback to handle a stanza with a particular id.
    50 			if errno, ok := err.(*net.OpError) ; ok {
    50 			if errno, ok := err.(*net.OpError) ; ok {
    51 				if errno.Timeout() {
    51 				if errno.Timeout() {
    52 					continue
    52 					continue
    53 				}
    53 				}
    54 			}
    54 			}
    55 			log.Printf("read: %s", err.String())
    55 			if Log != nil {
       
    56 				Log.Err("read: " + err.String())
       
    57 			}
    56 			break
    58 			break
    57 		}
    59 		}
    58 		nw, err := w.Write(p[:nr])
    60 		nw, err := w.Write(p[:nr])
    59 		if nw < nr {
    61 		if nw < nr {
    60 			log.Println("read: %s", err.String())
    62 			if Log != nil {
       
    63 				Log.Err("read: " + err.String())
       
    64 			}
    61 			break
    65 			break
    62 		}
    66 		}
    63 	}
    67 	}
    64 }
    68 }
    65 
    69 
    67 	defer tryClose(r, cl.socket)
    71 	defer tryClose(r, cl.socket)
    68 	p := make([]byte, 1024)
    72 	p := make([]byte, 1024)
    69 	for {
    73 	for {
    70 		nr, err := r.Read(p)
    74 		nr, err := r.Read(p)
    71 		if nr == 0 {
    75 		if nr == 0 {
    72 			log.Printf("write: %s", err.String())
    76 			if Log != nil {
       
    77 				Log.Err("write: " + err.String())
       
    78 			}
    73 			break
    79 			break
    74 		}
    80 		}
    75 		nw, err := cl.socket.Write(p[:nr])
    81 		nw, err := cl.socket.Write(p[:nr])
    76 		if nw < nr {
    82 		if nw < nr {
    77 			log.Println("write: %s", err.String())
    83 			if Log != nil {
       
    84 				Log.Err("write: " + err.String())
       
    85 			}
    78 			break
    86 			break
    79 		}
    87 		}
    80 	}
    88 	}
    81 }
    89 }
    82 
    90 
    83 func readXml(r io.Reader, ch chan<- interface{},
    91 func readXml(r io.Reader, ch chan<- interface{},
    84 	extStanza map[string] func(*xml.Name) interface{}) {
    92 	extStanza map[string] func(*xml.Name) interface{}) {
    85 	if debug {
    93 	if Loglevel >= syslog.LOG_DEBUG {
    86 		pr, pw := io.Pipe()
    94 		pr, pw := io.Pipe()
    87 		go tee(r, pw, "S: ")
    95 		go tee(r, pw, "S: ")
    88 		r = pr
    96 		r = pr
    89 	}
    97 	}
    90 	defer tryClose(r, ch)
    98 	defer tryClose(r, ch)
    93 	for {
   101 	for {
    94 		// Sniff the next token on the stream.
   102 		// Sniff the next token on the stream.
    95 		t, err := p.Token()
   103 		t, err := p.Token()
    96 		if t == nil {
   104 		if t == nil {
    97 			if err != os.EOF {
   105 			if err != os.EOF {
    98 				log.Printf("read: %v", err)
   106 				if Log != nil {
       
   107 					Log.Err("read: " + err.String())
       
   108 				}
    99 			}
   109 			}
   100 			break
   110 			break
   101 		}
   111 		}
   102 		var se xml.StartElement
   112 		var se xml.StartElement
   103 		var ok bool
   113 		var ok bool
   109 		var obj interface{}
   119 		var obj interface{}
   110 		switch se.Name.Space + " " + se.Name.Local {
   120 		switch se.Name.Space + " " + se.Name.Local {
   111 		case NsStream + " stream":
   121 		case NsStream + " stream":
   112 			st, err := parseStream(se)
   122 			st, err := parseStream(se)
   113 			if err != nil {
   123 			if err != nil {
   114 				log.Printf("unmarshal stream: %v",
   124 				if Log != nil {
   115 					err)
   125 					Log.Err("unmarshal stream: " +
       
   126 						err.String())
       
   127 				}
   116 				break
   128 				break
   117 			}
   129 			}
   118 			ch <- st
   130 			ch <- st
   119 			continue
   131 			continue
   120 		case "stream error", NsStream + " error":
   132 		case "stream error", NsStream + " error":
   132 			obj = &Message{}
   144 			obj = &Message{}
   133 		case "jabber:client presence":
   145 		case "jabber:client presence":
   134 			obj = &Presence{}
   146 			obj = &Presence{}
   135 		default:
   147 		default:
   136 			obj = &Generic{}
   148 			obj = &Generic{}
   137 			log.Printf("Ignoring unrecognized: %s %s\n",
   149 			if Log != nil {
   138 				se.Name.Space, se.Name.Local)
   150 				Log.Notice("Ignoring unrecognized: " +
       
   151 					se.Name.Space + " " + se.Name.Local)
       
   152 			}
   139 		}
   153 		}
   140 
   154 
   141 		// Read the complete XML stanza.
   155 		// Read the complete XML stanza.
   142 		err = p.Unmarshal(obj, &se)
   156 		err = p.Unmarshal(obj, &se)
   143 		if err != nil {
   157 		if err != nil {
   144 			log.Printf("unmarshal: %v", err)
   158 			if Log != nil {
       
   159 				Log.Err("unmarshal: " + err.String())
       
   160 			}
   145 			break
   161 			break
   146 		}
   162 		}
   147 
   163 
   148 		// If it's a Stanza, we try to unmarshal its innerxml
   164 		// If it's a Stanza, we try to unmarshal its innerxml
   149 		// into objects of the appropriate respective
   165 		// into objects of the appropriate respective
   150 		// types. This is specified by our extensions.
   166 		// types. This is specified by our extensions.
   151 		if st, ok := obj.(Stanza) ; ok {
   167 		if st, ok := obj.(Stanza) ; ok {
   152 			err = parseExtended(st, extStanza)
   168 			err = parseExtended(st, extStanza)
   153 			if err != nil {
   169 			if err != nil {
   154 				log.Printf("ext unmarshal: %v",
   170 				if Log != nil {
   155 					err)
   171 					Log.Err("ext unmarshal: " +
       
   172 						err.String())
       
   173 				}
   156 				break
   174 				break
   157 			}
   175 			}
   158 		}
   176 		}
   159 
   177 
   160 		// Put it on the channel.
   178 		// Put it on the channel.
   193 
   211 
   194 	return nil
   212 	return nil
   195 }
   213 }
   196 
   214 
   197 func writeXml(w io.Writer, ch <-chan interface{}) {
   215 func writeXml(w io.Writer, ch <-chan interface{}) {
   198 	if debug {
   216 	if Loglevel >= syslog.LOG_DEBUG {
   199 		pr, pw := io.Pipe()
   217 		pr, pw := io.Pipe()
   200 		go tee(pr, w, "C: ")
   218 		go tee(pr, w, "C: ")
   201 		w = pw
   219 		w = pw
   202 	}
   220 	}
   203 	defer tryClose(w, ch)
   221 	defer tryClose(w, ch)
   204 
   222 
   205 	for obj := range ch {
   223 	for obj := range ch {
   206 		err := xml.Marshal(w, obj)
   224 		err := xml.Marshal(w, obj)
   207 		if err != nil {
   225 		if err != nil {
   208 			log.Printf("write: %v", err)
   226 			if Log != nil {
       
   227 				Log.Err("write: " + err.String())
       
   228 			}
   209 			break
   229 			break
   210 		}
   230 		}
   211 	}
   231 	}
   212 }
   232 }
   213 
   233 
   241 			if !send {
   261 			if !send {
   242 				continue
   262 				continue
   243 			}
   263 			}
   244 			st, ok := x.(Stanza)
   264 			st, ok := x.(Stanza)
   245 			if !ok {
   265 			if !ok {
   246 				log.Printf("Unhandled non-stanza: %v",
   266 				if Log != nil {
   247 					x)
   267 					Log.Warning(fmt.Sprintf(
       
   268 						"Unhandled non-stanza: %v", x))
       
   269 				}
   248 				continue
   270 				continue
   249 			}
   271 			}
   250 			if handlers[st.GetId()] != nil {
   272 			if handlers[st.GetId()] != nil {
   251 				f := handlers[st.GetId()]
   273 				f := handlers[st.GetId()]
   252 				handlers[st.GetId()] = nil
   274 				handlers[st.GetId()] = nil
   279 			case -1:
   301 			case -1:
   280 				break
   302 				break
   281 			}
   303 			}
   282 		case x := <- input:
   304 		case x := <- input:
   283 			if x == nil {
   305 			if x == nil {
   284 				log.Println("Refusing to send nil stanza")
   306 				if Log != nil {
       
   307 					Log.Notice("Refusing to send" +
       
   308 						" nil stanza")
       
   309 				}
   285 				continue
   310 				continue
   286 			}
   311 			}
   287 			srvOut <- x
   312 			srvOut <- x
   288 		}
   313 		}
   289 	}
   314 	}
   296 	defer close(app)
   321 	defer close(app)
   297 	for {
   322 	for {
   298 		select {
   323 		select {
   299 		case newFilterOut := <- filterOut:
   324 		case newFilterOut := <- filterOut:
   300 			if newFilterOut == nil {
   325 			if newFilterOut == nil {
   301 				log.Println("Received nil filter")
   326 				if Log != nil {
       
   327 					Log.Warning("Received nil filter")
       
   328 				}
   302 				filterIn <- nil
   329 				filterIn <- nil
   303 				continue
   330 				continue
   304 			}
   331 			}
   305 			filterIn <- topFilter
   332 			filterIn <- topFilter
   306 			topFilter = newFilterOut
   333 			topFilter = newFilterOut
   323 
   350 
   324 func handleStream(ss *stream) {
   351 func handleStream(ss *stream) {
   325 }
   352 }
   326 
   353 
   327 func (cl *Client) handleStreamError(se *streamError) {
   354 func (cl *Client) handleStreamError(se *streamError) {
   328 	log.Printf("Received stream error: %v", se)
   355 	if Log != nil {
       
   356 		Log.Notice(fmt.Sprintf("Received stream error: %v", se))
       
   357 	}
   329 	cl.Close()
   358 	cl.Close()
   330 }
   359 }
   331 
   360 
   332 func (cl *Client) handleFeatures(fe *Features) {
   361 func (cl *Client) handleFeatures(fe *Features) {
   333 	cl.Features = fe
   362 	cl.Features = fe
   372 
   401 
   373 	// Reset the read timeout on the (underlying) socket so the
   402 	// Reset the read timeout on the (underlying) socket so the
   374 	// reader doesn't get woken up unnecessarily.
   403 	// reader doesn't get woken up unnecessarily.
   375 	tcp.SetReadTimeout(0)
   404 	tcp.SetReadTimeout(0)
   376 
   405 
   377 	log.Println("TLS negotiation succeeded.")
   406 	if Log != nil {
       
   407 		Log.Info("TLS negotiation succeeded.")
       
   408 	}
   378 	cl.Features = nil
   409 	cl.Features = nil
   379 
   410 
   380 	// Now re-send the initial handshake message to start the new
   411 	// Now re-send the initial handshake message to start the new
   381 	// session.
   412 	// session.
   382 	hsOut := &stream{To: cl.Jid.Domain, Version: Version}
   413 	hsOut := &stream{To: cl.Jid.Domain, Version: Version}
   419 	switch strings.ToLower(srv.XMLName.Local) {
   450 	switch strings.ToLower(srv.XMLName.Local) {
   420 	case "challenge":
   451 	case "challenge":
   421 		b64 := base64.StdEncoding
   452 		b64 := base64.StdEncoding
   422 		str, err := b64.DecodeString(srv.Chardata)
   453 		str, err := b64.DecodeString(srv.Chardata)
   423 		if err != nil {
   454 		if err != nil {
   424 			log.Printf("SASL challenge decode: %s",
   455 			if Log != nil {
   425 				err.String())
   456 				Log.Err("SASL challenge decode: " +
       
   457 					err.String())
       
   458 			}
   426 			return;
   459 			return;
   427 		}
   460 		}
   428 		srvMap := parseSasl(string(str))
   461 		srvMap := parseSasl(string(str))
   429 
   462 
   430 		if cl.saslExpected == "" {
   463 		if cl.saslExpected == "" {
   431 			cl.saslDigest1(srvMap)
   464 			cl.saslDigest1(srvMap)
   432 		} else {
   465 		} else {
   433 			cl.saslDigest2(srvMap)
   466 			cl.saslDigest2(srvMap)
   434 		}
   467 		}
   435 	case "failure":
   468 	case "failure":
   436 		log.Println("SASL authentication failed")
   469 		if Log != nil {
       
   470 			Log.Notice("SASL authentication failed")
       
   471 		}
   437 	case "success":
   472 	case "success":
   438 		log.Println("SASL authentication succeeded")
   473 		if Log != nil {
       
   474 			Log.Info("Sasl authentication succeeded")
       
   475 		}
   439 		cl.Features = nil
   476 		cl.Features = nil
   440 		ss := &stream{To: cl.Jid.Domain, Version: Version}
   477 		ss := &stream{To: cl.Jid.Domain, Version: Version}
   441 		cl.xmlOut <- ss
   478 		cl.xmlOut <- ss
   442 	}
   479 	}
   443 }
   480 }
   449 		if qop == "auth" {
   486 		if qop == "auth" {
   450 			hasAuth = true
   487 			hasAuth = true
   451 		}
   488 		}
   452 	}
   489 	}
   453 	if !hasAuth {
   490 	if !hasAuth {
   454 		log.Println("Server doesn't support SASL auth")
   491 		if Log != nil {
       
   492 			Log.Err("Server doesn't support SASL auth")
       
   493 		}
   455 		return;
   494 		return;
   456 	}
   495 	}
   457 
   496 
   458 	// Pick a realm.
   497 	// Pick a realm.
   459 	var realm string
   498 	var realm string
   479 	// Generate our own nonce from random data.
   518 	// Generate our own nonce from random data.
   480 	randSize := big.NewInt(0)
   519 	randSize := big.NewInt(0)
   481 	randSize.Lsh(big.NewInt(1), 64)
   520 	randSize.Lsh(big.NewInt(1), 64)
   482 	cnonce, err := rand.Int(rand.Reader, randSize)
   521 	cnonce, err := rand.Int(rand.Reader, randSize)
   483 	if err != nil {
   522 	if err != nil {
   484 		log.Println("SASL rand: %s", err.String())
   523 		if Log != nil {
       
   524 			Log.Err("SASL rand: " + err.String())
       
   525 		}
   485 		return
   526 		return
   486 	}
   527 	}
   487 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   528 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   488 
   529 
   489 	/* Now encode the actual password response, as well as the
   530 	/* Now encode the actual password response, as well as the
   589 		bindReq.Resource = &res
   630 		bindReq.Resource = &res
   590 	}
   631 	}
   591 	msg := &Iq{Type: "set", Id: <- Id, Nested: []interface{}{bindReq}}
   632 	msg := &Iq{Type: "set", Id: <- Id, Nested: []interface{}{bindReq}}
   592 	f := func(st Stanza) bool {
   633 	f := func(st Stanza) bool {
   593 		if st.GetType() == "error" {
   634 		if st.GetType() == "error" {
   594 			log.Println("Resource binding failed")
   635 			if Log != nil {
       
   636 				Log.Err("Resource binding failed")
       
   637 			}
   595 			return false
   638 			return false
   596 		}
   639 		}
   597 		var bindRepl *bindIq
   640 		var bindRepl *bindIq
   598 		for _, ele := range(st.GetNested()) {
   641 		for _, ele := range(st.GetNested()) {
   599 			if b, ok := ele.(*bindIq) ; ok {
   642 			if b, ok := ele.(*bindIq) ; ok {
   600 				bindRepl = b
   643 				bindRepl = b
   601 				break
   644 				break
   602 			}
   645 			}
   603 		}
   646 		}
   604 		if bindRepl == nil {
   647 		if bindRepl == nil {
   605 			log.Printf("bad bind reply: %v", st)
   648 			if Log != nil {
       
   649 				Log.Err(fmt.Sprintf("Bad bind reply: %v",
       
   650 					st))
       
   651 			}
   606 			return false
   652 			return false
   607 		}
   653 		}
   608 		jidStr := bindRepl.Jid
   654 		jidStr := bindRepl.Jid
   609 		if jidStr == nil || *jidStr == "" {
   655 		if jidStr == nil || *jidStr == "" {
   610 			log.Println("empty resource")
   656 			if Log != nil {
       
   657 				Log.Err("Can't bind empty resource")
       
   658 			}
   611 			return false
   659 			return false
   612 		}
   660 		}
   613 		jid := new(JID)
   661 		jid := new(JID)
   614 		if !jid.Set(*jidStr) {
   662 		if !jid.Set(*jidStr) {
   615 			log.Println("Can't parse JID %s", jidStr)
   663 			if Log != nil {
       
   664 				Log.Err("Can't parse JID " + *jidStr)
       
   665 			}
   616 			return false
   666 			return false
   617 		}
   667 		}
   618 		cl.Jid = *jid
   668 		cl.Jid = *jid
   619 		log.Printf("Bound resource: %s", cl.Jid.String())
   669 		if Log != nil {
       
   670 			Log.Info("Bound resource: " + cl.Jid.String())
       
   671 		}
   620 		cl.bindDone()
   672 		cl.bindDone()
   621 		return false
   673 		return false
   622 	}
   674 	}
   623 	cl.HandleStanza(msg.Id, f)
   675 	cl.HandleStanza(msg.Id, f)
   624 	cl.xmlOut <- msg
   676 	cl.xmlOut <- msg