github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/session_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  // Session tests.
     8  
     9  import (
    10  	"bytes"
    11  	crypto_rand "crypto/rand"
    12  	"errors"
    13  	"io"
    14  	"io/ioutil"
    15  	"math/rand"
    16  	"net"
    17  	"testing"
    18  
    19  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh/terminal"
    20  )
    21  
    22  type serverType func(Channel, <-chan *Request, *testing.T)
    23  
    24  // dial constructs a new test server and returns a *ClientConn.
    25  func dial(handler serverType, t *testing.T) *Client {
    26  	c1, c2, err := netPipe()
    27  	if err != nil {
    28  		t.Fatalf("netPipe: %v", err)
    29  	}
    30  
    31  	go func() {
    32  		defer c1.Close()
    33  		conf := ServerConfig{
    34  			NoClientAuth: true,
    35  		}
    36  		conf.AddHostKey(testSigners["rsa"])
    37  
    38  		conn, chans, reqs, err := NewServerConn(c1, &conf)
    39  		if err != nil {
    40  			t.Fatalf("Unable to handshake: %v", err)
    41  		}
    42  		go DiscardRequests(reqs)
    43  
    44  		for newCh := range chans {
    45  			if newCh.ChannelType() != "session" {
    46  				newCh.Reject(UnknownChannelType, "unknown channel type")
    47  				continue
    48  			}
    49  
    50  			ch, inReqs, err := newCh.Accept()
    51  			if err != nil {
    52  				t.Errorf("Accept: %v", err)
    53  				continue
    54  			}
    55  			go func() {
    56  				handler(ch, inReqs, t)
    57  			}()
    58  		}
    59  		if err := conn.Wait(); err != io.EOF {
    60  			t.Logf("server exit reason: %v", err)
    61  		}
    62  	}()
    63  
    64  	config := &ClientConfig{
    65  		User:            "testuser",
    66  		HostKeyCallback: InsecureIgnoreHostKey(),
    67  	}
    68  
    69  	conn, chans, reqs, err := NewClientConn(c2, "", config)
    70  	if err != nil {
    71  		t.Fatalf("unable to dial remote side: %v", err)
    72  	}
    73  
    74  	return NewClient(conn, chans, reqs)
    75  }
    76  
    77  // Test a simple string is returned to session.Stdout.
    78  func TestSessionShell(t *testing.T) {
    79  	conn := dial(shellHandler, t)
    80  	defer conn.Close()
    81  	session, err := conn.NewSession()
    82  	if err != nil {
    83  		t.Fatalf("Unable to request new session: %v", err)
    84  	}
    85  	defer session.Close()
    86  	stdout := new(bytes.Buffer)
    87  	session.Stdout = stdout
    88  	if err := session.Shell(); err != nil {
    89  		t.Fatalf("Unable to execute command: %s", err)
    90  	}
    91  	if err := session.Wait(); err != nil {
    92  		t.Fatalf("Remote command did not exit cleanly: %v", err)
    93  	}
    94  	actual := stdout.String()
    95  	if actual != "golang" {
    96  		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
    97  	}
    98  }
    99  
   100  // TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
   101  
   102  // Test a simple string is returned via StdoutPipe.
   103  func TestSessionStdoutPipe(t *testing.T) {
   104  	conn := dial(shellHandler, t)
   105  	defer conn.Close()
   106  	session, err := conn.NewSession()
   107  	if err != nil {
   108  		t.Fatalf("Unable to request new session: %v", err)
   109  	}
   110  	defer session.Close()
   111  	stdout, err := session.StdoutPipe()
   112  	if err != nil {
   113  		t.Fatalf("Unable to request StdoutPipe(): %v", err)
   114  	}
   115  	var buf bytes.Buffer
   116  	if err := session.Shell(); err != nil {
   117  		t.Fatalf("Unable to execute command: %v", err)
   118  	}
   119  	done := make(chan bool, 1)
   120  	go func() {
   121  		if _, err := io.Copy(&buf, stdout); err != nil {
   122  			t.Errorf("Copy of stdout failed: %v", err)
   123  		}
   124  		done <- true
   125  	}()
   126  	if err := session.Wait(); err != nil {
   127  		t.Fatalf("Remote command did not exit cleanly: %v", err)
   128  	}
   129  	<-done
   130  	actual := buf.String()
   131  	if actual != "golang" {
   132  		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
   133  	}
   134  }
   135  
   136  // Test that a simple string is returned via the Output helper,
   137  // and that stderr is discarded.
   138  func TestSessionOutput(t *testing.T) {
   139  	conn := dial(fixedOutputHandler, t)
   140  	defer conn.Close()
   141  	session, err := conn.NewSession()
   142  	if err != nil {
   143  		t.Fatalf("Unable to request new session: %v", err)
   144  	}
   145  	defer session.Close()
   146  
   147  	buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
   148  	if err != nil {
   149  		t.Error("Remote command did not exit cleanly:", err)
   150  	}
   151  	w := "this-is-stdout."
   152  	g := string(buf)
   153  	if g != w {
   154  		t.Error("Remote command did not return expected string:")
   155  		t.Logf("want %q", w)
   156  		t.Logf("got  %q", g)
   157  	}
   158  }
   159  
   160  // Test that both stdout and stderr are returned
   161  // via the CombinedOutput helper.
   162  func TestSessionCombinedOutput(t *testing.T) {
   163  	conn := dial(fixedOutputHandler, t)
   164  	defer conn.Close()
   165  	session, err := conn.NewSession()
   166  	if err != nil {
   167  		t.Fatalf("Unable to request new session: %v", err)
   168  	}
   169  	defer session.Close()
   170  
   171  	buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
   172  	if err != nil {
   173  		t.Error("Remote command did not exit cleanly:", err)
   174  	}
   175  	const stdout = "this-is-stdout."
   176  	const stderr = "this-is-stderr."
   177  	g := string(buf)
   178  	if g != stdout+stderr && g != stderr+stdout {
   179  		t.Error("Remote command did not return expected string:")
   180  		t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
   181  		t.Logf("got  %q", g)
   182  	}
   183  }
   184  
   185  // Test non-0 exit status is returned correctly.
   186  func TestExitStatusNonZero(t *testing.T) {
   187  	conn := dial(exitStatusNonZeroHandler, t)
   188  	defer conn.Close()
   189  	session, err := conn.NewSession()
   190  	if err != nil {
   191  		t.Fatalf("Unable to request new session: %v", err)
   192  	}
   193  	defer session.Close()
   194  	if err := session.Shell(); err != nil {
   195  		t.Fatalf("Unable to execute command: %v", err)
   196  	}
   197  	err = session.Wait()
   198  	if err == nil {
   199  		t.Fatalf("expected command to fail but it didn't")
   200  	}
   201  	e, ok := err.(*ExitError)
   202  	if !ok {
   203  		t.Fatalf("expected *ExitError but got %T", err)
   204  	}
   205  	if e.ExitStatus() != 15 {
   206  		t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
   207  	}
   208  }
   209  
   210  // Test 0 exit status is returned correctly.
   211  func TestExitStatusZero(t *testing.T) {
   212  	conn := dial(exitStatusZeroHandler, t)
   213  	defer conn.Close()
   214  	session, err := conn.NewSession()
   215  	if err != nil {
   216  		t.Fatalf("Unable to request new session: %v", err)
   217  	}
   218  	defer session.Close()
   219  
   220  	if err := session.Shell(); err != nil {
   221  		t.Fatalf("Unable to execute command: %v", err)
   222  	}
   223  	err = session.Wait()
   224  	if err != nil {
   225  		t.Fatalf("expected nil but got %v", err)
   226  	}
   227  }
   228  
   229  // Test exit signal and status are both returned correctly.
   230  func TestExitSignalAndStatus(t *testing.T) {
   231  	conn := dial(exitSignalAndStatusHandler, t)
   232  	defer conn.Close()
   233  	session, err := conn.NewSession()
   234  	if err != nil {
   235  		t.Fatalf("Unable to request new session: %v", err)
   236  	}
   237  	defer session.Close()
   238  	if err := session.Shell(); err != nil {
   239  		t.Fatalf("Unable to execute command: %v", err)
   240  	}
   241  	err = session.Wait()
   242  	if err == nil {
   243  		t.Fatalf("expected command to fail but it didn't")
   244  	}
   245  	e, ok := err.(*ExitError)
   246  	if !ok {
   247  		t.Fatalf("expected *ExitError but got %T", err)
   248  	}
   249  	if e.Signal() != "TERM" || e.ExitStatus() != 15 {
   250  		t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   251  	}
   252  }
   253  
   254  // Test exit signal and status are both returned correctly.
   255  func TestKnownExitSignalOnly(t *testing.T) {
   256  	conn := dial(exitSignalHandler, t)
   257  	defer conn.Close()
   258  	session, err := conn.NewSession()
   259  	if err != nil {
   260  		t.Fatalf("Unable to request new session: %v", err)
   261  	}
   262  	defer session.Close()
   263  	if err := session.Shell(); err != nil {
   264  		t.Fatalf("Unable to execute command: %v", err)
   265  	}
   266  	err = session.Wait()
   267  	if err == nil {
   268  		t.Fatalf("expected command to fail but it didn't")
   269  	}
   270  	e, ok := err.(*ExitError)
   271  	if !ok {
   272  		t.Fatalf("expected *ExitError but got %T", err)
   273  	}
   274  	if e.Signal() != "TERM" || e.ExitStatus() != 143 {
   275  		t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   276  	}
   277  }
   278  
   279  // Test exit signal and status are both returned correctly.
   280  func TestUnknownExitSignal(t *testing.T) {
   281  	conn := dial(exitSignalUnknownHandler, t)
   282  	defer conn.Close()
   283  	session, err := conn.NewSession()
   284  	if err != nil {
   285  		t.Fatalf("Unable to request new session: %v", err)
   286  	}
   287  	defer session.Close()
   288  	if err := session.Shell(); err != nil {
   289  		t.Fatalf("Unable to execute command: %v", err)
   290  	}
   291  	err = session.Wait()
   292  	if err == nil {
   293  		t.Fatalf("expected command to fail but it didn't")
   294  	}
   295  	e, ok := err.(*ExitError)
   296  	if !ok {
   297  		t.Fatalf("expected *ExitError but got %T", err)
   298  	}
   299  	if e.Signal() != "SYS" || e.ExitStatus() != 128 {
   300  		t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
   301  	}
   302  }
   303  
   304  func TestExitWithoutStatusOrSignal(t *testing.T) {
   305  	conn := dial(exitWithoutSignalOrStatus, t)
   306  	defer conn.Close()
   307  	session, err := conn.NewSession()
   308  	if err != nil {
   309  		t.Fatalf("Unable to request new session: %v", err)
   310  	}
   311  	defer session.Close()
   312  	if err := session.Shell(); err != nil {
   313  		t.Fatalf("Unable to execute command: %v", err)
   314  	}
   315  	err = session.Wait()
   316  	if err == nil {
   317  		t.Fatalf("expected command to fail but it didn't")
   318  	}
   319  	if _, ok := err.(*ExitMissingError); !ok {
   320  		t.Fatalf("got %T want *ExitMissingError", err)
   321  	}
   322  }
   323  
   324  // windowTestBytes is the number of bytes that we'll send to the SSH server.
   325  const windowTestBytes = 16000 * 200
   326  
   327  // TestServerWindow writes random data to the server. The server is expected to echo
   328  // the same data back, which is compared against the original.
   329  func TestServerWindow(t *testing.T) {
   330  	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
   331  	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
   332  	origBytes := origBuf.Bytes()
   333  
   334  	conn := dial(echoHandler, t)
   335  	defer conn.Close()
   336  	session, err := conn.NewSession()
   337  	if err != nil {
   338  		t.Fatal(err)
   339  	}
   340  	defer session.Close()
   341  	result := make(chan []byte)
   342  
   343  	go func() {
   344  		defer close(result)
   345  		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
   346  		serverStdout, err := session.StdoutPipe()
   347  		if err != nil {
   348  			t.Errorf("StdoutPipe failed: %v", err)
   349  			return
   350  		}
   351  		n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
   352  		if err != nil && err != io.EOF {
   353  			t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
   354  		}
   355  		result <- echoedBuf.Bytes()
   356  	}()
   357  
   358  	serverStdin, err := session.StdinPipe()
   359  	if err != nil {
   360  		t.Fatalf("StdinPipe failed: %v", err)
   361  	}
   362  	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
   363  	if err != nil {
   364  		t.Errorf("failed to copy origBuf to serverStdin: %v", err)
   365  	} else if written != windowTestBytes {
   366  		t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes)
   367  	}
   368  
   369  	echoedBytes := <-result
   370  
   371  	if !bytes.Equal(origBytes, echoedBytes) {
   372  		t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
   373  	}
   374  }
   375  
   376  // Verify the client can handle a keepalive packet from the server.
   377  func TestClientHandlesKeepalives(t *testing.T) {
   378  	conn := dial(channelKeepaliveSender, t)
   379  	defer conn.Close()
   380  	session, err := conn.NewSession()
   381  	if err != nil {
   382  		t.Fatal(err)
   383  	}
   384  	defer session.Close()
   385  	if err := session.Shell(); err != nil {
   386  		t.Fatalf("Unable to execute command: %v", err)
   387  	}
   388  	err = session.Wait()
   389  	if err != nil {
   390  		t.Fatalf("expected nil but got: %v", err)
   391  	}
   392  }
   393  
   394  type exitStatusMsg struct {
   395  	Status uint32
   396  }
   397  
   398  type exitSignalMsg struct {
   399  	Signal     string
   400  	CoreDumped bool
   401  	Errmsg     string
   402  	Lang       string
   403  }
   404  
   405  func handleTerminalRequests(in <-chan *Request) {
   406  	for req := range in {
   407  		ok := false
   408  		switch req.Type {
   409  		case "shell":
   410  			ok = true
   411  			if len(req.Payload) > 0 {
   412  				// We don't accept any commands, only the default shell.
   413  				ok = false
   414  			}
   415  		case "env":
   416  			ok = true
   417  		}
   418  		req.Reply(ok, nil)
   419  	}
   420  }
   421  
   422  func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal {
   423  	term := terminal.NewTerminal(ch, prompt)
   424  	go handleTerminalRequests(in)
   425  	return term
   426  }
   427  
   428  func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
   429  	defer ch.Close()
   430  	// this string is returned to stdout
   431  	shell := newServerShell(ch, in, "> ")
   432  	readLine(shell, t)
   433  	sendStatus(0, ch, t)
   434  }
   435  
   436  func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) {
   437  	defer ch.Close()
   438  	shell := newServerShell(ch, in, "> ")
   439  	readLine(shell, t)
   440  	sendStatus(15, ch, t)
   441  }
   442  
   443  func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) {
   444  	defer ch.Close()
   445  	shell := newServerShell(ch, in, "> ")
   446  	readLine(shell, t)
   447  	sendStatus(15, ch, t)
   448  	sendSignal("TERM", ch, t)
   449  }
   450  
   451  func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) {
   452  	defer ch.Close()
   453  	shell := newServerShell(ch, in, "> ")
   454  	readLine(shell, t)
   455  	sendSignal("TERM", ch, t)
   456  }
   457  
   458  func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) {
   459  	defer ch.Close()
   460  	shell := newServerShell(ch, in, "> ")
   461  	readLine(shell, t)
   462  	sendSignal("SYS", ch, t)
   463  }
   464  
   465  func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) {
   466  	defer ch.Close()
   467  	shell := newServerShell(ch, in, "> ")
   468  	readLine(shell, t)
   469  }
   470  
   471  func shellHandler(ch Channel, in <-chan *Request, t *testing.T) {
   472  	defer ch.Close()
   473  	// this string is returned to stdout
   474  	shell := newServerShell(ch, in, "golang")
   475  	readLine(shell, t)
   476  	sendStatus(0, ch, t)
   477  }
   478  
   479  // Ignores the command, writes fixed strings to stderr and stdout.
   480  // Strings are "this-is-stdout." and "this-is-stderr.".
   481  func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) {
   482  	defer ch.Close()
   483  	_, err := ch.Read(nil)
   484  
   485  	req, ok := <-in
   486  	if !ok {
   487  		t.Fatalf("error: expected channel request, got: %#v", err)
   488  		return
   489  	}
   490  
   491  	// ignore request, always send some text
   492  	req.Reply(true, nil)
   493  
   494  	_, err = io.WriteString(ch, "this-is-stdout.")
   495  	if err != nil {
   496  		t.Fatalf("error writing on server: %v", err)
   497  	}
   498  	_, err = io.WriteString(ch.Stderr(), "this-is-stderr.")
   499  	if err != nil {
   500  		t.Fatalf("error writing on server: %v", err)
   501  	}
   502  	sendStatus(0, ch, t)
   503  }
   504  
   505  func readLine(shell *terminal.Terminal, t *testing.T) {
   506  	if _, err := shell.ReadLine(); err != nil && err != io.EOF {
   507  		t.Errorf("unable to read line: %v", err)
   508  	}
   509  }
   510  
   511  func sendStatus(status uint32, ch Channel, t *testing.T) {
   512  	msg := exitStatusMsg{
   513  		Status: status,
   514  	}
   515  	if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil {
   516  		t.Errorf("unable to send status: %v", err)
   517  	}
   518  }
   519  
   520  func sendSignal(signal string, ch Channel, t *testing.T) {
   521  	sig := exitSignalMsg{
   522  		Signal:     signal,
   523  		CoreDumped: false,
   524  		Errmsg:     "Process terminated",
   525  		Lang:       "en-GB-oed",
   526  	}
   527  	if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil {
   528  		t.Errorf("unable to send signal: %v", err)
   529  	}
   530  }
   531  
   532  func discardHandler(ch Channel, t *testing.T) {
   533  	defer ch.Close()
   534  	io.Copy(ioutil.Discard, ch)
   535  }
   536  
   537  func echoHandler(ch Channel, in <-chan *Request, t *testing.T) {
   538  	defer ch.Close()
   539  	if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil {
   540  		t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err)
   541  	}
   542  }
   543  
   544  // copyNRandomly copies n bytes from src to dst. It uses a variable, and random,
   545  // buffer size to exercise more code paths.
   546  func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) {
   547  	var (
   548  		buf       = make([]byte, 32*1024)
   549  		written   int
   550  		remaining = n
   551  	)
   552  	for remaining > 0 {
   553  		l := rand.Intn(1 << 15)
   554  		if remaining < l {
   555  			l = remaining
   556  		}
   557  		nr, er := src.Read(buf[:l])
   558  		nw, ew := dst.Write(buf[:nr])
   559  		remaining -= nw
   560  		written += nw
   561  		if ew != nil {
   562  			return written, ew
   563  		}
   564  		if nr != nw {
   565  			return written, io.ErrShortWrite
   566  		}
   567  		if er != nil && er != io.EOF {
   568  			return written, er
   569  		}
   570  	}
   571  	return written, nil
   572  }
   573  
   574  func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) {
   575  	defer ch.Close()
   576  	shell := newServerShell(ch, in, "> ")
   577  	readLine(shell, t)
   578  	if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil {
   579  		t.Errorf("unable to send channel keepalive request: %v", err)
   580  	}
   581  	sendStatus(0, ch, t)
   582  }
   583  
   584  func TestClientWriteEOF(t *testing.T) {
   585  	conn := dial(simpleEchoHandler, t)
   586  	defer conn.Close()
   587  
   588  	session, err := conn.NewSession()
   589  	if err != nil {
   590  		t.Fatal(err)
   591  	}
   592  	defer session.Close()
   593  	stdin, err := session.StdinPipe()
   594  	if err != nil {
   595  		t.Fatalf("StdinPipe failed: %v", err)
   596  	}
   597  	stdout, err := session.StdoutPipe()
   598  	if err != nil {
   599  		t.Fatalf("StdoutPipe failed: %v", err)
   600  	}
   601  
   602  	data := []byte(`0000`)
   603  	_, err = stdin.Write(data)
   604  	if err != nil {
   605  		t.Fatalf("Write failed: %v", err)
   606  	}
   607  	stdin.Close()
   608  
   609  	res, err := ioutil.ReadAll(stdout)
   610  	if err != nil {
   611  		t.Fatalf("Read failed: %v", err)
   612  	}
   613  
   614  	if !bytes.Equal(data, res) {
   615  		t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res)
   616  	}
   617  }
   618  
   619  func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) {
   620  	defer ch.Close()
   621  	data, err := ioutil.ReadAll(ch)
   622  	if err != nil {
   623  		t.Errorf("handler read error: %v", err)
   624  	}
   625  	_, err = ch.Write(data)
   626  	if err != nil {
   627  		t.Errorf("handler write error: %v", err)
   628  	}
   629  }
   630  
   631  func TestSessionID(t *testing.T) {
   632  	c1, c2, err := netPipe()
   633  	if err != nil {
   634  		t.Fatalf("netPipe: %v", err)
   635  	}
   636  	defer c1.Close()
   637  	defer c2.Close()
   638  
   639  	serverID := make(chan []byte, 1)
   640  	clientID := make(chan []byte, 1)
   641  
   642  	serverConf := &ServerConfig{
   643  		NoClientAuth: true,
   644  	}
   645  	serverConf.AddHostKey(testSigners["ecdsa"])
   646  	clientConf := &ClientConfig{
   647  		HostKeyCallback: InsecureIgnoreHostKey(),
   648  		User:            "user",
   649  	}
   650  
   651  	go func() {
   652  		conn, chans, reqs, err := NewServerConn(c1, serverConf)
   653  		if err != nil {
   654  			t.Fatalf("server handshake: %v", err)
   655  		}
   656  		serverID <- conn.SessionID()
   657  		go DiscardRequests(reqs)
   658  		for ch := range chans {
   659  			ch.Reject(Prohibited, "")
   660  		}
   661  	}()
   662  
   663  	go func() {
   664  		conn, chans, reqs, err := NewClientConn(c2, "", clientConf)
   665  		if err != nil {
   666  			t.Fatalf("client handshake: %v", err)
   667  		}
   668  		clientID <- conn.SessionID()
   669  		go DiscardRequests(reqs)
   670  		for ch := range chans {
   671  			ch.Reject(Prohibited, "")
   672  		}
   673  	}()
   674  
   675  	s := <-serverID
   676  	c := <-clientID
   677  	if bytes.Compare(s, c) != 0 {
   678  		t.Errorf("server session ID (%x) != client session ID (%x)", s, c)
   679  	} else if len(s) == 0 {
   680  		t.Errorf("client and server SessionID were empty.")
   681  	}
   682  }
   683  
   684  type noReadConn struct {
   685  	readSeen bool
   686  	net.Conn
   687  }
   688  
   689  func (c *noReadConn) Close() error {
   690  	return nil
   691  }
   692  
   693  func (c *noReadConn) Read(b []byte) (int, error) {
   694  	c.readSeen = true
   695  	return 0, errors.New("noReadConn error")
   696  }
   697  
   698  func TestInvalidServerConfiguration(t *testing.T) {
   699  	c1, c2, err := netPipe()
   700  	if err != nil {
   701  		t.Fatalf("netPipe: %v", err)
   702  	}
   703  	defer c1.Close()
   704  	defer c2.Close()
   705  
   706  	serveConn := noReadConn{Conn: c1}
   707  	serverConf := &ServerConfig{}
   708  
   709  	NewServerConn(&serveConn, serverConf)
   710  	if serveConn.readSeen {
   711  		t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
   712  	}
   713  
   714  	serverConf.AddHostKey(testSigners["ecdsa"])
   715  
   716  	NewServerConn(&serveConn, serverConf)
   717  	if serveConn.readSeen {
   718  		t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
   719  	}
   720  }
   721  
   722  func TestHostKeyAlgorithms(t *testing.T) {
   723  	serverConf := &ServerConfig{
   724  		NoClientAuth: true,
   725  	}
   726  	serverConf.AddHostKey(testSigners["rsa"])
   727  	serverConf.AddHostKey(testSigners["ecdsa"])
   728  
   729  	connect := func(clientConf *ClientConfig, want string) {
   730  		var alg string
   731  		clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
   732  			alg = key.Type()
   733  			return nil
   734  		}
   735  		c1, c2, err := netPipe()
   736  		if err != nil {
   737  			t.Fatalf("netPipe: %v", err)
   738  		}
   739  		defer c1.Close()
   740  		defer c2.Close()
   741  
   742  		go NewServerConn(c1, serverConf)
   743  		_, _, _, err = NewClientConn(c2, "", clientConf)
   744  		if err != nil {
   745  			t.Fatalf("NewClientConn: %v", err)
   746  		}
   747  		if alg != want {
   748  			t.Errorf("selected key algorithm %s, want %s", alg, want)
   749  		}
   750  	}
   751  
   752  	// By default, we get the preferred algorithm, which is ECDSA 256.
   753  
   754  	clientConf := &ClientConfig{
   755  		HostKeyCallback: InsecureIgnoreHostKey(),
   756  	}
   757  	connect(clientConf, KeyAlgoECDSA256)
   758  
   759  	// Client asks for RSA explicitly.
   760  	clientConf.HostKeyAlgorithms = []string{SigAlgoRSA}
   761  	connect(clientConf, KeyAlgoRSA)
   762  
   763  	// Client asks for RSA-SHA2-512 explicitly.
   764  	clientConf.HostKeyAlgorithms = []string{SigAlgoRSASHA2512}
   765  	// We get back an "ssh-rsa" key but the verification happened
   766  	// with an RSA-SHA2-512 signature.
   767  	connect(clientConf, KeyAlgoRSA)
   768  
   769  	c1, c2, err := netPipe()
   770  	if err != nil {
   771  		t.Fatalf("netPipe: %v", err)
   772  	}
   773  	defer c1.Close()
   774  	defer c2.Close()
   775  
   776  	go NewServerConn(c1, serverConf)
   777  	clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
   778  	_, _, _, err = NewClientConn(c2, "", clientConf)
   779  	if err == nil {
   780  		t.Fatal("succeeded connecting with unknown hostkey algorithm")
   781  	}
   782  }