stream.go
changeset 102 872e936f9f3f
parent 100 24231ff0016c
child 104 99e03b33b20d
equal deleted inserted replaced
101:5d721a565503 102:872e936f9f3f
    48 			if errno, ok := err.(*net.OpError); ok {
    48 			if errno, ok := err.(*net.OpError); ok {
    49 				if errno.Timeout() {
    49 				if errno.Timeout() {
    50 					continue
    50 					continue
    51 				}
    51 				}
    52 			}
    52 			}
    53 			Warnf("read: %s", err)
    53 			Warn.Logf("read: %s", err)
    54 			break
    54 			break
    55 		}
    55 		}
    56 		nw, err := w.Write(p[:nr])
    56 		nw, err := w.Write(p[:nr])
    57 		if nw < nr {
    57 		if nw < nr {
    58 			Warnf("read: %s", err)
    58 			Warn.Logf("read: %s", err)
    59 			break
    59 			break
    60 		}
    60 		}
    61 	}
    61 	}
    62 }
    62 }
    63 
    63 
    65 	defer cl.socket.Close()
    65 	defer cl.socket.Close()
    66 	p := make([]byte, 1024)
    66 	p := make([]byte, 1024)
    67 	for {
    67 	for {
    68 		nr, err := r.Read(p)
    68 		nr, err := r.Read(p)
    69 		if nr == 0 {
    69 		if nr == 0 {
    70 			Warnf("write: %s", err)
    70 			Warn.Logf("write: %s", err)
    71 			break
    71 			break
    72 		}
    72 		}
    73 		nw, err := cl.socket.Write(p[:nr])
    73 		nw, err := cl.socket.Write(p[:nr])
    74 		if nw < nr {
    74 		if nw < nr {
    75 			Warnf("write: %s", err)
    75 			Warn.Logf("write: %s", err)
    76 			break
    76 			break
    77 		}
    77 		}
    78 	}
    78 	}
    79 }
    79 }
    80 
    80 
    81 func readXml(r io.Reader, ch chan<- interface{},
    81 func readXml(r io.Reader, ch chan<- interface{},
    82 extStanza map[string]func(*xml.Name) interface{}) {
    82 extStanza map[string]func(*xml.Name) interface{}) {
    83 	if Debug != nil {
    83 	if _, ok := Debug.(*noLog) ; ok {
    84 		pr, pw := io.Pipe()
    84 		pr, pw := io.Pipe()
    85 		go tee(r, pw, "S: ")
    85 		go tee(r, pw, "S: ")
    86 		r = pr
    86 		r = pr
    87 	}
    87 	}
    88 	defer close(ch)
    88 	defer close(ch)
    94 	for {
    94 	for {
    95 		// Sniff the next token on the stream.
    95 		// Sniff the next token on the stream.
    96 		t, err := p.Token()
    96 		t, err := p.Token()
    97 		if t == nil {
    97 		if t == nil {
    98 			if err != io.EOF {
    98 			if err != io.EOF {
    99 				Warnf("read: %s", err)
    99 				Warn.Logf("read: %s", err)
   100 			}
   100 			}
   101 			break
   101 			break
   102 		}
   102 		}
   103 		var se xml.StartElement
   103 		var se xml.StartElement
   104 		var ok bool
   104 		var ok bool
   110 		var obj interface{}
   110 		var obj interface{}
   111 		switch se.Name.Space + " " + se.Name.Local {
   111 		switch se.Name.Space + " " + se.Name.Local {
   112 		case NsStream + " stream":
   112 		case NsStream + " stream":
   113 			st, err := parseStream(se)
   113 			st, err := parseStream(se)
   114 			if err != nil {
   114 			if err != nil {
   115 				Warnf("unmarshal stream: %s", err)
   115 				Warn.Logf("unmarshal stream: %s", err)
   116 				break Loop
   116 				break Loop
   117 			}
   117 			}
   118 			ch <- st
   118 			ch <- st
   119 			continue
   119 			continue
   120 		case "stream error", NsStream + " error":
   120 		case "stream error", NsStream + " error":
   132 			obj = &Message{}
   132 			obj = &Message{}
   133 		case NsClient + " presence":
   133 		case NsClient + " presence":
   134 			obj = &Presence{}
   134 			obj = &Presence{}
   135 		default:
   135 		default:
   136 			obj = &Generic{}
   136 			obj = &Generic{}
   137 			Infof("Ignoring unrecognized: %s %s", se.Name.Space,
   137 			Info.Logf("Ignoring unrecognized: %s %s", se.Name.Space,
   138 				se.Name.Local)
   138 				se.Name.Local)
   139 		}
   139 		}
   140 
   140 
   141 		// Read the complete XML stanza.
   141 		// Read the complete XML stanza.
   142 		err = p.DecodeElement(obj, &se)
   142 		err = p.DecodeElement(obj, &se)
   143 		if err != nil {
   143 		if err != nil {
   144 			Warnf("unmarshal: %s", err)
   144 			Warn.Logf("unmarshal: %s", err)
   145 			break Loop
   145 			break Loop
   146 		}
   146 		}
   147 
   147 
   148 		// If it's a Stanza, we try to unmarshal its innerxml
   148 		// If it's a Stanza, we try to unmarshal its innerxml
   149 		// into objects of the appropriate respective
   149 		// into objects of the appropriate respective
   150 		// types. This is specified by our extensions.
   150 		// types. This is specified by our extensions.
   151 		if st, ok := obj.(Stanza); ok {
   151 		if st, ok := obj.(Stanza); ok {
   152 			err = parseExtended(st, extStanza)
   152 			err = parseExtended(st, extStanza)
   153 			if err != nil {
   153 			if err != nil {
   154 				Warnf("ext unmarshal: %s", err)
   154 				Warn.Logf("ext unmarshal: %s", err)
   155 				break Loop
   155 				break Loop
   156 			}
   156 			}
   157 		}
   157 		}
   158 
   158 
   159 		// Put it on the channel.
   159 		// Put it on the channel.
   192 
   192 
   193 	return nil
   193 	return nil
   194 }
   194 }
   195 
   195 
   196 func writeXml(w io.Writer, ch <-chan interface{}) {
   196 func writeXml(w io.Writer, ch <-chan interface{}) {
   197 	if Debug != nil {
   197 	if _, ok := Debug.(*noLog) ; ok {
   198 		pr, pw := io.Pipe()
   198 		pr, pw := io.Pipe()
   199 		go tee(pr, w, "C: ")
   199 		go tee(pr, w, "C: ")
   200 		w = pw
   200 		w = pw
   201 	}
   201 	}
   202 	defer func(w io.Writer) {
   202 	defer func(w io.Writer) {
   211 
   211 
   212 	for obj := range ch {
   212 	for obj := range ch {
   213 		if st, ok := obj.(*stream); ok {
   213 		if st, ok := obj.(*stream); ok {
   214 			_, err := w.Write([]byte(st.String()))
   214 			_, err := w.Write([]byte(st.String()))
   215 			if err != nil {
   215 			if err != nil {
   216 				Warnf("write: %s", err)
   216 				Warn.Logf("write: %s", err)
   217 			}
   217 			}
   218 		} else {
   218 		} else {
   219 			err := enc.Encode(obj)
   219 			err := enc.Encode(obj)
   220 			if err != nil {
   220 			if err != nil {
   221 				Warnf("marshal: %s", err)
   221 				Warn.Logf("marshal: %s", err)
   222 				break
   222 				break
   223 			}
   223 			}
   224 		}
   224 		}
   225 	}
   225 	}
   226 }
   226 }
   256 			if !send {
   256 			if !send {
   257 				continue
   257 				continue
   258 			}
   258 			}
   259 			st, ok := x.(Stanza)
   259 			st, ok := x.(Stanza)
   260 			if !ok {
   260 			if !ok {
   261 				Warnf("Unhandled non-stanza: %v", x)
   261 				Warn.Logf("Unhandled non-stanza: %v", x)
   262 				continue
   262 				continue
   263 			}
   263 			}
   264 			if handlers[st.GetId()] != nil {
   264 			if handlers[st.GetId()] != nil {
   265 				f := handlers[st.GetId()]
   265 				f := handlers[st.GetId()]
   266 				delete(handlers, st.GetId())
   266 				delete(handlers, st.GetId())
   297 		case x, ok := <-input:
   297 		case x, ok := <-input:
   298 			if !ok {
   298 			if !ok {
   299 				break Loop
   299 				break Loop
   300 			}
   300 			}
   301 			if x == nil {
   301 			if x == nil {
   302 				Infof("Refusing to send nil stanza")
   302 				Info.Logf("Refusing to send nil stanza")
   303 				continue
   303 				continue
   304 			}
   304 			}
   305 			srvOut <- x
   305 			srvOut <- x
   306 		}
   306 		}
   307 	}
   307 	}
   315 Loop:
   315 Loop:
   316 	for {
   316 	for {
   317 		select {
   317 		select {
   318 		case newFilterOut := <-filterOut:
   318 		case newFilterOut := <-filterOut:
   319 			if newFilterOut == nil {
   319 			if newFilterOut == nil {
   320 				Warnf("Received nil filter")
   320 				Warn.Logf("Received nil filter")
   321 				filterIn <- nil
   321 				filterIn <- nil
   322 				continue
   322 				continue
   323 			}
   323 			}
   324 			filterIn <- topFilter
   324 			filterIn <- topFilter
   325 			topFilter = newFilterOut
   325 			topFilter = newFilterOut
   342 
   342 
   343 func handleStream(ss *stream) {
   343 func handleStream(ss *stream) {
   344 }
   344 }
   345 
   345 
   346 func (cl *Client) handleStreamError(se *streamError) {
   346 func (cl *Client) handleStreamError(se *streamError) {
   347 	Infof("Received stream error: %v", se)
   347 	Info.Logf("Received stream error: %v", se)
   348 	close(cl.Out)
   348 	close(cl.Out)
   349 }
   349 }
   350 
   350 
   351 func (cl *Client) handleFeatures(fe *Features) {
   351 func (cl *Client) handleFeatures(fe *Features) {
   352 	cl.Features = fe
   352 	cl.Features = fe
   387 	// for it to signal that it's working again.
   387 	// for it to signal that it's working again.
   388 	cl.socketSync.Add(1)
   388 	cl.socketSync.Add(1)
   389 	cl.socket = tls
   389 	cl.socket = tls
   390 	cl.socketSync.Wait()
   390 	cl.socketSync.Wait()
   391 
   391 
   392 	Infof("TLS negotiation succeeded.")
   392 	Info.Logf("TLS negotiation succeeded.")
   393 	cl.Features = nil
   393 	cl.Features = nil
   394 
   394 
   395 	// Now re-send the initial handshake message to start the new
   395 	// Now re-send the initial handshake message to start the new
   396 	// session.
   396 	// session.
   397 	hsOut := &stream{To: cl.Jid.Domain, Version: Version}
   397 	hsOut := &stream{To: cl.Jid.Domain, Version: Version}
   433 	switch strings.ToLower(srv.XMLName.Local) {
   433 	switch strings.ToLower(srv.XMLName.Local) {
   434 	case "challenge":
   434 	case "challenge":
   435 		b64 := base64.StdEncoding
   435 		b64 := base64.StdEncoding
   436 		str, err := b64.DecodeString(srv.Chardata)
   436 		str, err := b64.DecodeString(srv.Chardata)
   437 		if err != nil {
   437 		if err != nil {
   438 			Warnf("SASL challenge decode: %s", err)
   438 			Warn.Logf("SASL challenge decode: %s", err)
   439 			return
   439 			return
   440 		}
   440 		}
   441 		srvMap := parseSasl(string(str))
   441 		srvMap := parseSasl(string(str))
   442 
   442 
   443 		if cl.saslExpected == "" {
   443 		if cl.saslExpected == "" {
   444 			cl.saslDigest1(srvMap)
   444 			cl.saslDigest1(srvMap)
   445 		} else {
   445 		} else {
   446 			cl.saslDigest2(srvMap)
   446 			cl.saslDigest2(srvMap)
   447 		}
   447 		}
   448 	case "failure":
   448 	case "failure":
   449 		Infof("SASL authentication failed")
   449 		Info.Logf("SASL authentication failed")
   450 	case "success":
   450 	case "success":
   451 		Infof("Sasl authentication succeeded")
   451 		Info.Logf("Sasl authentication succeeded")
   452 		cl.Features = nil
   452 		cl.Features = nil
   453 		ss := &stream{To: cl.Jid.Domain, Version: Version}
   453 		ss := &stream{To: cl.Jid.Domain, Version: Version}
   454 		cl.xmlOut <- ss
   454 		cl.xmlOut <- ss
   455 	}
   455 	}
   456 }
   456 }
   462 		if qop == "auth" {
   462 		if qop == "auth" {
   463 			hasAuth = true
   463 			hasAuth = true
   464 		}
   464 		}
   465 	}
   465 	}
   466 	if !hasAuth {
   466 	if !hasAuth {
   467 		Warnf("Server doesn't support SASL auth")
   467 		Warn.Logf("Server doesn't support SASL auth")
   468 		return
   468 		return
   469 	}
   469 	}
   470 
   470 
   471 	// Pick a realm.
   471 	// Pick a realm.
   472 	var realm string
   472 	var realm string
   492 	// Generate our own nonce from random data.
   492 	// Generate our own nonce from random data.
   493 	randSize := big.NewInt(0)
   493 	randSize := big.NewInt(0)
   494 	randSize.Lsh(big.NewInt(1), 64)
   494 	randSize.Lsh(big.NewInt(1), 64)
   495 	cnonce, err := rand.Int(rand.Reader, randSize)
   495 	cnonce, err := rand.Int(rand.Reader, randSize)
   496 	if err != nil {
   496 	if err != nil {
   497 		Warnf("SASL rand: %s", err)
   497 		Warn.Logf("SASL rand: %s", err)
   498 		return
   498 		return
   499 	}
   499 	}
   500 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   500 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   501 
   501 
   502 	/* Now encode the actual password response, as well as the
   502 	/* Now encode the actual password response, as well as the
   598 	}
   598 	}
   599 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
   599 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
   600 	f := func(st Stanza) bool {
   600 	f := func(st Stanza) bool {
   601 		iq, ok := st.(*Iq)
   601 		iq, ok := st.(*Iq)
   602 		if !ok {
   602 		if !ok {
   603 			Warnf("non-iq response")
   603 			Warn.Logf("non-iq response")
   604 		}
   604 		}
   605 		if iq.Type == "error" {
   605 		if iq.Type == "error" {
   606 			Warnf("Resource binding failed")
   606 			Warn.Logf("Resource binding failed")
   607 			return false
   607 			return false
   608 		}
   608 		}
   609 		var bindRepl *bindIq
   609 		var bindRepl *bindIq
   610 		for _, ele := range iq.Nested {
   610 		for _, ele := range iq.Nested {
   611 			if b, ok := ele.(*bindIq); ok {
   611 			if b, ok := ele.(*bindIq); ok {
   612 				bindRepl = b
   612 				bindRepl = b
   613 				break
   613 				break
   614 			}
   614 			}
   615 		}
   615 		}
   616 		if bindRepl == nil {
   616 		if bindRepl == nil {
   617 			Warnf("Bad bind reply: %v", iq)
   617 			Warn.Logf("Bad bind reply: %v", iq)
   618 			return false
   618 			return false
   619 		}
   619 		}
   620 		jidStr := bindRepl.Jid
   620 		jidStr := bindRepl.Jid
   621 		if jidStr == nil || *jidStr == "" {
   621 		if jidStr == nil || *jidStr == "" {
   622 			Warnf("Can't bind empty resource")
   622 			Warn.Logf("Can't bind empty resource")
   623 			return false
   623 			return false
   624 		}
   624 		}
   625 		jid := new(JID)
   625 		jid := new(JID)
   626 		if err := jid.Set(*jidStr); err != nil {
   626 		if err := jid.Set(*jidStr); err != nil {
   627 			Warnf("Can't parse JID %s: %s", *jidStr, err)
   627 			Warn.Logf("Can't parse JID %s: %s", *jidStr, err)
   628 			return false
   628 			return false
   629 		}
   629 		}
   630 		cl.Jid = *jid
   630 		cl.Jid = *jid
   631 		Infof("Bound resource: %s", cl.Jid.String())
   631 		Info.Logf("Bound resource: %s", cl.Jid.String())
   632 		cl.bindDone()
   632 		cl.bindDone()
   633 		return false
   633 		return false
   634 	}
   634 	}
   635 	cl.HandleStanza(msg.Id, f)
   635 	cl.HandleStanza(msg.Id, f)
   636 	cl.xmlOut <- msg
   636 	cl.xmlOut <- msg