github.com/glycerine/xcryptossh@v7.0.4+incompatible/session_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  // Session tests.
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	crypto_rand "crypto/rand"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"io/ioutil"
    17  	"math/rand"
    18  	"net"
    19  	"testing"
    20  
    21  	"github.com/glycerine/xcryptossh/terminal"
    22  )
    23  
    24  type serverType func(Channel, <-chan *Request, *testing.T)
    25  
    26  // dial constructs a new test server and returns a *ClientConn.
    27  func dial(handler serverType, t *testing.T, halt *Halter) *Client {
    28  	c1, c2, err := netPipe()
    29  	if err != nil {
    30  		t.Fatalf("netPipe: %v", err)
    31  	}
    32  
    33  	go func() {
    34  		defer c1.Close()
    35  		conf := ServerConfig{
    36  			NoClientAuth: true,
    37  			Config: Config{
    38  				Halt: halt,
    39  			},
    40  		}
    41  		conf.AddHostKey(testSigners["rsa"])
    42  		ctx := context.Background()
    43  
    44  		_, chans, reqs, err := NewServerConn(ctx, c1, &conf)
    45  		if err != nil {
    46  			t.Fatalf("Unable to handshake: %v", err)
    47  		}
    48  		go DiscardRequests(ctx, reqs, conf.Halt)
    49  
    50  		for newCh := range chans {
    51  			if newCh.ChannelType() != "session" {
    52  				newCh.Reject(UnknownChannelType, "unknown channel type")
    53  				continue
    54  			}
    55  
    56  			ch, inReqs, err := newCh.Accept()
    57  			if err != nil {
    58  				t.Errorf("Accept: %v", err)
    59  				continue
    60  			}
    61  			go func() {
    62  				handler(ch, inReqs, t)
    63  			}()
    64  		}
    65  	}()
    66  
    67  	config := &ClientConfig{
    68  		User:            "testuser",
    69  		HostKeyCallback: InsecureIgnoreHostKey(),
    70  		Config: Config{
    71  			Halt: halt,
    72  		},
    73  	}
    74  	ctx := context.Background()
    75  
    76  	conn, chans, reqs, err := NewClientConn(ctx, c2, "", config)
    77  	if err != nil {
    78  		t.Fatalf("unable to dial remote side: %v", err)
    79  	}
    80  	return NewClient(ctx, conn, chans, reqs, halt)
    81  }
    82  
    83  // Test a simple string is returned to session.Stdout.
    84  func TestSessionShell(t *testing.T) {
    85  	defer xtestend(xtestbegin(t))
    86  
    87  	halt := NewHalter()
    88  	defer halt.RequestStop()
    89  
    90  	conn := dial(shellHandler, t, halt)
    91  	defer conn.Close()
    92  	ctx := context.Background()
    93  
    94  	session, err := conn.NewSession(ctx)
    95  	if err != nil {
    96  		t.Fatalf("Unable to request new session: %v", err)
    97  	}
    98  	defer session.Close()
    99  	stdout := new(bytes.Buffer)
   100  	session.Stdout = stdout
   101  	if err := session.Shell(); err != nil {
   102  		t.Fatalf("Unable to execute command: %s", err)
   103  	}
   104  	if err := session.Wait(); err != nil {
   105  		t.Fatalf("Remote command did not exit cleanly: %v", err)
   106  	}
   107  	actual := stdout.String()
   108  	if actual != "golang" {
   109  		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
   110  	}
   111  }
   112  
   113  // TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
   114  
   115  // Test a simple string is returned via StdoutPipe.
   116  func TestSessionStdoutPipe(t *testing.T) {
   117  	defer xtestend(xtestbegin(t))
   118  
   119  	halt := NewHalter()
   120  	defer halt.RequestStop()
   121  
   122  	conn := dial(shellHandler, t, halt)
   123  	defer conn.Close()
   124  	ctx := context.Background()
   125  
   126  	session, err := conn.NewSession(ctx)
   127  	if err != nil {
   128  		t.Fatalf("Unable to request new session: %v", err)
   129  	}
   130  	defer session.Close()
   131  	stdout, err := session.StdoutPipe()
   132  	if err != nil {
   133  		t.Fatalf("Unable to request StdoutPipe(): %v", err)
   134  	}
   135  	var buf bytes.Buffer
   136  	if err := session.Shell(); err != nil {
   137  		t.Fatalf("Unable to execute command: %v", err)
   138  	}
   139  	done := make(chan bool, 1)
   140  	go func() {
   141  		if _, err := io.Copy(&buf, stdout); err != nil {
   142  			t.Errorf("Copy of stdout failed: %v", err)
   143  		}
   144  		done <- true
   145  	}()
   146  	if err := session.Wait(); err != nil {
   147  		t.Fatalf("Remote command did not exit cleanly: %v", err)
   148  	}
   149  	<-done
   150  	actual := buf.String()
   151  	if actual != "golang" {
   152  		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
   153  	}
   154  }
   155  
   156  // Test that a simple string is returned via the Output helper,
   157  // and that stderr is discarded.
   158  func TestSessionOutput(t *testing.T) {
   159  	defer xtestend(xtestbegin(t))
   160  
   161  	halt := NewHalter()
   162  	defer halt.RequestStop()
   163  
   164  	conn := dial(fixedOutputHandler, t, halt)
   165  	defer conn.Close()
   166  	ctx := context.Background()
   167  
   168  	session, err := conn.NewSession(ctx)
   169  	if err != nil {
   170  		t.Fatalf("Unable to request new session: %v", err)
   171  	}
   172  	defer session.Close()
   173  
   174  	buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
   175  	if err != nil {
   176  		t.Error("Remote command did not exit cleanly:", err)
   177  	}
   178  	w := "this-is-stdout."
   179  	g := string(buf)
   180  	if g != w {
   181  		t.Error("Remote command did not return expected string:")
   182  		t.Logf("want %q", w)
   183  		t.Logf("got  %q", g)
   184  	}
   185  }
   186  
   187  // Test that both stdout and stderr are returned
   188  // via the CombinedOutput helper.
   189  func TestSessionCombinedOutput(t *testing.T) {
   190  	defer xtestend(xtestbegin(t))
   191  
   192  	halt := NewHalter()
   193  	defer halt.RequestStop()
   194  
   195  	conn := dial(fixedOutputHandler, t, halt)
   196  	defer conn.Close()
   197  	ctx := context.Background()
   198  
   199  	session, err := conn.NewSession(ctx)
   200  	if err != nil {
   201  		t.Fatalf("Unable to request new session: %v", err)
   202  	}
   203  	defer session.Close()
   204  
   205  	buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
   206  	if err != nil {
   207  		t.Error("Remote command did not exit cleanly:", err)
   208  	}
   209  	const stdout = "this-is-stdout."
   210  	const stderr = "this-is-stderr."
   211  	g := string(buf)
   212  	if g != stdout+stderr && g != stderr+stdout {
   213  		t.Error("Remote command did not return expected string:")
   214  		t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
   215  		t.Logf("got  %q", g)
   216  	}
   217  }
   218  
   219  // Test non-0 exit status is returned correctly.
   220  func TestExitStatusNonZero(t *testing.T) {
   221  	defer xtestend(xtestbegin(t))
   222  
   223  	halt := NewHalter()
   224  	defer halt.RequestStop()
   225  
   226  	conn := dial(exitStatusNonZeroHandler, t, halt)
   227  	defer conn.Close()
   228  	ctx := context.Background()
   229  
   230  	session, err := conn.NewSession(ctx)
   231  	if err != nil {
   232  		t.Fatalf("Unable to request new session: %v", err)
   233  	}
   234  	defer session.Close()
   235  	if err := session.Shell(); err != nil {
   236  		t.Fatalf("Unable to execute command: %v", err)
   237  	}
   238  	err = session.Wait()
   239  	if err == nil {
   240  		t.Fatalf("expected command to fail but it didn't")
   241  	}
   242  	e, ok := err.(*ExitError)
   243  	if !ok {
   244  		t.Fatalf("expected *ExitError but got %T", err)
   245  	}
   246  	if e.ExitStatus() != 15 {
   247  		t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
   248  	}
   249  }
   250  
   251  // Test 0 exit status is returned correctly.
   252  func TestExitStatusZero(t *testing.T) {
   253  	defer xtestend(xtestbegin(t))
   254  
   255  	halt := NewHalter()
   256  	defer halt.RequestStop()
   257  
   258  	conn := dial(exitStatusZeroHandler, t, halt)
   259  	defer conn.Close()
   260  	ctx := context.Background()
   261  
   262  	session, err := conn.NewSession(ctx)
   263  	if err != nil {
   264  		t.Fatalf("Unable to request new session: %v", err)
   265  	}
   266  	defer session.Close()
   267  
   268  	if err := session.Shell(); err != nil {
   269  		t.Fatalf("Unable to execute command: %v", err)
   270  	}
   271  	err = session.Wait()
   272  	if err != nil {
   273  		panic(fmt.Sprintf("expected nil but got %v", err))
   274  	}
   275  }
   276  
   277  // Test exit signal and status are both returned correctly.
   278  func TestExitSignalAndStatus(t *testing.T) {
   279  	defer xtestend(xtestbegin(t))
   280  
   281  	halt := NewHalter()
   282  	defer halt.RequestStop()
   283  
   284  	conn := dial(exitSignalAndStatusHandler, t, halt)
   285  	defer conn.Close()
   286  	ctx := context.Background()
   287  
   288  	session, err := conn.NewSession(ctx)
   289  	if err != nil {
   290  		t.Fatalf("Unable to request new session: %v", err)
   291  	}
   292  	defer session.Close()
   293  	if err := session.Shell(); err != nil {
   294  		t.Fatalf("Unable to execute command: %v", err)
   295  	}
   296  	err = session.Wait()
   297  	if err == nil {
   298  		t.Fatalf("expected command to fail but it didn't")
   299  	}
   300  	e, ok := err.(*ExitError)
   301  	if !ok {
   302  		t.Fatalf("expected *ExitError but got %T", err)
   303  	}
   304  	if e.Signal() != "TERM" || e.ExitStatus() != 15 {
   305  		t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   306  	}
   307  }
   308  
   309  // Test exit signal and status are both returned correctly.
   310  func TestKnownExitSignalOnly(t *testing.T) {
   311  	defer xtestend(xtestbegin(t))
   312  
   313  	halt := NewHalter()
   314  	defer halt.RequestStop()
   315  
   316  	conn := dial(exitSignalHandler, t, halt)
   317  	defer conn.Close()
   318  	ctx := context.Background()
   319  
   320  	session, err := conn.NewSession(ctx)
   321  	if err != nil {
   322  		t.Fatalf("Unable to request new session: %v", err)
   323  	}
   324  	defer session.Close()
   325  	if err := session.Shell(); err != nil {
   326  		t.Fatalf("Unable to execute command: %v", err)
   327  	}
   328  	err = session.Wait()
   329  	if err == nil {
   330  		t.Fatalf("expected command to fail but it didn't")
   331  	}
   332  	e, ok := err.(*ExitError)
   333  	if !ok {
   334  		t.Fatalf("expected *ExitError but got %T", err)
   335  	}
   336  	if e.Signal() != "TERM" || e.ExitStatus() != 143 {
   337  		t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   338  	}
   339  }
   340  
   341  // Test exit signal and status are both returned correctly.
   342  func TestUnknownExitSignal(t *testing.T) {
   343  	defer xtestend(xtestbegin(t))
   344  
   345  	halt := NewHalter()
   346  	defer halt.RequestStop()
   347  
   348  	conn := dial(exitSignalUnknownHandler, t, halt)
   349  	defer conn.Close()
   350  	ctx := context.Background()
   351  
   352  	session, err := conn.NewSession(ctx)
   353  	if err != nil {
   354  		t.Fatalf("Unable to request new session: %v", err)
   355  	}
   356  	defer session.Close()
   357  	if err := session.Shell(); err != nil {
   358  		t.Fatalf("Unable to execute command: %v", err)
   359  	}
   360  	err = session.Wait()
   361  	if err == nil {
   362  		t.Fatalf("expected command to fail but it didn't")
   363  	}
   364  	e, ok := err.(*ExitError)
   365  	if !ok {
   366  		t.Fatalf("expected *ExitError but got %T", err)
   367  	}
   368  	if e.Signal() != "SYS" || e.ExitStatus() != 128 {
   369  		t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   370  	}
   371  }
   372  
   373  func TestExitWithoutStatusOrSignal(t *testing.T) {
   374  	defer xtestend(xtestbegin(t))
   375  
   376  	halt := NewHalter()
   377  	defer halt.RequestStop()
   378  
   379  	conn := dial(exitWithoutSignalOrStatus, t, halt)
   380  	defer conn.Close()
   381  	ctx := context.Background()
   382  
   383  	session, err := conn.NewSession(ctx)
   384  	if err != nil {
   385  		t.Fatalf("Unable to request new session: %v", err)
   386  	}
   387  	defer session.Close()
   388  	if err := session.Shell(); err != nil {
   389  		t.Fatalf("Unable to execute command: %v", err)
   390  	}
   391  	err = session.Wait()
   392  	if err == nil {
   393  		t.Fatalf("expected command to fail but it didn't")
   394  	}
   395  	if _, ok := err.(*ExitMissingError); !ok {
   396  		t.Fatalf("got %T want *ExitMissingError", err)
   397  	}
   398  }
   399  
   400  // windowTestBytes is the number of bytes that we'll send to the SSH server.
   401  const windowTestBytes = 16000 * 200
   402  
   403  // TestServerWindow writes random data to the server. The server is expected to echo
   404  // the same data back, which is compared against the original.
   405  func TestServerWindow(t *testing.T) {
   406  	defer xtestend(xtestbegin(t))
   407  	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
   408  	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
   409  	origBytes := origBuf.Bytes()
   410  
   411  	halt := NewHalter()
   412  	defer halt.RequestStop()
   413  
   414  	conn := dial(echoHandler, t, halt)
   415  	defer conn.Close()
   416  	ctx := context.Background()
   417  
   418  	session, err := conn.NewSession(ctx)
   419  	if err != nil {
   420  		t.Fatal(err)
   421  	}
   422  	defer session.Close()
   423  	result := make(chan []byte)
   424  
   425  	go func() {
   426  		defer close(result)
   427  		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
   428  		serverStdout, err := session.StdoutPipe()
   429  		if err != nil {
   430  			t.Errorf("StdoutPipe failed: %v", err)
   431  			return
   432  		}
   433  		n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
   434  		if err != nil && err != io.EOF {
   435  			t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
   436  		}
   437  		result <- echoedBuf.Bytes()
   438  	}()
   439  
   440  	serverStdin, err := session.StdinPipe()
   441  	if err != nil {
   442  		t.Fatalf("StdinPipe failed: %v", err)
   443  	}
   444  	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
   445  	if err != nil {
   446  		t.Fatalf("failed to copy origBuf to serverStdin: %v", err)
   447  	}
   448  	if written != windowTestBytes {
   449  		t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes)
   450  	}
   451  
   452  	echoedBytes := <-result
   453  
   454  	if !bytes.Equal(origBytes, echoedBytes) {
   455  		t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
   456  	}
   457  }
   458  
   459  // Verify the client can handle a keepalive packet from the server.
   460  func TestClientHandlesKeepalives(t *testing.T) {
   461  	defer xtestend(xtestbegin(t))
   462  
   463  	halt := NewHalter()
   464  	defer halt.RequestStop()
   465  
   466  	conn := dial(channelKeepaliveSender, t, halt)
   467  	defer conn.Close()
   468  	ctx := context.Background()
   469  
   470  	session, err := conn.NewSession(ctx)
   471  	if err != nil {
   472  		t.Fatal(err)
   473  	}
   474  	defer session.Close()
   475  	if err := session.Shell(); err != nil {
   476  		t.Fatalf("Unable to execute command: %v", err)
   477  	}
   478  	err = session.Wait()
   479  	if err != nil {
   480  		t.Fatalf("expected nil but got: %v", err)
   481  	}
   482  }
   483  
   484  type exitStatusMsg struct {
   485  	Status uint32
   486  }
   487  
   488  type exitSignalMsg struct {
   489  	Signal     string
   490  	CoreDumped bool
   491  	Errmsg     string
   492  	Lang       string
   493  }
   494  
   495  func handleTerminalRequests(in <-chan *Request) {
   496  	for req := range in {
   497  		ok := false
   498  		switch req.Type {
   499  		case "shell":
   500  			ok = true
   501  			if len(req.Payload) > 0 {
   502  				// We don't accept any commands, only the default shell.
   503  				ok = false
   504  			}
   505  		case "env":
   506  			ok = true
   507  		}
   508  		req.Reply(ok, nil)
   509  	}
   510  }
   511  
   512  func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
   513  	term := terminal.NewTerminal(ch, prompt)
   514  	go handleTerminalRequests(in)
   515  	return term
   516  }
   517  
   518  func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
   519  	defer ch.Close()
   520  	// this string is returned to stdout
   521  	shell := newServerShell(ch, in, "> ")
   522  	readLine(shell, t)
   523  	sendStatus(0, ch, t)
   524  }
   525  
   526  func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
   527  	defer ch.Close()
   528  	shell := newServerShell(ch, in, "> ")
   529  	readLine(shell, t)
   530  	sendStatus(15, ch, t)
   531  }
   532  
   533  func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
   534  	defer ch.Close()
   535  	shell := newServerShell(ch, in, "> ")
   536  	readLine(shell, t)
   537  	sendStatus(15, ch, t)
   538  	sendSignal("TERM", ch, t)
   539  }
   540  
   541  func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
   542  	defer ch.Close()
   543  	shell := newServerShell(ch, in, "> ")
   544  	readLine(shell, t)
   545  	sendSignal("TERM", ch, t)
   546  }
   547  
   548  func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
   549  	defer ch.Close()
   550  	shell := newServerShell(ch, in, "> ")
   551  	readLine(shell, t)
   552  	sendSignal("SYS", ch, t)
   553  }
   554  
   555  func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
   556  	defer ch.Close()
   557  	shell := newServerShell(ch, in, "> ")
   558  	readLine(shell, t)
   559  }
   560  
   561  func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
   562  	defer ch.Close()
   563  	// this string is returned to stdout
   564  	shell := newServerShell(ch, in, "golang")
   565  	readLine(shell, t)
   566  	sendStatus(0, ch, t)
   567  }
   568  
   569  // Ignores the command, writes fixed strings to stderr and stdout.
   570  // Strings are "this-is-stdout." and "this-is-stderr.".
   571  func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
   572  	defer ch.Close()
   573  	_, err := ch.Read(nil)
   574  
   575  	req, ok := <-in
   576  	if !ok {
   577  		t.Fatalf("error: expected channel request, got: %#v", err)
   578  		return
   579  	}
   580  
   581  	// ignore request, always send some text
   582  	req.Reply(true, nil)
   583  
   584  	_, err = io.WriteString(ch, "this-is-stdout.")
   585  	if err != nil {
   586  		t.Fatalf("error writing on server: %v", err)
   587  	}
   588  	_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
   589  	if err != nil {
   590  		t.Fatalf("error writing on server: %v", err)
   591  	}
   592  	sendStatus(0, ch, t)
   593  }
   594  
   595  func readLine(shell *terminal.Terminal, t *testing.T) {
   596  	if _, err := shell.ReadLine(); err != nil && err != io.EOF {
   597  		t.Errorf("unable to read line: %v", err)
   598  	}
   599  }
   600  
   601  func sendStatus(status uint32, ch Channel, t *testing.T) {
   602  	msg := exitStatusMsg{
   603  		Status: status,
   604  	}
   605  	if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
   606  		t.Errorf("unable to send status: %v", err)
   607  	}
   608  }
   609  
   610  func sendSignal(signal string, ch Channel, t *testing.T) {
   611  	sig := exitSignalMsg{
   612  		Signal:     signal,
   613  		CoreDumped: false,
   614  		Errmsg:     "Process terminated",
   615  		Lang:       "en-GB-oed",
   616  	}
   617  	if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
   618  		t.Errorf("unable to send signal: %v", err)
   619  	}
   620  }
   621  
   622  func discardHandler(ch Channel, t *testing.T) {
   623  	defer ch.Close()
   624  	io.Copy(ioutil.Discard, ch)
   625  }
   626  
   627  func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
   628  	defer ch.Close()
   629  	if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
   630  		t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
   631  	}
   632  }
   633  
   634  // copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
   635  // buffer size to exercise more code paths.
   636  func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
   637  	var (
   638  		buf       = make([]byte, 32*1024)
   639  		written   int
   640  		remaining = n
   641  	)
   642  	for remaining > 0 {
   643  		l := rand.Intn(1 << 15)
   644  		if remaining < l {
   645  			l = remaining
   646  		}
   647  		nr, er := src.Read(buf[:l])
   648  		nw, ew := dst.Write(buf[:nr])
   649  		remaining -= nw
   650  		written += nw
   651  		if ew != nil {
   652  			return written, ew
   653  		}
   654  		if nr != nw {
   655  			return written, io.ErrShortWrite
   656  		}
   657  		if er != nil && er != io.EOF {
   658  			return written, er
   659  		}
   660  	}
   661  	return written, nil
   662  }
   663  
   664  func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
   665  	defer ch.Close()
   666  	shell := newServerShell(ch, in, "> ")
   667  	readLine(shell, t)
   668  	if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
   669  		t.Errorf("unable to send channel keepalive request: %v", err)
   670  	}
   671  	sendStatus(0, ch, t)
   672  }
   673  
   674  func TestClientWriteEOF(t *testing.T) {
   675  	defer xtestend(xtestbegin(t))
   676  
   677  	halt := NewHalter()
   678  	defer halt.RequestStop()
   679  
   680  	conn := dial(simpleEchoHandler, t, halt)
   681  	defer conn.Close()
   682  	ctx := context.Background()
   683  
   684  	session, err := conn.NewSession(ctx)
   685  	if err != nil {
   686  		t.Fatal(err)
   687  	}
   688  	defer session.Close()
   689  	stdin, err := session.StdinPipe()
   690  	if err != nil {
   691  		t.Fatalf("StdinPipe failed: %v", err)
   692  	}
   693  	stdout, err := session.StdoutPipe()
   694  	if err != nil {
   695  		t.Fatalf("StdoutPipe failed: %v", err)
   696  	}
   697  
   698  	data := []byte(`0000`)
   699  	_, err = stdin.Write(data)
   700  	if err != nil {
   701  		t.Fatalf("Write failed: %v", err)
   702  	}
   703  	stdin.Close()
   704  
   705  	res, err := ioutil.ReadAll(stdout)
   706  	if err != nil {
   707  		t.Fatalf("Read failed: %v", err)
   708  	}
   709  
   710  	if !bytes.Equal(data, res) {
   711  		t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
   712  	}
   713  }
   714  
   715  func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
   716  	defer ch.Close()
   717  	data, err := ioutil.ReadAll(ch)
   718  	if err != nil {
   719  		t.Errorf("handler read error: %v", err)
   720  	}
   721  	_, err = ch.Write(data)
   722  	if err != nil {
   723  		t.Errorf("handler write error: %v", err)
   724  	}
   725  }
   726  
   727  func TestSessionID(t *testing.T) {
   728  	defer xtestend(xtestbegin(t))
   729  
   730  	halt := NewHalter()
   731  	defer halt.RequestStop()
   732  
   733  	c1, c2, err := netPipe()
   734  	if err != nil {
   735  		t.Fatalf("netPipe: %v", err)
   736  	}
   737  	defer c1.Close()
   738  	defer c2.Close()
   739  
   740  	serverID := make(chan []byte, 1)
   741  	clientID := make(chan []byte, 1)
   742  
   743  	serverConf := &ServerConfig{
   744  		NoClientAuth: true,
   745  		Config: Config{
   746  			Halt: halt,
   747  		},
   748  	}
   749  	serverConf.AddHostKey(testSigners["ecdsa"])
   750  	clientConf := &ClientConfig{
   751  		HostKeyCallback: InsecureIgnoreHostKey(),
   752  		User:            "user",
   753  		Config: Config{
   754  			Halt: halt,
   755  		},
   756  	}
   757  
   758  	go func() {
   759  		ctx := context.Background()
   760  		conn, chans, reqs, err := NewServerConn(ctx, c1, serverConf)
   761  		if err != nil {
   762  			t.Fatalf("server handshake: %v", err)
   763  		}
   764  		serverID <- conn.SessionID()
   765  		go DiscardRequests(ctx, reqs, serverConf.Halt)
   766  		for ch := range chans {
   767  			ch.Reject(Prohibited, "")
   768  		}
   769  	}()
   770  
   771  	go func() {
   772  		ctx := context.Background()
   773  		conn, chans, reqs, err := NewClientConn(ctx, c2, "", clientConf)
   774  		if err != nil {
   775  			t.Fatalf("client handshake: %v", err)
   776  		}
   777  		clientID <- conn.SessionID()
   778  		go DiscardRequests(ctx, reqs, clientConf.Halt)
   779  		for ch := range chans {
   780  			ch.Reject(Prohibited, "")
   781  		}
   782  	}()
   783  
   784  	s := <-serverID
   785  	c := <-clientID
   786  	if bytes.Compare(s, c) != 0 {
   787  		t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
   788  	} else if len(s) == 0 {
   789  		t.Errorf("client and server SessionID were empty.")
   790  	}
   791  }
   792  
   793  type noReadConn struct {
   794  	readSeen bool
   795  	net.Conn
   796  }
   797  
   798  func (c *noReadConn) Close() error {
   799  	return nil
   800  }
   801  
   802  func (c *noReadConn) Read(b []byte) (int, error) {
   803  	c.readSeen = true
   804  	return 0, errors.New("noReadConn error")
   805  }
   806  
   807  func TestInvalidServerConfiguration(t *testing.T) {
   808  	defer xtestend(xtestbegin(t))
   809  	c1, c2, err := netPipe()
   810  	if err != nil {
   811  		t.Fatalf("netPipe: %v", err)
   812  	}
   813  	defer c1.Close()
   814  	defer c2.Close()
   815  
   816  	serveConn := noReadConn{Conn: c1}
   817  	serverConf := &ServerConfig{}
   818  	ctx := context.Background()
   819  
   820  	NewServerConn(ctx, &serveConn, serverConf)
   821  	if serveConn.readSeen {
   822  		t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
   823  	}
   824  
   825  	serverConf.AddHostKey(testSigners["ecdsa"])
   826  
   827  	NewServerConn(ctx, &serveConn, serverConf)
   828  	if serveConn.readSeen {
   829  		t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
   830  	}
   831  }
   832  
   833  func TestHostKeyAlgorithms(t *testing.T) {
   834  	defer xtestend(xtestbegin(t))
   835  	serverConf := &ServerConfig{
   836  		NoClientAuth: true,
   837  		Config: Config{
   838  			Halt: NewHalter(),
   839  		},
   840  	}
   841  	defer serverConf.Halt.RequestStop()
   842  	serverConf.AddHostKey(testSigners["rsa"])
   843  	serverConf.AddHostKey(testSigners["ecdsa"])
   844  
   845  	connect := func(clientConf *ClientConfig, want string) {
   846  		var alg string
   847  		clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
   848  			alg = key.Type()
   849  			return nil
   850  		}
   851  		c1, c2, err := netPipe()
   852  		if err != nil {
   853  			t.Fatalf("netPipe: %v", err)
   854  		}
   855  		defer c1.Close()
   856  		defer c2.Close()
   857  		ctx := context.Background()
   858  
   859  		go NewServerConn(ctx, c1, serverConf)
   860  		_, _, _, err = NewClientConn(ctx, c2, "", clientConf)
   861  		if err != nil {
   862  			t.Fatalf("NewClientConn: %v", err)
   863  		}
   864  		if alg != want {
   865  			t.Errorf("selected key algorithm %s, want %s", alg, want)
   866  		}
   867  	}
   868  
   869  	// By default, we get the preferred algorithm, which is ECDSA 256.
   870  
   871  	clientConf := &ClientConfig{
   872  		HostKeyCallback: InsecureIgnoreHostKey(),
   873  		Config: Config{
   874  			Halt: NewHalter(),
   875  		},
   876  	}
   877  	defer clientConf.Halt.RequestStop()
   878  	connect(clientConf, KeyAlgoECDSA256)
   879  
   880  	// Client asks for RSA explicitly.
   881  	clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA}
   882  	connect(clientConf, KeyAlgoRSA)
   883  
   884  	c1, c2, err := netPipe()
   885  	if err != nil {
   886  		t.Fatalf("netPipe: %v", err)
   887  	}
   888  	defer c1.Close()
   889  	defer c2.Close()
   890  	ctx := context.Background()
   891  
   892  	go NewServerConn(ctx, c1, serverConf)
   893  	clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
   894  	_, _, _, err = NewClientConn(ctx, c2, "", clientConf)
   895  	if err == nil {
   896  		t.Fatal("succeeded connecting with unknown hostkey algorithm")
   897  	}
   898  }