github.com/glycerine/xcryptossh@v7.0.4+incompatible/mux_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "context" 9 "io" 10 "io/ioutil" 11 "sync" 12 "testing" 13 ) 14 15 func muxPair(halt *Halter) (*mux, *mux) { 16 a, b := memPipe() 17 18 ctx := context.Background() 19 20 s := newMux(ctx, a, halt) 21 c := newMux(ctx, b, halt) 22 23 return s, c 24 } 25 26 // Returns both ends of a channel, and the mux for the the 2nd 27 // channel. 28 func channelPair(t *testing.T, halt *Halter) (*channel, *channel, *mux) { 29 c, s := muxPair(halt) 30 31 res := make(chan *channel, 1) 32 go func() { 33 newCh, ok := <-s.incomingChannels 34 if !ok { 35 t.Fatalf("No incoming channel") 36 } 37 if newCh.ChannelType() != "chan" { 38 t.Fatalf("got type %q want chan", newCh.ChannelType()) 39 } 40 ch, _, err := newCh.Accept() 41 if err != nil { 42 t.Fatalf("Accept %v", err) 43 } 44 res <- ch.(*channel) 45 }() 46 ctx := context.Background() 47 48 chc, err := c.openChannel(ctx, "chan", nil, nil) 49 if err != nil { 50 t.Fatalf("OpenChannel: %v", err) 51 } 52 53 chs := <-res 54 55 // setup the idleTimers on the memTransports. 56 tc := c.conn.(*memTransport) 57 tc.Lock() 58 tc.idle = chc.idleR 59 tc.Unlock() 60 61 ts := s.conn.(*memTransport) 62 ts.Lock() 63 ts.idle = chs.idleR 64 ts.Unlock() 65 66 return chs, chc, c 67 } 68 69 // Test that stderr and stdout can be addressed from different 70 // goroutines. This is intended for use with the race detector. 
func TestMuxChannelExtendedThreadSafety(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	writer, reader, mux := channelPair(t, halt)
	defer writer.Close()
	defer reader.Close()
	defer mux.Close()

	var wr, rd sync.WaitGroup
	magic := "hello world"

	wr.Add(2)
	go func() {
		defer wr.Done()
		io.WriteString(writer, magic)
	}()
	go func() {
		defer wr.Done()
		io.WriteString(writer.Stderr(), magic)
	}()

	// The readers run on spawned goroutines, so they report with t.Errorf
	// (t.Fatalf is only valid on the test goroutine) and always call
	// rd.Done so rd.Wait below cannot hang after a failure.
	rd.Add(2)
	go func() {
		defer rd.Done()
		c, err := ioutil.ReadAll(reader)
		if string(c) != magic {
			t.Errorf("stdout read got %q, want %q (error %v)", c, magic, err)
		}
	}()
	go func() {
		defer rd.Done()
		c, err := ioutil.ReadAll(reader.Stderr())
		if string(c) != magic {
			t.Errorf("stderr read got %q, want %q (error %v)", c, magic, err)
		}
	}()

	wr.Wait()
	writer.CloseWrite()
	rd.Wait()
}

func TestMuxReadWrite(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	s, c, mux := channelPair(t, halt)
	defer s.Close()
	defer c.Close()
	defer mux.Close()

	writeDone := make(chan struct{})

	magic := "hello world"
	magicExt := "hello stderr"
	go func() {
		// Closed on every path so the final <-writeDone cannot hang.
		defer close(writeDone)
		if _, err := s.Write([]byte(magic)); err != nil {
			t.Errorf("Write: %v", err)
			return
		}
		if _, err := s.Extended(1).Write([]byte(magicExt)); err != nil {
			t.Errorf("Write: %v", err)
			return
		}
		if err := s.Close(); err != nil {
			t.Errorf("Close: %v", err)
		}
	}()

	var buf [1024]byte
	n, err := c.Read(buf[:])
	if err != nil {
		t.Fatalf("server Read: %v", err)
	}
	if got := string(buf[:n]); got != magic {
		t.Fatalf("server: got %q want %q", got, magic)
	}

	n, err = c.Extended(1).Read(buf[:])
	if err != nil {
		t.Fatalf("server Read: %v", err)
	}
	if got := string(buf[:n]); got != magicExt {
		t.Fatalf("server: got %q want %q", got, magicExt)
	}

	<-writeDone
}

func TestMuxChannelOverflow(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	reader, writer, mux := channelPair(t, halt)
	defer reader.Close()
	defer writer.Close()
	defer mux.Close()

	wDone := make(chan int, 1)
	go func() {
		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
			t.Errorf("could not fill window: %v", err)
		}
		// This write blocks on the exhausted window; its result is
		// deliberately ignored, only the unblocking matters.
		writer.Write(make([]byte, 1))
		wDone <- 1
	}()
	writer.remoteWin.waitWriterBlocked()

	// Send 1 byte.
	packet := make([]byte, 1+4+4+1)
	packet[0] = msgChannelData
	marshalUint32(packet[1:], writer.remoteId)
	marshalUint32(packet[5:], uint32(1))
	packet[9] = 42

	if err := writer.mux.conn.writePacket(packet); err != nil {
		t.Errorf("could not send packet")
	}
	if _, err := reader.SendRequest("hello", true, nil); err == nil {
		t.Errorf("SendRequest succeeded.")
	}
	<-wDone
}

func TestMuxChannelCloseWriteUnblock(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	reader, writer, mux := channelPair(t, halt)
	defer reader.Close()
	defer writer.Close()
	defer mux.Close()

	wDone := make(chan int, 1)
	go func() {
		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
			t.Errorf("could not fill window: %v", err)
		}
		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
			t.Errorf("got %v, want EOF for unblock write", err)
		}
		wDone <- 1
	}()

	writer.remoteWin.waitWriterBlocked()
	reader.Close()
	<-wDone
}

func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	reader, writer, mux := channelPair(t, halt)
	defer reader.Close()
	defer writer.Close()
	defer mux.Close()

	wDone := make(chan int, 1)
	go func() {
		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
			t.Errorf("could not fill window: %v", err)
		}
		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
			t.Errorf("got %v, want EOF for unblock write", err)
		}
		wDone <- 1
	}()

	writer.remoteWin.waitWriterBlocked()
	mux.Close()
	<-wDone
}

func TestMuxReject(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	client, server := muxPair(halt)
	defer server.Close()
	defer client.Close()

	go func() {
		ch, ok := <-server.incomingChannels
		if !ok {
			t.Errorf("Accept")
			return
		}
		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
			t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
		}
		// Reject even on a mismatch so openChannel below does not hang.
		ch.Reject(RejectionReason(42), "message")
	}()
	ctx := context.Background()

	ch, err := client.openChannel(ctx, "ch", []byte("extra"), nil)
	if ch != nil {
		t.Fatal("openChannel not rejected")
	}
	if err == nil {
		t.Fatal("openChannel returned neither a channel nor an error")
	}

	ocf, ok := err.(*OpenChannelError)
	if !ok {
		t.Errorf("got %#v want *OpenChannelError", err)
	} else if ocf.Reason != 42 || ocf.Message != "message" {
		t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
	}

	want := "ssh: rejected: unknown reason 42 (message)"
	if err.Error() != want {
		t.Errorf("got %q, want %q", err.Error(), want)
	}
}

func TestMuxChannelRequest(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	client, server, mux := channelPair(t, halt)
	defer server.Close()
	defer client.Close()
	defer mux.Close()

	// received is only read after wg.Wait, which orders it after the
	// goroutine's increments.
	var received int
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for r := range server.incomingRequests {
			received++
			r.Reply(r.Type == "yes", nil)
		}
	}()
	_, err := client.SendRequest("yes", false, nil)
	if err != nil {
		t.Fatalf("SendRequest: %v", err)
	}
	ok, err := client.SendRequest("yes", true, nil)
	if err != nil {
		t.Fatalf("SendRequest: %v", err)
	}
	if !ok {
		t.Errorf("SendRequest(yes): %v", ok)
	}

	ok, err = client.SendRequest("no", true, nil)
	if err != nil {
		t.Fatalf("SendRequest: %v", err)
	}
	if ok {
		t.Errorf("SendRequest(no): %v", ok)
	}

	client.Close()
	wg.Wait()

	if received != 3 {
		t.Errorf("got %d requests, want %d", received, 3)
	}
}

func TestMuxGlobalRequest(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	clientMux, serverMux := muxPair(halt)
	defer serverMux.Close()
	defer clientMux.Close()

	var seen bool
	go func() {
		for r := range serverMux.incomingRequests {
			seen = seen || r.Type == "peek"
			if r.WantReply {
				err := r.Reply(r.Type == "yes",
					append([]byte(r.Type), r.Payload...))
				if err != nil {
					t.Errorf("AckRequest: %v", err)
				}
			}
		}
	}()
	ctx := context.Background()

	_, _, err := clientMux.SendRequest(ctx, "peek", false, nil)
	if err != nil {
		t.Errorf("SendRequest: %v", err)
	}

	ok, data, err := clientMux.SendRequest(ctx, "yes", true, []byte("a"))
	if !ok || string(data) != "yesa" || err != nil {
		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
			ok, data, err)
	}
	if ok, data, err := clientMux.SendRequest(ctx, "yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
			ok, data, err)
	}

	if ok, data, err := clientMux.SendRequest(ctx, "no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
			ok, data, err)
	}

	if !seen {
		t.Errorf("never saw 'peek' request")
	}
}

func TestMuxGlobalRequestUnblock(t *testing.T) {
	defer xtestend(xtestbegin(t))
	halt := NewHalter()
	defer halt.RequestStop()

	clientMux, serverMux := muxPair(halt)
	defer serverMux.Close()
	defer clientMux.Close()

	result := make(chan error, 1)
	ctx := context.Background()

	go func() {
		_, _, err := clientMux.SendRequest(ctx, "hello", true, nil)
		result <- err
	}()

	<-serverMux.incomingRequests
	serverMux.conn.Close()
	err := <-result

	if err != io.EOF {
		t.Errorf("want EOF, got %v", err)
	}
}

func TestMuxChannelRequestUnblock(t *testing.T) {
	defer xtestend(xtestbegin(t))
	halt := NewHalter()
	defer halt.RequestStop()

	a, b, connB := channelPair(t, halt)
	defer a.Close()
	defer b.Close()
	defer connB.Close()

	result := make(chan error, 1)
	go func() {
		_, err := a.SendRequest("hello", true, nil)
		result <- err
	}()

	<-b.incomingRequests
	connB.conn.Close()
	err := <-result

	if err != io.EOF {
		t.Errorf("want EOF, got %v", err)
	}
}

func TestMuxCloseChannel(t *testing.T) {
	defer xtestend(xtestbegin(t))
	halt := NewHalter()
	defer halt.RequestStop()

	r, w, mux := channelPair(t, halt)
	defer mux.Close()
	defer r.Close()
	defer w.Close()

	result := make(chan error, 1)
	go func() {
		var b [1024]byte
		_, err := r.Read(b[:])
		result <- err
	}()
	if err := w.Close(); err != nil {
		t.Errorf("w.Close: %v", err)
	}

	if _, err := w.Write([]byte("hello")); err != io.EOF {
		t.Errorf("got err %v, want io.EOF after Close", err)
	}

	if err := <-result; err != io.EOF {
		t.Errorf("got %v (%T), want io.EOF", err, err)
	}
}

func TestMuxCloseWriteChannel(t *testing.T) {
	defer xtestend(xtestbegin(t))
	halt := NewHalter()
	defer halt.RequestStop()

	r, w, mux := channelPair(t, halt)
	defer mux.Close()

	result := make(chan error, 1)
	go func() {
		var b [1024]byte
		_, err := r.Read(b[:])
		result <- err
	}()
	if err := w.CloseWrite(); err != nil {
		t.Errorf("w.CloseWrite: %v", err)
	}

	if _, err := w.Write([]byte("hello")); err != io.EOF {
		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
	}

	if err := <-result; err != io.EOF {
		t.Errorf("got %v (%T), want io.EOF", err, err)
	}
}

func TestMuxInvalidRecord(t *testing.T) {
	defer xtestend(xtestbegin(t))
	halt := NewHalter()
	defer halt.RequestStop()

	a, b := muxPair(halt)
	defer a.Close()
	defer b.Close()

	packet := make([]byte, 1+4+4+1)
	packet[0] = msgChannelData
	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
	marshalUint32(packet[5:], 1)
	packet[9] = 42

	a.conn.writePacket(packet)
	ctx := context.Background()

	go a.SendRequest(ctx, "hello", false, nil)
	// 'a' wrote an invalid packet, so 'b' has exited.
	req, ok := <-b.incomingRequests
	if ok {
		t.Errorf("got request %#v after receiving invalid packet", req)
	}
}

func TestZeroWindowAdjust(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	a, b, mux := channelPair(t, halt)
	defer a.Close()
	defer b.Close()
	defer mux.Close()

	go func() {
		io.WriteString(a, "hello")
		// bogus adjust.
		a.sendMessage(windowAdjustMsg{})
		io.WriteString(a, "world")
		a.Close()
	}()

	want := "helloworld"
	c, _ := ioutil.ReadAll(b)
	if string(c) != want {
		t.Errorf("got %q want %q", c, want)
	}
}

func TestMuxMaxPacketSize(t *testing.T) {
	defer xtestend(xtestbegin(t))

	halt := NewHalter()
	defer halt.RequestStop()

	a, b, mux := channelPair(t, halt)
	defer a.Close()
	defer b.Close()
	defer mux.Close()

	large := make([]byte, a.maxRemotePayload+1)
	packet := make([]byte, 1+4+4+1+len(large))
	packet[0] = msgChannelData
	marshalUint32(packet[1:], a.remoteId)
	marshalUint32(packet[5:], uint32(len(large)))
	packet[9] = 42

	if err := a.mux.conn.writePacket(packet); err != nil {
		t.Errorf("could not send packet")
	}

	go a.SendRequest("hello", false, nil)

	_, ok := <-b.incomingRequests
	if ok {
		t.Errorf("connection still alive after receiving large packet.")
	}
}

// Don't ship code with debug=true.
func TestDebug(t *testing.T) {
	defer xtestend(xtestbegin(t))
	if debugMux {
		t.Error("mux debug switched on")
	}
	if debugHandshake {
		t.Error("handshake debug switched on")
	}
	if debugTransport {
		t.Error("transport debug switched on")
	}
}