github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/messages_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"math/big"
    10  	"math/rand"
    11  	"reflect"
    12  	"testing"
    13  	"testing/quick"
    14  )
    15  
    16  var intLengthTests = []struct {
    17  	val, length int
    18  }{
    19  	{0, 4 + 0},
    20  	{1, 4 + 1},
    21  	{127, 4 + 1},
    22  	{128, 4 + 2},
    23  	{-1, 4 + 1},
    24  }
    25  
    26  func TestIntLength(t *testing.T) {
    27  	for _, test := range intLengthTests {
    28  		v := new(big.Int).SetInt64(int64(test.val))
    29  		length := intLength(v)
    30  		if length != test.length {
    31  			t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
    32  		}
    33  	}
    34  }
    35  
// msgAllTypes is a test-only message carrying one field of each kind that
// TestMarshalUnmarshal pushes through a Marshal/Unmarshal round trip: a bool,
// a fixed-size byte array, unsigned integers of three widths, a string, a
// string list, a byte slice, a *big.Int and a trailing byte slice.
//
// Field order and struct tags are part of what is being tested; do not
// reorder or retag fields casually.
type msgAllTypes struct {
	// NOTE(review): upstream ssh.Marshal documents an "sshtype" tag on the
	// first field as the message type number written ahead of the payload;
	// confirm that this fork behaves the same way.
	Bool    bool `sshtype:"21"`
	Array   [16]byte
	Uint64  uint64
	Uint32  uint32
	Uint8   uint8
	String  string
	Strings []string
	Bytes   []byte
	Int     *big.Int
	// NOTE(review): upstream ssh.Marshal/Unmarshal document `ssh:"rest"` on
	// the last field as holding whatever bytes remain in the packet; confirm
	// against this fork's implementation.
	Rest    []byte `ssh:"rest"`
}
    48  
    49  func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
    50  	m := &msgAllTypes{}
    51  	m.Bool = rand.Intn(2) == 1
    52  	randomBytes(m.Array[:], rand)
    53  	m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
    54  	m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
    55  	m.Uint8 = uint8(rand.Intn(1 << 8))
    56  	m.String = string(m.Array[:])
    57  	m.Strings = randomNameList(rand)
    58  	m.Bytes = m.Array[:]
    59  	m.Int = randomInt(rand)
    60  	m.Rest = m.Array[:]
    61  	return reflect.ValueOf(m)
    62  }
    63  
    64  func TestMarshalUnmarshal(t *testing.T) {
    65  	rand := rand.New(rand.NewSource(0))
    66  	iface := &msgAllTypes{}
    67  	ty := reflect.ValueOf(iface).Type()
    68  
    69  	n := 100
    70  	if testing.Short() {
    71  		n = 5
    72  	}
    73  	for j := 0; j < n; j++ {
    74  		v, ok := quick.Value(ty, rand)
    75  		if !ok {
    76  			t.Errorf("failed to create value")
    77  			break
    78  		}
    79  
    80  		m1 := v.Elem().Interface()
    81  		m2 := iface
    82  
    83  		marshaled := Marshal(m1)
    84  		if err := Unmarshal(marshaled, m2); err != nil {
    85  			t.Errorf("Unmarshal %#v: %s", m1, err)
    86  			break
    87  		}
    88  
    89  		if !reflect.DeepEqual(v.Interface(), m2) {
    90  			t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
    91  			break
    92  		}
    93  	}
    94  }
    95  
    96  func TestUnmarshalEmptyPacket(t *testing.T) {
    97  	var b []byte
    98  	var m channelRequestSuccessMsg
    99  	if err := Unmarshal(b, &m); err == nil {
   100  		t.Fatalf("unmarshal of empty slice succeeded")
   101  	}
   102  }
   103  
   104  func TestUnmarshalUnexpectedPacket(t *testing.T) {
   105  	type S struct {
   106  		I uint32 `sshtype:"43"`
   107  		S string
   108  		B bool
   109  	}
   110  
   111  	s := S{11, "hello", true}
   112  	packet := Marshal(s)
   113  	packet[0] = 42
   114  	roundtrip := S{}
   115  	err := Unmarshal(packet, &roundtrip)
   116  	if err == nil {
   117  		t.Fatal("expected error, not nil")
   118  	}
   119  }
   120  
   121  func TestMarshalPtr(t *testing.T) {
   122  	s := struct {
   123  		S string
   124  	}{"hello"}
   125  
   126  	m1 := Marshal(s)
   127  	m2 := Marshal(&s)
   128  	if !bytes.Equal(m1, m2) {
   129  		t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
   130  	}
   131  }
   132  
   133  func TestBareMarshalUnmarshal(t *testing.T) {
   134  	type S struct {
   135  		I uint32
   136  		S string
   137  		B bool
   138  	}
   139  
   140  	s := S{42, "hello", true}
   141  	packet := Marshal(s)
   142  	roundtrip := S{}
   143  	Unmarshal(packet, &roundtrip)
   144  
   145  	if !reflect.DeepEqual(s, roundtrip) {
   146  		t.Errorf("got %#v, want %#v", roundtrip, s)
   147  	}
   148  }
   149  
   150  func TestBareMarshal(t *testing.T) {
   151  	type S2 struct {
   152  		I uint32
   153  	}
   154  	s := S2{42}
   155  	packet := Marshal(s)
   156  	i, rest, ok := parseUint32(packet)
   157  	if len(rest) > 0 || !ok {
   158  		t.Errorf("parseInt(%q): parse error", packet)
   159  	}
   160  	if i != s.I {
   161  		t.Errorf("got %d, want %d", i, s.I)
   162  	}
   163  }
   164  
   165  func TestUnmarshalShortKexInitPacket(t *testing.T) {
   166  	// This used to panic.
   167  	// Issue 11348
   168  	packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
   169  	kim := &kexInitMsg{}
   170  	if err := Unmarshal(packet, kim); err == nil {
   171  		t.Error("truncated packet unmarshaled without error")
   172  	}
   173  }
   174  
   175  func TestMarshalMultiTag(t *testing.T) {
   176  	var res struct {
   177  		A uint32 `sshtype:"1|2"`
   178  	}
   179  
   180  	good1 := struct {
   181  		A uint32 `sshtype:"1"`
   182  	}{
   183  		1,
   184  	}
   185  	good2 := struct {
   186  		A uint32 `sshtype:"2"`
   187  	}{
   188  		1,
   189  	}
   190  
   191  	if e := Unmarshal(Marshal(good1), &res); e != nil {
   192  		t.Errorf("error unmarshaling multipart tag: %v", e)
   193  	}
   194  
   195  	if e := Unmarshal(Marshal(good2), &res); e != nil {
   196  		t.Errorf("error unmarshaling multipart tag: %v", e)
   197  	}
   198  
   199  	bad1 := struct {
   200  		A uint32 `sshtype:"3"`
   201  	}{
   202  		1,
   203  	}
   204  	if e := Unmarshal(Marshal(bad1), &res); e == nil {
   205  		t.Errorf("bad struct unmarshaled without error")
   206  	}
   207  }
   208  
   209  func randomBytes(out []byte, rand *rand.Rand) {
   210  	for i := 0; i < len(out); i++ {
   211  		out[i] = byte(rand.Int31())
   212  	}
   213  }
   214  
   215  func randomNameList(rand *rand.Rand) []string {
   216  	ret := make([]string, rand.Int31()&15)
   217  	for i := range ret {
   218  		s := make([]byte, 1+(rand.Int31()&15))
   219  		for j := range s {
   220  			s[j] = 'a' + uint8(rand.Int31()&15)
   221  		}
   222  		ret[i] = string(s)
   223  	}
   224  	return ret
   225  }
   226  
   227  func randomInt(rand *rand.Rand) *big.Int {
   228  	return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
   229  }
   230  
   231  func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   232  	ki := &kexInitMsg{}
   233  	randomBytes(ki.Cookie[:], rand)
   234  	ki.KexAlgos = randomNameList(rand)
   235  	ki.ServerHostKeyAlgos = randomNameList(rand)
   236  	ki.CiphersClientServer = randomNameList(rand)
   237  	ki.CiphersServerClient = randomNameList(rand)
   238  	ki.MACsClientServer = randomNameList(rand)
   239  	ki.MACsServerClient = randomNameList(rand)
   240  	ki.CompressionClientServer = randomNameList(rand)
   241  	ki.CompressionServerClient = randomNameList(rand)
   242  	ki.LanguagesClientServer = randomNameList(rand)
   243  	ki.LanguagesServerClient = randomNameList(rand)
   244  	if rand.Int31()&1 == 1 {
   245  		ki.FirstKexFollows = true
   246  	}
   247  	return reflect.ValueOf(ki)
   248  }
   249  
   250  func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   251  	dhi := &kexDHInitMsg{}
   252  	dhi.X = randomInt(rand)
   253  	return reflect.ValueOf(dhi)
   254  }
   255  
   256  var (
   257  	_kexInitMsg   = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
   258  	_kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
   259  
   260  	_kexInit   = Marshal(_kexInitMsg)
   261  	_kexDHInit = Marshal(_kexDHInitMsg)
   262  )
   263  
   264  func BenchmarkMarshalKexInitMsg(b *testing.B) {
   265  	for i := 0; i < b.N; i++ {
   266  		Marshal(_kexInitMsg)
   267  	}
   268  }
   269  
   270  func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
   271  	m := new(kexInitMsg)
   272  	for i := 0; i < b.N; i++ {
   273  		Unmarshal(_kexInit, m)
   274  	}
   275  }
   276  
   277  func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
   278  	for i := 0; i < b.N; i++ {
   279  		Marshal(_kexDHInitMsg)
   280  	}
   281  }
   282  
   283  func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
   284  	m := new(kexDHInitMsg)
   285  	for i := 0; i < b.N; i++ {
   286  		Unmarshal(_kexDHInit, m)
   287  	}
   288  }