stream.go
changeset 11 48be1ae93fd4
parent 10 f38b0ee7b1c1
child 12 122ab6208c3c
equal deleted inserted replaced
10:f38b0ee7b1c1 11:48be1ae93fd4
     8 // we respond to XMPP events on behalf of the library client).
     8 // we respond to XMPP events on behalf of the library client).
     9 
     9 
    10 package xmpp
    10 package xmpp
    11 
    11 
    12 import (
    12 import (
       
    13 	"big"
       
    14 	"crypto/md5"
       
    15 	"crypto/rand"
    13 	"crypto/tls"
    16 	"crypto/tls"
       
    17 	"encoding/base64"
       
    18 	"fmt"
    14 	"io"
    19 	"io"
    15 	"log"
    20 	"log"
    16 	"net"
    21 	"net"
    17 	"os"
    22 	"os"
       
    23 	"regexp"
       
    24 	"strings"
    18 	"time"
    25 	"time"
    19 	"xml"
    26 	"xml"
    20 )
    27 )
    21 
    28 
    22 func (cl *Client) readTransport(w io.Writer) {
    29 func (cl *Client) readTransport(w io.Writer) {
   102 			obj = &StreamError{}
   109 			obj = &StreamError{}
   103 		case nsStream + " features":
   110 		case nsStream + " features":
   104 			obj = &Features{}
   111 			obj = &Features{}
   105 		case nsTLS + " proceed", nsTLS + " failure":
   112 		case nsTLS + " proceed", nsTLS + " failure":
   106 			obj = &starttls{}
   113 			obj = &starttls{}
       
   114 		case nsSASL + " challenge", nsSASL + " failure",
       
   115 			nsSASL + " success":
       
   116 			obj = &auth{}
   107 		default:
   117 		default:
   108 			obj = &Unrecognized{}
   118 			obj = &Unrecognized{}
   109 			log.Printf("Ignoring unrecognized: %s %s\n",
   119 			log.Printf("Ignoring unrecognized: %s %s\n",
   110 				se.Name.Space, se.Name.Local)
   120 				se.Name.Space, se.Name.Local)
   111 		}
   121 		}
   154 			break
   164 			break
   155 		}
   165 		}
   156 	}
   166 	}
   157 }
   167 }
   158 
   168 
   159 func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) {
   169 func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- interface{}) {
   160 	defer tryClose(srvIn, cliOut)
   170 	defer tryClose(srvIn, cliOut)
   161 
   171 
   162 	for x := range srvIn {
   172 	for x := range srvIn {
   163 		switch obj := x.(type) {
   173 		switch obj := x.(type) {
   164 		case *Stream:
   174 		case *Stream:
   165 			handleStream(obj)
   175 			handleStream(obj)
   166 		case *Features:
   176 		case *Features:
   167 			handleFeatures(obj, srvOut)
   177 			cl.handleFeatures(obj)
   168 		case *starttls:
   178 		case *starttls:
   169 			cl.handleTls(obj)
   179 			cl.handleTls(obj)
       
   180 		case *auth:
       
   181 			cl.handleSasl(obj)
   170 		default:
   182 		default:
   171 			cliOut <- x
   183 			cliOut <- x
   172 		}
   184 		}
   173 	}
   185 	}
   174 }
   186 }
   182 }
   194 }
   183 
   195 
   184 func handleStream(ss *Stream) {
   196 func handleStream(ss *Stream) {
   185 }
   197 }
   186 
   198 
   187 func handleFeatures(fe *Features, srvOut chan<- interface{}) {
   199 func (cl *Client) handleFeatures(fe *Features) {
   188 	if fe.Starttls != nil {
   200 	if fe.Starttls != nil {
   189 		start := &starttls{XMLName: xml.Name{Space: nsTLS,
   201 		start := &starttls{XMLName: xml.Name{Space: nsTLS,
   190 			Local: "starttls"}}
   202 			Local: "starttls"}}
   191 		srvOut <- start
   203 		cl.xmlOut <- start
       
   204 		return
       
   205 	}
       
   206 
       
   207 	if len(fe.Mechanisms.Mechanism) > 0 {
       
   208 		cl.chooseSasl(fe)
       
   209 		return
   192 	}
   210 	}
   193 }
   211 }
   194 
   212 
   195 // readTransport() is running concurrently. We need to stop it,
   213 // readTransport() is running concurrently. We need to stop it,
   196 // negotiate TLS, then start it again. It calls waitForSocket() in
   214 // negotiate TLS, then start it again. It calls waitForSocket() in
   237 	}
   255 	}
   238 
   256 
   239 	// Signal that we're going back to the read loop.
   257 	// Signal that we're going back to the read loop.
   240 	cl.socketSync.Done()
   258 	cl.socketSync.Done()
   241 }
   259 }
       
   260 
       
   261 func (cl *Client) chooseSasl(fe *Features) {
       
   262 	var digestMd5 bool
       
   263 	for _, m := range(fe.Mechanisms.Mechanism) {
       
   264 		switch strings.ToLower(m) {
       
   265 		case "digest-md5":
       
   266 			digestMd5 = true
       
   267 		}
       
   268 	}
       
   269 
       
   270 	if digestMd5 {
       
   271 		auth := &auth{XMLName: xml.Name{Space: nsSASL, Local:
       
   272 				"auth"}, Mechanism: "DIGEST-MD5"}
       
   273 		cl.xmlOut <- auth
       
   274 	}
       
   275 }
       
   276 
       
   277 func (cl *Client) handleSasl(srv *auth) {
       
   278 	switch strings.ToLower(srv.XMLName.Local) {
       
   279 	case "challenge":
       
   280 		b64 := base64.StdEncoding
       
   281 		str, err := b64.DecodeString(srv.Chardata)
       
   282 		if err != nil {
       
   283 			log.Printf("SASL challenge decode: %s",
       
   284 				err.String())
       
   285 			return;
       
   286 		}
       
   287 		srvMap := parseSasl(string(str))
       
   288 
       
   289 		if cl.saslExpected == "" {
       
   290 			cl.saslDigest1(srvMap)
       
   291 		} else {
       
   292 			cl.saslDigest2(srvMap)
       
   293 		}
       
   294 	case "failure":
       
   295 		log.Println("SASL authentication failed")
       
   296 	case "success":
       
   297 		log.Println("SASL authentication succeeded")
       
   298 		ss := &Stream{To: cl.Jid.Domain, Version: Version}
       
   299 		cl.xmlOut <- ss
       
   300 	}
       
   301 }
       
   302 
       
   303 func (cl *Client) saslDigest1(srvMap map[string] string) {
       
   304 	// Make sure it supports qop=auth
       
   305 	var hasAuth bool
       
   306 	for _, qop := range(strings.Fields(srvMap["qop"])) {
       
   307 		if qop == "auth" {
       
   308 			hasAuth = true
       
   309 		}
       
   310 	}
       
   311 	if !hasAuth {
       
   312 		log.Println("Server doesn't support SASL auth")
       
   313 		return;
       
   314 	}
       
   315 
       
   316 	// Pick a realm.
       
   317 	var realm string
       
   318 	if srvMap["realm"] != "" {
       
   319 		realm = strings.Fields(srvMap["realm"])[0]
       
   320 	}
       
   321 
       
   322 	passwd := cl.password
       
   323 	nonce := srvMap["nonce"]
       
   324 	digestUri := "xmpp/" + cl.Jid.Domain
       
   325 	nonceCount := int32(1)
       
   326 	nonceCountStr := fmt.Sprintf("%08x", nonceCount)
       
   327 
       
   328 	// Begin building the response. Username is
       
   329 	// user@domain or just domain.
       
   330 	var username string
       
   331 	if cl.Jid.Node == nil {
       
   332 		username = cl.Jid.Domain
       
   333 	} else {
       
   334 		username = *cl.Jid.Node
       
   335 	}
       
   336 
       
   337 	// Generate our own nonce from random data.
       
   338 	randSize := big.NewInt(0)
       
   339 	randSize.Lsh(big.NewInt(1), 64)
       
   340 	cnonce, err := rand.Int(rand.Reader, randSize)
       
   341 	if err != nil {
       
   342 		log.Println("SASL rand: %s", err.String())
       
   343 		return
       
   344 	}
       
   345 	cnonceStr := fmt.Sprintf("%016x", cnonce)
       
   346 
       
   347 	/* Now encode the actual password response, as well as the
       
   348 	 * expected next challenge from the server. */
       
   349 	response := saslDigestResponse(username, realm, passwd, nonce,
       
   350 		cnonceStr, "AUTHENTICATE", digestUri, nonceCountStr)
       
   351 	next := saslDigestResponse(username, realm, passwd, nonce,
       
   352 		cnonceStr, "", digestUri, nonceCountStr)
       
   353 	cl.saslExpected = next
       
   354 
       
   355 	// Build the map which will be encoded.
       
   356 	clMap := make(map[string]string)
       
   357 	clMap["realm"] = `"` + realm + `"`
       
   358 	clMap["username"] = `"` + username + `"`
       
   359 	clMap["nonce"] = `"` + nonce + `"`
       
   360 	clMap["cnonce"] = `"` + cnonceStr + `"`
       
   361 	clMap["nc"] =  nonceCountStr
       
   362 	clMap["qop"] = "auth"
       
   363 	clMap["digest-uri"] = `"` + digestUri + `"`
       
   364 	clMap["response"] = response
       
   365 	if srvMap["charset"] == "utf-8" {
       
   366 		clMap["charset"] = "utf-8"
       
   367 	}
       
   368 
       
   369 	// Encode the map and send it.
       
   370 	clStr := packSasl(clMap)
       
   371 	b64 := base64.StdEncoding
       
   372 	clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local:
       
   373 			"response"}, Chardata:
       
   374 		b64.EncodeToString([]byte(clStr))}
       
   375 	cl.xmlOut <- clObj
       
   376 }
       
   377 
       
   378 func (cl *Client) saslDigest2(srvMap map[string] string) {
       
   379 	if cl.saslExpected == srvMap["rspauth"] {
       
   380 		clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local:
       
   381 				"response"}}
       
   382 		cl.xmlOut <- clObj
       
   383 	} else {
       
   384 		clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local:
       
   385 				"failure"}, Any:
       
   386 			&Unrecognized{XMLName: xml.Name{Space: nsSASL,
       
   387 				Local: "abort"}}}
       
   388 		cl.xmlOut <- clObj
       
   389 	}
       
   390 }
       
   391 
       
   392 // Takes a string like `key1=value1,key2="value2"...` and returns a
       
   393 // key/value map.
       
   394 func parseSasl(in string) map[string]string {
       
   395 	re := regexp.MustCompile(`([^=]+)="?([^",]+)"?,?`)
       
   396 	strs := re.FindAllStringSubmatch(in, -1)
       
   397 	m := make(map[string]string)
       
   398 	for _, pair := range(strs) {
       
   399 		key := strings.ToLower(string(pair[1]))
       
   400 		value := string(pair[2])
       
   401 		m[key] = value
       
   402 	}
       
   403 	return m
       
   404 }
       
   405 
       
   406 func packSasl(m map[string]string) string {
       
   407 	var terms []string
       
   408 	for key, value := range(m) {
       
   409 		if key == "" || value == "" || value == `""` {
       
   410 			continue
       
   411 		}
       
   412 		terms = append(terms, key + "=" + value)
       
   413 	}
       
   414 	return strings.Join(terms, ",")
       
   415 }
       
   416 
       
   417 func saslDigestResponse(username, realm, passwd, nonce, cnonceStr,
       
   418 	authenticate, digestUri, nonceCountStr string) string {
       
   419 	h := func(text string) []byte {
       
   420 		h := md5.New()
       
   421 		h.Write([]byte(text))
       
   422 		return h.Sum()
       
   423 	}
       
   424 	hex := func(bytes []byte) string {
       
   425 		return fmt.Sprintf("%x", bytes)
       
   426 	}
       
   427 	kd := func(secret, data string) []byte {
       
   428 		return h(secret + ":" + data)
       
   429 	}
       
   430 
       
   431 	a1 := string(h(username + ":" + realm + ":" + passwd)) + ":" +
       
   432 		nonce + ":" + cnonceStr
       
   433 	a2 := authenticate + ":" + digestUri
       
   434 	response := hex(kd(hex(h(a1)), nonce + ":" +
       
   435 		nonceCountStr + ":" + cnonceStr + ":auth:" +
       
   436 		hex(h(a2))))
       
   437 	return response
       
   438 }