github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/kex2/transport_test.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package kex2 5 6 import ( 7 "bytes" 8 "crypto/rand" 9 "io" 10 "net" 11 "runtime" 12 "strings" 13 "sync" 14 "testing" 15 "time" 16 17 "github.com/stretchr/testify/require" 18 "golang.org/x/net/context" 19 ) 20 21 type message struct { 22 seqno Seqno 23 msg []byte 24 } 25 26 type simplexSession struct { 27 ch chan message 28 } 29 30 var zeroDeviceID DeviceID 31 32 func (d DeviceID) isZero() bool { 33 return d.Eq(zeroDeviceID) 34 } 35 36 func newSimplexSession() *simplexSession { 37 return &simplexSession{ 38 ch: make(chan message, 100), 39 } 40 } 41 42 type session struct { 43 id SessionID 44 devices [2]DeviceID 45 simplexSessions [2](*simplexSession) 46 } 47 48 func newSession(i SessionID) *session { 49 sess := &session{id: i} 50 for j := 0; j < 2; j++ { 51 sess.simplexSessions[j] = newSimplexSession() 52 } 53 return sess 54 } 55 56 func (s *session) getDeviceNumber(d DeviceID) int { 57 if s.devices[0].Eq(d) { 58 return 0 59 } 60 if s.devices[0].isZero() { 61 s.devices[0] = d 62 return 0 63 } 64 s.devices[1] = d 65 return 1 66 } 67 68 type mockRouter struct { 69 behavior int 70 maxPoll time.Duration 71 72 sessionMutex sync.Mutex 73 sessions map[SessionID]*session 74 } 75 76 const ( 77 GoodRouter = 0 78 BadRouterCorruptedSession = 1 << iota 79 BadRouterCorruptedSender = 1 << iota 80 BadRouterCorruptedCiphertext = 1 << iota 81 BadRouterReorder = 1 << iota 82 BadRouterDrop = 1 << iota 83 ) 84 85 func corruptMessage(behavior int, msg []byte) { 86 if (behavior & BadRouterCorruptedSession) != 0 { 87 msg[23] ^= 0x80 88 } 89 if (behavior & BadRouterCorruptedSender) != 0 { 90 msg[10] ^= 0x40 91 } 92 if (behavior & BadRouterCorruptedCiphertext) != 0 { 93 msg[len(msg)-10] ^= 0x01 94 } 95 } 96 97 func newMockRouterWithBehavior(b int) *mockRouter { 98 return &mockRouter{ 99 behavior: b, 100 sessions: make(map[SessionID]*session), 101 } 102 } 103 104 func newMockRouterWithBehaviorAndMaxPoll(b int, mp time.Duration) *mockRouter { 105 return &mockRouter{ 106 behavior: b, 107 maxPoll: mp, 108 sessions: make(map[SessionID]*session), 109 } 110 } 111 112 func (ss *simplexSession) post(seqno Seqno, msg []byte) error { 113 ss.ch <- message{seqno, msg} 114 return nil 115 } 116 117 type lookupType int 118 119 const ( 120 bySender lookupType = 0 121 byReceiver lookupType = 1 122 ) 123 124 func (s *session) findOrMakeSimplexSession(sender DeviceID, lt lookupType) *simplexSession { 125 i := s.getDeviceNumber(sender) 126 if lt == byReceiver { 127 i = 1 - i 128 } 129 return s.simplexSessions[i] 130 } 131 132 func (mr *mockRouter) findOrMakeSimplexSession(i SessionID, sender DeviceID, lt lookupType) *simplexSession { 133 mr.sessionMutex.Lock() 134 defer mr.sessionMutex.Unlock() 135 136 sess, ok := mr.sessions[i] 137 if !ok { 138 sess = newSession(i) 139 mr.sessions[i] = sess 140 } 141 return sess.findOrMakeSimplexSession(sender, lt) 142 } 143 144 func (mr *mockRouter) Post(i SessionID, sender DeviceID, seqno Seqno, msg []byte) error { 145 ss := mr.findOrMakeSimplexSession(i, sender, bySender) 146 corruptMessage(mr.behavior, msg) 147 return ss.post(seqno, msg) 148 } 149 150 func (ss *simplexSession) get(seqno Seqno, poll time.Duration, behavior int) (ret [][]byte, err error) { 151 timeout := false 152 handleMessage := func(msg message) { 153 ret = append(ret, msg.msg) 154 } 155 if poll.Nanoseconds() > 0 { 156 select { 157 case msg := <-ss.ch: 158 handleMessage(msg) 159 case <-time.After(poll): 160 timeout = true 161 } 162 } 163 if !timeout { 164 loopMessages: 165 for { 166 select { 167 case msg := <-ss.ch: 168 handleMessage(msg) 169 default: 170 break loopMessages 171 } 172 } 173 } 174 175 if (behavior&BadRouterReorder) != 0 && len(ret) > 1 { 176 ret[0], ret[1] = ret[1], ret[0] 177 } 178 if (behavior&BadRouterDrop) != 0 && len(ret) > 1 { 179 ret = ret[1:] 180 } 181 182 return ret, err 183 } 184 185 func (mr *mockRouter) Get(i SessionID, receiver DeviceID, seqno Seqno, poll time.Duration) ([][]byte, error) { 186 ss := mr.findOrMakeSimplexSession(i, receiver, byReceiver) 187 if mr.maxPoll > time.Duration(0) && poll > mr.maxPoll { 188 poll = mr.maxPoll 189 } 190 return ss.get(seqno, poll, mr.behavior) 191 } 192 193 func genSecret(t *testing.T) (ret Secret) { 194 _, err := rand.Read(ret[:]) 195 if err != nil { 196 t.Fatal(err) 197 } 198 return ret 199 } 200 201 func genDeviceID(t *testing.T) (ret DeviceID) { 202 _, err := rand.Read(ret[:]) 203 if err != nil { 204 t.Fatal(err) 205 } 206 return ret 207 } 208 209 type testLogCtx struct { 210 sync.Mutex 211 t *testing.T 212 } 213 214 func newTestLogCtx(t *testing.T) (ret *testLogCtx, closer func()) { 215 ret = &testLogCtx{t: t} 216 closer = func() { 217 ret.Lock() 218 defer ret.Unlock() 219 ret.t = nil 220 } 221 return ret, closer 222 } 223 224 func (t *testLogCtx) Debug(format string, args ...interface{}) { 225 t.Lock() 226 if t.t != nil { 227 t.t.Logf(format, args...) 228 } 229 t.Unlock() 230 } 231 232 func genNewConn(t *testLogCtx, mr MessageRouter, s Secret, d DeviceID, rt time.Duration) net.Conn { 233 ret, err := NewConn(context.TODO(), t, mr, s, d, rt) 234 if err != nil { 235 t.t.Fatal(err) 236 } 237 return ret 238 } 239 240 func genConnPair(t *testLogCtx, behavior int, readTimeout time.Duration) (c1 net.Conn, c2 net.Conn, d1 DeviceID, d2 DeviceID) { 241 r := newMockRouterWithBehavior(behavior) 242 s := genSecret(t.t) 243 d1 = genDeviceID(t.t) 244 d2 = genDeviceID(t.t) 245 c1 = genNewConn(t, r, s, d1, readTimeout) 246 c2 = genNewConn(t, r, s, d2, readTimeout) 247 return 248 } 249 250 func maybeDisableTest(t *testing.T) { 251 if runtime.GOOS == "windows" { 252 t.Skip() 253 } 254 } 255 256 func TestHello(t *testing.T) { 257 testLogCtx, cleanup := newTestLogCtx(t) 258 defer cleanup() 259 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0)) 260 txt := []byte("hello friend") 261 if _, err := c1.Write(txt); err != nil { 262 t.Fatal(err) 263 } 264 buf := make([]byte, 100) 265 if n, err := c2.Read(buf); err != nil { 266 t.Fatal(err) 267 } else if n != len(txt) { 268 t.Fatal("bad read len") 269 } else if !bytes.Equal(buf[0:n], txt) { 270 t.Fatal("wrong message back") 271 } 272 txt2 := []byte("pong PONG pong PONG pong PONG") 273 if _, err := c2.Write(txt2); err != nil { 274 t.Fatal(err) 275 } else if n, err := c1.Read(buf); err != nil { 276 t.Fatal(err) 277 } else if n != len(txt2) { 278 t.Fatal("bad read len") 279 } else if !bytes.Equal(buf[0:n], txt2) { 280 t.Fatal("wrong ponged text") 281 } 282 } 283 284 func TestBadMetadata(t *testing.T) { 285 testLogCtx, cleanup := newTestLogCtx(t) 286 defer cleanup() 287 288 testBehavior := func(b int, wanted error) { 289 c1, c2, _, _ := genConnPair(testLogCtx, b, time.Duration(0)) 290 txt := []byte("hello friend") 291 if _, err := c1.Write(txt); err != nil { 292 t.Fatal(err) 293 } 294 buf := make([]byte, 100) 295 if _, err := c2.Read(buf); err == nil { 296 t.Fatalf("behavior %d: wanted an error, didn't get one", b) 297 } else if err != wanted { 298 t.Fatalf("behavior %d: wanted error '%v', got '%v'", b, err, wanted) 299 } 300 } 301 testBehavior(BadRouterCorruptedSession, ErrBadMetadata) 302 testBehavior(BadRouterCorruptedSender, ErrBadMetadata) 303 testBehavior(BadRouterCorruptedCiphertext, ErrDecryption) 304 } 305 306 func TestReadDeadline(t *testing.T) { 307 testLogCtx, cleanup := newTestLogCtx(t) 308 defer cleanup() 309 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0)) 310 wait := time.Duration(10) * time.Millisecond 311 err := c2.SetReadDeadline(time.Now().Add(wait)) 312 require.NoError(t, err) 313 go func() { 314 time.Sleep(wait * 2) 315 _, _ = c1.Write([]byte("hello friend")) 316 }() 317 buf := make([]byte, 100) 318 _, err = c2.Read(buf) 319 if err != ErrTimedOut { 320 t.Fatalf("wanted a read timeout") 321 } 322 } 323 324 func TestReadTimeout(t *testing.T) { 325 testLogCtx, cleanup := newTestLogCtx(t) 326 defer cleanup() 327 wait := time.Duration(10) * time.Millisecond 328 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, wait) 329 go func() { 330 time.Sleep(wait * 2) 331 _, _ = c1.Write([]byte("hello friend")) 332 }() 333 buf := make([]byte, 100) 334 _, err := c2.Read(buf) 335 if err != ErrTimedOut { 336 t.Fatalf("wanted a read timeout") 337 } 338 } 339 340 func TestReadDelayedWrite(t *testing.T) { 341 maybeDisableTest(t) 342 testLogCtx, cleanup := newTestLogCtx(t) 343 defer cleanup() 344 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0)) 345 wait := time.Duration(50) * time.Millisecond 346 err := c2.SetReadDeadline(time.Now().Add(wait)) 347 require.NoError(t, err) 348 text := "hello friend" 349 go func() { 350 time.Sleep(wait / 32) 351 _, _ = c1.Write([]byte(text)) 352 }() 353 buf := make([]byte, 100) 354 n, err := c2.Read(buf) 355 if err != nil { 356 t.Fatal(err) 357 } 358 if n != len(text) { 359 t.Fatalf("wrong read length") 360 } 361 } 362 363 func TestMultipleWritesOneRead(t *testing.T) { 364 maybeDisableTest(t) 365 testLogCtx, cleanup := newTestLogCtx(t) 366 defer cleanup() 367 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0)) 368 msgs := []string{ 369 "Alas, poor Yorick! I knew him, Horatio: a fellow", 370 "of infinite jest, of most excellent fancy: he hath", 371 "borne me on his back a thousand times; and now, how", 372 "abhorred in my imagination it is! my gorge rims at", 373 "it.", 374 } 375 for i, m := range msgs { 376 if i > 0 { 377 m = "\n" + m 378 } 379 if _, err := c1.Write([]byte(m)); err != nil { 380 t.Fatal(err) 381 } 382 } 383 buf := make([]byte, 1000) 384 if n, err := c2.Read(buf); err != nil { 385 t.Fatal(err) 386 } else if strings.Join(msgs, "\n") != string(buf[0:n]) { 387 t.Fatal("string mismatch") 388 } 389 } 390 391 func TestOneWriteMultipleReads(t *testing.T) { 392 testLogCtx, cleanup := newTestLogCtx(t) 393 defer cleanup() 394 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0)) 395 msg := `Crows maunder on the petrified fairway. 396 Absence! My heart grows tense 397 as though a harpoon were sparring for the kill.` 398 if _, err := c1.Write([]byte(msg)); err != nil { 399 return 400 } 401 small := make([]byte, 3) 402 var buf []byte 403 for { 404 if n, err := c2.Read(small); err != nil && err != ErrAgain { 405 t.Fatal(err) 406 } else if n == 0 { 407 if err != ErrAgain { 408 t.Fatalf("exepcted ErrAgain if we read 0 bytes, but got %v", err) 409 } 410 break 411 } else { 412 buf = append(buf, small[0:n]...) 413 } 414 } 415 if string(buf) != msg { 416 t.Fatal("message mismatch") 417 } 418 } 419 420 func TestReorder(t *testing.T) { 421 testLogCtx, cleanup := newTestLogCtx(t) 422 defer cleanup() 423 c1, c2, _, _ := genConnPair(testLogCtx, BadRouterReorder, time.Duration(0)) 424 msgs := []string{ 425 "Alas, poor Yorick! I knew him, Horatio: a fellow", 426 "of infinite jest, of most excellent fancy: he hath", 427 "borne me on his back a thousand times; and now, how", 428 "abhorred in my imagination it is! my gorge rims at", 429 "it.", 430 } 431 for i, m := range msgs { 432 if i > 0 { 433 m = "\n" + m 434 } 435 if _, err := c1.Write([]byte(m)); err != nil { 436 t.Fatal(err) 437 } 438 } 439 buf := make([]byte, 1000) 440 _, err := c2.Read(buf) 441 if _, ok := err.(ErrBadPacketSequence); !ok { 442 t.Fatalf("expected an ErrBadPacketSequence; got %v", err) 443 } 444 } 445 446 func TestDrop(t *testing.T) { 447 testLogCtx, cleanup := newTestLogCtx(t) 448 defer cleanup() 449 c1, c2, _, _ := genConnPair(testLogCtx, BadRouterDrop, time.Duration(0)) 450 msgs := []string{ 451 "Alas, poor Yorick! I knew him, Horatio: a fellow", 452 "of infinite jest, of most excellent fancy: he hath", 453 "borne me on his back a thousand times; and now, how", 454 "abhorred in my imagination it is! my gorge rims at", 455 "it.", 456 } 457 for i, m := range msgs { 458 if i > 0 { 459 m = "\n" + m 460 } 461 if _, err := c1.Write([]byte(m)); err != nil { 462 t.Fatal(err) 463 } 464 } 465 buf := make([]byte, 1000) 466 _, err := c2.Read(buf) 467 if _, ok := err.(ErrBadPacketSequence); !ok { 468 t.Fatalf("expected an ErrBadPacketSequence; got %v", err) 469 } 470 } 471 472 func TestClose(t *testing.T) { 473 testLogCtx, cleanup := newTestLogCtx(t) 474 defer cleanup() 475 c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(4)*time.Second) 476 msg := "Hello friend. I'm going to mic drop." 477 if _, err := c1.Write([]byte(msg)); err != nil { 478 t.Fatal(err) 479 } 480 if err := c1.Close(); err != nil { 481 t.Fatal(err) 482 } 483 buf := make([]byte, 1000) 484 if n, err := c2.Read(buf); err != nil { 485 t.Fatal(err) 486 } else if n != len(msg) { 487 t.Fatalf("short read: %d v %d: %v", n, len(msg), msg) 488 } else if string(buf[0:n]) != msg { 489 t.Fatal("wrong msg") 490 } 491 492 // Assert we get an EOF now and forever... 493 for i := 0; i < 3; i++ { 494 if n, err := c2.Read(buf); err != io.EOF { 495 t.Fatalf("expected EOF, but got err = %v", err) 496 } else if n != 0 { 497 t.Fatalf("Expected 0-length read, but got %d", n) 498 } 499 } 500 } 501 502 func TestErrAgain(t *testing.T) { 503 testLogCtx, cleanup := newTestLogCtx(t) 504 defer cleanup() 505 _, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0)) 506 buf := make([]byte, 100) 507 if n, err := c2.Read(buf); err != ErrAgain { 508 t.Fatalf("wanted ErrAgain, but got err = %v", err) 509 } else if n != 0 { 510 t.Fatalf("Wanted 0 bytes back; got %d", n) 511 } 512 } 513 514 func TestPollLoopSuccess(t *testing.T) { 515 maybeDisableTest(t) 516 517 testLogCtx, cleanup := newTestLogCtx(t) 518 defer cleanup() 519 520 wait := time.Duration(100) * time.Millisecond 521 r := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, wait/128) 522 s := genSecret(t) 523 d1 := genDeviceID(t) 524 d2 := genDeviceID(t) 525 c1 := genNewConn(testLogCtx, r, s, d1, wait) 526 c2 := genNewConn(testLogCtx, r, s, d2, wait) 527 528 text := "poll for this, will you?" 529 530 go func() { 531 time.Sleep(wait / 32) 532 _, _ = c1.Write([]byte(text)) 533 }() 534 buf := make([]byte, 100) 535 n, err := c2.Read(buf) 536 if err != nil { 537 t.Fatal(err) 538 } 539 if n != len(text) { 540 t.Fatalf("wrong read length") 541 } 542 } 543 544 func TestPollLoopTimeout(t *testing.T) { 545 maybeDisableTest(t) 546 547 testLogCtx, cleanup := newTestLogCtx(t) 548 defer cleanup() 549 550 wait := time.Duration(8) * time.Millisecond 551 r := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, wait/32) 552 s := genSecret(t) 553 d1 := genDeviceID(t) 554 d2 := genDeviceID(t) 555 c1 := genNewConn(testLogCtx, r, s, d1, wait) 556 c2 := genNewConn(testLogCtx, r, s, d2, wait) 557 558 text := "poll for this, will you?" 559 560 go func() { 561 time.Sleep(wait * 2) 562 _, _ = c1.Write([]byte(text)) 563 }() 564 buf := make([]byte, 100) 565 if _, err := c2.Read(buf); err != ErrTimedOut { 566 t.Fatalf("Wanted ErrTimedOut; got %v", err) 567 } 568 }