gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/connection_handshake.go (about) 1 package rethinkdb 2 3 import ( 4 "bufio" 5 "crypto/hmac" 6 "crypto/rand" 7 "crypto/sha256" 8 "encoding/base64" 9 "encoding/binary" 10 "encoding/json" 11 "fmt" 12 "hash" 13 "io" 14 "strconv" 15 "strings" 16 17 "golang.org/x/crypto/pbkdf2" 18 19 p "gopkg.in/rethinkdb/rethinkdb-go.v6/ql2" 20 ) 21 22 type HandshakeVersion int 23 24 const ( 25 HandshakeV1_0 HandshakeVersion = iota 26 HandshakeV0_4 27 ) 28 29 type connectionHandshake interface { 30 Send() error 31 } 32 33 func (c *Connection) handshake(version HandshakeVersion) (connectionHandshake, error) { 34 switch version { 35 case HandshakeV0_4: 36 return &connectionHandshakeV0_4{conn: c}, nil 37 case HandshakeV1_0: 38 return &connectionHandshakeV1_0{conn: c}, nil 39 default: 40 return nil, fmt.Errorf("Unrecognised handshake version") 41 } 42 } 43 44 type connectionHandshakeV0_4 struct { 45 conn *Connection 46 } 47 48 func (c *connectionHandshakeV0_4) Send() error { 49 // Send handshake request 50 if err := c.writeHandshakeReq(); err != nil { 51 c.conn.Close() 52 return RQLConnectionError{rqlError(err.Error())} 53 } 54 // Read handshake response 55 if err := c.readHandshakeSuccess(); err != nil { 56 c.conn.Close() 57 return RQLConnectionError{rqlError(err.Error())} 58 } 59 60 return nil 61 } 62 63 func (c *connectionHandshakeV0_4) writeHandshakeReq() error { 64 pos := 0 65 dataLen := 4 + 4 + len(c.conn.opts.AuthKey) + 4 66 data := make([]byte, dataLen) 67 68 // Send the protocol version to the server as a 4-byte little-endian-encoded integer 69 binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_V0_4)) 70 pos += 4 71 72 // Send the length of the auth key to the server as a 4-byte little-endian-encoded integer 73 binary.LittleEndian.PutUint32(data[pos:], uint32(len(c.conn.opts.AuthKey))) 74 pos += 4 75 76 // Send the auth key as an ASCII string 77 if len(c.conn.opts.AuthKey) > 0 { 78 pos += copy(data[pos:], c.conn.opts.AuthKey) 79 } 80 81 // Send the protocol type as a 4-byte little-endian-encoded integer 82 binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_JSON)) 83 pos += 4 84 85 return c.conn.writeData(data) 86 } 87 88 func (c *connectionHandshakeV0_4) readHandshakeSuccess() error { 89 reader := bufio.NewReader(c.conn.Conn) 90 line, err := reader.ReadBytes('\x00') 91 if err != nil { 92 if err == io.EOF { 93 return fmt.Errorf("Unexpected EOF: %s", string(line)) 94 } 95 return err 96 } 97 // convert to string and remove trailing NUL byte 98 response := string(line[:len(line)-1]) 99 if response != "SUCCESS" { 100 response = strings.TrimSpace(response) 101 // we failed authorization or something else terrible happened 102 return RQLDriverError{rqlError(fmt.Sprintf("Server dropped connection with message: \"%s\"", response))} 103 } 104 105 return nil 106 } 107 108 const ( 109 handshakeV1_0_protocolVersionNumber = 0 110 handshakeV1_0_authenticationMethod = "SCRAM-SHA-256" 111 ) 112 113 type connectionHandshakeV1_0 struct { 114 conn *Connection 115 reader *bufio.Reader 116 117 authMsg string 118 } 119 120 func (c *connectionHandshakeV1_0) Send() error { 121 c.reader = bufio.NewReader(c.conn.Conn) 122 123 // Generate client nonce 124 clientNonce, err := c.generateNonce() 125 if err != nil { 126 c.conn.Close() 127 return RQLDriverError{rqlError(fmt.Sprintf("Failed to generate client nonce: %s", err))} 128 } 129 // Send client first message 130 if err := c.writeFirstMessage(clientNonce); err != nil { 131 c.conn.Close() 132 return err 133 } 134 // Read status 135 if err := c.checkServerVersions(); err != nil { 136 c.conn.Close() 137 return err 138 } 139 140 // Read server first message 141 i, salt, serverNonce, err := c.readFirstMessage() 142 if err != nil { 143 c.conn.Close() 144 return err 145 } 146 147 // Check server nonce 148 if !strings.HasPrefix(serverNonce, clientNonce) { 149 return RQLAuthError{RQLDriverError{rqlError("Invalid nonce from server")}} 150 } 151 152 // Generate proof 153 saltedPass := c.saltPassword(i, salt) 154 clientProof := c.calculateProof(saltedPass, clientNonce, serverNonce) 155 serverSignature := c.serverSignature(saltedPass) 156 157 // Send client final message 158 if err := c.writeFinalMessage(serverNonce, clientProof); err != nil { 159 c.conn.Close() 160 return err 161 } 162 // Read server final message 163 if err := c.readFinalMessage(serverSignature); err != nil { 164 c.conn.Close() 165 return err 166 } 167 168 return nil 169 } 170 171 func (c *connectionHandshakeV1_0) writeFirstMessage(clientNonce string) error { 172 // Default username to admin if not set 173 username := "admin" 174 if c.conn.opts.Username != "" { 175 username = c.conn.opts.Username 176 } 177 178 c.authMsg = fmt.Sprintf("n=%s,r=%s", username, clientNonce) 179 msg := fmt.Sprintf( 180 `{"protocol_version": %d,"authentication": "n,,%s","authentication_method": "%s"}`, 181 handshakeV1_0_protocolVersionNumber, c.authMsg, handshakeV1_0_authenticationMethod, 182 ) 183 184 pos := 0 185 dataLen := 4 + len(msg) + 1 186 data := make([]byte, dataLen) 187 188 // Send the protocol version to the server as a 4-byte little-endian-encoded integer 189 binary.LittleEndian.PutUint32(data[pos:], uint32(p.VersionDummy_V1_0)) 190 pos += 4 191 192 // Send the auth message as an ASCII string 193 pos += copy(data[pos:], msg) 194 195 // Add null terminating byte 196 data[pos] = '\x00' 197 198 return c.writeData(data) 199 } 200 201 func (c *connectionHandshakeV1_0) checkServerVersions() error { 202 b, err := c.readResponse() 203 if err != nil { 204 return err 205 } 206 207 // Read status 208 type versionsResponse struct { 209 Success bool `json:"success"` 210 MinProtocolVersion int `json:"min_protocol_version"` 211 MaxProtocolVersion int `json:"max_protocol_version"` 212 ServerVersion string `json:"server_version"` 213 ErrorCode int `json:"error_code"` 214 Error string `json:"error"` 215 } 216 var rsp *versionsResponse 217 statusStr := string(b) 218 219 if err := json.Unmarshal(b, &rsp); err != nil { 220 if strings.HasPrefix(statusStr, "ERROR: ") { 221 statusStr = strings.TrimPrefix(statusStr, "ERROR: ") 222 return RQLConnectionError{rqlError(statusStr)} 223 } 224 225 return RQLDriverError{rqlError(fmt.Sprintf("Error reading versions: %s", err))} 226 } 227 228 if !rsp.Success { 229 return c.handshakeError(rsp.ErrorCode, rsp.Error) 230 } 231 if rsp.MinProtocolVersion > handshakeV1_0_protocolVersionNumber || 232 rsp.MaxProtocolVersion < handshakeV1_0_protocolVersionNumber { 233 return RQLDriverError{rqlError( 234 fmt.Sprintf( 235 "Unsupported protocol version %d, expected between %d and %d.", 236 handshakeV1_0_protocolVersionNumber, 237 rsp.MinProtocolVersion, 238 rsp.MaxProtocolVersion, 239 ), 240 )} 241 } 242 243 return nil 244 } 245 246 func (c *connectionHandshakeV1_0) readFirstMessage() (i int64, salt []byte, serverNonce string, err error) { 247 b, err2 := c.readResponse() 248 if err2 != nil { 249 err = err2 250 return 251 } 252 253 // Read server message 254 type firstMessageResponse struct { 255 Success bool `json:"success"` 256 Authentication string `json:"authentication"` 257 ErrorCode int `json:"error_code"` 258 Error string `json:"error"` 259 } 260 var rsp *firstMessageResponse 261 262 if err2 := json.Unmarshal(b, &rsp); err2 != nil { 263 err = RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err2))} 264 return 265 } 266 if !rsp.Success { 267 err = c.handshakeError(rsp.ErrorCode, rsp.Error) 268 return 269 } 270 271 c.authMsg += "," 272 c.authMsg += rsp.Authentication 273 274 // Parse authentication field 275 auth := map[string]string{} 276 parts := strings.Split(rsp.Authentication, ",") 277 for _, part := range parts { 278 i := strings.Index(part, "=") 279 if i != -1 { 280 auth[part[:i]] = part[i+1:] 281 } 282 } 283 284 // Extract return values 285 if v, ok := auth["i"]; ok { 286 i, err = strconv.ParseInt(v, 10, 64) 287 if err != nil { 288 return 289 } 290 } 291 if v, ok := auth["s"]; ok { 292 salt, err = base64.StdEncoding.DecodeString(v) 293 if err != nil { 294 return 295 } 296 } 297 if v, ok := auth["r"]; ok { 298 serverNonce = v 299 } 300 301 return 302 } 303 304 func (c *connectionHandshakeV1_0) writeFinalMessage(serverNonce, clientProof string) error { 305 authMsg := "c=biws,r=" 306 authMsg += serverNonce 307 authMsg += ",p=" 308 authMsg += clientProof 309 310 msg := fmt.Sprintf(`{"authentication": "%s"}`, authMsg) 311 312 pos := 0 313 dataLen := len(msg) + 1 314 data := make([]byte, dataLen) 315 316 // Send the auth message as an ASCII string 317 pos += copy(data[pos:], msg) 318 319 // Add null terminating byte 320 data[pos] = '\x00' 321 322 return c.writeData(data) 323 } 324 325 func (c *connectionHandshakeV1_0) readFinalMessage(serverSignature string) error { 326 b, err := c.readResponse() 327 if err != nil { 328 return err 329 } 330 331 // Read server message 332 type finalMessageResponse struct { 333 Success bool `json:"success"` 334 Authentication string `json:"authentication"` 335 ErrorCode int `json:"error_code"` 336 Error string `json:"error"` 337 } 338 var rsp *finalMessageResponse 339 340 if err := json.Unmarshal(b, &rsp); err != nil { 341 return RQLDriverError{rqlError(fmt.Sprintf("Error parsing auth response: %s", err))} 342 } 343 if !rsp.Success { 344 return c.handshakeError(rsp.ErrorCode, rsp.Error) 345 } 346 347 // Parse authentication field 348 auth := map[string]string{} 349 parts := strings.Split(rsp.Authentication, ",") 350 for _, part := range parts { 351 i := strings.Index(part, "=") 352 if i != -1 { 353 auth[part[:i]] = part[i+1:] 354 } 355 } 356 357 // Validate server response 358 if serverSignature != auth["v"] { 359 return RQLAuthError{RQLDriverError{rqlError("Invalid server signature")}} 360 } 361 362 return nil 363 } 364 365 func (c *connectionHandshakeV1_0) writeData(data []byte) error { 366 367 if err := c.conn.writeData(data); err != nil { 368 return RQLConnectionError{rqlError(err.Error())} 369 } 370 371 return nil 372 } 373 374 func (c *connectionHandshakeV1_0) readResponse() ([]byte, error) { 375 line, err := c.reader.ReadBytes('\x00') 376 if err != nil { 377 if err == io.EOF { 378 return nil, RQLConnectionError{rqlError(fmt.Sprintf("Unexpected EOF: %s", string(line)))} 379 } 380 return nil, RQLConnectionError{rqlError(err.Error())} 381 } 382 383 // Strip null byte and return 384 return line[:len(line)-1], nil 385 } 386 387 func (c *connectionHandshakeV1_0) generateNonce() (string, error) { 388 const nonceSize = 24 389 390 b := make([]byte, nonceSize) 391 _, err := rand.Read(b) 392 if err != nil { 393 return "", err 394 } 395 396 return base64.StdEncoding.EncodeToString(b), nil 397 } 398 399 func (c *connectionHandshakeV1_0) saltPassword(iter int64, salt []byte) []byte { 400 pass := []byte(c.conn.opts.Password) 401 402 return pbkdf2.Key(pass, salt, int(iter), sha256.Size, sha256.New) 403 } 404 405 func (c *connectionHandshakeV1_0) calculateProof(saltedPass []byte, clientNonce, serverNonce string) string { 406 // Generate proof 407 c.authMsg += ",c=biws,r=" + serverNonce 408 409 mac := hmac.New(c.hashFunc(), saltedPass) 410 mac.Write([]byte("Client Key")) 411 clientKey := mac.Sum(nil) 412 413 hash := c.hashFunc()() 414 hash.Write(clientKey) 415 storedKey := hash.Sum(nil) 416 417 mac = hmac.New(c.hashFunc(), storedKey) 418 mac.Write([]byte(c.authMsg)) 419 clientSignature := mac.Sum(nil) 420 clientProof := make([]byte, len(clientKey)) 421 for i, _ := range clientKey { 422 clientProof[i] = clientKey[i] ^ clientSignature[i] 423 } 424 425 return base64.StdEncoding.EncodeToString(clientProof) 426 } 427 428 func (c *connectionHandshakeV1_0) serverSignature(saltedPass []byte) string { 429 mac := hmac.New(c.hashFunc(), saltedPass) 430 mac.Write([]byte("Server Key")) 431 serverKey := mac.Sum(nil) 432 433 mac = hmac.New(c.hashFunc(), serverKey) 434 mac.Write([]byte(c.authMsg)) 435 serverSignature := mac.Sum(nil) 436 437 return base64.StdEncoding.EncodeToString(serverSignature) 438 } 439 440 func (c *connectionHandshakeV1_0) handshakeError(code int, message string) error { 441 if code >= 10 || code <= 20 { 442 return RQLAuthError{RQLDriverError{rqlError(message)}} 443 } 444 445 return RQLDriverError{rqlError(message)} 446 } 447 448 func (c *connectionHandshakeV1_0) hashFunc() func() hash.Hash { 449 return sha256.New 450 }