stream.go
changeset 98 c9cc4eda6dce
parent 72 53f15893a1a7
child 100 24231ff0016c
equal deleted inserted replaced
88:d2ec96c80efe 98:c9cc4eda6dce
     9 // those events to the client.
     9 // those events to the client.
    10 
    10 
    11 package xmpp
    11 package xmpp
    12 
    12 
    13 import (
    13 import (
    14 	"big"
       
    15 	"crypto/md5"
    14 	"crypto/md5"
    16 	"crypto/rand"
    15 	"crypto/rand"
    17 	"crypto/tls"
    16 	"crypto/tls"
    18 	"encoding/base64"
    17 	"encoding/base64"
       
    18 	"encoding/xml"
    19 	"fmt"
    19 	"fmt"
    20 	"io"
    20 	"io"
       
    21 	"log/syslog"
       
    22 	"math/big"
    21 	"net"
    23 	"net"
    22 	"os"
       
    23 	"regexp"
    24 	"regexp"
    24 	"strings"
    25 	"strings"
    25 	"syslog"
       
    26 	"time"
    26 	"time"
    27 	"xml"
       
    28 )
    27 )
    29 
    28 
    30 // Callback to handle a stanza with a particular id.
    29 // Callback to handle a stanza with a particular id.
    31 type stanzaHandler struct {
    30 type stanzaHandler struct {
    32 	id string
    31 	id string
    37 // BUG(cjyar) Review all these *Client receiver methods. They should
    36 // BUG(cjyar) Review all these *Client receiver methods. They should
    38 // probably either all be receivers, or none.
    37 // probably either all be receivers, or none.
    39 
    38 
    40 func (cl *Client) readTransport(w io.WriteCloser) {
    39 func (cl *Client) readTransport(w io.WriteCloser) {
    41 	defer w.Close()
    40 	defer w.Close()
    42 	cl.socket.SetReadTimeout(1e8)
       
    43 	p := make([]byte, 1024)
    41 	p := make([]byte, 1024)
    44 	for {
    42 	for {
    45 		if cl.socket == nil {
    43 		if cl.socket == nil {
    46 			cl.waitForSocket()
    44 			cl.waitForSocket()
    47 		}
    45 		}
       
    46 		cl.socket.SetReadDeadline(time.Now().Add(time.Second))
    48 		nr, err := cl.socket.Read(p)
    47 		nr, err := cl.socket.Read(p)
    49 		if nr == 0 {
    48 		if nr == 0 {
    50 			if errno, ok := err.(*net.OpError); ok {
    49 			if errno, ok := err.(*net.OpError); ok {
    51 				if errno.Timeout() {
    50 				if errno.Timeout() {
    52 					continue
    51 					continue
    53 				}
    52 				}
    54 			}
    53 			}
    55 			if Log != nil {
    54 			if Log != nil {
    56 				Log.Err("read: " + err.String())
    55 				Log.Err("read: " + err.Error())
    57 			}
    56 			}
    58 			break
    57 			break
    59 		}
    58 		}
    60 		nw, err := w.Write(p[:nr])
    59 		nw, err := w.Write(p[:nr])
    61 		if nw < nr {
    60 		if nw < nr {
    62 			if Log != nil {
    61 			if Log != nil {
    63 				Log.Err("read: " + err.String())
    62 				Log.Err("read: " + err.Error())
    64 			}
    63 			}
    65 			break
    64 			break
    66 		}
    65 		}
    67 	}
    66 	}
    68 }
    67 }
    72 	p := make([]byte, 1024)
    71 	p := make([]byte, 1024)
    73 	for {
    72 	for {
    74 		nr, err := r.Read(p)
    73 		nr, err := r.Read(p)
    75 		if nr == 0 {
    74 		if nr == 0 {
    76 			if Log != nil {
    75 			if Log != nil {
    77 				Log.Err("write: " + err.String())
    76 				Log.Err("write: " + err.Error())
    78 			}
    77 			}
    79 			break
    78 			break
    80 		}
    79 		}
    81 		nw, err := cl.socket.Write(p[:nr])
    80 		nw, err := cl.socket.Write(p[:nr])
    82 		if nw < nr {
    81 		if nw < nr {
    83 			if Log != nil {
    82 			if Log != nil {
    84 				Log.Err("write: " + err.String())
    83 				Log.Err("write: " + err.Error())
    85 			}
    84 			}
    86 			break
    85 			break
    87 		}
    86 		}
    88 	}
    87 	}
    89 }
    88 }
    95 		go tee(r, pw, "S: ")
    94 		go tee(r, pw, "S: ")
    96 		r = pr
    95 		r = pr
    97 	}
    96 	}
    98 	defer close(ch)
    97 	defer close(ch)
    99 
    98 
   100 	p := xml.NewParser(r)
    99 	p := xml.NewDecoder(r)
       
   100 	p.Context.Map[""] = NsClient
       
   101 	p.Context.Map["stream"] = NsStream
   101 Loop:
   102 Loop:
   102 	for {
   103 	for {
   103 		// Sniff the next token on the stream.
   104 		// Sniff the next token on the stream.
   104 		t, err := p.Token()
   105 		t, err := p.Token()
   105 		if t == nil {
   106 		if t == nil {
   106 			if err != os.EOF {
   107 			if err != io.EOF {
   107 				if Log != nil {
   108 				if Log != nil {
   108 					Log.Err("read: " + err.String())
   109 					Log.Err("read: " + err.Error())
   109 				}
   110 				}
   110 			}
   111 			}
   111 			break
   112 			break
   112 		}
   113 		}
   113 		var se xml.StartElement
   114 		var se xml.StartElement
   122 		case NsStream + " stream":
   123 		case NsStream + " stream":
   123 			st, err := parseStream(se)
   124 			st, err := parseStream(se)
   124 			if err != nil {
   125 			if err != nil {
   125 				if Log != nil {
   126 				if Log != nil {
   126 					Log.Err("unmarshal stream: " +
   127 					Log.Err("unmarshal stream: " +
   127 						err.String())
   128 						err.Error())
   128 				}
   129 				}
   129 				break Loop
   130 				break Loop
   130 			}
   131 			}
   131 			ch <- st
   132 			ch <- st
   132 			continue
   133 			continue
   137 		case NsTLS + " proceed", NsTLS + " failure":
   138 		case NsTLS + " proceed", NsTLS + " failure":
   138 			obj = &starttls{}
   139 			obj = &starttls{}
   139 		case NsSASL + " challenge", NsSASL + " failure",
   140 		case NsSASL + " challenge", NsSASL + " failure",
   140 			NsSASL + " success":
   141 			NsSASL + " success":
   141 			obj = &auth{}
   142 			obj = &auth{}
   142 		case "jabber:client iq":
   143 		case NsClient + " iq":
   143 			obj = &Iq{}
   144 			obj = &Iq{}
   144 		case "jabber:client message":
   145 		case NsClient + " message":
   145 			obj = &Message{}
   146 			obj = &Message{}
   146 		case "jabber:client presence":
   147 		case NsClient + " presence":
   147 			obj = &Presence{}
   148 			obj = &Presence{}
   148 		default:
   149 		default:
   149 			obj = &Generic{}
   150 			obj = &Generic{}
   150 			if Log != nil {
   151 			if Log != nil {
   151 				Log.Notice("Ignoring unrecognized: " +
   152 				Log.Notice("Ignoring unrecognized: " +
   152 					se.Name.Space + " " + se.Name.Local)
   153 					se.Name.Space + " " + se.Name.Local)
   153 			}
   154 			}
   154 		}
   155 		}
   155 
   156 
   156 		// Read the complete XML stanza.
   157 		// Read the complete XML stanza.
   157 		err = p.Unmarshal(obj, &se)
   158 		err = p.DecodeElement(obj, &se)
   158 		if err != nil {
   159 		if err != nil {
   159 			if Log != nil {
   160 			if Log != nil {
   160 				Log.Err("unmarshal: " + err.String())
   161 				Log.Err("unmarshal: " + err.Error())
   161 			}
   162 			}
   162 			break Loop
   163 			break Loop
   163 		}
   164 		}
   164 
   165 
   165 		// If it's a Stanza, we try to unmarshal its innerxml
   166 		// If it's a Stanza, we try to unmarshal its innerxml
   168 		if st, ok := obj.(Stanza); ok {
   169 		if st, ok := obj.(Stanza); ok {
   169 			err = parseExtended(st, extStanza)
   170 			err = parseExtended(st, extStanza)
   170 			if err != nil {
   171 			if err != nil {
   171 				if Log != nil {
   172 				if Log != nil {
   172 					Log.Err("ext unmarshal: " +
   173 					Log.Err("ext unmarshal: " +
   173 						err.String())
   174 						err.Error())
   174 				}
   175 				}
       
   176 				fmt.Printf("ext: %v\n", err)
   175 				break Loop
   177 				break Loop
   176 			}
   178 			}
   177 		}
   179 		}
   178 
   180 
   179 		// Put it on the channel.
   181 		// Put it on the channel.
   180 		ch <- obj
   182 		ch <- obj
   181 	}
   183 	}
   182 }
   184 }
   183 
   185 
   184 func parseExtended(st Stanza, extStanza map[string]func(*xml.Name) interface{}) os.Error {
   186 func parseExtended(st Stanza, extStanza map[string]func(*xml.Name) interface{}) error {
   185 	// Now parse the stanza's innerxml to find the string that we
   187 	// Now parse the stanza's innerxml to find the string that we
   186 	// can unmarshal this nested element from.
   188 	// can unmarshal this nested element from.
   187 	reader := strings.NewReader(st.innerxml())
   189 	reader := strings.NewReader(st.innerxml())
   188 	p := xml.NewParser(reader)
   190 	p := xml.NewDecoder(reader)
   189 	for {
   191 	for {
   190 		t, err := p.Token()
   192 		t, err := p.Token()
   191 		if err == os.EOF {
   193 		if err == io.EOF {
   192 			break
   194 			break
   193 		}
   195 		}
   194 		if err != nil {
   196 		if err != nil {
   195 			return err
   197 			return err
   196 		}
   198 		}
   199 				// Call the indicated constructor.
   201 				// Call the indicated constructor.
   200 				nested := con(&se.Name)
   202 				nested := con(&se.Name)
   201 
   203 
   202 				// Unmarshal the nested element and
   204 				// Unmarshal the nested element and
   203 				// stuff it back into the stanza.
   205 				// stuff it back into the stanza.
   204 				err := p.Unmarshal(nested, &se)
   206 				err := p.DecodeElement(nested, &se)
   205 				if err != nil {
   207 				if err != nil {
   206 					return err
   208 					return err
   207 				}
   209 				}
   208 				st.addNested(nested)
   210 				st.addNested(nested)
   209 			}
   211 			}
   223 		if c, ok := w.(io.Closer); ok {
   225 		if c, ok := w.(io.Closer); ok {
   224 			c.Close()
   226 			c.Close()
   225 		}
   227 		}
   226 	}(w)
   228 	}(w)
   227 
   229 
       
   230 	enc := xml.NewEncoder(w)
       
   231 	enc.Context.Map[NsClient] = ""
       
   232 	enc.Context.Map[NsStream] = "stream"
       
   233 
   228 	for obj := range ch {
   234 	for obj := range ch {
   229 		err := xml.Marshal(w, obj)
   235 		if st, ok := obj.(*stream); ok {
   230 		if err != nil {
   236 			_, err := w.Write([]byte(st.String()))
   231 			if Log != nil {
   237 			if err != nil {
   232 				Log.Err("write: " + err.String())
   238 				if Log != nil {
   233 			}
   239 					Log.Err("write: " + err.Error())
   234 			break
   240 				}
       
   241 			}
       
   242 		} else {
       
   243 			err := enc.Encode(obj)
       
   244 			if err != nil {
       
   245 				if Log != nil {
       
   246 					Log.Err("marshal: " + err.Error())
       
   247 				}
       
   248 				break
       
   249 			}
   235 		}
   250 		}
   236 	}
   251 	}
   237 }
   252 }
   238 
   253 
   239 func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- Stanza) {
   254 func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- Stanza) {
   275 				}
   290 				}
   276 				continue
   291 				continue
   277 			}
   292 			}
   278 			if handlers[st.GetId()] != nil {
   293 			if handlers[st.GetId()] != nil {
   279 				f := handlers[st.GetId()]
   294 				f := handlers[st.GetId()]
   280 				handlers[st.GetId()] = nil
   295 				delete(handlers, st.GetId())
   281 				send = f(st)
   296 				send = f(st)
   282 			}
   297 			}
   283 			if send {
   298 			if send {
   284 				cliOut <- st
   299 				cliOut <- st
   285 			}
   300 			}
   408 	// for it to signal that it's working again.
   423 	// for it to signal that it's working again.
   409 	cl.socketSync.Add(1)
   424 	cl.socketSync.Add(1)
   410 	cl.socket = tls
   425 	cl.socket = tls
   411 	cl.socketSync.Wait()
   426 	cl.socketSync.Wait()
   412 
   427 
   413 	// Reset the read timeout on the (underlying) socket so the
       
   414 	// reader doesn't get woken up unnecessarily.
       
   415 	tcp.SetReadTimeout(0)
       
   416 
       
   417 	if Log != nil {
   428 	if Log != nil {
   418 		Log.Info("TLS negotiation succeeded.")
   429 		Log.Info("TLS negotiation succeeded.")
   419 	}
   430 	}
   420 	cl.Features = nil
   431 	cl.Features = nil
   421 
   432 
   462 		b64 := base64.StdEncoding
   473 		b64 := base64.StdEncoding
   463 		str, err := b64.DecodeString(srv.Chardata)
   474 		str, err := b64.DecodeString(srv.Chardata)
   464 		if err != nil {
   475 		if err != nil {
   465 			if Log != nil {
   476 			if Log != nil {
   466 				Log.Err("SASL challenge decode: " +
   477 				Log.Err("SASL challenge decode: " +
   467 					err.String())
   478 					err.Error())
   468 			}
   479 			}
   469 			return
   480 			return
   470 		}
   481 		}
   471 		srvMap := parseSasl(string(str))
   482 		srvMap := parseSasl(string(str))
   472 
   483 
   529 	randSize := big.NewInt(0)
   540 	randSize := big.NewInt(0)
   530 	randSize.Lsh(big.NewInt(1), 64)
   541 	randSize.Lsh(big.NewInt(1), 64)
   531 	cnonce, err := rand.Int(rand.Reader, randSize)
   542 	cnonce, err := rand.Int(rand.Reader, randSize)
   532 	if err != nil {
   543 	if err != nil {
   533 		if Log != nil {
   544 		if Log != nil {
   534 			Log.Err("SASL rand: " + err.String())
   545 			Log.Err("SASL rand: " + err.Error())
   535 		}
   546 		}
   536 		return
   547 		return
   537 	}
   548 	}
   538 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   549 	cnonceStr := fmt.Sprintf("%016x", cnonce)
   539 
   550 
   607 func saslDigestResponse(username, realm, passwd, nonce, cnonceStr,
   618 func saslDigestResponse(username, realm, passwd, nonce, cnonceStr,
   608 authenticate, digestUri, nonceCountStr string) string {
   619 authenticate, digestUri, nonceCountStr string) string {
   609 	h := func(text string) []byte {
   620 	h := func(text string) []byte {
   610 		h := md5.New()
   621 		h := md5.New()
   611 		h.Write([]byte(text))
   622 		h.Write([]byte(text))
   612 		return h.Sum()
   623 		return h.Sum(nil)
   613 	}
   624 	}
   614 	hex := func(bytes []byte) string {
   625 	hex := func(bytes []byte) string {
   615 		return fmt.Sprintf("%x", bytes)
   626 		return fmt.Sprintf("%x", bytes)
   616 	}
   627 	}
   617 	kd := func(secret, data string) []byte {
   628 	kd := func(secret, data string) []byte {
   634 	if res != "" {
   645 	if res != "" {
   635 		bindReq.Resource = &res
   646 		bindReq.Resource = &res
   636 	}
   647 	}
   637 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
   648 	msg := &Iq{Type: "set", Id: <-Id, Nested: []interface{}{bindReq}}
   638 	f := func(st Stanza) bool {
   649 	f := func(st Stanza) bool {
   639 		if st.GetType() == "error" {
   650 		iq, ok := st.(*Iq)
       
   651 		if !ok {
       
   652 			if Log != nil {
       
   653 				Log.Err("non-iq response")
       
   654 			}
       
   655 		}
       
   656 		if iq.Type == "error" {
   640 			if Log != nil {
   657 			if Log != nil {
   641 				Log.Err("Resource binding failed")
   658 				Log.Err("Resource binding failed")
   642 			}
   659 			}
   643 			return false
   660 			return false
   644 		}
   661 		}
   645 		var bindRepl *bindIq
   662 		var bindRepl *bindIq
   646 		for _, ele := range st.GetNested() {
   663 		for _, ele := range iq.Nested {
   647 			if b, ok := ele.(*bindIq); ok {
   664 			if b, ok := ele.(*bindIq); ok {
   648 				bindRepl = b
   665 				bindRepl = b
   649 				break
   666 				break
   650 			}
   667 			}
   651 		}
   668 		}
   652 		if bindRepl == nil {
   669 		if bindRepl == nil {
   653 			if Log != nil {
   670 			if Log != nil {
   654 				Log.Err(fmt.Sprintf("Bad bind reply: %v",
   671 				Log.Err(fmt.Sprintf("Bad bind reply: %v",
   655 					st))
   672 					iq))
   656 			}
   673 			}
   657 			return false
   674 			return false
   658 		}
   675 		}
   659 		jidStr := bindRepl.Jid
   676 		jidStr := bindRepl.Jid
   660 		if jidStr == nil || *jidStr == "" {
   677 		if jidStr == nil || *jidStr == "" {
   662 				Log.Err("Can't bind empty resource")
   679 				Log.Err("Can't bind empty resource")
   663 			}
   680 			}
   664 			return false
   681 			return false
   665 		}
   682 		}
   666 		jid := new(JID)
   683 		jid := new(JID)
   667 		if !jid.Set(*jidStr) {
   684 		if err := jid.Set(*jidStr); err != nil {
   668 			if Log != nil {
   685 			if Log != nil {
   669 				Log.Err("Can't parse JID " + *jidStr)
   686 				Log.Err(fmt.Sprintf("Can't parse JID %s: %s",
       
   687 						*jidStr, err))
   670 			}
   688 			}
   671 			return false
   689 			return false
   672 		}
   690 		}
   673 		cl.Jid = *jid
   691 		cl.Jid = *jid
   674 		if Log != nil {
   692 		if Log != nil {