stream.go
changeset 98 c9cc4eda6dce
parent 72 53f15893a1a7
child 100 24231ff0016c
--- a/stream.go	Mon Jan 23 21:54:41 2012 -0700
+++ b/stream.go	Sun Dec 16 13:03:03 2012 -0700
@@ -11,20 +11,19 @@
 package xmpp
 
 import (
-	"big"
 	"crypto/md5"
 	"crypto/rand"
 	"crypto/tls"
 	"encoding/base64"
+	"encoding/xml"
 	"fmt"
 	"io"
+	"log/syslog"
+	"math/big"
 	"net"
-	"os"
 	"regexp"
 	"strings"
-	"syslog"
 	"time"
-	"xml"
 )
 
 // Callback to handle a stanza with a particular id.
@@ -39,12 +38,12 @@
 
 func (cl *Client) readTransport(w io.WriteCloser) {
 	defer w.Close()
-	cl.socket.SetReadTimeout(1e8)
 	p := make([]byte, 1024)
 	for {
 		if cl.socket == nil {
 			cl.waitForSocket()
 		}
+		cl.socket.SetReadDeadline(time.Now().Add(time.Second))
 		nr, err := cl.socket.Read(p)
 		if nr == 0 {
 			if errno, ok := err.(*net.OpError); ok {
@@ -53,14 +52,14 @@
 				}
 			}
 			if Log != nil {
-				Log.Err("read: " + err.String())
+				Log.Err("read: " + err.Error())
 			}
 			break
 		}
 		nw, err := w.Write(p[:nr])
 		if nw < nr {
 			if Log != nil {
-				Log.Err("read: " + err.String())
+				Log.Err("read: " + err.Error())
 			}
 			break
 		}
@@ -74,14 +73,14 @@
 		nr, err := r.Read(p)
 		if nr == 0 {
 			if Log != nil {
-				Log.Err("write: " + err.String())
+				Log.Err("write: " + err.Error())
 			}
 			break
 		}
 		nw, err := cl.socket.Write(p[:nr])
 		if nw < nr {
 			if Log != nil {
-				Log.Err("write: " + err.String())
+				Log.Err("write: " + err.Error())
 			}
 			break
 		}
@@ -97,15 +96,17 @@
 	}
 	defer close(ch)
 
-	p := xml.NewParser(r)
+	p := xml.NewDecoder(r)
+	p.Context.Map[""] = NsClient
+	p.Context.Map["stream"] = NsStream
 Loop:
 	for {
 		// Sniff the next token on the stream.
 		t, err := p.Token()
 		if t == nil {
-			if err != os.EOF {
+			if err != io.EOF {
 				if Log != nil {
-					Log.Err("read: " + err.String())
+					Log.Err("read: " + err.Error())
 				}
 			}
 			break
@@ -124,7 +125,7 @@
 			if err != nil {
 				if Log != nil {
 					Log.Err("unmarshal stream: " +
-						err.String())
+						err.Error())
 				}
 				break Loop
 			}
@@ -139,11 +140,11 @@
 		case NsSASL + " challenge", NsSASL + " failure",
 			NsSASL + " success":
 			obj = &auth{}
-		case "jabber:client iq":
+		case NsClient + " iq":
 			obj = &Iq{}
-		case "jabber:client message":
+		case NsClient + " message":
 			obj = &Message{}
-		case "jabber:client presence":
+		case NsClient + " presence":
 			obj = &Presence{}
 		default:
 			obj = &Generic{}
@@ -154,10 +155,10 @@
 		}
 
 		// Read the complete XML stanza.
-		err = p.Unmarshal(obj, &se)
+		err = p.DecodeElement(obj, &se)
 		if err != nil {
 			if Log != nil {
-				Log.Err("unmarshal: " + err.String())
+				Log.Err("unmarshal: " + err.Error())
 			}
 			break Loop
 		}
@@ -170,8 +171,9 @@
 			if err != nil {
 				if Log != nil {
 					Log.Err("ext unmarshal: " +
-						err.String())
+						err.Error())
 				}
+				fmt.Printf("ext: %v\n", err)
 				break Loop
 			}
 		}
@@ -181,14 +183,14 @@
 	}
 }
 
-func parseExtended(st Stanza, extStanza map[string]func(*xml.Name) interface{}) os.Error {
+func parseExtended(st Stanza, extStanza map[string]func(*xml.Name) interface{}) error {
 	// Now parse the stanza's innerxml to find the string that we
 	// can unmarshal this nested element from.
 	reader := strings.NewReader(st.innerxml())
-	p := xml.NewParser(reader)
+	p := xml.NewDecoder(reader)
 	for {
 		t, err := p.Token()
-		if err == os.EOF {
+		if err == io.EOF {
 			break
 		}
 		if err != nil {
@@ -201,7 +203,7 @@
 
 				// Unmarshal the nested element and
 				// stuff it back into the stanza.
-				err := p.Unmarshal(nested, &se)
+				err := p.DecodeElement(nested, &se)
 				if err != nil {
 					return err
 				}
@@ -225,13 +227,26 @@
 		}
 	}(w)
 
+	enc := xml.NewEncoder(w)
+	enc.Context.Map[NsClient] = ""
+	enc.Context.Map[NsStream] = "stream"
+
 	for obj := range ch {
-		err := xml.Marshal(w, obj)
-		if err != nil {
-			if Log != nil {
-				Log.Err("write: " + err.String())
+		if st, ok := obj.(*stream); ok {
+			_, err := w.Write([]byte(st.String()))
+			if err != nil {
+				if Log != nil {
+					Log.Err("write: " + err.Error())
+				}
 			}
-			break
+		} else {
+			err := enc.Encode(obj)
+			if err != nil {
+				if Log != nil {
+					Log.Err("marshal: " + err.Error())
+				}
+				break
+			}
 		}
 	}
 }
@@ -277,7 +292,7 @@
 			}
 			if handlers[st.GetId()] != nil {
 				f := handlers[st.GetId()]
-				handlers[st.GetId()] = nil
+				delete(handlers, st.GetId())
 				send = f(st)
 			}
 			if send {
@@ -410,10 +425,6 @@
 	cl.socket = tls
 	cl.socketSync.Wait()
 
-	// Reset the read timeout on the (underlying) socket so the
-	// reader doesn't get woken up unnecessarily.
-	tcp.SetReadTimeout(0)
-
 	if Log != nil {
 		Log.Info("TLS negotiation succeeded.")
 	}
@@ -464,7 +475,7 @@
 		if err != nil {
 			if Log != nil {
 				Log.Err("SASL challenge decode: " +
-					err.String())
+					err.Error())
 			}
 			return
 		}
@@ -531,7 +542,7 @@
 	cnonce, err := rand.Int(rand.Reader, randSize)
 	if err != nil {
 		if Log != nil {
-			Log.Err("SASL rand: " + err.String())
+			Log.Err("SASL rand: " + err.Error())
 		}
 		return
 	}
@@ -609,7 +620,7 @@
 	h := func(text string) []byte {
 		h := md5.New()
 		h.Write([]byte(text))
-		return h.Sum()
+		return h.Sum(nil)
 	}
 	hex := func(bytes []byte) string {
 		return fmt.Sprintf("%x", bytes)
@@ -636,14 +647,20 @@
 	}
 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
 	f := func(st Stanza) bool {
-		if st.GetType() == "error" {
+		iq, ok := st.(*Iq)
+		if !ok {
+			if Log != nil {
+				Log.Err("non-iq response")
+			}
+		}
+		if iq.Type == "error" {
 			if Log != nil {
 				Log.Err("Resource binding failed")
 			}
 			return false
 		}
 		var bindRepl *bindIq
-		for _, ele := range st.GetNested() {
+		for _, ele := range iq.Nested {
 			if b, ok := ele.(*bindIq); ok {
 				bindRepl = b
 				break
@@ -652,7 +669,7 @@
 		if bindRepl == nil {
 			if Log != nil {
 				Log.Err(fmt.Sprintf("Bad bind reply: %v",
-					st))
+					iq))
 			}
 			return false
 		}
@@ -664,9 +681,10 @@
 			return false
 		}
 		jid := new(JID)
-		if !jid.Set(*jidStr) {
+		if err := jid.Set(*jidStr); err != nil {
 			if Log != nil {
-				Log.Err("Can't parse JID " + *jidStr)
+				Log.Err(fmt.Sprintf("Can't parse JID %s: %s",
+						*jidStr, err))
 			}
 			return false
 		}