github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/cipher_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto"
    10  	"crypto/aes"
    11  	"crypto/rand"
    12  	"testing"
    13  )
    14  
    15  func TestDefaultCiphersExist(t *testing.T) {
    16  	for _, cipherAlgo := range supportedCiphers {
    17  		if _, ok := cipherModes[cipherAlgo]; !ok {
    18  			t.Errorf("default cipher %q is unknown", cipherAlgo)
    19  		}
    20  	}
    21  }
    22  
    23  func TestPacketCiphers(t *testing.T) {
    24  	// Still test aes128cbc cipher although it's commented out.
    25  	cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
    26  	defer delete(cipherModes, aes128cbcID)
    27  
    28  	for cipher := range cipherModes {
    29  		kr := &kexResult{Hash: crypto.SHA1}
    30  		algs := directionAlgorithms{
    31  			Cipher:      cipher,
    32  			MAC:         "hmac-sha1",
    33  			Compression: "none",
    34  		}
    35  		client, err := newPacketCipher(clientKeys, algs, kr)
    36  		if err != nil {
    37  			t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
    38  			continue
    39  		}
    40  		server, err := newPacketCipher(clientKeys, algs, kr)
    41  		if err != nil {
    42  			t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
    43  			continue
    44  		}
    45  
    46  		want := "bla bla"
    47  		input := []byte(want)
    48  		buf := &bytes.Buffer{}
    49  		if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
    50  			t.Errorf("writePacket(%q): %v", cipher, err)
    51  			continue
    52  		}
    53  
    54  		packet, err := server.readPacket(0, buf)
    55  		if err != nil {
    56  			t.Errorf("readPacket(%q): %v", cipher, err)
    57  			continue
    58  		}
    59  
    60  		if string(packet) != want {
    61  			t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
    62  		}
    63  	}
    64  }
    65  
    66  func TestCBCOracleCounterMeasure(t *testing.T) {
    67  	cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
    68  	defer delete(cipherModes, aes128cbcID)
    69  
    70  	kr := &kexResult{Hash: crypto.SHA1}
    71  	algs := directionAlgorithms{
    72  		Cipher:      aes128cbcID,
    73  		MAC:         "hmac-sha1",
    74  		Compression: "none",
    75  	}
    76  	client, err := newPacketCipher(clientKeys, algs, kr)
    77  	if err != nil {
    78  		t.Fatalf("newPacketCipher(client): %v", err)
    79  	}
    80  
    81  	want := "bla bla"
    82  	input := []byte(want)
    83  	buf := &bytes.Buffer{}
    84  	if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
    85  		t.Errorf("writePacket: %v", err)
    86  	}
    87  
    88  	packetSize := buf.Len()
    89  	buf.Write(make([]byte, 2*maxPacket))
    90  
    91  	// We corrupt each byte, but this usually will only test the
    92  	// 'packet too large' or 'MAC failure' cases.
    93  	lastRead := -1
    94  	for i := 0; i < packetSize; i++ {
    95  		server, err := newPacketCipher(clientKeys, algs, kr)
    96  		if err != nil {
    97  			t.Fatalf("newPacketCipher(client): %v", err)
    98  		}
    99  
   100  		fresh := &bytes.Buffer{}
   101  		fresh.Write(buf.Bytes())
   102  		fresh.Bytes()[i] ^= 0x01
   103  
   104  		before := fresh.Len()
   105  		_, err = server.readPacket(0, fresh)
   106  		if err == nil {
   107  			t.Errorf("corrupt byte %d: readPacket succeeded ", i)
   108  			continue
   109  		}
   110  		if _, ok := err.(cbcError); !ok {
   111  			t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
   112  			continue
   113  		}
   114  
   115  		after := fresh.Len()
   116  		bytesRead := before - after
   117  		if bytesRead < maxPacket {
   118  			t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
   119  			continue
   120  		}
   121  
   122  		if i > 0 && bytesRead != lastRead {
   123  			t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
   124  		}
   125  		lastRead = bytesRead
   126  	}
   127  }