github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/common_test.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"reflect"
     9  	"testing"
    10  )
    11  
    12  func TestFindAgreedAlgorithms(t *testing.T) {
    13  	initKex := func(k *kexInitMsg) {
    14  		if k.KexAlgos == nil {
    15  			k.KexAlgos = []string{"kex1"}
    16  		}
    17  		if k.ServerHostKeyAlgos == nil {
    18  			k.ServerHostKeyAlgos = []string{"hostkey1"}
    19  		}
    20  		if k.CiphersClientServer == nil {
    21  			k.CiphersClientServer = []string{"cipher1"}
    22  
    23  		}
    24  		if k.CiphersServerClient == nil {
    25  			k.CiphersServerClient = []string{"cipher1"}
    26  
    27  		}
    28  		if k.MACsClientServer == nil {
    29  			k.MACsClientServer = []string{"mac1"}
    30  
    31  		}
    32  		if k.MACsServerClient == nil {
    33  			k.MACsServerClient = []string{"mac1"}
    34  
    35  		}
    36  		if k.CompressionClientServer == nil {
    37  			k.CompressionClientServer = []string{"compression1"}
    38  
    39  		}
    40  		if k.CompressionServerClient == nil {
    41  			k.CompressionServerClient = []string{"compression1"}
    42  
    43  		}
    44  		if k.LanguagesClientServer == nil {
    45  			k.LanguagesClientServer = []string{"language1"}
    46  
    47  		}
    48  		if k.LanguagesServerClient == nil {
    49  			k.LanguagesServerClient = []string{"language1"}
    50  
    51  		}
    52  	}
    53  
    54  	initDirAlgs := func(a *directionAlgorithms) {
    55  		if a.Cipher == "" {
    56  			a.Cipher = "cipher1"
    57  		}
    58  		if a.MAC == "" {
    59  			a.MAC = "mac1"
    60  		}
    61  		if a.Compression == "" {
    62  			a.Compression = "compression1"
    63  		}
    64  	}
    65  
    66  	initAlgs := func(a *algorithms) {
    67  		if a.kex == "" {
    68  			a.kex = "kex1"
    69  		}
    70  		if a.hostKey == "" {
    71  			a.hostKey = "hostkey1"
    72  		}
    73  		initDirAlgs(&a.r)
    74  		initDirAlgs(&a.w)
    75  	}
    76  
    77  	type testcase struct {
    78  		name                   string
    79  		clientIn, serverIn     kexInitMsg
    80  		wantClient, wantServer algorithms
    81  		wantErr                bool
    82  	}
    83  
    84  	cases := []testcase{
    85  		testcase{
    86  			name: "standard",
    87  		},
    88  
    89  		testcase{
    90  			name: "no common hostkey",
    91  			serverIn: kexInitMsg{
    92  				ServerHostKeyAlgos: []string{"hostkey2"},
    93  			},
    94  			wantErr: true,
    95  		},
    96  
    97  		testcase{
    98  			name: "no common kex",
    99  			serverIn: kexInitMsg{
   100  				KexAlgos: []string{"kex2"},
   101  			},
   102  			wantErr: true,
   103  		},
   104  
   105  		testcase{
   106  			name: "no common cipher",
   107  			serverIn: kexInitMsg{
   108  				CiphersClientServer: []string{"cipher2"},
   109  			},
   110  			wantErr: true,
   111  		},
   112  
   113  		testcase{
   114  			name: "client decides cipher",
   115  			serverIn: kexInitMsg{
   116  				CiphersClientServer: []string{"cipher1", "cipher2"},
   117  				CiphersServerClient: []string{"cipher2", "cipher3"},
   118  			},
   119  			clientIn: kexInitMsg{
   120  				CiphersClientServer: []string{"cipher2", "cipher1"},
   121  				CiphersServerClient: []string{"cipher3", "cipher2"},
   122  			},
   123  			wantClient: algorithms{
   124  				r: directionAlgorithms{
   125  					Cipher: "cipher3",
   126  				},
   127  				w: directionAlgorithms{
   128  					Cipher: "cipher2",
   129  				},
   130  			},
   131  			wantServer: algorithms{
   132  				w: directionAlgorithms{
   133  					Cipher: "cipher3",
   134  				},
   135  				r: directionAlgorithms{
   136  					Cipher: "cipher2",
   137  				},
   138  			},
   139  		},
   140  
   141  		// TODO(hanwen): fix and add tests for AEAD ignoring
   142  		// the MACs field
   143  	}
   144  
   145  	for i := range cases {
   146  		initKex(&cases[i].clientIn)
   147  		initKex(&cases[i].serverIn)
   148  		initAlgs(&cases[i].wantClient)
   149  		initAlgs(&cases[i].wantServer)
   150  	}
   151  
   152  	for _, c := range cases {
   153  		t.Run(c.name, func(t *testing.T) {
   154  			serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
   155  			clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
   156  
   157  			serverHasErr := serverErr != nil
   158  			clientHasErr := clientErr != nil
   159  			if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
   160  				t.Fatalf("got client/server error (%v, %v), want hasError %v",
   161  					clientErr, serverErr, c.wantErr)
   162  
   163  			}
   164  			if c.wantErr {
   165  				return
   166  			}
   167  
   168  			if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
   169  				t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
   170  			}
   171  			if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
   172  				t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
   173  			}
   174  		})
   175  	}
   176  }