diff -r f38b0ee7b1c1 -r 48be1ae93fd4 stream.go --- a/stream.go Mon Dec 26 18:07:14 2011 -0700 +++ b/stream.go Tue Dec 27 15:36:07 2011 -0700 @@ -10,11 +10,18 @@ package xmpp import ( + "big" + "crypto/md5" + "crypto/rand" "crypto/tls" + "encoding/base64" + "fmt" "io" "log" "net" "os" + "regexp" + "strings" "time" "xml" ) @@ -104,6 +111,9 @@ obj = &Features{} case nsTLS + " proceed", nsTLS + " failure": obj = &starttls{} + case nsSASL + " challenge", nsSASL + " failure", + nsSASL + " success": + obj = &auth{} default: obj = &Unrecognized{} log.Printf("Ignoring unrecognized: %s %s\n", @@ -156,7 +166,7 @@ } } -func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) { +func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- interface{}) { defer tryClose(srvIn, cliOut) for x := range srvIn { @@ -164,9 +174,11 @@ case *Stream: handleStream(obj) case *Features: - handleFeatures(obj, srvOut) + cl.handleFeatures(obj) case *starttls: cl.handleTls(obj) + case *auth: + cl.handleSasl(obj) default: cliOut <- x } @@ -184,11 +196,17 @@ func handleStream(ss *Stream) { } -func handleFeatures(fe *Features, srvOut chan<- interface{}) { +func (cl *Client) handleFeatures(fe *Features) { if fe.Starttls != nil { start := &starttls{XMLName: xml.Name{Space: nsTLS, Local: "starttls"}} - srvOut <- start + cl.xmlOut <- start + return + } + + if len(fe.Mechanisms.Mechanism) > 0 { + cl.chooseSasl(fe) + return } } @@ -239,3 +257,182 @@ // Signal that we're going back to the read loop. cl.socketSync.Done() } + +func (cl *Client) chooseSasl(fe *Features) { + var digestMd5 bool + for _, m := range(fe.Mechanisms.Mechanism) { + switch strings.ToLower(m) { + case "digest-md5": + digestMd5 = true + } + } + + if digestMd5 { + auth := &auth{XMLName: xml.Name{Space: nsSASL, Local: + "auth"}, Mechanism: "DIGEST-MD5"} + cl.xmlOut <- auth + } +} + +func (cl *Client) handleSasl(srv *auth) { + switch strings.ToLower(srv.XMLName.Local) { + case "challenge": + b64 := base64.StdEncoding + str, err := b64.DecodeString(srv.Chardata) + if err != nil { + log.Printf("SASL challenge decode: %s", + err.String()) + return; + } + srvMap := parseSasl(string(str)) + + if cl.saslExpected == "" { + cl.saslDigest1(srvMap) + } else { + cl.saslDigest2(srvMap) + } + case "failure": + log.Println("SASL authentication failed") + case "success": + log.Println("SASL authentication succeeded") + ss := &Stream{To: cl.Jid.Domain, Version: Version} + cl.xmlOut <- ss + } +} + +func (cl *Client) saslDigest1(srvMap map[string] string) { + // Make sure it supports qop=auth + var hasAuth bool + for _, qop := range(strings.Fields(srvMap["qop"])) { + if qop == "auth" { + hasAuth = true + } + } + if !hasAuth { + log.Println("Server doesn't support SASL auth") + return; + } + + // Pick a realm. + var realm string + if srvMap["realm"] != "" { + realm = strings.Fields(srvMap["realm"])[0] + } + + passwd := cl.password + nonce := srvMap["nonce"] + digestUri := "xmpp/" + cl.Jid.Domain + nonceCount := int32(1) + nonceCountStr := fmt.Sprintf("%08x", nonceCount) + + // Begin building the response. Username is + // user@domain or just domain. + var username string + if cl.Jid.Node == nil { + username = cl.Jid.Domain + } else { + username = *cl.Jid.Node + } + + // Generate our own nonce from random data. + randSize := big.NewInt(0) + randSize.Lsh(big.NewInt(1), 64) + cnonce, err := rand.Int(rand.Reader, randSize) + if err != nil { + log.Println("SASL rand: %s", err.String()) + return + } + cnonceStr := fmt.Sprintf("%016x", cnonce) + + /* Now encode the actual password response, as well as the + * expected next challenge from the server. */ + response := saslDigestResponse(username, realm, passwd, nonce, + cnonceStr, "AUTHENTICATE", digestUri, nonceCountStr) + next := saslDigestResponse(username, realm, passwd, nonce, + cnonceStr, "", digestUri, nonceCountStr) + cl.saslExpected = next + + // Build the map which will be encoded. + clMap := make(map[string]string) + clMap["realm"] = `"` + realm + `"` + clMap["username"] = `"` + username + `"` + clMap["nonce"] = `"` + nonce + `"` + clMap["cnonce"] = `"` + cnonceStr + `"` + clMap["nc"] = nonceCountStr + clMap["qop"] = "auth" + clMap["digest-uri"] = `"` + digestUri + `"` + clMap["response"] = response + if srvMap["charset"] == "utf-8" { + clMap["charset"] = "utf-8" + } + + // Encode the map and send it. + clStr := packSasl(clMap) + b64 := base64.StdEncoding + clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local: + "response"}, Chardata: + b64.EncodeToString([]byte(clStr))} + cl.xmlOut <- clObj +} + +func (cl *Client) saslDigest2(srvMap map[string] string) { + if cl.saslExpected == srvMap["rspauth"] { + clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local: + "response"}} + cl.xmlOut <- clObj + } else { + clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local: + "failure"}, Any: + &Unrecognized{XMLName: xml.Name{Space: nsSASL, + Local: "abort"}}} + cl.xmlOut <- clObj + } +} + +// Takes a string like `key1=value1,key2="value2"...` and returns a +// key/value map. +func parseSasl(in string) map[string]string { + re := regexp.MustCompile(`([^=]+)="?([^",]+)"?,?`) + strs := re.FindAllStringSubmatch(in, -1) + m := make(map[string]string) + for _, pair := range(strs) { + key := strings.ToLower(string(pair[1])) + value := string(pair[2]) + m[key] = value + } + return m +} + +func packSasl(m map[string]string) string { + var terms []string + for key, value := range(m) { + if key == "" || value == "" || value == `""` { + continue + } + terms = append(terms, key + "=" + value) + } + return strings.Join(terms, ",") +} + +func saslDigestResponse(username, realm, passwd, nonce, cnonceStr, + authenticate, digestUri, nonceCountStr string) string { + h := func(text string) []byte { + h := md5.New() + h.Write([]byte(text)) + return h.Sum() + } + hex := func(bytes []byte) string { + return fmt.Sprintf("%x", bytes) + } + kd := func(secret, data string) []byte { + return h(secret + ":" + data) + } + + a1 := string(h(username + ":" + realm + ":" + passwd)) + ":" + + nonce + ":" + cnonceStr + a2 := authenticate + ":" + digestUri + response := hex(kd(hex(h(a1)), nonce + ":" + + nonceCountStr + ":" + cnonceStr + ":auth:" + + hex(h(a2)))) + return response +}