github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kex2/transport.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package kex2 5 6 import ( 7 "crypto/hmac" 8 "crypto/rand" 9 "crypto/sha256" 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "sync" 15 "time" 16 17 "github.com/keybase/go-codec/codec" 18 "golang.org/x/crypto/nacl/secretbox" 19 "golang.org/x/net/context" 20 ) 21 22 // DeviceID is a 16-byte identifier that each side of key exchange has. It's 23 // used primarily to tell sender from receiver. 24 type DeviceID [16]byte 25 26 // SessionID is a 32-byte session identifier that's derived from the shared 27 // session secret. It's used to route messages on the server side. 28 type SessionID [32]byte 29 30 // SecretLen is the number of bytes in the secret. 31 const SecretLen = 32 32 33 // Secret is the 32-byte shared secret identifier 34 type Secret [SecretLen]byte 35 36 // Seqno increments on every message sent from a Kex sender. 37 type Seqno uint32 38 39 // Eq returns true if the two device IDs are equal 40 func (d DeviceID) Eq(d2 DeviceID) bool { 41 return hmac.Equal(d[:], d2[:]) 42 } 43 44 // Eq returns true if the two session IDs are equal 45 func (s SessionID) Eq(s2 SessionID) bool { 46 return hmac.Equal(s[:], s2[:]) 47 } 48 49 // MessageRouter is a stateful message router that will be implemented by 50 // JSON/REST calls to the Keybase API server. 51 type MessageRouter interface { 52 53 // Post a message. Message will always be non-nil and non-empty. 54 // Even for an EOF, the empty buffer is encrypted via SecretBox, 55 // so the buffer posted to the server will have data. 56 Post(I SessionID, sender DeviceID, seqno Seqno, msg []byte) error 57 58 // Get messages on the channel. Only poll for `poll` milliseconds. If the timeout 59 // elapses without any data ready, then just return an empty result, with nil error. 
60 // Several messages can be returned at once, which should be processed in serial. 61 // They are guaranteed to be in order; otherwise, there was an issue. 62 // Get() should only return a non-nil error if there was an HTTPS or TCP-level error. 63 // Application-level errors like EOF or no data ready are handled by modulating 64 // the `msgs` result. 65 Get(I SessionID, receiver DeviceID, seqno Seqno, poll time.Duration) (msg [][]byte, err error) 66 } 67 68 // Conn is a struct that obeys the net.Conn interface. It establishes a session abstraction 69 // over a message channel bounced off the Keybase API server, applying the appropriate 70 // e2e encryption/MAC'ing. 71 type Conn struct { 72 router MessageRouter 73 secret Secret 74 sessionID SessionID 75 deviceID DeviceID 76 77 // Protects the read path. There should only be one reader outstanding at once. 78 readMutex sync.Mutex 79 readSeqno Seqno 80 readDeadline time.Time 81 readTimeout time.Duration 82 bufferedMsgs [][]byte 83 84 // Protects the write path. There should only be one writer outstanding at once. 85 writeMutex sync.Mutex 86 writeSeqno Seqno 87 88 // Protects the pollLoopRunning mutex. We expose this mainly for testing purposes 89 pollLoopRunningMutex sync.Mutex 90 pollLoopRunning bool 91 92 // Protects the setting of error states. Only one thread should be setting or 93 // accessing these errors at a time. 94 errMutex sync.Mutex 95 readErr error 96 writeErr error 97 closed bool 98 99 ctx context.Context 100 lctx LogContext 101 } 102 103 const sessionIDText = "Kex v2 Session ID" 104 105 // NewConn establishes a Kex session based on the given secret. Will work for 106 // both ends of the connection, regardless of which order the two started 107 // their connection. Will communicate with the other end via the given message router. 108 // You can specify an optional timeout to cancel any reads longer than that timeout. 
109 func NewConn(ctx context.Context, lctx LogContext, r MessageRouter, s Secret, d DeviceID, readTimeout time.Duration) (con net.Conn, err error) { 110 mac := hmac.New(sha256.New, s[:]) 111 _, err = mac.Write([]byte(sessionIDText)) 112 if err != nil { 113 return nil, err 114 } 115 tmp := mac.Sum(nil) 116 var sessionID SessionID 117 copy(sessionID[:], tmp) 118 ret := &Conn{ 119 router: r, 120 secret: s, 121 sessionID: sessionID, 122 deviceID: d, 123 readSeqno: 0, 124 readTimeout: readTimeout, 125 writeSeqno: 0, 126 ctx: ctx, 127 lctx: lctx, 128 } 129 return ret, nil 130 } 131 132 // TimedoutError is for operations that timed out; for instance, if no read 133 // data was available before the deadline. 134 type timedoutError struct{} 135 136 // Error returns the string representation of this error 137 func (t timedoutError) Error() string { return "operation timed out" } 138 139 // Temporary returns if the error is retryable 140 func (t timedoutError) Temporary() bool { return true } 141 142 // Timeout returns if this error is a timeout 143 func (t timedoutError) Timeout() bool { return true } 144 145 // ErrTimedOut is the signleton error we use if the operation timedout. 146 var ErrTimedOut net.Error = timedoutError{} 147 148 // ErrUnimplemented indicates the given method isn't implemented 149 var ErrUnimplemented = errors.New("unimplemented") 150 151 // ErrBadMetadata indicates that the metadata outside the encrypted message 152 // didn't match what was inside. 
var ErrBadMetadata = errors.New("bad metadata")

// ErrDecryption indicates that a ciphertext failed to decrypt or MAC properly.
var ErrDecryption = errors.New("decryption failed")

// ErrNotEnoughRandomness indicates that encryption failed due to insufficient
// randomness.
var ErrNotEnoughRandomness = errors.New("not enough random data")

// ErrWrongSession indicates that the given session didn't match the
// client's expectations.
var ErrWrongSession = errors.New("got message for wrong Session ID")

// ErrSelfRecieve indicates that the client received a message sent by
// itself, which should never happen. The identifier's spelling is kept
// as-is for compatibility with existing callers.
var ErrSelfRecieve = errors.New("got message back that we sent")

// ErrAgain indicates that no data was available to read, but the
// reader was in non-blocking mode, so to try again later.
var ErrAgain = errors.New("no data were ready to read")

// ErrBadSecret indicates that the secret received was invalid.
var ErrBadSecret = errors.New("bad secret")

// ErrHelloTimeout indicates that the Hello() part of the
// protocol timed out. Most likely due to an incorrect
// secret phrase from the user.
var ErrHelloTimeout = errors.New("hello timeout")

// ErrBadPacketSequence indicates that packets arrived out of order from the
// server (which they shouldn't).
184 type ErrBadPacketSequence struct { 185 SessionID SessionID 186 SenderID DeviceID 187 ReceivedSeqno Seqno 188 PrevSeqno Seqno 189 } 190 191 func (e ErrBadPacketSequence) Error() string { 192 return fmt.Sprintf("Unexpected out-of-order packet arrival {SessionID: %v, SenderID: %v, ReceivedSeqno: %d, PrevSeqno: %d})", 193 e.SessionID, e.SenderID, e.ReceivedSeqno, e.PrevSeqno) 194 } 195 196 func (c *Conn) setReadError(e error) error { 197 c.errMutex.Lock() 198 c.readErr = e 199 c.errMutex.Unlock() 200 return e 201 } 202 203 func (c *Conn) setWriteError(e error) error { 204 c.errMutex.Lock() 205 c.writeErr = e 206 c.errMutex.Unlock() 207 return e 208 } 209 210 func (c *Conn) getErrorForWrite() error { 211 var err error 212 c.errMutex.Lock() 213 if c.readErr != nil && c.readErr != io.EOF { 214 err = c.readErr 215 } else if c.writeErr != nil { 216 err = c.writeErr 217 } 218 c.errMutex.Unlock() 219 return err 220 } 221 222 func (c *Conn) setClosed() { 223 c.errMutex.Lock() 224 c.closed = true 225 c.errMutex.Unlock() 226 } 227 228 func (c *Conn) getClosed() bool { 229 c.errMutex.Lock() 230 ret := c.closed 231 c.errMutex.Unlock() 232 return ret 233 } 234 235 func (c *Conn) getErrorForRead() error { 236 var err error 237 c.errMutex.Lock() 238 if c.readErr != nil { 239 err = c.readErr 240 } else if c.writeErr != nil && c.writeErr != io.EOF { 241 err = c.writeErr 242 } 243 c.errMutex.Unlock() 244 return err 245 } 246 247 func (c *Conn) setPollLoopRunning(b bool) { 248 c.pollLoopRunningMutex.Lock() 249 c.pollLoopRunning = b 250 c.pollLoopRunningMutex.Unlock() 251 } 252 253 type outerMsg struct { 254 _struct bool `codec:",toarray"` //nolint 255 SenderID DeviceID `codec:"senderID"` 256 SessionID SessionID `codec:"sessionID"` 257 Seqno Seqno `codec:"seqno"` 258 Nonce [24]byte `codec:"nonce"` 259 Payload []byte `codec:"payload"` 260 } 261 262 type innerMsg struct { 263 _struct bool `codec:",toarray"` //nolint 264 SenderID DeviceID `codec:"senderID"` 265 SessionID SessionID 
`codec:"sessionID"` 266 Seqno Seqno `codec:"seqno"` 267 Payload []byte `codec:"payload"` 268 } 269 270 func (c *Conn) decryptIncomingMessage(msg []byte) (int, error) { 271 var err error 272 mh := codec.MsgpackHandle{WriteExt: true} 273 dec := codec.NewDecoderBytes(msg, &mh) 274 var om outerMsg 275 err = dec.Decode(&om) 276 if err != nil { 277 c.lctx.Debug("Conn#decryptIncomingMessage: decoding failure: %s", err.Error()) 278 return 0, err 279 } 280 var plaintext []byte 281 var ok bool 282 plaintext, ok = secretbox.Open(plaintext, om.Payload, &om.Nonce, (*[32]byte)(&c.secret)) 283 if !ok { 284 return 0, ErrDecryption 285 } 286 dec = codec.NewDecoderBytes(plaintext, &mh) 287 var im innerMsg 288 err = dec.Decode(&im) 289 if err != nil { 290 return 0, err 291 } 292 if !om.SenderID.Eq(im.SenderID) || !om.SessionID.Eq(im.SessionID) || om.Seqno != im.Seqno { 293 return 0, ErrBadMetadata 294 } 295 if !im.SessionID.Eq(c.sessionID) { 296 return 0, ErrWrongSession 297 } 298 if im.SenderID.Eq(c.deviceID) { 299 return 0, ErrSelfRecieve 300 } 301 302 if im.Seqno != c.readSeqno+1 { 303 return 0, ErrBadPacketSequence{im.SessionID, im.SenderID, im.Seqno, c.readSeqno} 304 } 305 c.readSeqno = im.Seqno 306 307 c.bufferedMsgs = append(c.bufferedMsgs, im.Payload) 308 return len(im.Payload), nil 309 } 310 311 func (c *Conn) decryptIncomingMessages(msgs [][]byte) (int, error) { 312 var ret int 313 for _, msg := range msgs { 314 n, e := c.decryptIncomingMessage(msg) 315 if e != nil { 316 return ret, e 317 } 318 ret += n 319 } 320 return ret, nil 321 } 322 323 func (c *Conn) readBufferedMsgsIntoBytes(out []byte) (int, error) { 324 p := 0 325 326 // If no buffered messages, then return that we didn't pull any 327 // new data from the server. 
328 if len(c.bufferedMsgs) == 0 { 329 return 0, nil 330 } 331 332 // Any empty buffer signals an EOF condition 333 if len(c.bufferedMsgs[0]) == 0 { 334 c.lctx.Debug("conn#readBufferedMsgsIntoBytes: empty buffer signaling EOF condition") 335 return 0, io.EOF 336 } 337 338 for p < len(out) { 339 rem := len(out) - p 340 if len(c.bufferedMsgs) > 0 { 341 front := c.bufferedMsgs[0] 342 n := len(front) 343 344 // An empty buffer signifies that the other side wanted 345 // and EOF condition. However, we shouldn't return an EOF 346 // if we've read anything, this time through. 347 if n == 0 { 348 var err error 349 if p == 0 { 350 c.lctx.Debug("conn#readBufferedMsgsIntoBytes: empty buffer signaling EOF condition (after consume loop)") 351 err = io.EOF 352 } 353 return p, err 354 } 355 356 if rem < n { 357 n = rem 358 copy(out[p:(p+n)], front[0:n]) 359 front = front[n:] 360 if len(front) == 0 { 361 // Be careful not to recycle an empty buffer into the 362 // list of buffered messages, since that has special 363 // significance (see above). 364 c.bufferedMsgs = c.bufferedMsgs[1:] 365 } else { 366 c.bufferedMsgs[0] = front 367 } 368 } else { 369 copy(out[p:(p+n)], front) 370 c.bufferedMsgs = c.bufferedMsgs[1:] 371 } 372 373 p += n 374 } else { 375 break 376 } 377 } 378 return p, nil 379 } 380 381 func (c *Conn) pollLoop(poll time.Duration) (msgs [][]byte, err error) { 382 383 var totalWaitTime time.Duration 384 385 c.setPollLoopRunning(true) 386 defer c.setPollLoopRunning(false) 387 388 start := time.Now() 389 for { 390 newPoll := poll - totalWaitTime 391 msgs, err = c.router.Get(c.sessionID, c.deviceID, c.readSeqno+1, newPoll) 392 totalWaitTime = time.Since(start) 393 if err != nil || len(msgs) > 0 || totalWaitTime >= poll || c.getClosed() { 394 return 395 } 396 397 select { 398 case <-c.ctx.Done(): 399 return nil, ErrCanceled 400 default: 401 } 402 } 403 } 404 405 // Read data from the connection, returning plaintext data if all 406 // cryptographic checks passed. 
Obeys the `net.Conn` interface. 407 // Returns the number of bytes read into the output buffer. 408 func (c *Conn) Read(out []byte) (n int, err error) { 409 410 c.readMutex.Lock() 411 defer c.readMutex.Unlock() 412 413 // The first error kills the whole stream 414 if err = c.getErrorForRead(); err != nil { 415 return 0, err 416 } 417 // First see if there's anything buffered, and read that 418 // out now. 419 if n, err = c.readBufferedMsgsIntoBytes(out); err != nil { 420 return 0, c.setReadError(err) 421 } 422 if n > 0 { 423 return n, nil 424 } 425 426 var poll time.Duration 427 if !c.readDeadline.IsZero() { 428 poll = time.Until(c.readDeadline) 429 if poll.Nanoseconds() < 0 { 430 return 0, c.setReadError(ErrTimedOut) 431 } 432 } else { 433 poll = c.readTimeout 434 } 435 436 var msgs [][]byte 437 msgs, err = c.pollLoop(poll) 438 439 if err != nil { 440 return 0, c.setReadError(err) 441 } 442 if _, err = c.decryptIncomingMessages(msgs); err != nil { 443 return 0, c.setReadError(err) 444 } 445 if n, err = c.readBufferedMsgsIntoBytes(out); err != nil { 446 return 0, c.setReadError(err) 447 } 448 449 if n == 0 { 450 switch { 451 case c.getClosed(): 452 c.lctx.Debug("conn#Read: EOF since connection was closed") 453 err = io.EOF 454 case poll > 0: 455 err = ErrTimedOut 456 default: 457 err = ErrAgain 458 } 459 } 460 461 return n, err 462 } 463 464 func (c *Conn) encryptOutgoingMessage(seqno Seqno, buf []byte) (ret []byte, err error) { 465 var nonce [24]byte 466 var n int 467 468 if n, err = rand.Read(nonce[:]); err != nil { 469 return nil, err 470 } else if n != 24 { 471 return nil, ErrNotEnoughRandomness 472 } 473 im := innerMsg{ 474 SenderID: c.deviceID, 475 SessionID: c.sessionID, 476 Seqno: seqno, 477 Payload: buf, 478 } 479 mh := codec.MsgpackHandle{WriteExt: true} 480 var imPacked []byte 481 enc := codec.NewEncoderBytes(&imPacked, &mh) 482 if err = enc.Encode(im); err != nil { 483 return nil, err 484 } 485 ciphertext := secretbox.Seal(nil, imPacked, &nonce, 
(*[32]byte)(&c.secret)) 486 487 om := outerMsg{ 488 SenderID: c.deviceID, 489 SessionID: c.sessionID, 490 Seqno: seqno, 491 Nonce: nonce, 492 Payload: ciphertext, 493 } 494 enc = codec.NewEncoderBytes(&ret, &mh) 495 if err = enc.Encode(om); err != nil { 496 return nil, err 497 } 498 return ret, nil 499 } 500 501 func (c *Conn) nextWriteSeqno() Seqno { 502 c.writeSeqno++ 503 return c.writeSeqno 504 } 505 506 // Write data to the connection, encrypting and MAC'ing along the way. 507 // Obeys the `net.Conn` interface 508 func (c *Conn) Write(buf []byte) (n int, err error) { 509 510 c.writeMutex.Lock() 511 defer c.writeMutex.Unlock() 512 513 // Our protocol specifies that writing an empty buffer means "close" 514 // the connection. We don't want callers of `Write` to do this by 515 // accident, we want them to call `Close()` explicitly. So short-circuit 516 // the write operation here for empty buffers. 517 if len(buf) == 0 { 518 return 0, nil 519 } 520 521 return c.writeWithLock(buf) 522 } 523 524 func (c *Conn) writeWithLock(buf []byte) (n int, err error) { 525 526 var ctext []byte 527 528 // The first error kills the whole stream 529 if err = c.getErrorForWrite(); err != nil { 530 return 0, err 531 } 532 seqno := c.nextWriteSeqno() 533 534 ctext, err = c.encryptOutgoingMessage(seqno, buf) 535 if err != nil { 536 return 0, c.setWriteError(err) 537 } 538 539 if err = c.router.Post(c.sessionID, c.deviceID, seqno, ctext); err != nil { 540 return 0, c.setWriteError(err) 541 } 542 543 return len(ctext), nil 544 } 545 546 // Close the connection to the server, sending an empty buffer via POST 547 // through the `MessageRouter`. 
Fulfills the `net.Conn` interface 548 func (c *Conn) Close() error { 549 550 c.writeMutex.Lock() 551 defer c.writeMutex.Unlock() 552 553 c.lctx.Debug("Conn#Close: all subsequent writes are EOFs") 554 555 // set closed so that the read loop will bail out above 556 c.setClosed() 557 558 // Write an empty buffer to signal EOF 559 if _, err := c.writeWithLock([]byte{}); err != nil { 560 return err 561 } 562 563 // All subsequent writes should fail. 564 _ = c.setWriteError(io.EOF) 565 566 return nil 567 } 568 569 // LocalAddr returns the local network address, fulfilling the `net.Conn interface` 570 func (c *Conn) LocalAddr() (addr net.Addr) { 571 return 572 } 573 574 // RemoteAddr returns the remote network address, fulfilling the `net.Conn interface` 575 func (c *Conn) RemoteAddr() (addr net.Addr) { 576 return 577 } 578 579 // SetDeadline sets the read and write deadlines associated 580 // with the connection. It is equivalent to calling both 581 // SetReadDeadline and SetWriteDeadline. 582 // 583 // A deadline is an absolute time after which I/O operations 584 // fail with a timeout (see type Error) instead of 585 // blocking. The deadline applies to all future I/O, not just 586 // the immediately following call to Read or Write. 587 // 588 // An idle timeout can be implemented by repeatedly extending 589 // the deadline after successful Read or Write calls. 590 // 591 // A zero value for t means I/O operations will not time out. 592 func (c *Conn) SetDeadline(t time.Time) error { 593 return c.SetReadDeadline(t) 594 } 595 596 // SetReadDeadline sets the deadline for future Read calls. 597 // A zero value for t means Read will not time out. 598 func (c *Conn) SetReadDeadline(t time.Time) error { 599 c.readMutex.Lock() 600 c.readDeadline = t 601 c.readMutex.Unlock() 602 return nil 603 } 604 605 // SetWriteDeadline sets the deadline for future Write calls. 
606 // Even if write times out, it may return n > 0, indicating that 607 // some of the data was successfully written. 608 // A zero value for t means Write will not time out. 609 // We're not implementing this feature for now, so make it an error 610 // if we try to do so. 611 func (c *Conn) SetWriteDeadline(t time.Time) error { 612 return ErrUnimplemented 613 }