github.com/glycerine/xcryptossh@v7.0.4+incompatible/mux_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"context"
     9  	"io"
    10  	"io/ioutil"
    11  	"sync"
    12  	"testing"
    13  )
    14  
    15  func muxPair(halt *Halter) (*mux, *mux) {
    16  	a, b := memPipe()
    17  
    18  	ctx := context.Background()
    19  
    20  	s := newMux(ctx, a, halt)
    21  	c := newMux(ctx, b, halt)
    22  
    23  	return s, c
    24  }
    25  
    26  // Returns both ends of a channel, and the mux for the the 2nd
    27  // channel.
    28  func channelPair(t *testing.T, halt *Halter) (*channel, *channel, *mux) {
    29  	c, s := muxPair(halt)
    30  
    31  	res := make(chan *channel, 1)
    32  	go func() {
    33  		newCh, ok := <-s.incomingChannels
    34  		if !ok {
    35  			t.Fatalf("No incoming channel")
    36  		}
    37  		if newCh.ChannelType() != "chan" {
    38  			t.Fatalf("got type %q want chan", newCh.ChannelType())
    39  		}
    40  		ch, _, err := newCh.Accept()
    41  		if err != nil {
    42  			t.Fatalf("Accept %v", err)
    43  		}
    44  		res <- ch.(*channel)
    45  	}()
    46  	ctx := context.Background()
    47  
    48  	chc, err := c.openChannel(ctx, "chan", nil, nil)
    49  	if err != nil {
    50  		t.Fatalf("OpenChannel: %v", err)
    51  	}
    52  
    53  	chs := <-res
    54  
    55  	// setup the idleTimers on the memTransports.
    56  	tc := c.conn.(*memTransport)
    57  	tc.Lock()
    58  	tc.idle = chc.idleR
    59  	tc.Unlock()
    60  
    61  	ts := s.conn.(*memTransport)
    62  	ts.Lock()
    63  	ts.idle = chs.idleR
    64  	ts.Unlock()
    65  
    66  	return chs, chc, c
    67  }
    68  
    69  // Test that stderr and stdout can be addressed from different
    70  // goroutines. This is intended for use with the race detector.
    71  func TestMuxChannelExtendedThreadSafety(t *testing.T) {
    72  	defer xtestend(xtestbegin(t))
    73  
    74  	halt := NewHalter()
    75  	defer halt.RequestStop()
    76  
    77  	writer, reader, mux := channelPair(t, halt)
    78  	defer writer.Close()
    79  	defer reader.Close()
    80  	defer mux.Close()
    81  
    82  	var wr, rd sync.WaitGroup
    83  	magic := "hello world"
    84  
    85  	wr.Add(2)
    86  	go func() {
    87  		io.WriteString(writer, magic)
    88  		wr.Done()
    89  	}()
    90  	go func() {
    91  		io.WriteString(writer.Stderr(), magic)
    92  		wr.Done()
    93  	}()
    94  
    95  	rd.Add(2)
    96  	go func() {
    97  		c, err := ioutil.ReadAll(reader)
    98  		if string(c) != magic {
    99  			t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
   100  		}
   101  		rd.Done()
   102  	}()
   103  	go func() {
   104  		c, err := ioutil.ReadAll(reader.Stderr())
   105  		if string(c) != magic {
   106  			t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
   107  		}
   108  		rd.Done()
   109  	}()
   110  
   111  	wr.Wait()
   112  	writer.CloseWrite()
   113  	rd.Wait()
   114  }
   115  
   116  func TestMuxReadWrite(t *testing.T) {
   117  	defer xtestend(xtestbegin(t))
   118  
   119  	halt := NewHalter()
   120  	defer halt.RequestStop()
   121  
   122  	s, c, mux := channelPair(t, halt)
   123  	defer s.Close()
   124  	defer c.Close()
   125  	defer mux.Close()
   126  
   127  	writeDone := make(chan bool)
   128  
   129  	magic := "hello world"
   130  	magicExt := "hello stderr"
   131  	go func() {
   132  		_, err := s.Write([]byte(magic))
   133  		if err != nil {
   134  			t.Fatalf("Write: %v", err)
   135  		}
   136  		_, err = s.Extended(1).Write([]byte(magicExt))
   137  		if err != nil {
   138  			t.Fatalf("Write: %v", err)
   139  		}
   140  		err = s.Close()
   141  		if err != nil {
   142  			t.Fatalf("Close: %v", err)
   143  		}
   144  
   145  		close(writeDone)
   146  	}()
   147  
   148  	var buf [1024]byte
   149  	n, err := c.Read(buf[:])
   150  	if err != nil {
   151  		t.Fatalf("server Read: %v", err)
   152  	}
   153  	got := string(buf[:n])
   154  	if got != magic {
   155  		t.Fatalf("server: got %q want %q", got, magic)
   156  	}
   157  
   158  	n, err = c.Extended(1).Read(buf[:])
   159  	if err != nil {
   160  		t.Fatalf("server Read: %v", err)
   161  	}
   162  
   163  	got = string(buf[:n])
   164  	if got != magicExt {
   165  		t.Fatalf("server: got %q want %q", got, magic)
   166  	}
   167  
   168  	<-writeDone
   169  }
   170  
   171  func TestMuxChannelOverflow(t *testing.T) {
   172  	defer xtestend(xtestbegin(t))
   173  
   174  	halt := NewHalter()
   175  	defer halt.RequestStop()
   176  
   177  	reader, writer, mux := channelPair(t, halt)
   178  	defer reader.Close()
   179  	defer writer.Close()
   180  	defer mux.Close()
   181  
   182  	wDone := make(chan int, 1)
   183  	go func() {
   184  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   185  			t.Errorf("could not fill window: %v", err)
   186  		}
   187  		writer.Write(make([]byte, 1))
   188  		wDone <- 1
   189  	}()
   190  	writer.remoteWin.waitWriterBlocked()
   191  
   192  	// Send 1 byte.
   193  	packet := make([]byte, 1+4+4+1)
   194  	packet[0] = msgChannelData
   195  	marshalUint32(packet[1:], writer.remoteId)
   196  	marshalUint32(packet[5:], uint32(1))
   197  	packet[9] = 42
   198  
   199  	if err := writer.mux.conn.writePacket(packet); err != nil {
   200  		t.Errorf("could not send packet")
   201  	}
   202  	if _, err := reader.SendRequest("hello", true, nil); err == nil {
   203  		t.Errorf("SendRequest succeeded.")
   204  	}
   205  	<-wDone
   206  }
   207  
   208  func TestMuxChannelCloseWriteUnblock(t *testing.T) {
   209  	defer xtestend(xtestbegin(t))
   210  
   211  	halt := NewHalter()
   212  	defer halt.RequestStop()
   213  
   214  	reader, writer, mux := channelPair(t, halt)
   215  	defer reader.Close()
   216  	defer writer.Close()
   217  	defer mux.Close()
   218  
   219  	wDone := make(chan int, 1)
   220  	go func() {
   221  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   222  			t.Errorf("could not fill window: %v", err)
   223  		}
   224  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   225  			t.Errorf("got %v, want EOF for unblock write", err)
   226  		}
   227  		wDone <- 1
   228  	}()
   229  
   230  	writer.remoteWin.waitWriterBlocked()
   231  	reader.Close()
   232  	<-wDone
   233  }
   234  
   235  func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
   236  	defer xtestend(xtestbegin(t))
   237  
   238  	halt := NewHalter()
   239  	defer halt.RequestStop()
   240  
   241  	reader, writer, mux := channelPair(t, halt)
   242  	defer reader.Close()
   243  	defer writer.Close()
   244  	defer mux.Close()
   245  
   246  	wDone := make(chan int, 1)
   247  	go func() {
   248  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   249  			t.Errorf("could not fill window: %v", err)
   250  		}
   251  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   252  			t.Errorf("got %v, want EOF for unblock write", err)
   253  		}
   254  		wDone <- 1
   255  	}()
   256  
   257  	writer.remoteWin.waitWriterBlocked()
   258  	mux.Close()
   259  	<-wDone
   260  }
   261  
   262  func TestMuxReject(t *testing.T) {
   263  	defer xtestend(xtestbegin(t))
   264  
   265  	halt := NewHalter()
   266  	defer halt.RequestStop()
   267  
   268  	client, server := muxPair(halt)
   269  	defer server.Close()
   270  	defer client.Close()
   271  
   272  	go func() {
   273  		ch, ok := <-server.incomingChannels
   274  		if !ok {
   275  			t.Fatalf("Accept")
   276  		}
   277  		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
   278  			t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
   279  		}
   280  		ch.Reject(RejectionReason(42), "message")
   281  	}()
   282  	ctx := context.Background()
   283  
   284  	ch, err := client.openChannel(ctx, "ch", []byte("extra"), nil)
   285  	if ch != nil {
   286  		t.Fatal("openChannel not rejected")
   287  	}
   288  
   289  	ocf, ok := err.(*OpenChannelError)
   290  	if !ok {
   291  		t.Errorf("got %#v want *OpenChannelError", err)
   292  	} else if ocf.Reason != 42 || ocf.Message != "message" {
   293  		t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
   294  	}
   295  
   296  	want := "ssh: rejected: unknown reason 42 (message)"
   297  	if err.Error() != want {
   298  		t.Errorf("got %q, want %q", err.Error(), want)
   299  	}
   300  }
   301  
   302  func TestMuxChannelRequest(t *testing.T) {
   303  	defer xtestend(xtestbegin(t))
   304  
   305  	halt := NewHalter()
   306  	defer halt.RequestStop()
   307  
   308  	client, server, mux := channelPair(t, halt)
   309  	defer server.Close()
   310  	defer client.Close()
   311  	defer mux.Close()
   312  
   313  	var received int
   314  	var wg sync.WaitGroup
   315  	wg.Add(1)
   316  	go func() {
   317  		for r := range server.incomingRequests {
   318  			received++
   319  			r.Reply(r.Type == "yes", nil)
   320  		}
   321  		wg.Done()
   322  	}()
   323  	_, err := client.SendRequest("yes", false, nil)
   324  	if err != nil {
   325  		t.Fatalf("SendRequest: %v", err)
   326  	}
   327  	ok, err := client.SendRequest("yes", true, nil)
   328  	if err != nil {
   329  		t.Fatalf("SendRequest: %v", err)
   330  	}
   331  
   332  	if !ok {
   333  		t.Errorf("SendRequest(yes): %v", ok)
   334  
   335  	}
   336  
   337  	ok, err = client.SendRequest("no", true, nil)
   338  	if err != nil {
   339  		t.Fatalf("SendRequest: %v", err)
   340  	}
   341  	if ok {
   342  		t.Errorf("SendRequest(no): %v", ok)
   343  
   344  	}
   345  
   346  	client.Close()
   347  	wg.Wait()
   348  
   349  	if received != 3 {
   350  		t.Errorf("got %d requests, want %d", received, 3)
   351  	}
   352  }
   353  
   354  func TestMuxGlobalRequest(t *testing.T) {
   355  	defer xtestend(xtestbegin(t))
   356  
   357  	halt := NewHalter()
   358  	defer halt.RequestStop()
   359  
   360  	clientMux, serverMux := muxPair(halt)
   361  	defer serverMux.Close()
   362  	defer clientMux.Close()
   363  
   364  	var seen bool
   365  	go func() {
   366  		for r := range serverMux.incomingRequests {
   367  			seen = seen || r.Type == "peek"
   368  			if r.WantReply {
   369  				err := r.Reply(r.Type == "yes",
   370  					append([]byte(r.Type), r.Payload...))
   371  				if err != nil {
   372  					t.Errorf("AckRequest: %v", err)
   373  				}
   374  			}
   375  		}
   376  	}()
   377  	ctx := context.Background()
   378  
   379  	_, _, err := clientMux.SendRequest(ctx, "peek", false, nil)
   380  	if err != nil {
   381  		t.Errorf("SendRequest: %v", err)
   382  	}
   383  
   384  	ok, data, err := clientMux.SendRequest(ctx, "yes", true, []byte("a"))
   385  	if !ok || string(data) != "yesa" || err != nil {
   386  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   387  			ok, data, err)
   388  	}
   389  	if ok, data, err := clientMux.SendRequest(ctx, "yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
   390  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   391  			ok, data, err)
   392  	}
   393  
   394  	if ok, data, err := clientMux.SendRequest(ctx, "no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
   395  		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
   396  			ok, data, err)
   397  	}
   398  
   399  	if !seen {
   400  		t.Errorf("never saw 'peek' request")
   401  	}
   402  }
   403  
   404  func TestMuxGlobalRequestUnblock(t *testing.T) {
   405  	defer xtestend(xtestbegin(t))
   406  	halt := NewHalter()
   407  	defer halt.RequestStop()
   408  
   409  	clientMux, serverMux := muxPair(halt)
   410  	defer serverMux.Close()
   411  	defer clientMux.Close()
   412  
   413  	result := make(chan error, 1)
   414  	ctx := context.Background()
   415  
   416  	go func() {
   417  		_, _, err := clientMux.SendRequest(ctx, "hello", true, nil)
   418  		result <- err
   419  	}()
   420  
   421  	<-serverMux.incomingRequests
   422  	serverMux.conn.Close()
   423  	err := <-result
   424  
   425  	if err != io.EOF {
   426  		t.Errorf("want EOF, got %v", err)
   427  	}
   428  }
   429  
   430  func TestMuxChannelRequestUnblock(t *testing.T) {
   431  	defer xtestend(xtestbegin(t))
   432  	halt := NewHalter()
   433  	defer halt.RequestStop()
   434  
   435  	a, b, connB := channelPair(t, halt)
   436  	defer a.Close()
   437  	defer b.Close()
   438  	defer connB.Close()
   439  
   440  	result := make(chan error, 1)
   441  	go func() {
   442  		_, err := a.SendRequest("hello", true, nil)
   443  		result <- err
   444  	}()
   445  
   446  	<-b.incomingRequests
   447  	connB.conn.Close()
   448  	err := <-result
   449  
   450  	if err != io.EOF {
   451  		t.Errorf("want EOF, got %v", err)
   452  	}
   453  }
   454  
   455  func TestMuxCloseChannel(t *testing.T) {
   456  	defer xtestend(xtestbegin(t))
   457  	halt := NewHalter()
   458  	defer halt.RequestStop()
   459  
   460  	r, w, mux := channelPair(t, halt)
   461  	defer mux.Close()
   462  	defer r.Close()
   463  	defer w.Close()
   464  
   465  	result := make(chan error, 1)
   466  	go func() {
   467  		var b [1024]byte
   468  		_, err := r.Read(b[:])
   469  		result <- err
   470  	}()
   471  	if err := w.Close(); err != nil {
   472  		t.Errorf("w.Close: %v", err)
   473  	}
   474  
   475  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   476  		t.Errorf("got err %v, want io.EOF after Close", err)
   477  	}
   478  
   479  	if err := <-result; err != io.EOF {
   480  		t.Errorf("got %v (%T), want io.EOF", err, err)
   481  	}
   482  }
   483  
   484  func TestMuxCloseWriteChannel(t *testing.T) {
   485  	defer xtestend(xtestbegin(t))
   486  	halt := NewHalter()
   487  	defer halt.RequestStop()
   488  
   489  	r, w, mux := channelPair(t, halt)
   490  	defer mux.Close()
   491  
   492  	result := make(chan error, 1)
   493  	go func() {
   494  		var b [1024]byte
   495  		_, err := r.Read(b[:])
   496  		result <- err
   497  	}()
   498  	if err := w.CloseWrite(); err != nil {
   499  		t.Errorf("w.CloseWrite: %v", err)
   500  	}
   501  
   502  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   503  		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
   504  	}
   505  
   506  	if err := <-result; err != io.EOF {
   507  		t.Errorf("got %v (%T), want io.EOF", err, err)
   508  	}
   509  }
   510  
   511  func TestMuxInvalidRecord(t *testing.T) {
   512  	defer xtestend(xtestbegin(t))
   513  	halt := NewHalter()
   514  	defer halt.RequestStop()
   515  
   516  	a, b := muxPair(halt)
   517  	defer a.Close()
   518  	defer b.Close()
   519  
   520  	packet := make([]byte, 1+4+4+1)
   521  	packet[0] = msgChannelData
   522  	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
   523  	marshalUint32(packet[5:], 1)
   524  	packet[9] = 42
   525  
   526  	a.conn.writePacket(packet)
   527  	ctx := context.Background()
   528  
   529  	go a.SendRequest(ctx, "hello", false, nil)
   530  	// 'a' wrote an invalid packet, so 'b' has exited.
   531  	req, ok := <-b.incomingRequests
   532  	if ok {
   533  		t.Errorf("got request %#v after receiving invalid packet", req)
   534  	}
   535  }
   536  
   537  func TestZeroWindowAdjust(t *testing.T) {
   538  	defer xtestend(xtestbegin(t))
   539  
   540  	halt := NewHalter()
   541  	defer halt.RequestStop()
   542  
   543  	a, b, mux := channelPair(t, halt)
   544  	defer a.Close()
   545  	defer b.Close()
   546  	defer mux.Close()
   547  
   548  	go func() {
   549  		io.WriteString(a, "hello")
   550  		// bogus adjust.
   551  		a.sendMessage(windowAdjustMsg{})
   552  		io.WriteString(a, "world")
   553  		a.Close()
   554  	}()
   555  
   556  	want := "helloworld"
   557  	c, _ := ioutil.ReadAll(b)
   558  	if string(c) != want {
   559  		t.Errorf("got %q want %q", c, want)
   560  	}
   561  }
   562  
   563  func TestMuxMaxPacketSize(t *testing.T) {
   564  	defer xtestend(xtestbegin(t))
   565  
   566  	halt := NewHalter()
   567  	defer halt.RequestStop()
   568  
   569  	a, b, mux := channelPair(t, halt)
   570  	defer a.Close()
   571  	defer b.Close()
   572  	defer mux.Close()
   573  
   574  	large := make([]byte, a.maxRemotePayload+1)
   575  	packet := make([]byte, 1+4+4+1+len(large))
   576  	packet[0] = msgChannelData
   577  	marshalUint32(packet[1:], a.remoteId)
   578  	marshalUint32(packet[5:], uint32(len(large)))
   579  	packet[9] = 42
   580  
   581  	if err := a.mux.conn.writePacket(packet); err != nil {
   582  		t.Errorf("could not send packet")
   583  	}
   584  
   585  	go a.SendRequest("hello", false, nil)
   586  
   587  	_, ok := <-b.incomingRequests
   588  	if ok {
   589  		t.Errorf("connection still alive after receiving large packet.")
   590  	}
   591  }
   592  
   593  // Don't ship code with debug=true.
   594  func TestDebug(t *testing.T) {
   595  	defer xtestend(xtestbegin(t))
   596  	if debugMux {
   597  		t.Error("mux debug switched on")
   598  	}
   599  	if debugHandshake {
   600  		t.Error("handshake debug switched on")
   601  	}
   602  	if debugTransport {
   603  		t.Error("transport debug switched on")
   604  	}
   605  }