github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/mux_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "io" 9 "io/ioutil" 10 "sync" 11 "testing" 12 "time" 13 ) 14 15 // PSIPHON 16 // ======= 17 // See comment in channel.go 18 var testChannelWindowSize = getChannelWindowSize("") 19 20 func muxPair() (*mux, *mux) { 21 a, b := memPipe() 22 23 s := newMux(a) 24 c := newMux(b) 25 26 return s, c 27 } 28 29 // Returns both ends of a channel, and the mux for the 2nd 30 // channel. 31 func channelPair(t *testing.T) (*channel, *channel, *mux) { 32 c, s := muxPair() 33 34 res := make(chan *channel, 1) 35 go func() { 36 newCh, ok := <-s.incomingChannels 37 if !ok { 38 t.Fatalf("No incoming channel") 39 } 40 if newCh.ChannelType() != "chan" { 41 t.Fatalf("got type %q want chan", newCh.ChannelType()) 42 } 43 ch, _, err := newCh.Accept() 44 if err != nil { 45 t.Fatalf("Accept %v", err) 46 } 47 res <- ch.(*channel) 48 }() 49 50 ch, err := c.openChannel("chan", nil) 51 if err != nil { 52 t.Fatalf("OpenChannel: %v", err) 53 } 54 55 return <-res, ch, c 56 } 57 58 // Test that stderr and stdout can be addressed from different 59 // goroutines. This is intended for use with the race detector. 60 func TestMuxChannelExtendedThreadSafety(t *testing.T) { 61 writer, reader, mux := channelPair(t) 62 defer writer.Close() 63 defer reader.Close() 64 defer mux.Close() 65 66 var wr, rd sync.WaitGroup 67 magic := "hello world" 68 69 wr.Add(2) 70 go func() { 71 io.WriteString(writer, magic) 72 wr.Done() 73 }() 74 go func() { 75 io.WriteString(writer.Stderr(), magic) 76 wr.Done() 77 }() 78 79 rd.Add(2) 80 go func() { 81 c, err := ioutil.ReadAll(reader) 82 if string(c) != magic { 83 t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) 84 } 85 rd.Done() 86 }() 87 go func() { 88 c, err := ioutil.ReadAll(reader.Stderr()) 89 if string(c) != magic { 90 t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) 91 } 92 rd.Done() 93 }() 94 95 wr.Wait() 96 writer.CloseWrite() 97 rd.Wait() 98 } 99 100 func TestMuxReadWrite(t *testing.T) { 101 s, c, mux := channelPair(t) 102 defer s.Close() 103 defer c.Close() 104 defer mux.Close() 105 106 magic := "hello world" 107 magicExt := "hello stderr" 108 go func() { 109 _, err := s.Write([]byte(magic)) 110 if err != nil { 111 t.Fatalf("Write: %v", err) 112 } 113 _, err = s.Extended(1).Write([]byte(magicExt)) 114 if err != nil { 115 t.Fatalf("Write: %v", err) 116 } 117 }() 118 119 var buf [1024]byte 120 n, err := c.Read(buf[:]) 121 if err != nil { 122 t.Fatalf("server Read: %v", err) 123 } 124 got := string(buf[:n]) 125 if got != magic { 126 t.Fatalf("server: got %q want %q", got, magic) 127 } 128 129 n, err = c.Extended(1).Read(buf[:]) 130 if err != nil { 131 t.Fatalf("server Read: %v", err) 132 } 133 134 got = string(buf[:n]) 135 if got != magicExt { 136 t.Fatalf("server: got %q want %q", got, magic) 137 } 138 } 139 140 func TestMuxChannelOverflow(t *testing.T) { 141 reader, writer, mux := channelPair(t) 142 defer reader.Close() 143 defer writer.Close() 144 defer mux.Close() 145 146 wDone := make(chan int, 1) 147 go func() { 148 if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil { 149 t.Errorf("could not fill window: %v", err) 150 } 151 writer.Write(make([]byte, 1)) 152 wDone <- 1 153 }() 154 writer.remoteWin.waitWriterBlocked() 155 156 // Send 1 byte. 157 packet := make([]byte, 1+4+4+1) 158 packet[0] = msgChannelData 159 marshalUint32(packet[1:], writer.remoteId) 160 marshalUint32(packet[5:], uint32(1)) 161 packet[9] = 42 162 163 if err := writer.mux.conn.writePacket(packet); err != nil { 164 t.Errorf("could not send packet") 165 } 166 if _, err := reader.SendRequest("hello", true, nil); err == nil { 167 t.Errorf("SendRequest succeeded.") 168 } 169 <-wDone 170 } 171 172 func TestMuxChannelCloseWriteUnblock(t *testing.T) { 173 reader, writer, mux := channelPair(t) 174 defer reader.Close() 175 defer writer.Close() 176 defer mux.Close() 177 178 wDone := make(chan int, 1) 179 go func() { 180 if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil { 181 t.Errorf("could not fill window: %v", err) 182 } 183 if _, err := writer.Write(make([]byte, 1)); err != io.EOF { 184 t.Errorf("got %v, want EOF for unblock write", err) 185 } 186 wDone <- 1 187 }() 188 189 writer.remoteWin.waitWriterBlocked() 190 reader.Close() 191 <-wDone 192 } 193 194 func TestMuxConnectionCloseWriteUnblock(t *testing.T) { 195 reader, writer, mux := channelPair(t) 196 defer reader.Close() 197 defer writer.Close() 198 defer mux.Close() 199 200 wDone := make(chan int, 1) 201 go func() { 202 if _, err := writer.Write(make([]byte, testChannelWindowSize)); err != nil { 203 t.Errorf("could not fill window: %v", err) 204 } 205 if _, err := writer.Write(make([]byte, 1)); err != io.EOF { 206 t.Errorf("got %v, want EOF for unblock write", err) 207 } 208 wDone <- 1 209 }() 210 211 writer.remoteWin.waitWriterBlocked() 212 mux.Close() 213 <-wDone 214 } 215 216 func TestMuxReject(t *testing.T) { 217 client, server := muxPair() 218 defer server.Close() 219 defer client.Close() 220 221 go func() { 222 ch, ok := <-server.incomingChannels 223 if !ok { 224 t.Fatalf("Accept") 225 } 226 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { 227 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) 228 } 229 ch.Reject(RejectionReason(42), "message") 230 }() 231 232 ch, err := client.openChannel("ch", []byte("extra")) 233 if ch != nil { 234 t.Fatal("openChannel not rejected") 235 } 236 237 ocf, ok := err.(*OpenChannelError) 238 if !ok { 239 t.Errorf("got %#v want *OpenChannelError", err) 240 } else if ocf.Reason != 42 || ocf.Message != "message" { 241 t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") 242 } 243 244 want := "ssh: rejected: unknown reason 42 (message)" 245 if err.Error() != want { 246 t.Errorf("got %q, want %q", err.Error(), want) 247 } 248 } 249 250 func TestMuxChannelRequest(t *testing.T) { 251 client, server, mux := channelPair(t) 252 defer server.Close() 253 defer client.Close() 254 defer mux.Close() 255 256 var received int 257 var wg sync.WaitGroup 258 wg.Add(1) 259 go func() { 260 for r := range server.incomingRequests { 261 received++ 262 r.Reply(r.Type == "yes", nil) 263 } 264 wg.Done() 265 }() 266 _, err := client.SendRequest("yes", false, nil) 267 if err != nil { 268 t.Fatalf("SendRequest: %v", err) 269 } 270 ok, err := client.SendRequest("yes", true, nil) 271 if err != nil { 272 t.Fatalf("SendRequest: %v", err) 273 } 274 275 if !ok { 276 t.Errorf("SendRequest(yes): %v", ok) 277 278 } 279 280 ok, err = client.SendRequest("no", true, nil) 281 if err != nil { 282 t.Fatalf("SendRequest: %v", err) 283 } 284 if ok { 285 t.Errorf("SendRequest(no): %v", ok) 286 287 } 288 289 client.Close() 290 wg.Wait() 291 292 if received != 3 { 293 t.Errorf("got %d requests, want %d", received, 3) 294 } 295 } 296 297 func TestMuxUnknownChannelRequests(t *testing.T) { 298 clientPipe, serverPipe := memPipe() 299 client := newMux(clientPipe) 300 defer serverPipe.Close() 301 defer client.Close() 302 303 kDone := make(chan struct{}) 304 go func() { 305 // Ignore unknown channel messages that don't want a reply. 306 err := serverPipe.writePacket(Marshal(channelRequestMsg{ 307 PeersID: 1, 308 Request: "keepalive@openssh.com", 309 WantReply: false, 310 RequestSpecificData: []byte{}, 311 })) 312 if err != nil { 313 t.Fatalf("send: %v", err) 314 } 315 316 // Send a keepalive, which should get a channel failure message 317 // in response. 318 err = serverPipe.writePacket(Marshal(channelRequestMsg{ 319 PeersID: 2, 320 Request: "keepalive@openssh.com", 321 WantReply: true, 322 RequestSpecificData: []byte{}, 323 })) 324 if err != nil { 325 t.Fatalf("send: %v", err) 326 } 327 328 packet, err := serverPipe.readPacket() 329 if err != nil { 330 t.Fatalf("read packet: %v", err) 331 } 332 decoded, err := decode(packet) 333 if err != nil { 334 t.Fatalf("decode failed: %v", err) 335 } 336 337 switch msg := decoded.(type) { 338 case *channelRequestFailureMsg: 339 if msg.PeersID != 2 { 340 t.Fatalf("received response to wrong message: %v", msg) 341 } 342 default: 343 t.Fatalf("unexpected channel message: %v", msg) 344 } 345 346 kDone <- struct{}{} 347 348 // Receive and respond to the keepalive to confirm the mux is 349 // still processing requests. 350 packet, err = serverPipe.readPacket() 351 if err != nil { 352 t.Fatalf("read packet: %v", err) 353 } 354 if packet[0] != msgGlobalRequest { 355 t.Fatalf("expected global request") 356 } 357 358 err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ 359 Data: []byte{}, 360 })) 361 if err != nil { 362 t.Fatalf("failed to send failure msg: %v", err) 363 } 364 365 close(kDone) 366 }() 367 368 // Wait for the server to send the keepalive message and receive back a 369 // response. 370 select { 371 case <-kDone: 372 case <-time.After(10 * time.Second): 373 t.Fatalf("server never received ack") 374 } 375 376 // Confirm client hasn't closed. 377 if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { 378 t.Fatalf("failed to send keepalive: %v", err) 379 } 380 381 select { 382 case <-kDone: 383 case <-time.After(10 * time.Second): 384 t.Fatalf("server never shut down") 385 } 386 } 387 388 func TestMuxClosedChannel(t *testing.T) { 389 clientPipe, serverPipe := memPipe() 390 client := newMux(clientPipe) 391 defer serverPipe.Close() 392 defer client.Close() 393 394 kDone := make(chan struct{}) 395 go func() { 396 // Open the channel. 397 packet, err := serverPipe.readPacket() 398 if err != nil { 399 t.Fatalf("read packet: %v", err) 400 } 401 if packet[0] != msgChannelOpen { 402 t.Fatalf("expected chan open") 403 } 404 405 var openMsg channelOpenMsg 406 if err := Unmarshal(packet, &openMsg); err != nil { 407 t.Fatalf("unmarshal: %v", err) 408 } 409 410 // Send back the opened channel confirmation. 411 err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{ 412 PeersID: openMsg.PeersID, 413 MyID: 0, 414 MyWindow: 0, 415 MaxPacketSize: channelMaxPacket, 416 })) 417 if err != nil { 418 t.Fatalf("send: %v", err) 419 } 420 421 // Close the channel. 422 err = serverPipe.writePacket(Marshal(channelCloseMsg{ 423 PeersID: openMsg.PeersID, 424 })) 425 if err != nil { 426 t.Fatalf("send: %v", err) 427 } 428 429 // Send a keepalive message on the channel we just closed. 430 err = serverPipe.writePacket(Marshal(channelRequestMsg{ 431 PeersID: openMsg.PeersID, 432 Request: "keepalive@openssh.com", 433 WantReply: true, 434 RequestSpecificData: []byte{}, 435 })) 436 if err != nil { 437 t.Fatalf("send: %v", err) 438 } 439 440 // Receive the channel closed response. 441 packet, err = serverPipe.readPacket() 442 if err != nil { 443 t.Fatalf("read packet: %v", err) 444 } 445 if packet[0] != msgChannelClose { 446 t.Fatalf("expected channel close") 447 } 448 449 // Receive the keepalive response failure. 450 packet, err = serverPipe.readPacket() 451 if err != nil { 452 t.Fatalf("read packet: %v", err) 453 } 454 if packet[0] != msgChannelFailure { 455 t.Fatalf("expected channel close") 456 } 457 kDone <- struct{}{} 458 459 // Receive and respond to the keepalive to confirm the mux is 460 // still processing requests. 461 packet, err = serverPipe.readPacket() 462 if err != nil { 463 t.Fatalf("read packet: %v", err) 464 } 465 if packet[0] != msgGlobalRequest { 466 t.Fatalf("expected global request") 467 } 468 469 err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ 470 Data: []byte{}, 471 })) 472 if err != nil { 473 t.Fatalf("failed to send failure msg: %v", err) 474 } 475 476 close(kDone) 477 }() 478 479 // Open a channel. 480 ch, err := client.openChannel("chan", nil) 481 if err != nil { 482 t.Fatalf("OpenChannel: %v", err) 483 } 484 defer ch.Close() 485 486 // Wait for the server to close the channel and send the keepalive. 487 select { 488 case <-kDone: 489 case <-time.After(10 * time.Second): 490 t.Fatalf("server never received ack") 491 } 492 493 // Make sure the channel closed. 494 if _, ok := <-ch.incomingRequests; ok { 495 t.Fatalf("channel not closed") 496 } 497 498 // Confirm client hasn't closed 499 if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { 500 t.Fatalf("failed to send keepalive: %v", err) 501 } 502 503 select { 504 case <-kDone: 505 case <-time.After(10 * time.Second): 506 t.Fatalf("server never shut down") 507 } 508 } 509 510 func TestMuxGlobalRequest(t *testing.T) { 511 clientMux, serverMux := muxPair() 512 defer serverMux.Close() 513 defer clientMux.Close() 514 515 var seen bool 516 go func() { 517 for r := range serverMux.incomingRequests { 518 seen = seen || r.Type == "peek" 519 if r.WantReply { 520 err := r.Reply(r.Type == "yes", 521 append([]byte(r.Type), r.Payload...)) 522 if err != nil { 523 t.Errorf("AckRequest: %v", err) 524 } 525 } 526 } 527 }() 528 529 _, _, err := clientMux.SendRequest("peek", false, nil) 530 if err != nil { 531 t.Errorf("SendRequest: %v", err) 532 } 533 534 ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) 535 if !ok || string(data) != "yesa" || err != nil { 536 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", 537 ok, data, err) 538 } 539 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { 540 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", 541 ok, data, err) 542 } 543 544 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { 545 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", 546 ok, data, err) 547 } 548 549 if !seen { 550 t.Errorf("never saw 'peek' request") 551 } 552 } 553 554 func TestMuxGlobalRequestUnblock(t *testing.T) { 555 clientMux, serverMux := muxPair() 556 defer serverMux.Close() 557 defer clientMux.Close() 558 559 result := make(chan error, 1) 560 go func() { 561 _, _, err := clientMux.SendRequest("hello", true, nil) 562 result <- err 563 }() 564 565 <-serverMux.incomingRequests 566 serverMux.conn.Close() 567 err := <-result 568 569 if err != io.EOF { 570 t.Errorf("want EOF, got %v", io.EOF) 571 } 572 } 573 574 func TestMuxChannelRequestUnblock(t *testing.T) { 575 a, b, connB := channelPair(t) 576 defer a.Close() 577 defer b.Close() 578 defer connB.Close() 579 580 result := make(chan error, 1) 581 go func() { 582 _, err := a.SendRequest("hello", true, nil) 583 result <- err 584 }() 585 586 <-b.incomingRequests 587 connB.conn.Close() 588 err := <-result 589 590 if err != io.EOF { 591 t.Errorf("want EOF, got %v", err) 592 } 593 } 594 595 func TestMuxCloseChannel(t *testing.T) { 596 r, w, mux := channelPair(t) 597 defer mux.Close() 598 defer r.Close() 599 defer w.Close() 600 601 result := make(chan error, 1) 602 go func() { 603 var b [1024]byte 604 _, err := r.Read(b[:]) 605 result <- err 606 }() 607 if err := w.Close(); err != nil { 608 t.Errorf("w.Close: %v", err) 609 } 610 611 if _, err := w.Write([]byte("hello")); err != io.EOF { 612 t.Errorf("got err %v, want io.EOF after Close", err) 613 } 614 615 if err := <-result; err != io.EOF { 616 t.Errorf("got %v (%T), want io.EOF", err, err) 617 } 618 } 619 620 func TestMuxCloseWriteChannel(t *testing.T) { 621 r, w, mux := channelPair(t) 622 defer mux.Close() 623 624 result := make(chan error, 1) 625 go func() { 626 var b [1024]byte 627 _, err := r.Read(b[:]) 628 result <- err 629 }() 630 if err := w.CloseWrite(); err != nil { 631 t.Errorf("w.CloseWrite: %v", err) 632 } 633 634 if _, err := w.Write([]byte("hello")); err != io.EOF { 635 t.Errorf("got err %v, want io.EOF after CloseWrite", err) 636 } 637 638 if err := <-result; err != io.EOF { 639 t.Errorf("got %v (%T), want io.EOF", err, err) 640 } 641 } 642 643 func TestMuxInvalidRecord(t *testing.T) { 644 a, b := muxPair() 645 defer a.Close() 646 defer b.Close() 647 648 packet := make([]byte, 1+4+4+1) 649 packet[0] = msgChannelData 650 marshalUint32(packet[1:], 29348723 /* invalid channel id */) 651 marshalUint32(packet[5:], 1) 652 packet[9] = 42 653 654 a.conn.writePacket(packet) 655 go a.SendRequest("hello", false, nil) 656 // 'a' wrote an invalid packet, so 'b' has exited. 657 req, ok := <-b.incomingRequests 658 if ok { 659 t.Errorf("got request %#v after receiving invalid packet", req) 660 } 661 } 662 663 func TestZeroWindowAdjust(t *testing.T) { 664 a, b, mux := channelPair(t) 665 defer a.Close() 666 defer b.Close() 667 defer mux.Close() 668 669 go func() { 670 io.WriteString(a, "hello") 671 // bogus adjust. 672 a.sendMessage(windowAdjustMsg{}) 673 io.WriteString(a, "world") 674 a.Close() 675 }() 676 677 want := "helloworld" 678 c, _ := ioutil.ReadAll(b) 679 if string(c) != want { 680 t.Errorf("got %q want %q", c, want) 681 } 682 } 683 684 func TestMuxMaxPacketSize(t *testing.T) { 685 a, b, mux := channelPair(t) 686 defer a.Close() 687 defer b.Close() 688 defer mux.Close() 689 690 large := make([]byte, a.maxRemotePayload+1) 691 packet := make([]byte, 1+4+4+1+len(large)) 692 packet[0] = msgChannelData 693 marshalUint32(packet[1:], a.remoteId) 694 marshalUint32(packet[5:], uint32(len(large))) 695 packet[9] = 42 696 697 if err := a.mux.conn.writePacket(packet); err != nil { 698 t.Errorf("could not send packet") 699 } 700 701 go a.SendRequest("hello", false, nil) 702 703 _, ok := <-b.incomingRequests 704 if ok { 705 t.Errorf("connection still alive after receiving large packet.") 706 } 707 } 708 709 // Don't ship code with debug=true. 710 func TestDebug(t *testing.T) { 711 if debugMux { 712 t.Error("mux debug switched on") 713 } 714 if debugHandshake { 715 t.Error("handshake debug switched on") 716 } 717 if debugTransport { 718 t.Error("transport debug switched on") 719 } 720 }