--- a/stream.go Mon Jan 23 21:54:41 2012 -0700
+++ b/stream.go Sun Dec 16 13:03:03 2012 -0700
@@ -11,20 +11,19 @@
package xmpp
import (
- "big"
"crypto/md5"
"crypto/rand"
"crypto/tls"
"encoding/base64"
+ "encoding/xml"
"fmt"
"io"
+ "log/syslog"
+ "math/big"
"net"
- "os"
"regexp"
"strings"
- "syslog"
"time"
- "xml"
)
// Callback to handle a stanza with a particular id.
@@ -39,12 +38,12 @@
func (cl *Client) readTransport(w io.WriteCloser) {
defer w.Close()
- cl.socket.SetReadTimeout(1e8)
p := make([]byte, 1024)
for {
if cl.socket == nil {
cl.waitForSocket()
}
+ cl.socket.SetReadDeadline(time.Now().Add(time.Second))
nr, err := cl.socket.Read(p)
if nr == 0 {
if errno, ok := err.(*net.OpError); ok {
@@ -53,14 +52,14 @@
}
}
if Log != nil {
- Log.Err("read: " + err.String())
+ Log.Err("read: " + err.Error())
}
break
}
nw, err := w.Write(p[:nr])
if nw < nr {
if Log != nil {
- Log.Err("read: " + err.String())
+ Log.Err("read: " + err.Error())
}
break
}
@@ -74,14 +73,14 @@
nr, err := r.Read(p)
if nr == 0 {
if Log != nil {
- Log.Err("write: " + err.String())
+ Log.Err("write: " + err.Error())
}
break
}
nw, err := cl.socket.Write(p[:nr])
if nw < nr {
if Log != nil {
- Log.Err("write: " + err.String())
+ Log.Err("write: " + err.Error())
}
break
}
@@ -97,15 +96,17 @@
}
defer close(ch)
- p := xml.NewParser(r)
+ p := xml.NewDecoder(r)
+ p.Context.Map[""] = NsClient
+ p.Context.Map["stream"] = NsStream
Loop:
for {
// Sniff the next token on the stream.
t, err := p.Token()
if t == nil {
- if err != os.EOF {
+ if err != io.EOF {
if Log != nil {
- Log.Err("read: " + err.String())
+ Log.Err("read: " + err.Error())
}
}
break
@@ -124,7 +125,7 @@
if err != nil {
if Log != nil {
Log.Err("unmarshal stream: " +
- err.String())
+ err.Error())
}
break Loop
}
@@ -139,11 +140,11 @@
case NsSASL + " challenge", NsSASL + " failure",
NsSASL + " success":
obj = &auth{}
- case "jabber:client iq":
+ case NsClient + " iq":
obj = &Iq{}
- case "jabber:client message":
+ case NsClient + " message":
obj = &Message{}
- case "jabber:client presence":
+ case NsClient + " presence":
obj = &Presence{}
default:
obj = &Generic{}
@@ -154,10 +155,10 @@
}
// Read the complete XML stanza.
- err = p.Unmarshal(obj, &se)
+ err = p.DecodeElement(obj, &se)
if err != nil {
if Log != nil {
- Log.Err("unmarshal: " + err.String())
+ Log.Err("unmarshal: " + err.Error())
}
break Loop
}
@@ -170,8 +171,9 @@
if err != nil {
if Log != nil {
Log.Err("ext unmarshal: " +
- err.String())
+ err.Error())
}
+ fmt.Printf("ext: %v\n", err)
break Loop
}
}
@@ -181,14 +183,14 @@
}
}
-func parseExtended(st Stanza, extStanza map[string]func(*xml.Name) interface{}) os.Error {
+func parseExtended(st Stanza, extStanza map[string]func(*xml.Name) interface{}) error {
// Now parse the stanza's innerxml to find the string that we
// can unmarshal this nested element from.
reader := strings.NewReader(st.innerxml())
- p := xml.NewParser(reader)
+ p := xml.NewDecoder(reader)
for {
t, err := p.Token()
- if err == os.EOF {
+ if err == io.EOF {
break
}
if err != nil {
@@ -201,7 +203,7 @@
// Unmarshal the nested element and
// stuff it back into the stanza.
- err := p.Unmarshal(nested, &se)
+ err := p.DecodeElement(nested, &se)
if err != nil {
return err
}
@@ -225,13 +227,26 @@
}
}(w)
+ enc := xml.NewEncoder(w)
+ enc.Context.Map[NsClient] = ""
+ enc.Context.Map[NsStream] = "stream"
+
for obj := range ch {
- err := xml.Marshal(w, obj)
- if err != nil {
- if Log != nil {
- Log.Err("write: " + err.String())
+ if st, ok := obj.(*stream); ok {
+ _, err := w.Write([]byte(st.String()))
+ if err != nil {
+ if Log != nil {
+ Log.Err("write: " + err.Error())
+ }
}
- break
+ } else {
+ err := enc.Encode(obj)
+ if err != nil {
+ if Log != nil {
+ Log.Err("marshal: " + err.Error())
+ }
+ break
+ }
}
}
}
@@ -277,7 +292,7 @@
}
if handlers[st.GetId()] != nil {
f := handlers[st.GetId()]
- handlers[st.GetId()] = nil
+ delete(handlers, st.GetId())
send = f(st)
}
if send {
@@ -410,10 +425,6 @@
cl.socket = tls
cl.socketSync.Wait()
- // Reset the read timeout on the (underlying) socket so the
- // reader doesn't get woken up unnecessarily.
- tcp.SetReadTimeout(0)
-
if Log != nil {
Log.Info("TLS negotiation succeeded.")
}
@@ -464,7 +475,7 @@
if err != nil {
if Log != nil {
Log.Err("SASL challenge decode: " +
- err.String())
+ err.Error())
}
return
}
@@ -531,7 +542,7 @@
cnonce, err := rand.Int(rand.Reader, randSize)
if err != nil {
if Log != nil {
- Log.Err("SASL rand: " + err.String())
+ Log.Err("SASL rand: " + err.Error())
}
return
}
@@ -609,7 +620,7 @@
h := func(text string) []byte {
h := md5.New()
h.Write([]byte(text))
- return h.Sum()
+ return h.Sum(nil)
}
hex := func(bytes []byte) string {
return fmt.Sprintf("%x", bytes)
@@ -636,14 +647,20 @@
}
msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
f := func(st Stanza) bool {
- if st.GetType() == "error" {
+ iq, ok := st.(*Iq)
+ if !ok {
+ if Log != nil {
+ Log.Err("non-iq response")
+ }
+ }
+ if iq.Type == "error" {
if Log != nil {
Log.Err("Resource binding failed")
}
return false
}
var bindRepl *bindIq
- for _, ele := range st.GetNested() {
+ for _, ele := range iq.Nested {
if b, ok := ele.(*bindIq); ok {
bindRepl = b
break
@@ -652,7 +669,7 @@
if bindRepl == nil {
if Log != nil {
Log.Err(fmt.Sprintf("Bad bind reply: %v",
- st))
+ iq))
}
return false
}
@@ -664,9 +681,10 @@
return false
}
jid := new(JID)
- if !jid.Set(*jidStr) {
+ if err := jid.Set(*jidStr); err != nil {
if Log != nil {
- Log.Err("Can't parse JID " + *jidStr)
+ Log.Err(fmt.Sprintf("Can't parse JID %s: %s",
+ *jidStr, err))
}
return false
}