gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/connection_handshake.go (about)

     1  package rethinkdb
     2  
     3  import (
     4  	"bufio"
     5  	"crypto/hmac"
     6  	"crypto/rand"
     7  	"crypto/sha256"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"encoding/json"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"golang.org/x/crypto/pbkdf2"
    18  
    19  	p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2"
    20  )
    21  
    22  type HandshakeVersion int
    23  
    24  const (
    25  	HandshakeV1_0 HandshakeVersion = iota
    26  	HandshakeV0_4
    27  )
    28  
    29  type connectionHandshake interface {
    30  	Send() error
    31  }
    32  
    33  func (c *Connection) handshake(version HandshakeVersion) (connectionHandshake, error) {
    34  	switch version {
    35  	case HandshakeV0_4:
    36  		return &connectionHandshakeV0_4{conn: c}, nil
    37  	case HandshakeV1_0:
    38  		return &connectionHandshakeV1_0{conn: c}, nil
    39  	default:
    40  		return nil, fmt.Errorf("Unrecognised handshake version")
    41  	}
    42  }
    43  
    44  type connectionHandshakeV0_4 struct {
    45  	conn *Connection
    46  }
    47  
    48  func (c *connectionHandshakeV0_4) Send() error {
    49  	// Send handshake request
    50  	if err := c.writeHandshakeReq(); err != nil {
    51  		c.conn.Close()
    52  		return RQLConnectionError{rqlError(err.Error())}
    53  	}
    54  	// Read handshake response
    55  	if err := c.readHandshakeSuccess(); err != nil {
    56  		c.conn.Close()
    57  		return RQLConnectionError{rqlError(err.Error())}
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (c *connectionHandshakeV0_4) writeHandshakeReq() error {
    64  	pos := 0
    65  	dataLen := 4 + 4 + len(c.conn.opts.AuthKey) + 4
    66  	data := make([]byte, dataLen)
    67  
    68  	// Send the protocol version to the server as a 4-byte little-endian-encoded integer
    69  	binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_V0_4))
    70  	pos += 4
    71  
    72  	// Send the length of the auth key to the server as a 4-byte little-endian-encoded integer
    73  	binary.LittleEndian.PutUint32(data[pos:], uint32(len(c.conn.opts.AuthKey)))
    74  	pos += 4
    75  
    76  	// Send the auth key as an ASCII string
    77  	if len(c.conn.opts.AuthKey) > 0 {
    78  		pos += copy(data[pos:], c.conn.opts.AuthKey)
    79  	}
    80  
    81  	// Send the protocol type as a 4-byte little-endian-encoded integer
    82  	binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_JSON))
    83  	pos += 4
    84  
    85  	return c.conn.writeData(data)
    86  }
    87  
    88  func (c *connectionHandshakeV0_4) readHandshakeSuccess() error {
    89  	reader := bufio.NewReader(c.conn.Conn)
    90  	line, err := reader.ReadBytes('\x00')
    91  	if err != nil {
    92  		if err == io.EOF {
    93  			return fmt.Errorf("Unexpected EOF: %s", string(line))
    94  		}
    95  		return err
    96  	}
    97  	// convert to string and remove trailing NUL byte
    98  	response := string(line[:len(line)-1])
    99  	if response != "SUCCESS" {
   100  		response = strings.TrimSpace(response)
   101  		// we failed authorization or something else terrible happened
   102  		return RQLDriverError{rqlError(fmt.Sprintf("Server dropped connection with message: \"%s\"", response))}
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  const (
   109  	handshakeV1_0_protocolVersionNumber = 0
   110  	handshakeV1_0_authenticationMethod  = "SCRAM-SHA-256"
   111  )
   112  
   113  type connectionHandshakeV1_0 struct {
   114  	conn   *Connection
   115  	reader *bufio.Reader
   116  
   117  	authMsg string
   118  }
   119  
   120  func (c *connectionHandshakeV1_0) Send() error {
   121  	c.reader = bufio.NewReader(c.conn.Conn)
   122  
   123  	// Generate client nonce
   124  	clientNonce, err := c.generateNonce()
   125  	if err != nil {
   126  		c.conn.Close()
   127  		return RQLDriverError{rqlError(fmt.Sprintf("Failed to generate client nonce: %s", err))}
   128  	}
   129  	// Send client first message
   130  	if err := c.writeFirstMessage(clientNonce); err != nil {
   131  		c.conn.Close()
   132  		return err
   133  	}
   134  	// Read status
   135  	if err := c.checkServerVersions(); err != nil {
   136  		c.conn.Close()
   137  		return err
   138  	}
   139  
   140  	// Read server first message
   141  	i, salt, serverNonce, err := c.readFirstMessage()
   142  	if err != nil {
   143  		c.conn.Close()
   144  		return err
   145  	}
   146  
   147  	// Check server nonce
   148  	if !strings.HasPrefix(serverNonce, clientNonce) {
   149  		return RQLAuthError{RQLDriverError{rqlError("Invalid nonce from server")}}
   150  	}
   151  
   152  	// Generate proof
   153  	saltedPass := c.saltPassword(i, salt)
   154  	clientProof := c.calculateProof(saltedPass, clientNonce, serverNonce)
   155  	serverSignature := c.serverSignature(saltedPass)
   156  
   157  	// Send client final message
   158  	if err := c.writeFinalMessage(serverNonce, clientProof); err != nil {
   159  		c.conn.Close()
   160  		return err
   161  	}
   162  	// Read server final message
   163  	if err := c.readFinalMessage(serverSignature); err != nil {
   164  		c.conn.Close()
   165  		return err
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  func (c *connectionHandshakeV1_0) writeFirstMessage(clientNonce string) error {
   172  	// Default username to admin if not set
   173  	username := "admin"
   174  	if c.conn.opts.Username != "" {
   175  		username = c.conn.opts.Username
   176  	}
   177  
   178  	c.authMsg = fmt.Sprintf("n=%s,r=%s", username, clientNonce)
   179  	msg := fmt.Sprintf(
   180  		`{"protocol_version": %d,"authentication": "n,,%s","authentication_method": "%s"}`,
   181  		handshakeV1_0_protocolVersionNumber, c.authMsg, handshakeV1_0_authenticationMethod,
   182  	)
   183  
   184  	pos := 0
   185  	dataLen := 4 + len(msg) + 1
   186  	data := make([]byte, dataLen)
   187  
   188  	// Send the protocol version to the server as a 4-byte little-endian-encoded integer
   189  	binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_V1_0))
   190  	pos += 4
   191  
   192  	// Send the auth message as an ASCII string
   193  	pos += copy(data[pos:], msg)
   194  
   195  	// Add null terminating byte
   196  	data[pos] = '\x00'
   197  
   198  	return c.writeData(data)
   199  }
   200  
   201  func (c *connectionHandshakeV1_0) checkServerVersions() error {
   202  	b, err := c.readResponse()
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	// Read status
   208  	type versionsResponse struct {
   209  		Success            bool   `json:"success"`
   210  		MinProtocolVersion int    `json:"min_protocol_version"`
   211  		MaxProtocolVersion int    `json:"max_protocol_version"`
   212  		ServerVersion      string `json:"server_version"`
   213  		ErrorCode          int    `json:"error_code"`
   214  		Error              string `json:"error"`
   215  	}
   216  	var rsp *versionsResponse
   217  	statusStr := string(b)
   218  
   219  	if err := json.Unmarshal(b, &rsp); err != nil {
   220  		if strings.HasPrefix(statusStr, "ERROR: ") {
   221  			statusStr = strings.TrimPrefix(statusStr, "ERROR: ")
   222  			return RQLConnectionError{rqlError(statusStr)}
   223  		}
   224  
   225  		return RQLDriverError{rqlError(fmt.Sprintf("Error reading versions: %s", err))}
   226  	}
   227  
   228  	if !rsp.Success {
   229  		return c.handshakeError(rsp.ErrorCode, rsp.Error)
   230  	}
   231  	if rsp.MinProtocolVersion > handshakeV1_0_protocolVersionNumber ||
   232  		rsp.MaxProtocolVersion < handshakeV1_0_protocolVersionNumber {
   233  		return RQLDriverError{rqlError(
   234  			fmt.Sprintf(
   235  				"Unsupported protocol version %d, expected between %d and %d.",
   236  				handshakeV1_0_protocolVersionNumber,
   237  				rsp.MinProtocolVersion,
   238  				rsp.MaxProtocolVersion,
   239  			),
   240  		)}
   241  	}
   242  
   243  	return nil
   244  }
   245  
   246  func (c *connectionHandshakeV1_0) readFirstMessage() (i int64, salt []byte, serverNonce string, err error) {
   247  	b, err2 := c.readResponse()
   248  	if err2 != nil {
   249  		err = err2
   250  		return
   251  	}
   252  
   253  	// Read server message
   254  	type firstMessageResponse struct {
   255  		Success        bool   `json:"success"`
   256  		Authentication string `json:"authentication"`
   257  		ErrorCode      int    `json:"error_code"`
   258  		Error          string `json:"error"`
   259  	}
   260  	var rsp *firstMessageResponse
   261  
   262  	if err2 := json.Unmarshal(b, &rsp); err2 != nil {
   263  		err = RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err2))}
   264  		return
   265  	}
   266  	if !rsp.Success {
   267  		err = c.handshakeError(rsp.ErrorCode, rsp.Error)
   268  		return
   269  	}
   270  
   271  	c.authMsg += ","
   272  	c.authMsg += rsp.Authentication
   273  
   274  	// Parse authentication field
   275  	auth := map[string]string{}
   276  	parts := strings.Split(rsp.Authentication, ",")
   277  	for _, part := range parts {
   278  		i := strings.Index(part, "=")
   279  		if i != -1 {
   280  			auth[part[:i]] = part[i+1:]
   281  		}
   282  	}
   283  
   284  	// Extract return values
   285  	if v, ok := auth["i"]; ok {
   286  		i, err = strconv.ParseInt(v, 10, 64)
   287  		if err != nil {
   288  			return
   289  		}
   290  	}
   291  	if v, ok := auth["s"]; ok {
   292  		salt, err = base64.StdEncoding.DecodeString(v)
   293  		if err != nil {
   294  			return
   295  		}
   296  	}
   297  	if v, ok := auth["r"]; ok {
   298  		serverNonce = v
   299  	}
   300  
   301  	return
   302  }
   303  
   304  func (c *connectionHandshakeV1_0) writeFinalMessage(serverNonce, clientProof string) error {
   305  	authMsg := "c=biws,r="
   306  	authMsg += serverNonce
   307  	authMsg += ",p="
   308  	authMsg += clientProof
   309  
   310  	msg := fmt.Sprintf(`{"authentication": "%s"}`, authMsg)
   311  
   312  	pos := 0
   313  	dataLen := len(msg) + 1
   314  	data := make([]byte, dataLen)
   315  
   316  	// Send the auth message as an ASCII string
   317  	pos += copy(data[pos:], msg)
   318  
   319  	// Add null terminating byte
   320  	data[pos] = '\x00'
   321  
   322  	return c.writeData(data)
   323  }
   324  
   325  func (c *connectionHandshakeV1_0) readFinalMessage(serverSignature string) error {
   326  	b, err := c.readResponse()
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	// Read server message
   332  	type finalMessageResponse struct {
   333  		Success        bool   `json:"success"`
   334  		Authentication string `json:"authentication"`
   335  		ErrorCode      int    `json:"error_code"`
   336  		Error          string `json:"error"`
   337  	}
   338  	var rsp *finalMessageResponse
   339  
   340  	if err := json.Unmarshal(b, &rsp); err != nil {
   341  		return RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))}
   342  	}
   343  	if !rsp.Success {
   344  		return c.handshakeError(rsp.ErrorCode, rsp.Error)
   345  	}
   346  
   347  	// Parse authentication field
   348  	auth := map[string]string{}
   349  	parts := strings.Split(rsp.Authentication, ",")
   350  	for _, part := range parts {
   351  		i := strings.Index(part, "=")
   352  		if i != -1 {
   353  			auth[part[:i]] = part[i+1:]
   354  		}
   355  	}
   356  
   357  	// Validate server response
   358  	if serverSignature != auth["v"] {
   359  		return RQLAuthError{RQLDriverError{rqlError("Invalid server signature")}}
   360  	}
   361  
   362  	return nil
   363  }
   364  
   365  func (c *connectionHandshakeV1_0) writeData(data []byte) error {
   366  
   367  	if err := c.conn.writeData(data); err != nil {
   368  		return RQLConnectionError{rqlError(err.Error())}
   369  	}
   370  
   371  	return nil
   372  }
   373  
   374  func (c *connectionHandshakeV1_0) readResponse() ([]byte, error) {
   375  	line, err := c.reader.ReadBytes('\x00')
   376  	if err != nil {
   377  		if err == io.EOF {
   378  			return nil, RQLConnectionError{rqlError(fmt.Sprintf("Unexpected EOF: %s", string(line)))}
   379  		}
   380  		return nil, RQLConnectionError{rqlError(err.Error())}
   381  	}
   382  
   383  	// Strip null byte and return
   384  	return line[:len(line)-1], nil
   385  }
   386  
   387  func (c *connectionHandshakeV1_0) generateNonce() (string, error) {
   388  	const nonceSize = 24
   389  
   390  	b := make([]byte, nonceSize)
   391  	_, err := rand.Read(b)
   392  	if err != nil {
   393  		return "", err
   394  	}
   395  
   396  	return base64.StdEncoding.EncodeToString(b), nil
   397  }
   398  
   399  func (c *connectionHandshakeV1_0) saltPassword(iter int64, salt []byte) []byte {
   400  	pass := []byte(c.conn.opts.Password)
   401  
   402  	return pbkdf2.Key(pass, salt, int(iter), sha256.Size, sha256.New)
   403  }
   404  
   405  func (c *connectionHandshakeV1_0) calculateProof(saltedPass []byte, clientNonce, serverNonce string) string {
   406  	// Generate proof
   407  	c.authMsg += ",c=biws,r=" + serverNonce
   408  
   409  	mac := hmac.New(c.hashFunc(), saltedPass)
   410  	mac.Write([]byte("Client Key"))
   411  	clientKey := mac.Sum(nil)
   412  
   413  	hash := c.hashFunc()()
   414  	hash.Write(clientKey)
   415  	storedKey := hash.Sum(nil)
   416  
   417  	mac = hmac.New(c.hashFunc(), storedKey)
   418  	mac.Write([]byte(c.authMsg))
   419  	clientSignature := mac.Sum(nil)
   420  	clientProof := make([]byte, len(clientKey))
   421  	for i, _ := range clientKey {
   422  		clientProof[i] = clientKey[i] ^ clientSignature[i]
   423  	}
   424  
   425  	return base64.StdEncoding.EncodeToString(clientProof)
   426  }
   427  
   428  func (c *connectionHandshakeV1_0) serverSignature(saltedPass []byte) string {
   429  	mac := hmac.New(c.hashFunc(), saltedPass)
   430  	mac.Write([]byte("Server Key"))
   431  	serverKey := mac.Sum(nil)
   432  
   433  	mac = hmac.New(c.hashFunc(), serverKey)
   434  	mac.Write([]byte(c.authMsg))
   435  	serverSignature := mac.Sum(nil)
   436  
   437  	return base64.StdEncoding.EncodeToString(serverSignature)
   438  }
   439  
   440  func (c *connectionHandshakeV1_0) handshakeError(code int, message string) error {
   441  	if code >= 10 || code <= 20 {
   442  		return RQLAuthError{RQLDriverError{rqlError(message)}}
   443  	}
   444  
   445  	return RQLDriverError{rqlError(message)}
   446  }
   447  
   448  func (c *connectionHandshakeV1_0) hashFunc() func() hash.Hash {
   449  	return sha256.New
   450  }