github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/mux_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "io" 9 "io/ioutil" 10 "sync" 11 "testing" 12 ) 13 14 func muxPair() (*mux, *mux) { 15 a, b := memPipe() 16 17 s := newMux(a) 18 c := newMux(b) 19 20 return s, c 21 } 22 23 // Returns both ends of a channel, and the mux for the the 2nd 24 // channel. 25 func channelPair(t *testing.T) (*channel, *channel, *mux) { 26 c, s := muxPair() 27 28 res := make(chan *channel, 1) 29 go func() { 30 newCh, ok := <-s.incomingChannels 31 if !ok { 32 t.Fatalf("No incoming channel") 33 } 34 if newCh.ChannelType() != "chan" { 35 t.Fatalf("got type %q want chan", newCh.ChannelType()) 36 } 37 ch, _, err := newCh.Accept() 38 if err != nil { 39 t.Fatalf("Accept %v", err) 40 } 41 res <- ch.(*channel) 42 }() 43 44 ch, err := c.openChannel("chan", nil) 45 if err != nil { 46 t.Fatalf("OpenChannel: %v", err) 47 } 48 49 return <-res, ch, c 50 } 51 52 // Test that stderr and stdout can be addressed from different 53 // goroutines. This is intended for use with the race detector. 54 func TestMuxChannelExtendedThreadSafety(t *testing.T) { 55 writer, reader, mux := channelPair(t) 56 defer writer.Close() 57 defer reader.Close() 58 defer mux.Close() 59 60 var wr, rd sync.WaitGroup 61 magic := "hello world" 62 63 wr.Add(2) 64 go func() { 65 io.WriteString(writer, magic) 66 wr.Done() 67 }() 68 go func() { 69 io.WriteString(writer.Stderr(), magic) 70 wr.Done() 71 }() 72 73 rd.Add(2) 74 go func() { 75 c, err := ioutil.ReadAll(reader) 76 if string(c) != magic { 77 t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) 78 } 79 rd.Done() 80 }() 81 go func() { 82 c, err := ioutil.ReadAll(reader.Stderr()) 83 if string(c) != magic { 84 t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) 85 } 86 rd.Done() 87 }() 88 89 wr.Wait() 90 writer.CloseWrite() 91 rd.Wait() 92 } 93 94 func TestMuxReadWrite(t *testing.T) { 95 s, c, mux := channelPair(t) 96 defer s.Close() 97 defer c.Close() 98 defer mux.Close() 99 100 magic := "hello world" 101 magicExt := "hello stderr" 102 go func() { 103 _, err := s.Write([]byte(magic)) 104 if err != nil { 105 t.Fatalf("Write: %v", err) 106 } 107 _, err = s.Extended(1).Write([]byte(magicExt)) 108 if err != nil { 109 t.Fatalf("Write: %v", err) 110 } 111 err = s.Close() 112 if err != nil { 113 t.Fatalf("Close: %v", err) 114 } 115 }() 116 117 var buf [1024]byte 118 n, err := c.Read(buf[:]) 119 if err != nil { 120 t.Fatalf("server Read: %v", err) 121 } 122 got := string(buf[:n]) 123 if got != magic { 124 t.Fatalf("server: got %q want %q", got, magic) 125 } 126 127 n, err = c.Extended(1).Read(buf[:]) 128 if err != nil { 129 t.Fatalf("server Read: %v", err) 130 } 131 132 got = string(buf[:n]) 133 if got != magicExt { 134 t.Fatalf("server: got %q want %q", got, magic) 135 } 136 } 137 138 func TestMuxChannelOverflow(t *testing.T) { 139 reader, writer, mux := channelPair(t) 140 defer reader.Close() 141 defer writer.Close() 142 defer mux.Close() 143 144 wDone := make(chan int, 1) 145 go func() { 146 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { 147 t.Errorf("could not fill window: %v", err) 148 } 149 writer.Write(make([]byte, 1)) 150 wDone <- 1 151 }() 152 writer.remoteWin.waitWriterBlocked() 153 154 // Send 1 byte. 155 packet := make([]byte, 1+4+4+1) 156 packet[0] = msgChannelData 157 marshalUint32(packet[1:], writer.remoteId) 158 marshalUint32(packet[5:], uint32(1)) 159 packet[9] = 42 160 161 if err := writer.mux.conn.writePacket(packet); err != nil { 162 t.Errorf("could not send packet") 163 } 164 if _, err := reader.SendRequest("hello", true, nil); err == nil { 165 t.Errorf("SendRequest succeeded.") 166 } 167 <-wDone 168 } 169 170 func TestMuxChannelCloseWriteUnblock(t *testing.T) { 171 reader, writer, mux := channelPair(t) 172 defer reader.Close() 173 defer writer.Close() 174 defer mux.Close() 175 176 wDone := make(chan int, 1) 177 go func() { 178 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { 179 t.Errorf("could not fill window: %v", err) 180 } 181 if _, err := writer.Write(make([]byte, 1)); err != io.EOF { 182 t.Errorf("got %v, want EOF for unblock write", err) 183 } 184 wDone <- 1 185 }() 186 187 writer.remoteWin.waitWriterBlocked() 188 reader.Close() 189 <-wDone 190 } 191 192 func TestMuxConnectionCloseWriteUnblock(t *testing.T) { 193 reader, writer, mux := channelPair(t) 194 defer reader.Close() 195 defer writer.Close() 196 defer mux.Close() 197 198 wDone := make(chan int, 1) 199 go func() { 200 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { 201 t.Errorf("could not fill window: %v", err) 202 } 203 if _, err := writer.Write(make([]byte, 1)); err != io.EOF { 204 t.Errorf("got %v, want EOF for unblock write", err) 205 } 206 wDone <- 1 207 }() 208 209 writer.remoteWin.waitWriterBlocked() 210 mux.Close() 211 <-wDone 212 } 213 214 func TestMuxReject(t *testing.T) { 215 client, server := muxPair() 216 defer server.Close() 217 defer client.Close() 218 219 go func() { 220 ch, ok := <-server.incomingChannels 221 if !ok { 222 t.Fatalf("Accept") 223 } 224 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { 225 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) 226 } 227 ch.Reject(RejectionReason(42), "message") 228 }() 229 230 ch, err := client.openChannel("ch", []byte("extra")) 231 if ch != nil { 232 t.Fatal("openChannel not rejected") 233 } 234 235 ocf, ok := err.(*OpenChannelError) 236 if !ok { 237 t.Errorf("got %#v want *OpenChannelError", err) 238 } else if ocf.Reason != 42 || ocf.Message != "message" { 239 t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") 240 } 241 242 want := "ssh: rejected: unknown reason 42 (message)" 243 if err.Error() != want { 244 t.Errorf("got %q, want %q", err.Error(), want) 245 } 246 } 247 248 func TestMuxChannelRequest(t *testing.T) { 249 client, server, mux := channelPair(t) 250 defer server.Close() 251 defer client.Close() 252 defer mux.Close() 253 254 var received int 255 var wg sync.WaitGroup 256 wg.Add(1) 257 go func() { 258 for r := range server.incomingRequests { 259 received++ 260 r.Reply(r.Type == "yes", nil) 261 } 262 wg.Done() 263 }() 264 _, err := client.SendRequest("yes", false, nil) 265 if err != nil { 266 t.Fatalf("SendRequest: %v", err) 267 } 268 ok, err := client.SendRequest("yes", true, nil) 269 if err != nil { 270 t.Fatalf("SendRequest: %v", err) 271 } 272 273 if !ok { 274 t.Errorf("SendRequest(yes): %v", ok) 275 276 } 277 278 ok, err = client.SendRequest("no", true, nil) 279 if err != nil { 280 t.Fatalf("SendRequest: %v", err) 281 } 282 if ok { 283 t.Errorf("SendRequest(no): %v", ok) 284 285 } 286 287 client.Close() 288 wg.Wait() 289 290 if received != 3 { 291 t.Errorf("got %d requests, want %d", received, 3) 292 } 293 } 294 295 func TestMuxGlobalRequest(t *testing.T) { 296 clientMux, serverMux := muxPair() 297 defer serverMux.Close() 298 defer clientMux.Close() 299 300 var seen bool 301 go func() { 302 for r := range serverMux.incomingRequests { 303 seen = seen || r.Type == "peek" 304 if r.WantReply { 305 err := r.Reply(r.Type == "yes", 306 append([]byte(r.Type), r.Payload...)) 307 if err != nil { 308 t.Errorf("AckRequest: %v", err) 309 } 310 } 311 } 312 }() 313 314 _, _, err := clientMux.SendRequest("peek", false, nil) 315 if err != nil { 316 t.Errorf("SendRequest: %v", err) 317 } 318 319 ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) 320 if !ok || string(data) != "yesa" || err != nil { 321 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", 322 ok, data, err) 323 } 324 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { 325 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", 326 ok, data, err) 327 } 328 329 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { 330 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", 331 ok, data, err) 332 } 333 334 if !seen { 335 t.Errorf("never saw 'peek' request") 336 } 337 } 338 339 func TestMuxGlobalRequestUnblock(t *testing.T) { 340 clientMux, serverMux := muxPair() 341 defer serverMux.Close() 342 defer clientMux.Close() 343 344 result := make(chan error, 1) 345 go func() { 346 _, _, err := clientMux.SendRequest("hello", true, nil) 347 result <- err 348 }() 349 350 <-serverMux.incomingRequests 351 serverMux.conn.Close() 352 err := <-result 353 354 if err != io.EOF { 355 t.Errorf("want EOF, got %v", io.EOF) 356 } 357 } 358 359 func TestMuxChannelRequestUnblock(t *testing.T) { 360 a, b, connB := channelPair(t) 361 defer a.Close() 362 defer b.Close() 363 defer connB.Close() 364 365 result := make(chan error, 1) 366 go func() { 367 _, err := a.SendRequest("hello", true, nil) 368 result <- err 369 }() 370 371 <-b.incomingRequests 372 connB.conn.Close() 373 err := <-result 374 375 if err != io.EOF { 376 t.Errorf("want EOF, got %v", err) 377 } 378 } 379 380 func TestMuxCloseChannel(t *testing.T) { 381 r, w, mux := channelPair(t) 382 defer mux.Close() 383 defer r.Close() 384 defer w.Close() 385 386 result := make(chan error, 1) 387 go func() { 388 var b [1024]byte 389 _, err := r.Read(b[:]) 390 result <- err 391 }() 392 if err := w.Close(); err != nil { 393 t.Errorf("w.Close: %v", err) 394 } 395 396 if _, err := w.Write([]byte("hello")); err != io.EOF { 397 t.Errorf("got err %v, want io.EOF after Close", err) 398 } 399 400 if err := <-result; err != io.EOF { 401 t.Errorf("got %v (%T), want io.EOF", err, err) 402 } 403 } 404 405 func TestMuxCloseWriteChannel(t *testing.T) { 406 r, w, mux := channelPair(t) 407 defer mux.Close() 408 409 result := make(chan error, 1) 410 go func() { 411 var b [1024]byte 412 _, err := r.Read(b[:]) 413 result <- err 414 }() 415 if err := w.CloseWrite(); err != nil { 416 t.Errorf("w.CloseWrite: %v", err) 417 } 418 419 if _, err := w.Write([]byte("hello")); err != io.EOF { 420 t.Errorf("got err %v, want io.EOF after CloseWrite", err) 421 } 422 423 if err := <-result; err != io.EOF { 424 t.Errorf("got %v (%T), want io.EOF", err, err) 425 } 426 } 427 428 func TestMuxInvalidRecord(t *testing.T) { 429 a, b := muxPair() 430 defer a.Close() 431 defer b.Close() 432 433 packet := make([]byte, 1+4+4+1) 434 packet[0] = msgChannelData 435 marshalUint32(packet[1:], 29348723 /* invalid channel id */) 436 marshalUint32(packet[5:], 1) 437 packet[9] = 42 438 439 a.conn.writePacket(packet) 440 go a.SendRequest("hello", false, nil) 441 // 'a' wrote an invalid packet, so 'b' has exited. 442 req, ok := <-b.incomingRequests 443 if ok { 444 t.Errorf("got request %#v after receiving invalid packet", req) 445 } 446 } 447 448 func TestZeroWindowAdjust(t *testing.T) { 449 a, b, mux := channelPair(t) 450 defer a.Close() 451 defer b.Close() 452 defer mux.Close() 453 454 go func() { 455 io.WriteString(a, "hello") 456 // bogus adjust. 457 a.sendMessage(windowAdjustMsg{}) 458 io.WriteString(a, "world") 459 a.Close() 460 }() 461 462 want := "helloworld" 463 c, _ := ioutil.ReadAll(b) 464 if string(c) != want { 465 t.Errorf("got %q want %q", c, want) 466 } 467 } 468 469 func TestMuxMaxPacketSize(t *testing.T) { 470 a, b, mux := channelPair(t) 471 defer a.Close() 472 defer b.Close() 473 defer mux.Close() 474 475 large := make([]byte, a.maxRemotePayload+1) 476 packet := make([]byte, 1+4+4+1+len(large)) 477 packet[0] = msgChannelData 478 marshalUint32(packet[1:], a.remoteId) 479 marshalUint32(packet[5:], uint32(len(large))) 480 packet[9] = 42 481 482 if err := a.mux.conn.writePacket(packet); err != nil { 483 t.Errorf("could not send packet") 484 } 485 486 go a.SendRequest("hello", false, nil) 487 488 _, ok := <-b.incomingRequests 489 if ok { 490 t.Errorf("connection still alive after receiving large packet.") 491 } 492 } 493 494 // Don't ship code with debug=true. 495 func TestDebug(t *testing.T) { 496 if debugMux { 497 t.Error("mux debug switched on") 498 } 499 if debugHandshake { 500 t.Error("handshake debug switched on") 501 } 502 }