Added SASL digest authentication.
authorChris Jones <chris@cjones.org>
Tue, 27 Dec 2011 15:36:07 -0700
changeset 11 48be1ae93fd4
parent 10 f38b0ee7b1c1
child 12 122ab6208c3c
Added SASL digest authentication.
stream.go
stream_test.go
structs.go
xmpp.go
--- a/stream.go	Mon Dec 26 18:07:14 2011 -0700
+++ b/stream.go	Tue Dec 27 15:36:07 2011 -0700
@@ -10,11 +10,18 @@
 package xmpp
 
 import (
+	"big"
+	"crypto/md5"
+	"crypto/rand"
 	"crypto/tls"
+	"encoding/base64"
+	"fmt"
 	"io"
 	"log"
 	"net"
 	"os"
+	"regexp"
+	"strings"
 	"time"
 	"xml"
 )
@@ -104,6 +111,9 @@
 			obj = &Features{}
 		case nsTLS + " proceed", nsTLS + " failure":
 			obj = &starttls{}
+		case nsSASL + " challenge", nsSASL + " failure",
+			nsSASL + " success":
+			obj = &auth{}
 		default:
 			obj = &Unrecognized{}
 			log.Printf("Ignoring unrecognized: %s %s\n",
@@ -156,7 +166,7 @@
 	}
 }
 
-func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) {
+func (cl *Client) readStream(srvIn <-chan interface{}, cliOut chan<- interface{}) {
 	defer tryClose(srvIn, cliOut)
 
 	for x := range srvIn {
@@ -164,9 +174,11 @@
 		case *Stream:
 			handleStream(obj)
 		case *Features:
-			handleFeatures(obj, srvOut)
+			cl.handleFeatures(obj)
 		case *starttls:
 			cl.handleTls(obj)
+		case *auth:
+			cl.handleSasl(obj)
 		default:
 			cliOut <- x
 		}
@@ -184,11 +196,17 @@
 func handleStream(ss *Stream) {
 }
 
-func handleFeatures(fe *Features, srvOut chan<- interface{}) {
+func (cl *Client) handleFeatures(fe *Features) {
 	if fe.Starttls != nil {
 		start := &starttls{XMLName: xml.Name{Space: nsTLS,
 			Local: "starttls"}}
-		srvOut <- start
+		cl.xmlOut <- start
+		return
+	}
+
+	if len(fe.Mechanisms.Mechanism) > 0 {
+		cl.chooseSasl(fe)
+		return
 	}
 }
 
@@ -239,3 +257,182 @@
 	// Signal that we're going back to the read loop.
 	cl.socketSync.Done()
 }
+
+func (cl *Client) chooseSasl(fe *Features) {
+	var digestMd5 bool
+	for _, m := range(fe.Mechanisms.Mechanism) {
+		switch strings.ToLower(m) {
+		case "digest-md5":
+			digestMd5 = true
+		}
+	}
+
+	if digestMd5 {
+		auth := &auth{XMLName: xml.Name{Space: nsSASL, Local:
+				"auth"}, Mechanism: "DIGEST-MD5"}
+		cl.xmlOut <- auth
+	}
+}
+
+func (cl *Client) handleSasl(srv *auth) {
+	switch strings.ToLower(srv.XMLName.Local) {
+	case "challenge":
+		b64 := base64.StdEncoding
+		str, err := b64.DecodeString(srv.Chardata)
+		if err != nil {
+			log.Printf("SASL challenge decode: %s",
+				err.String())
+			return;
+		}
+		srvMap := parseSasl(string(str))
+
+		if cl.saslExpected == "" {
+			cl.saslDigest1(srvMap)
+		} else {
+			cl.saslDigest2(srvMap)
+		}
+	case "failure":
+		log.Println("SASL authentication failed")
+	case "success":
+		log.Println("SASL authentication succeeded")
+		ss := &Stream{To: cl.Jid.Domain, Version: Version}
+		cl.xmlOut <- ss
+	}
+}
+
+func (cl *Client) saslDigest1(srvMap map[string] string) {
+	// Make sure it supports qop=auth
+	var hasAuth bool
+	for _, qop := range(strings.Fields(srvMap["qop"])) {
+		if qop == "auth" {
+			hasAuth = true
+		}
+	}
+	if !hasAuth {
+		log.Println("Server doesn't support SASL auth")
+		return;
+	}
+
+	// Pick a realm.
+	var realm string
+	if srvMap["realm"] != "" {
+		realm = strings.Fields(srvMap["realm"])[0]
+	}
+
+	passwd := cl.password
+	nonce := srvMap["nonce"]
+	digestUri := "xmpp/" + cl.Jid.Domain
+	nonceCount := int32(1)
+	nonceCountStr := fmt.Sprintf("%08x", nonceCount)
+
+	// Begin building the response. Username is
+	// user@domain or just domain.
+	var username string
+	if cl.Jid.Node == nil {
+		username = cl.Jid.Domain
+	} else {
+		username = *cl.Jid.Node
+	}
+
+	// Generate our own nonce from random data.
+	randSize := big.NewInt(0)
+	randSize.Lsh(big.NewInt(1), 64)
+	cnonce, err := rand.Int(rand.Reader, randSize)
+	if err != nil {
+		log.Println("SASL rand: %s", err.String())
+		return
+	}
+	cnonceStr := fmt.Sprintf("%016x", cnonce)
+
+	/* Now encode the actual password response, as well as the
+	 * expected next challenge from the server. */
+	response := saslDigestResponse(username, realm, passwd, nonce,
+		cnonceStr, "AUTHENTICATE", digestUri, nonceCountStr)
+	next := saslDigestResponse(username, realm, passwd, nonce,
+		cnonceStr, "", digestUri, nonceCountStr)
+	cl.saslExpected = next
+
+	// Build the map which will be encoded.
+	clMap := make(map[string]string)
+	clMap["realm"] = `"` + realm + `"`
+	clMap["username"] = `"` + username + `"`
+	clMap["nonce"] = `"` + nonce + `"`
+	clMap["cnonce"] = `"` + cnonceStr + `"`
+	clMap["nc"] =  nonceCountStr
+	clMap["qop"] = "auth"
+	clMap["digest-uri"] = `"` + digestUri + `"`
+	clMap["response"] = response
+	if srvMap["charset"] == "utf-8" {
+		clMap["charset"] = "utf-8"
+	}
+
+	// Encode the map and send it.
+	clStr := packSasl(clMap)
+	b64 := base64.StdEncoding
+	clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local:
+			"response"}, Chardata:
+		b64.EncodeToString([]byte(clStr))}
+	cl.xmlOut <- clObj
+}
+
+func (cl *Client) saslDigest2(srvMap map[string] string) {
+	if cl.saslExpected == srvMap["rspauth"] {
+		clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local:
+				"response"}}
+		cl.xmlOut <- clObj
+	} else {
+		clObj := &auth{XMLName: xml.Name{Space: nsSASL, Local:
+				"failure"}, Any:
+			&Unrecognized{XMLName: xml.Name{Space: nsSASL,
+				Local: "abort"}}}
+		cl.xmlOut <- clObj
+	}
+}
+
+// Takes a string like `key1=value1,key2="value2"...` and returns a
+// key/value map.
+func parseSasl(in string) map[string]string {
+	re := regexp.MustCompile(`([^=]+)="?([^",]+)"?,?`)
+	strs := re.FindAllStringSubmatch(in, -1)
+	m := make(map[string]string)
+	for _, pair := range(strs) {
+		key := strings.ToLower(string(pair[1]))
+		value := string(pair[2])
+		m[key] = value
+	}
+	return m
+}
+
+func packSasl(m map[string]string) string {
+	var terms []string
+	for key, value := range(m) {
+		if key == "" || value == "" || value == `""` {
+			continue
+		}
+		terms = append(terms, key + "=" + value)
+	}
+	return strings.Join(terms, ",")
+}
+
+func saslDigestResponse(username, realm, passwd, nonce, cnonceStr,
+	authenticate, digestUri, nonceCountStr string) string {
+	h := func(text string) []byte {
+		h := md5.New()
+		h.Write([]byte(text))
+		return h.Sum()
+	}
+	hex := func(bytes []byte) string {
+		return fmt.Sprintf("%x", bytes)
+	}
+	kd := func(secret, data string) []byte {
+		return h(secret + ":" + data)
+	}
+
+	a1 := string(h(username + ":" + realm + ":" + passwd)) + ":" +
+		nonce + ":" + cnonceStr
+	a2 := authenticate + ":" + digestUri
+	response := hex(kd(hex(h(a1)), nonce + ":" +
+		nonceCountStr + ":" + cnonceStr + ":auth:" +
+		hex(h(a2))))
+	return response
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stream_test.go	Tue Dec 27 15:36:07 2011 -0700
@@ -0,0 +1,19 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xmpp
+
+import (
+	"testing"
+)
+
+func TestSaslDigest(t *testing.T) {
+	// These values are from RFC2831, section 4.
+	obs := saslDigestResponse("chris", "elwood.innosoft.com",
+		"secret", "OA6MG9tEQGm2hh", "OA6MHXh6VqTrRk",
+		"AUTHENTICATE", "imap/elwood.innosoft.com",
+		"00000001")
+	exp := "d388dad90d4bbd760a152321f2143af7"
+	assertEquals(t, exp, obs)
+}
--- a/structs.go	Mon Dec 26 18:07:14 2011 -0700
+++ b/structs.go	Tue Dec 27 15:36:07 2011 -0700
@@ -72,6 +72,13 @@
 	Mechanism []string
 }
 
+type auth struct {
+	XMLName xml.Name
+	Chardata string `xml:"chardata"`
+	Mechanism string `xml:"attr"`
+	Any *Unrecognized
+}
+
 type Unrecognized struct {
 	XMLName xml.Name
 }
--- a/xmpp.go	Mon Dec 26 18:07:14 2011 -0700
+++ b/xmpp.go	Tue Dec 27 15:36:07 2011 -0700
@@ -23,6 +23,7 @@
 	nsStreams = "urn:ietf:params:xml:ns:xmpp-streams"
 	nsStream = "http://etherx.jabber.org/streams"
 	nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
+	nsSASL = "urn:ietf:params:xml:ns:xmpp-sasl"
 
 	// DNS SRV names
 	serverSrv = "xmpp-server"
@@ -34,8 +35,10 @@
 // The client in a client-server XMPP connection.
 type Client struct {
 	Jid JID
+	password string
 	socket net.Conn
 	socketSync sync.WaitGroup
+	saslExpected string
 	In <-chan interface{}
 	Out chan<- interface{}
 	xmlOut chan<- interface{}
@@ -74,6 +77,7 @@
 	}
 
 	cl := new(Client)
+	cl.password = password
 	cl.Jid = *jid
 	cl.socket = tcp
 
@@ -136,7 +140,7 @@
 
 func (cl *Client) startStreamReader(xmlIn <-chan interface{}, srvOut chan<- interface{}) <-chan interface{} {
 	ch := make(chan interface{})
-	go cl.readStream(xmlIn, srvOut, ch)
+	go cl.readStream(xmlIn, ch)
 	return ch
 }