github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/mux_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"io"
     9  	"io/ioutil"
    10  	"sync"
    11  	"testing"
    12  )
    13  
    14  func muxPair() (*mux, *mux) {
    15  	a, b := memPipe()
    16  
    17  	s := newMux(a)
    18  	c := newMux(b)
    19  
    20  	return s, c
    21  }
    22  
    23  // Returns both ends of a channel, and the mux for the the 2nd
    24  // channel.
    25  func channelPair(t *testing.T) (*channel, *channel, *mux) {
    26  	c, s := muxPair()
    27  
    28  	res := make(chan *channel, 1)
    29  	go func() {
    30  		newCh, ok := <-s.incomingChannels
    31  		if !ok {
    32  			t.Fatalf("No incoming channel")
    33  		}
    34  		if newCh.ChannelType() != "chan" {
    35  			t.Fatalf("got type %q want chan", newCh.ChannelType())
    36  		}
    37  		ch, _, err := newCh.Accept()
    38  		if err != nil {
    39  			t.Fatalf("Accept %v", err)
    40  		}
    41  		res <- ch.(*channel)
    42  	}()
    43  
    44  	ch, err := c.openChannel("chan", nil)
    45  	if err != nil {
    46  		t.Fatalf("OpenChannel: %v", err)
    47  	}
    48  
    49  	return <-res, ch, c
    50  }
    51  
    52  // Test that stderr and stdout can be addressed from different
    53  // goroutines. This is intended for use with the race detector.
    54  func TestMuxChannelExtendedThreadSafety(t *testing.T) {
    55  	writer, reader, mux := channelPair(t)
    56  	defer writer.Close()
    57  	defer reader.Close()
    58  	defer mux.Close()
    59  
    60  	var wr, rd sync.WaitGroup
    61  	magic := "hello world"
    62  
    63  	wr.Add(2)
    64  	go func() {
    65  		io.WriteString(writer, magic)
    66  		wr.Done()
    67  	}()
    68  	go func() {
    69  		io.WriteString(writer.Stderr(), magic)
    70  		wr.Done()
    71  	}()
    72  
    73  	rd.Add(2)
    74  	go func() {
    75  		c, err := ioutil.ReadAll(reader)
    76  		if string(c) != magic {
    77  			t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
    78  		}
    79  		rd.Done()
    80  	}()
    81  	go func() {
    82  		c, err := ioutil.ReadAll(reader.Stderr())
    83  		if string(c) != magic {
    84  			t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
    85  		}
    86  		rd.Done()
    87  	}()
    88  
    89  	wr.Wait()
    90  	writer.CloseWrite()
    91  	rd.Wait()
    92  }
    93  
    94  func TestMuxReadWrite(t *testing.T) {
    95  	s, c, mux := channelPair(t)
    96  	defer s.Close()
    97  	defer c.Close()
    98  	defer mux.Close()
    99  
   100  	magic := "hello world"
   101  	magicExt := "hello stderr"
   102  	go func() {
   103  		_, err := s.Write([]byte(magic))
   104  		if err != nil {
   105  			t.Fatalf("Write: %v", err)
   106  		}
   107  		_, err = s.Extended(1).Write([]byte(magicExt))
   108  		if err != nil {
   109  			t.Fatalf("Write: %v", err)
   110  		}
   111  		err = s.Close()
   112  		if err != nil {
   113  			t.Fatalf("Close: %v", err)
   114  		}
   115  	}()
   116  
   117  	var buf [1024]byte
   118  	n, err := c.Read(buf[:])
   119  	if err != nil {
   120  		t.Fatalf("server Read: %v", err)
   121  	}
   122  	got := string(buf[:n])
   123  	if got != magic {
   124  		t.Fatalf("server: got %q want %q", got, magic)
   125  	}
   126  
   127  	n, err = c.Extended(1).Read(buf[:])
   128  	if err != nil {
   129  		t.Fatalf("server Read: %v", err)
   130  	}
   131  
   132  	got = string(buf[:n])
   133  	if got != magicExt {
   134  		t.Fatalf("server: got %q want %q", got, magic)
   135  	}
   136  }
   137  
   138  func TestMuxChannelOverflow(t *testing.T) {
   139  	reader, writer, mux := channelPair(t)
   140  	defer reader.Close()
   141  	defer writer.Close()
   142  	defer mux.Close()
   143  
   144  	wDone := make(chan int, 1)
   145  	go func() {
   146  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   147  			t.Errorf("could not fill window: %v", err)
   148  		}
   149  		writer.Write(make([]byte, 1))
   150  		wDone <- 1
   151  	}()
   152  	writer.remoteWin.waitWriterBlocked()
   153  
   154  	// Send 1 byte.
   155  	packet := make([]byte, 1+4+4+1)
   156  	packet[0] = msgChannelData
   157  	marshalUint32(packet[1:], writer.remoteId)
   158  	marshalUint32(packet[5:], uint32(1))
   159  	packet[9] = 42
   160  
   161  	if err := writer.mux.conn.writePacket(packet); err != nil {
   162  		t.Errorf("could not send packet")
   163  	}
   164  	if _, err := reader.SendRequest("hello", true, nil); err == nil {
   165  		t.Errorf("SendRequest succeeded.")
   166  	}
   167  	<-wDone
   168  }
   169  
   170  func TestMuxChannelCloseWriteUnblock(t *testing.T) {
   171  	reader, writer, mux := channelPair(t)
   172  	defer reader.Close()
   173  	defer writer.Close()
   174  	defer mux.Close()
   175  
   176  	wDone := make(chan int, 1)
   177  	go func() {
   178  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   179  			t.Errorf("could not fill window: %v", err)
   180  		}
   181  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   182  			t.Errorf("got %v, want EOF for unblock write", err)
   183  		}
   184  		wDone <- 1
   185  	}()
   186  
   187  	writer.remoteWin.waitWriterBlocked()
   188  	reader.Close()
   189  	<-wDone
   190  }
   191  
   192  func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
   193  	reader, writer, mux := channelPair(t)
   194  	defer reader.Close()
   195  	defer writer.Close()
   196  	defer mux.Close()
   197  
   198  	wDone := make(chan int, 1)
   199  	go func() {
   200  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   201  			t.Errorf("could not fill window: %v", err)
   202  		}
   203  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   204  			t.Errorf("got %v, want EOF for unblock write", err)
   205  		}
   206  		wDone <- 1
   207  	}()
   208  
   209  	writer.remoteWin.waitWriterBlocked()
   210  	mux.Close()
   211  	<-wDone
   212  }
   213  
   214  func TestMuxReject(t *testing.T) {
   215  	client, server := muxPair()
   216  	defer server.Close()
   217  	defer client.Close()
   218  
   219  	go func() {
   220  		ch, ok := <-server.incomingChannels
   221  		if !ok {
   222  			t.Fatalf("Accept")
   223  		}
   224  		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
   225  			t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
   226  		}
   227  		ch.Reject(RejectionReason(42), "message")
   228  	}()
   229  
   230  	ch, err := client.openChannel("ch", []byte("extra"))
   231  	if ch != nil {
   232  		t.Fatal("openChannel not rejected")
   233  	}
   234  
   235  	ocf, ok := err.(*OpenChannelError)
   236  	if !ok {
   237  		t.Errorf("got %#v want *OpenChannelError", err)
   238  	} else if ocf.Reason != 42 || ocf.Message != "message" {
   239  		t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
   240  	}
   241  
   242  	want := "ssh: rejected: unknown reason 42 (message)"
   243  	if err.Error() != want {
   244  		t.Errorf("got %q, want %q", err.Error(), want)
   245  	}
   246  }
   247  
   248  func TestMuxChannelRequest(t *testing.T) {
   249  	client, server, mux := channelPair(t)
   250  	defer server.Close()
   251  	defer client.Close()
   252  	defer mux.Close()
   253  
   254  	var received int
   255  	var wg sync.WaitGroup
   256  	wg.Add(1)
   257  	go func() {
   258  		for r := range server.incomingRequests {
   259  			received++
   260  			r.Reply(r.Type == "yes", nil)
   261  		}
   262  		wg.Done()
   263  	}()
   264  	_, err := client.SendRequest("yes", false, nil)
   265  	if err != nil {
   266  		t.Fatalf("SendRequest: %v", err)
   267  	}
   268  	ok, err := client.SendRequest("yes", true, nil)
   269  	if err != nil {
   270  		t.Fatalf("SendRequest: %v", err)
   271  	}
   272  
   273  	if !ok {
   274  		t.Errorf("SendRequest(yes): %v", ok)
   275  
   276  	}
   277  
   278  	ok, err = client.SendRequest("no", true, nil)
   279  	if err != nil {
   280  		t.Fatalf("SendRequest: %v", err)
   281  	}
   282  	if ok {
   283  		t.Errorf("SendRequest(no): %v", ok)
   284  
   285  	}
   286  
   287  	client.Close()
   288  	wg.Wait()
   289  
   290  	if received != 3 {
   291  		t.Errorf("got %d requests, want %d", received, 3)
   292  	}
   293  }
   294  
   295  func TestMuxGlobalRequest(t *testing.T) {
   296  	clientMux, serverMux := muxPair()
   297  	defer serverMux.Close()
   298  	defer clientMux.Close()
   299  
   300  	var seen bool
   301  	go func() {
   302  		for r := range serverMux.incomingRequests {
   303  			seen = seen || r.Type == "peek"
   304  			if r.WantReply {
   305  				err := r.Reply(r.Type == "yes",
   306  					append([]byte(r.Type), r.Payload...))
   307  				if err != nil {
   308  					t.Errorf("AckRequest: %v", err)
   309  				}
   310  			}
   311  		}
   312  	}()
   313  
   314  	_, _, err := clientMux.SendRequest("peek", false, nil)
   315  	if err != nil {
   316  		t.Errorf("SendRequest: %v", err)
   317  	}
   318  
   319  	ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
   320  	if !ok || string(data) != "yesa" || err != nil {
   321  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   322  			ok, data, err)
   323  	}
   324  	if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
   325  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   326  			ok, data, err)
   327  	}
   328  
   329  	if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
   330  		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
   331  			ok, data, err)
   332  	}
   333  
   334  	if !seen {
   335  		t.Errorf("never saw 'peek' request")
   336  	}
   337  }
   338  
   339  func TestMuxGlobalRequestUnblock(t *testing.T) {
   340  	clientMux, serverMux := muxPair()
   341  	defer serverMux.Close()
   342  	defer clientMux.Close()
   343  
   344  	result := make(chan error, 1)
   345  	go func() {
   346  		_, _, err := clientMux.SendRequest("hello", true, nil)
   347  		result <- err
   348  	}()
   349  
   350  	<-serverMux.incomingRequests
   351  	serverMux.conn.Close()
   352  	err := <-result
   353  
   354  	if err != io.EOF {
   355  		t.Errorf("want EOF, got %v", io.EOF)
   356  	}
   357  }
   358  
   359  func TestMuxChannelRequestUnblock(t *testing.T) {
   360  	a, b, connB := channelPair(t)
   361  	defer a.Close()
   362  	defer b.Close()
   363  	defer connB.Close()
   364  
   365  	result := make(chan error, 1)
   366  	go func() {
   367  		_, err := a.SendRequest("hello", true, nil)
   368  		result <- err
   369  	}()
   370  
   371  	<-b.incomingRequests
   372  	connB.conn.Close()
   373  	err := <-result
   374  
   375  	if err != io.EOF {
   376  		t.Errorf("want EOF, got %v", err)
   377  	}
   378  }
   379  
   380  func TestMuxCloseChannel(t *testing.T) {
   381  	r, w, mux := channelPair(t)
   382  	defer mux.Close()
   383  	defer r.Close()
   384  	defer w.Close()
   385  
   386  	result := make(chan error, 1)
   387  	go func() {
   388  		var b [1024]byte
   389  		_, err := r.Read(b[:])
   390  		result <- err
   391  	}()
   392  	if err := w.Close(); err != nil {
   393  		t.Errorf("w.Close: %v", err)
   394  	}
   395  
   396  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   397  		t.Errorf("got err %v, want io.EOF after Close", err)
   398  	}
   399  
   400  	if err := <-result; err != io.EOF {
   401  		t.Errorf("got %v (%T), want io.EOF", err, err)
   402  	}
   403  }
   404  
   405  func TestMuxCloseWriteChannel(t *testing.T) {
   406  	r, w, mux := channelPair(t)
   407  	defer mux.Close()
   408  
   409  	result := make(chan error, 1)
   410  	go func() {
   411  		var b [1024]byte
   412  		_, err := r.Read(b[:])
   413  		result <- err
   414  	}()
   415  	if err := w.CloseWrite(); err != nil {
   416  		t.Errorf("w.CloseWrite: %v", err)
   417  	}
   418  
   419  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   420  		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
   421  	}
   422  
   423  	if err := <-result; err != io.EOF {
   424  		t.Errorf("got %v (%T), want io.EOF", err, err)
   425  	}
   426  }
   427  
   428  func TestMuxInvalidRecord(t *testing.T) {
   429  	a, b := muxPair()
   430  	defer a.Close()
   431  	defer b.Close()
   432  
   433  	packet := make([]byte, 1+4+4+1)
   434  	packet[0] = msgChannelData
   435  	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
   436  	marshalUint32(packet[5:], 1)
   437  	packet[9] = 42
   438  
   439  	a.conn.writePacket(packet)
   440  	go a.SendRequest("hello", false, nil)
   441  	// 'a' wrote an invalid packet, so 'b' has exited.
   442  	req, ok := <-b.incomingRequests
   443  	if ok {
   444  		t.Errorf("got request %#v after receiving invalid packet", req)
   445  	}
   446  }
   447  
   448  func TestZeroWindowAdjust(t *testing.T) {
   449  	a, b, mux := channelPair(t)
   450  	defer a.Close()
   451  	defer b.Close()
   452  	defer mux.Close()
   453  
   454  	go func() {
   455  		io.WriteString(a, "hello")
   456  		// bogus adjust.
   457  		a.sendMessage(windowAdjustMsg{})
   458  		io.WriteString(a, "world")
   459  		a.Close()
   460  	}()
   461  
   462  	want := "helloworld"
   463  	c, _ := ioutil.ReadAll(b)
   464  	if string(c) != want {
   465  		t.Errorf("got %q want %q", c, want)
   466  	}
   467  }
   468  
   469  func TestMuxMaxPacketSize(t *testing.T) {
   470  	a, b, mux := channelPair(t)
   471  	defer a.Close()
   472  	defer b.Close()
   473  	defer mux.Close()
   474  
   475  	large := make([]byte, a.maxRemotePayload+1)
   476  	packet := make([]byte, 1+4+4+1+len(large))
   477  	packet[0] = msgChannelData
   478  	marshalUint32(packet[1:], a.remoteId)
   479  	marshalUint32(packet[5:], uint32(len(large)))
   480  	packet[9] = 42
   481  
   482  	if err := a.mux.conn.writePacket(packet); err != nil {
   483  		t.Errorf("could not send packet")
   484  	}
   485  
   486  	go a.SendRequest("hello", false, nil)
   487  
   488  	_, ok := <-b.incomingRequests
   489  	if ok {
   490  		t.Errorf("connection still alive after receiving large packet.")
   491  	}
   492  }
   493  
   494  // Don't ship code with debug=true.
   495  func TestDebug(t *testing.T) {
   496  	if debugMux {
   497  		t.Error("mux debug switched on")
   498  	}
   499  	if debugHandshake {
   500  		t.Error("handshake debug switched on")
   501  	}
   502  }