github.com/deis/deis@v1.13.5-0.20170519182049-1d9e59fbdbfc/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"io"
     9  	"io/ioutil"
    10  	"sync"
    11  	"testing"
    12  )
    13  
    14  func muxPair() (*mux, *mux) {
    15  	a, b := memPipe()
    16  
    17  	s := newMux(a)
    18  	c := newMux(b)
    19  
    20  	return s, c
    21  }
    22  
    23  // Returns both ends of a channel, and the mux for the the 2nd
    24  // channel.
    25  func channelPair(t *testing.T) (*channel, *channel, *mux) {
    26  	c, s := muxPair()
    27  
    28  	res := make(chan *channel, 1)
    29  	go func() {
    30  		newCh, ok := <-s.incomingChannels
    31  		if !ok {
    32  			t.Fatalf("No incoming channel")
    33  		}
    34  		if newCh.ChannelType() != "chan" {
    35  			t.Fatalf("got type %q want chan", newCh.ChannelType())
    36  		}
    37  		ch, _, err := newCh.Accept()
    38  		if err != nil {
    39  			t.Fatalf("Accept %v", err)
    40  		}
    41  		res <- ch.(*channel)
    42  	}()
    43  
    44  	ch, err := c.openChannel("chan", nil)
    45  	if err != nil {
    46  		t.Fatalf("OpenChannel: %v", err)
    47  	}
    48  
    49  	return <-res, ch, c
    50  }
    51  
    52  // Test that stderr and stdout can be addressed from different
    53  // goroutines. This is intended for use with the race detector.
    54  func TestMuxChannelExtendedThreadSafety(t *testing.T) {
    55  	writer, reader, mux := channelPair(t)
    56  	defer writer.Close()
    57  	defer reader.Close()
    58  	defer mux.Close()
    59  
    60  	var wr, rd sync.WaitGroup
    61  	magic := "hello world"
    62  
    63  	wr.Add(2)
    64  	go func() {
    65  		io.WriteString(writer, magic)
    66  		wr.Done()
    67  	}()
    68  	go func() {
    69  		io.WriteString(writer.Stderr(), magic)
    70  		wr.Done()
    71  	}()
    72  
    73  	rd.Add(2)
    74  	go func() {
    75  		c, err := ioutil.ReadAll(reader)
    76  		if string(c) != magic {
    77  			t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
    78  		}
    79  		rd.Done()
    80  	}()
    81  	go func() {
    82  		c, err := ioutil.ReadAll(reader.Stderr())
    83  		if string(c) != magic {
    84  			t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
    85  		}
    86  		rd.Done()
    87  	}()
    88  
    89  	wr.Wait()
    90  	writer.CloseWrite()
    91  	rd.Wait()
    92  }
    93  
    94  func TestMuxReadWrite(t *testing.T) {
    95  	s, c, mux := channelPair(t)
    96  	defer s.Close()
    97  	defer c.Close()
    98  	defer mux.Close()
    99  
   100  	magic := "hello world"
   101  	magicExt := "hello stderr"
   102  	go func() {
   103  		_, err := s.Write([]byte(magic))
   104  		if err != nil {
   105  			t.Fatalf("Write: %v", err)
   106  		}
   107  		_, err = s.Extended(1).Write([]byte(magicExt))
   108  		if err != nil {
   109  			t.Fatalf("Write: %v", err)
   110  		}
   111  		err = s.Close()
   112  		if err != nil {
   113  			t.Fatalf("Close: %v", err)
   114  		}
   115  	}()
   116  
   117  	var buf [1024]byte
   118  	n, err := c.Read(buf[:])
   119  	if err != nil {
   120  		t.Fatalf("server Read: %v", err)
   121  	}
   122  	got := string(buf[:n])
   123  	if got != magic {
   124  		t.Fatalf("server: got %q want %q", got, magic)
   125  	}
   126  
   127  	n, err = c.Extended(1).Read(buf[:])
   128  	if err != nil {
   129  		t.Fatalf("server Read: %v", err)
   130  	}
   131  
   132  	got = string(buf[:n])
   133  	if got != magicExt {
   134  		t.Fatalf("server: got %q want %q", got, magic)
   135  	}
   136  }
   137  
   138  func TestMuxChannelOverflow(t *testing.T) {
   139  	reader, writer, mux := channelPair(t)
   140  	defer reader.Close()
   141  	defer writer.Close()
   142  	defer mux.Close()
   143  
   144  	wDone := make(chan int, 1)
   145  	go func() {
   146  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   147  			t.Errorf("could not fill window: %v", err)
   148  		}
   149  		writer.Write(make([]byte, 1))
   150  		wDone <- 1
   151  	}()
   152  	writer.remoteWin.waitWriterBlocked()
   153  
   154  	// Send 1 byte.
   155  	packet := make([]byte, 1+4+4+1)
   156  	packet[0] = msgChannelData
   157  	marshalUint32(packet[1:], writer.remoteId)
   158  	marshalUint32(packet[5:], uint32(1))
   159  	packet[9] = 42
   160  
   161  	if err := writer.mux.conn.writePacket(packet); err != nil {
   162  		t.Errorf("could not send packet")
   163  	}
   164  	if _, err := reader.SendRequest("hello", true, nil); err == nil {
   165  		t.Errorf("SendRequest succeeded.")
   166  	}
   167  	<-wDone
   168  }
   169  
   170  func TestMuxChannelCloseWriteUnblock(t *testing.T) {
   171  	reader, writer, mux := channelPair(t)
   172  	defer reader.Close()
   173  	defer writer.Close()
   174  	defer mux.Close()
   175  
   176  	wDone := make(chan int, 1)
   177  	go func() {
   178  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   179  			t.Errorf("could not fill window: %v", err)
   180  		}
   181  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   182  			t.Errorf("got %v, want EOF for unblock write", err)
   183  		}
   184  		wDone <- 1
   185  	}()
   186  
   187  	writer.remoteWin.waitWriterBlocked()
   188  	reader.Close()
   189  	<-wDone
   190  }
   191  
   192  func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
   193  	reader, writer, mux := channelPair(t)
   194  	defer reader.Close()
   195  	defer writer.Close()
   196  	defer mux.Close()
   197  
   198  	wDone := make(chan int, 1)
   199  	go func() {
   200  		if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
   201  			t.Errorf("could not fill window: %v", err)
   202  		}
   203  		if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
   204  			t.Errorf("got %v, want EOF for unblock write", err)
   205  		}
   206  		wDone <- 1
   207  	}()
   208  
   209  	writer.remoteWin.waitWriterBlocked()
   210  	mux.Close()
   211  	<-wDone
   212  }
   213  
   214  func TestMuxReject(t *testing.T) {
   215  	client, server := muxPair()
   216  	defer server.Close()
   217  	defer client.Close()
   218  
   219  	go func() {
   220  		ch, ok := <-server.incomingChannels
   221  		if !ok {
   222  			t.Fatalf("Accept")
   223  		}
   224  		if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
   225  			t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
   226  		}
   227  		ch.Reject(RejectionReason(42), "message")
   228  	}()
   229  
   230  	ch, err := client.openChannel("ch", []byte("extra"))
   231  	if ch != nil {
   232  		t.Fatal("openChannel not rejected")
   233  	}
   234  
   235  	ocf, ok := err.(*OpenChannelError)
   236  	if !ok {
   237  		t.Errorf("got %#v want *OpenChannelError", err)
   238  	} else if ocf.Reason != 42 || ocf.Message != "message" {
   239  		t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
   240  	}
   241  
   242  	want := "ssh: rejected: unknown reason 42 (message)"
   243  	if err.Error() != want {
   244  		t.Errorf("got %q, want %q", err.Error(), want)
   245  	}
   246  }
   247  
   248  func TestMuxChannelRequest(t *testing.T) {
   249  	client, server, mux := channelPair(t)
   250  	defer server.Close()
   251  	defer client.Close()
   252  	defer mux.Close()
   253  
   254  	var received int
   255  	var wg sync.WaitGroup
   256  	wg.Add(1)
   257  	go func() {
   258  		for r := range server.incomingRequests {
   259  			received++
   260  			r.Reply(r.Type == "yes", nil)
   261  		}
   262  		wg.Done()
   263  	}()
   264  	_, err := client.SendRequest("yes", false, nil)
   265  	if err != nil {
   266  		t.Fatalf("SendRequest: %v", err)
   267  	}
   268  	ok, err := client.SendRequest("yes", true, nil)
   269  	if err != nil {
   270  		t.Fatalf("SendRequest: %v", err)
   271  	}
   272  
   273  	if !ok {
   274  		t.Errorf("SendRequest(yes): %v", ok)
   275  
   276  	}
   277  
   278  	ok, err = client.SendRequest("no", true, nil)
   279  	if err != nil {
   280  		t.Fatalf("SendRequest: %v", err)
   281  	}
   282  	if ok {
   283  		t.Errorf("SendRequest(no): %v", ok)
   284  
   285  	}
   286  
   287  	client.Close()
   288  	wg.Wait()
   289  
   290  	if received != 3 {
   291  		t.Errorf("got %d requests, want %d", received, 3)
   292  	}
   293  }
   294  
   295  func TestMuxGlobalRequest(t *testing.T) {
   296  	clientMux, serverMux := muxPair()
   297  	defer serverMux.Close()
   298  	defer clientMux.Close()
   299  
   300  	var seen bool
   301  	go func() {
   302  		for r := range serverMux.incomingRequests {
   303  			seen = seen || r.Type == "peek"
   304  			if r.WantReply {
   305  				err := r.Reply(r.Type == "yes",
   306  					append([]byte(r.Type), r.Payload...))
   307  				if err != nil {
   308  					t.Errorf("AckRequest: %v", err)
   309  				}
   310  			}
   311  		}
   312  	}()
   313  
   314  	_, _, err := clientMux.SendRequest("peek", false, nil)
   315  	if err != nil {
   316  		t.Errorf("SendRequest: %v", err)
   317  	}
   318  
   319  	ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
   320  	if !ok || string(data) != "yesa" || err != nil {
   321  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   322  			ok, data, err)
   323  	}
   324  	if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
   325  		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
   326  			ok, data, err)
   327  	}
   328  
   329  	if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
   330  		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
   331  			ok, data, err)
   332  	}
   333  
   334  	clientMux.Disconnect(0, "")
   335  	if !seen {
   336  		t.Errorf("never saw 'peek' request")
   337  	}
   338  }
   339  
   340  func TestMuxGlobalRequestUnblock(t *testing.T) {
   341  	clientMux, serverMux := muxPair()
   342  	defer serverMux.Close()
   343  	defer clientMux.Close()
   344  
   345  	result := make(chan error, 1)
   346  	go func() {
   347  		_, _, err := clientMux.SendRequest("hello", true, nil)
   348  		result <- err
   349  	}()
   350  
   351  	<-serverMux.incomingRequests
   352  	serverMux.conn.Close()
   353  	err := <-result
   354  
   355  	if err != io.EOF {
   356  		t.Errorf("want EOF, got %v", io.EOF)
   357  	}
   358  }
   359  
   360  func TestMuxChannelRequestUnblock(t *testing.T) {
   361  	a, b, connB := channelPair(t)
   362  	defer a.Close()
   363  	defer b.Close()
   364  	defer connB.Close()
   365  
   366  	result := make(chan error, 1)
   367  	go func() {
   368  		_, err := a.SendRequest("hello", true, nil)
   369  		result <- err
   370  	}()
   371  
   372  	<-b.incomingRequests
   373  	connB.conn.Close()
   374  	err := <-result
   375  
   376  	if err != io.EOF {
   377  		t.Errorf("want EOF, got %v", err)
   378  	}
   379  }
   380  
   381  func TestMuxDisconnect(t *testing.T) {
   382  	a, b := muxPair()
   383  	defer a.Close()
   384  	defer b.Close()
   385  
   386  	go func() {
   387  		for r := range b.incomingRequests {
   388  			r.Reply(true, nil)
   389  		}
   390  	}()
   391  
   392  	a.Disconnect(42, "whatever")
   393  	ok, _, err := a.SendRequest("hello", true, nil)
   394  	if ok || err == nil {
   395  		t.Errorf("got reply after disconnecting")
   396  	}
   397  	err = b.Wait()
   398  	if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
   399  		t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
   400  	}
   401  }
   402  
   403  func TestMuxCloseChannel(t *testing.T) {
   404  	r, w, mux := channelPair(t)
   405  	defer mux.Close()
   406  	defer r.Close()
   407  	defer w.Close()
   408  
   409  	result := make(chan error, 1)
   410  	go func() {
   411  		var b [1024]byte
   412  		_, err := r.Read(b[:])
   413  		result <- err
   414  	}()
   415  	if err := w.Close(); err != nil {
   416  		t.Errorf("w.Close: %v", err)
   417  	}
   418  
   419  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   420  		t.Errorf("got err %v, want io.EOF after Close", err)
   421  	}
   422  
   423  	if err := <-result; err != io.EOF {
   424  		t.Errorf("got %v (%T), want io.EOF", err, err)
   425  	}
   426  }
   427  
   428  func TestMuxCloseWriteChannel(t *testing.T) {
   429  	r, w, mux := channelPair(t)
   430  	defer mux.Close()
   431  
   432  	result := make(chan error, 1)
   433  	go func() {
   434  		var b [1024]byte
   435  		_, err := r.Read(b[:])
   436  		result <- err
   437  	}()
   438  	if err := w.CloseWrite(); err != nil {
   439  		t.Errorf("w.CloseWrite: %v", err)
   440  	}
   441  
   442  	if _, err := w.Write([]byte("hello")); err != io.EOF {
   443  		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
   444  	}
   445  
   446  	if err := <-result; err != io.EOF {
   447  		t.Errorf("got %v (%T), want io.EOF", err, err)
   448  	}
   449  }
   450  
   451  func TestMuxInvalidRecord(t *testing.T) {
   452  	a, b := muxPair()
   453  	defer a.Close()
   454  	defer b.Close()
   455  
   456  	packet := make([]byte, 1+4+4+1)
   457  	packet[0] = msgChannelData
   458  	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
   459  	marshalUint32(packet[5:], 1)
   460  	packet[9] = 42
   461  
   462  	a.conn.writePacket(packet)
   463  	go a.SendRequest("hello", false, nil)
   464  	// 'a' wrote an invalid packet, so 'b' has exited.
   465  	req, ok := <-b.incomingRequests
   466  	if ok {
   467  		t.Errorf("got request %#v after receiving invalid packet", req)
   468  	}
   469  }
   470  
   471  func TestZeroWindowAdjust(t *testing.T) {
   472  	a, b, mux := channelPair(t)
   473  	defer a.Close()
   474  	defer b.Close()
   475  	defer mux.Close()
   476  
   477  	go func() {
   478  		io.WriteString(a, "hello")
   479  		// bogus adjust.
   480  		a.sendMessage(windowAdjustMsg{})
   481  		io.WriteString(a, "world")
   482  		a.Close()
   483  	}()
   484  
   485  	want := "helloworld"
   486  	c, _ := ioutil.ReadAll(b)
   487  	if string(c) != want {
   488  		t.Errorf("got %q want %q", c, want)
   489  	}
   490  }
   491  
   492  func TestMuxMaxPacketSize(t *testing.T) {
   493  	a, b, mux := channelPair(t)
   494  	defer a.Close()
   495  	defer b.Close()
   496  	defer mux.Close()
   497  
   498  	large := make([]byte, a.maxRemotePayload+1)
   499  	packet := make([]byte, 1+4+4+1+len(large))
   500  	packet[0] = msgChannelData
   501  	marshalUint32(packet[1:], a.remoteId)
   502  	marshalUint32(packet[5:], uint32(len(large)))
   503  	packet[9] = 42
   504  
   505  	if err := a.mux.conn.writePacket(packet); err != nil {
   506  		t.Errorf("could not send packet")
   507  	}
   508  
   509  	go a.SendRequest("hello", false, nil)
   510  
   511  	_, ok := <-b.incomingRequests
   512  	if ok {
   513  		t.Errorf("connection still alive after receiving large packet.")
   514  	}
   515  }
   516  
   517  // Don't ship code with debug=true.
   518  func TestDebug(t *testing.T) {
   519  	if debugMux {
   520  		t.Error("mux debug switched on")
   521  	}
   522  	if debugHandshake {
   523  		t.Error("handshake debug switched on")
   524  	}
   525  }