github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/client_test.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"strings"
    11  	"testing"
    12  )
    13  
    14  func TestClientVersion(t *testing.T) {
    15  	for _, tt := range []struct {
    16  		name      string
    17  		version   string
    18  		multiLine string
    19  		wantErr   bool
    20  	}{
    21  		{
    22  			name:    "default version",
    23  			version: packageVersion,
    24  		},
    25  		{
    26  			name:    "custom version",
    27  			version: "SSH-2.0-CustomClientVersionString",
    28  		},
    29  		{
    30  			name:      "good multi line version",
    31  			version:   packageVersion,
    32  			multiLine: strings.Repeat("ignored\r\n", 20),
    33  		},
    34  		{
    35  			name:      "bad multi line version",
    36  			version:   packageVersion,
    37  			multiLine: "bad multi line version",
    38  			wantErr:   true,
    39  		},
    40  		{
    41  			name:      "long multi line version",
    42  			version:   packageVersion,
    43  			multiLine: strings.Repeat("long multi line version\r\n", 50)[:256],
    44  			wantErr:   true,
    45  		},
    46  	} {
    47  		t.Run(tt.name, func(t *testing.T) {
    48  			c1, c2, err := netPipe()
    49  			if err != nil {
    50  				t.Fatalf("netPipe: %v", err)
    51  			}
    52  			defer c1.Close()
    53  			defer c2.Close()
    54  			go func() {
    55  				if tt.multiLine != "" {
    56  					c1.Write([]byte(tt.multiLine))
    57  				}
    58  				NewClientConn(c1, "", &ClientConfig{
    59  					ClientVersion:   tt.version,
    60  					HostKeyCallback: InsecureIgnoreHostKey(),
    61  				})
    62  				c1.Close()
    63  			}()
    64  			conf := &ServerConfig{NoClientAuth: true}
    65  			conf.AddHostKey(testSigners["rsa"])
    66  			conn, _, _, err := NewServerConn(c2, conf)
    67  			if err == nil == tt.wantErr {
    68  				t.Fatalf("got err %v; wantErr %t", err, tt.wantErr)
    69  			}
    70  			if tt.wantErr {
    71  				// Don't verify the version on an expected error.
    72  				return
    73  			}
    74  			if got := string(conn.ClientVersion()); got != tt.version {
    75  				t.Fatalf("got %q; want %q", got, tt.version)
    76  			}
    77  		})
    78  	}
    79  }
    80  
    81  func TestHostKeyCheck(t *testing.T) {
    82  	for _, tt := range []struct {
    83  		name      string
    84  		wantError string
    85  		key       PublicKey
    86  	}{
    87  		{"no callback", "must specify HostKeyCallback", nil},
    88  		{"correct key", "", testSigners["rsa"].PublicKey()},
    89  		{"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()},
    90  	} {
    91  		c1, c2, err := netPipe()
    92  		if err != nil {
    93  			t.Fatalf("netPipe: %v", err)
    94  		}
    95  		defer c1.Close()
    96  		defer c2.Close()
    97  		serverConf := &ServerConfig{
    98  			NoClientAuth: true,
    99  		}
   100  		serverConf.AddHostKey(testSigners["rsa"])
   101  
   102  		go NewServerConn(c1, serverConf)
   103  		clientConf := ClientConfig{
   104  			User: "user",
   105  		}
   106  		if tt.key != nil {
   107  			clientConf.HostKeyCallback = FixedHostKey(tt.key)
   108  		}
   109  
   110  		_, _, _, err = NewClientConn(c2, "", &clientConf)
   111  		if err != nil {
   112  			if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
   113  				t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
   114  			}
   115  		} else if tt.wantError != "" {
   116  			t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
   117  		}
   118  	}
   119  }
   120  
   121  func TestVerifyHostKeySignature(t *testing.T) {
   122  	for _, tt := range []struct {
   123  		key        string
   124  		signAlgo   string
   125  		verifyAlgo string
   126  		wantError  string
   127  	}{
   128  		{"rsa", SigAlgoRSA, SigAlgoRSA, ""},
   129  		{"rsa", SigAlgoRSASHA2256, SigAlgoRSASHA2256, ""},
   130  		{"rsa", SigAlgoRSA, SigAlgoRSASHA2512, `ssh: invalid signature algorithm "ssh-rsa", expected "rsa-sha2-512"`},
   131  		{"ed25519", KeyAlgoED25519, KeyAlgoED25519, ""},
   132  	} {
   133  		key := testSigners[tt.key].PublicKey()
   134  		s, ok := testSigners[tt.key].(AlgorithmSigner)
   135  		if !ok {
   136  			t.Fatalf("needed an AlgorithmSigner")
   137  		}
   138  		sig, err := s.SignWithAlgorithm(rand.Reader, []byte("test"), tt.signAlgo)
   139  		if err != nil {
   140  			t.Fatalf("couldn't sign: %q", err)
   141  		}
   142  
   143  		b := bytes.Buffer{}
   144  		writeString(&b, []byte(sig.Format))
   145  		writeString(&b, sig.Blob)
   146  
   147  		result := kexResult{Signature: b.Bytes(), H: []byte("test")}
   148  
   149  		err = verifyHostKeySignature(key, tt.verifyAlgo, &result)
   150  		if err != nil {
   151  			if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
   152  				t.Errorf("got error %q, expecting %q", err.Error(), tt.wantError)
   153  			}
   154  		} else if tt.wantError != "" {
   155  			t.Errorf("succeeded, but want error string %q", tt.wantError)
   156  		}
   157  	}
   158  }
   159  
   160  func TestBannerCallback(t *testing.T) {
   161  	c1, c2, err := netPipe()
   162  	if err != nil {
   163  		t.Fatalf("netPipe: %v", err)
   164  	}
   165  	defer c1.Close()
   166  	defer c2.Close()
   167  
   168  	serverConf := &ServerConfig{
   169  		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   170  			return &Permissions{}, nil
   171  		},
   172  		BannerCallback: func(conn ConnMetadata) string {
   173  			return "Hello World"
   174  		},
   175  	}
   176  	serverConf.AddHostKey(testSigners["rsa"])
   177  	go NewServerConn(c1, serverConf)
   178  
   179  	var receivedBanner string
   180  	var bannerCount int
   181  	clientConf := ClientConfig{
   182  		Auth: []AuthMethod{
   183  			Password("123"),
   184  		},
   185  		User:            "user",
   186  		HostKeyCallback: InsecureIgnoreHostKey(),
   187  		BannerCallback: func(message string) error {
   188  			bannerCount++
   189  			receivedBanner = message
   190  			return nil
   191  		},
   192  	}
   193  
   194  	_, _, _, err = NewClientConn(c2, "", &clientConf)
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  
   199  	if bannerCount != 1 {
   200  		t.Errorf("got %d banners; want 1", bannerCount)
   201  	}
   202  
   203  	expected := "Hello World"
   204  	if receivedBanner != expected {
   205  		t.Fatalf("got %s; want %s", receivedBanner, expected)
   206  	}
   207  }
   208  
   209  func TestNewClientConn(t *testing.T) {
   210  	for _, tt := range []struct {
   211  		name string
   212  		user string
   213  	}{
   214  		{
   215  			name: "good user field for ConnMetadata",
   216  			user: "testuser",
   217  		},
   218  		{
   219  			name: "empty user field for ConnMetadata",
   220  			user: "",
   221  		},
   222  	} {
   223  		t.Run(tt.name, func(t *testing.T) {
   224  			c1, c2, err := netPipe()
   225  			if err != nil {
   226  				t.Fatalf("netPipe: %v", err)
   227  			}
   228  			defer c1.Close()
   229  			defer c2.Close()
   230  
   231  			serverConf := &ServerConfig{
   232  				PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
   233  					return &Permissions{}, nil
   234  				},
   235  			}
   236  			serverConf.AddHostKey(testSigners["rsa"])
   237  			go NewServerConn(c1, serverConf)
   238  
   239  			clientConf := &ClientConfig{
   240  				User: tt.user,
   241  				Auth: []AuthMethod{
   242  					Password("testpw"),
   243  				},
   244  				HostKeyCallback: InsecureIgnoreHostKey(),
   245  			}
   246  			clientConn, _, _, err := NewClientConn(c2, "", clientConf)
   247  			if err != nil {
   248  				t.Fatal(err)
   249  			}
   250  
   251  			if userGot := clientConn.User(); userGot != tt.user {
   252  				t.Errorf("got user %q; want user %q", userGot, tt.user)
   253  			}
   254  		})
   255  	}
   256  }