stream.go
changeset 76 caa722ab8a0f
parent 75 03a923eb5c01
child 78 a5848c75d270
equal deleted inserted replaced
75:03a923eb5c01 76:caa722ab8a0f
    50 				if errno.Timeout() {
    50 				if errno.Timeout() {
    51 					continue
    51 					continue
    52 				}
    52 				}
    53 			}
    53 			}
    54 			if Log != nil {
    54 			if Log != nil {
    55 				Log.Err("read: " + err.Error())
    55 				Log.Println("read: " + err.Error())
    56 			}
    56 			}
    57 			break
    57 			break
    58 		}
    58 		}
    59 		nw, err := w.Write(p[:nr])
    59 		nw, err := w.Write(p[:nr])
    60 		if nw < nr {
    60 		if nw < nr {
    61 			if Log != nil {
    61 			if Log != nil {
    62 				Log.Err("read: " + err.Error())
    62 				Log.Println("read: " + err.Error())
    63 			}
    63 			}
    64 			break
    64 			break
    65 		}
    65 		}
    66 	}
    66 	}
    67 }
    67 }
    71 	p := make([]byte, 1024)
    71 	p := make([]byte, 1024)
    72 	for {
    72 	for {
    73 		nr, err := r.Read(p)
    73 		nr, err := r.Read(p)
    74 		if nr == 0 {
    74 		if nr == 0 {
    75 			if Log != nil {
    75 			if Log != nil {
    76 				Log.Err("write: " + err.Error())
    76 				Log.Println("write: " + err.Error())
    77 			}
    77 			}
    78 			break
    78 			break
    79 		}
    79 		}
    80 		nw, err := cl.socket.Write(p[:nr])
    80 		nw, err := cl.socket.Write(p[:nr])
    81 		if nw < nr {
    81 		if nw < nr {
    82 			if Log != nil {
    82 			if Log != nil {
    83 				Log.Err("write: " + err.Error())
    83 				Log.Println("write: " + err.Error())
    84 			}
    84 			}
    85 			break
    85 			break
    86 		}
    86 		}
    87 	}
    87 	}
    88 }
    88 }
   102 		// Sniff the next token on the stream.
   102 		// Sniff the next token on the stream.
   103 		t, err := p.Token()
   103 		t, err := p.Token()
   104 		if t == nil {
   104 		if t == nil {
   105 			if err != io.EOF {
   105 			if err != io.EOF {
   106 				if Log != nil {
   106 				if Log != nil {
   107 					Log.Err("read: " + err.Error())
   107 					Log.Println("read: " + err.Error())
   108 				}
   108 				}
   109 			}
   109 			}
   110 			break
   110 			break
   111 		}
   111 		}
   112 		var se xml.StartElement
   112 		var se xml.StartElement
   120 		switch se.Name.Space + " " + se.Name.Local {
   120 		switch se.Name.Space + " " + se.Name.Local {
   121 		case NsStream + " stream":
   121 		case NsStream + " stream":
   122 			st, err := parseStream(se)
   122 			st, err := parseStream(se)
   123 			if err != nil {
   123 			if err != nil {
   124 				if Log != nil {
   124 				if Log != nil {
   125 					Log.Err("unmarshal stream: " +
   125 					Log.Println("unmarshal stream: " +
   126 						err.Error())
   126 						err.Error())
   127 				}
   127 				}
   128 				break Loop
   128 				break Loop
   129 			}
   129 			}
   130 			ch <- st
   130 			ch <- st
   144 			obj = &Message{}
   144 			obj = &Message{}
   145 		case "jabber:client presence":
   145 		case "jabber:client presence":
   146 			obj = &Presence{}
   146 			obj = &Presence{}
   147 		default:
   147 		default:
   148 			obj = &Generic{}
   148 			obj = &Generic{}
   149 			if Log != nil {
   149 			if Log != nil && Loglevel >= syslog.LOG_NOTICE {
   150 				Log.Notice("Ignoring unrecognized: " +
   150 				Log.Printf("Ignoring unrecognized: %s %s",
   151 					se.Name.Space + " " + se.Name.Local)
   151 					se.Name.Space, se.Name.Local)
   152 			}
   152 			}
   153 		}
   153 		}
   154 
   154 
   155 		// Read the complete XML stanza.
   155 		// Read the complete XML stanza.
   156 		err = p.Unmarshal(obj, &se)
   156 		err = p.Unmarshal(obj, &se)
   157 		if err != nil {
   157 		if err != nil {
   158 			if Log != nil {
   158 			if Log != nil {
   159 				Log.Err("unmarshal: " + err.Error())
   159 				Log.Println("unmarshal: " + err.Error())
   160 			}
   160 			}
   161 			break Loop
   161 			break Loop
   162 		}
   162 		}
   163 
   163 
   164 		// If it's a Stanza, we try to unmarshal its innerxml
   164 		// If it's a Stanza, we try to unmarshal its innerxml
   166 		// types. This is specified by our extensions.
   166 		// types. This is specified by our extensions.
   167 		if st, ok := obj.(Stanza); ok {
   167 		if st, ok := obj.(Stanza); ok {
   168 			err = parseExtended(st, extStanza)
   168 			err = parseExtended(st, extStanza)
   169 			if err != nil {
   169 			if err != nil {
   170 				if Log != nil {
   170 				if Log != nil {
   171 					Log.Err("ext unmarshal: " +
   171 					Log.Println("ext unmarshal: " +
   172 						err.Error())
   172 						err.Error())
   173 				}
   173 				}
   174 				break Loop
   174 				break Loop
   175 			}
   175 			}
   176 		}
   176 		}
   226 
   226 
   227 	for obj := range ch {
   227 	for obj := range ch {
   228 		err := xml.Marshal(w, obj)
   228 		err := xml.Marshal(w, obj)
   229 		if err != nil {
   229 		if err != nil {
   230 			if Log != nil {
   230 			if Log != nil {
   231 				Log.Err("write: " + err.Error())
   231 				Log.Println("write: " + err.Error())
   232 			}
   232 			}
   233 			break
   233 			break
   234 		}
   234 		}
   235 	}
   235 	}
   236 }
   236 }
   266 			if !send {
   266 			if !send {
   267 				continue
   267 				continue
   268 			}
   268 			}
   269 			st, ok := x.(Stanza)
   269 			st, ok := x.(Stanza)
   270 			if !ok {
   270 			if !ok {
   271 				if Log != nil {
   271 				if Log != nil && Loglevel >= syslog.LOG_WARNING {
   272 					Log.Warning(fmt.Sprintf(
   272 					Log.Printf(
   273 						"Unhandled non-stanza: %v", x))
   273 						"Unhandled non-stanza: %v", x)
   274 				}
   274 				}
   275 				continue
   275 				continue
   276 			}
   276 			}
   277 			if handlers[st.GetId()] != nil {
   277 			if handlers[st.GetId()] != nil {
   278 				f := handlers[st.GetId()]
   278 				f := handlers[st.GetId()]
   310 		case x, ok := <-input:
   310 		case x, ok := <-input:
   311 			if !ok {
   311 			if !ok {
   312 				break Loop
   312 				break Loop
   313 			}
   313 			}
   314 			if x == nil {
   314 			if x == nil {
   315 				if Log != nil {
   315 				if Log != nil && Loglevel >= syslog.LOG_NOTICE {
   316 					Log.Notice("Refusing to send" +
   316 					Log.Println("Refusing to send" +
   317 						" nil stanza")
   317 						" nil stanza")
   318 				}
   318 				}
   319 				continue
   319 				continue
   320 			}
   320 			}
   321 			srvOut <- x
   321 			srvOut <- x
   331 Loop:
   331 Loop:
   332 	for {
   332 	for {
   333 		select {
   333 		select {
   334 		case newFilterOut := <-filterOut:
   334 		case newFilterOut := <-filterOut:
   335 			if newFilterOut == nil {
   335 			if newFilterOut == nil {
   336 				if Log != nil {
   336 				if Log != nil && Loglevel >= syslog.LOG_WARNING {
   337 					Log.Warning("Received nil filter")
   337 					Log.Println("Received nil filter")
   338 				}
   338 				}
   339 				filterIn <- nil
   339 				filterIn <- nil
   340 				continue
   340 				continue
   341 			}
   341 			}
   342 			filterIn <- topFilter
   342 			filterIn <- topFilter
   360 
   360 
   361 func handleStream(ss *stream) {
   361 func handleStream(ss *stream) {
   362 }
   362 }
   363 
   363 
   364 func (cl *Client) handleStreamError(se *streamError) {
   364 func (cl *Client) handleStreamError(se *streamError) {
   365 	if Log != nil {
   365 	if Log != nil && Loglevel >= syslog.LOG_NOTICE {
   366 		Log.Notice(fmt.Sprintf("Received stream error: %v", se))
   366 		Log.Printf("Received stream error: %v", se)
   367 	}
   367 	}
   368 	close(cl.Out)
   368 	close(cl.Out)
   369 }
   369 }
   370 
   370 
   371 func (cl *Client) handleFeatures(fe *Features) {
   371 func (cl *Client) handleFeatures(fe *Features) {
   411 
   411 
   412 	// Reset the read timeout on the (underlying) socket so the
   412 	// Reset the read timeout on the (underlying) socket so the
   413 	// reader doesn't get woken up unnecessarily.
   413 	// reader doesn't get woken up unnecessarily.
   414 	tcp.SetReadTimeout(0)
   414 	tcp.SetReadTimeout(0)
   415 
   415 
   416 	if Log != nil {
   416 	if Log != nil && Loglevel >= syslog.LOG_INFO {
   417 		Log.Info("TLS negotiation succeeded.")
   417 		Log.Println("TLS negotiation succeeded.")
   418 	}
   418 	}
   419 	cl.Features = nil
   419 	cl.Features = nil
   420 
   420 
   421 	// Now re-send the initial handshake message to start the new
   421 	// Now re-send the initial handshake message to start the new
   422 	// session.
   422 	// session.
   460 	case "challenge":
   460 	case "challenge":
   461 		b64 := base64.StdEncoding
   461 		b64 := base64.StdEncoding
   462 		str, err := b64.DecodeString(srv.Chardata)
   462 		str, err := b64.DecodeString(srv.Chardata)
   463 		if err != nil {
   463 		if err != nil {
   464 			if Log != nil {
   464 			if Log != nil {
   465 				Log.Err("SASL challenge decode: " +
   465 				Log.Println("SASL challenge decode: " +
   466 					err.Error())
   466 					err.Error())
   467 			}
   467 			}
   468 			return
   468 			return
   469 		}
   469 		}
   470 		srvMap := parseSasl(string(str))
   470 		srvMap := parseSasl(string(str))
   473 			cl.saslDigest1(srvMap)
   473 			cl.saslDigest1(srvMap)
   474 		} else {
   474 		} else {
   475 			cl.saslDigest2(srvMap)
   475 			cl.saslDigest2(srvMap)
   476 		}
   476 		}
   477 	case "failure":
   477 	case "failure":
   478 		if Log != nil {
   478 		if Log != nil && Loglevel >= syslog.LOG_NOTICE {
   479 			Log.Notice("SASL authentication failed")
   479 			Log.Println("SASL authentication failed")
   480 		}
   480 		}
   481 	case "success":
   481 	case "success":
   482 		if Log != nil {
   482 		if Log != nil && Loglevel >= syslog.LOG_INFO {
   483 			Log.Info("Sasl authentication succeeded")
   483 			Log.Println("Sasl authentication succeeded")
   484 		}
   484 		}
   485 		cl.Features = nil
   485 		cl.Features = nil
   486 		ss := &stream{To: cl.Jid.Domain, Version: Version}
   486 		ss := &stream{To: cl.Jid.Domain, Version: Version}
   487 		cl.xmlOut <- ss
   487 		cl.xmlOut <- ss
   488 	}
   488 	}
   496 			hasAuth = true
   496 			hasAuth = true
   497 		}
   497 		}
   498 	}
   498 	}
   499 	if !hasAuth {
   499 	if !hasAuth {
   500 		if Log != nil {
   500 		if Log != nil {
   501 			Log.Err("Server doesn't support SASL auth")
   501 			Log.Println("Server doesn't support SASL auth")
   502 		}
   502 		}
   503 		return
   503 		return
   504 	}
   504 	}
   505 
   505 
   506 	// Pick a realm.
   506 	// Pick a realm.
   528 	randSize := big.NewInt(0)
   528 	randSize := big.NewInt(0)
   529 	randSize.Lsh(big.NewInt(1), 64)
   529 	randSize.Lsh(big.NewInt(1), 64)
   530 	cnonce, err := rand.Int(rand.Reader, randSize)
   530 	cnonce, err := rand.Int(rand.Reader, randSize)
   531 	if err != nil {
   531 	if err != nil {
   532 		if Log != nil {
   532 		if Log != nil {
   533 			Log.Err("SASL rand: " + err.Error())
   533 			Log.Println("SASL rand: " + err.Error())
   534 		}
   534 		}
   535 		return
   535 		return
   536 	}
   536 	}
   537 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   537 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   538 
   538 
   635 	}
   635 	}
   636 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
   636 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
   637 	f := func(st Stanza) bool {
   637 	f := func(st Stanza) bool {
   638 		if st.GetType() == "error" {
   638 		if st.GetType() == "error" {
   639 			if Log != nil {
   639 			if Log != nil {
   640 				Log.Err("Resource binding failed")
   640 				Log.Println("Resource binding failed")
   641 			}
   641 			}
   642 			return false
   642 			return false
   643 		}
   643 		}
   644 		var bindRepl *bindIq
   644 		var bindRepl *bindIq
   645 		for _, ele := range st.GetNested() {
   645 		for _, ele := range st.GetNested() {
   648 				break
   648 				break
   649 			}
   649 			}
   650 		}
   650 		}
   651 		if bindRepl == nil {
   651 		if bindRepl == nil {
   652 			if Log != nil {
   652 			if Log != nil {
   653 				Log.Err(fmt.Sprintf("Bad bind reply: %v",
   653 				Log.Printf("Bad bind reply: %v", st)
   654 					st))
       
   655 			}
   654 			}
   656 			return false
   655 			return false
   657 		}
   656 		}
   658 		jidStr := bindRepl.Jid
   657 		jidStr := bindRepl.Jid
   659 		if jidStr == nil || *jidStr == "" {
   658 		if jidStr == nil || *jidStr == "" {
   660 			if Log != nil {
   659 			if Log != nil {
   661 				Log.Err("Can't bind empty resource")
   660 				Log.Println("Can't bind empty resource")
   662 			}
   661 			}
   663 			return false
   662 			return false
   664 		}
   663 		}
   665 		jid := new(JID)
   664 		jid := new(JID)
   666 		if err := jid.Set(*jidStr); err != nil {
   665 		if err := jid.Set(*jidStr); err != nil {
   667 			if Log != nil {
   666 			if Log != nil {
   668 				Log.Err(err.Error())
   667 				Log.Println(err.Error())
   669 			}
   668 			}
   670 			return false
   669 			return false
   671 		}
   670 		}
   672 		cl.Jid = *jid
   671 		cl.Jid = *jid
   673 		if Log != nil {
   672 		if Log != nil && Loglevel >= syslog.LOG_INFO {
   674 			Log.Info("Bound resource: " + cl.Jid.String())
   673 			Log.Println("Bound resource: " + cl.Jid.String())
   675 		}
   674 		}
   676 		cl.bindDone()
   675 		cl.bindDone()
   677 		return false
   676 		return false
   678 	}
   677 	}
   679 	cl.HandleStanza(msg.Id, f)
   678 	cl.HandleStanza(msg.Id, f)