github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/cipher_test.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto"
    10  	"crypto/rand"
    11  	"encoding/binary"
    12  	"io"
    13  	"testing"
    14  
    15  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/internal/poly1305"
    16  	"golang.org/x/crypto/chacha20"
    17  )
    18  
    19  func TestDefaultCiphersExist(t *testing.T) {
    20  	for _, cipherAlgo := range supportedCiphers {
    21  		if _, ok := cipherModes[cipherAlgo]; !ok {
    22  			t.Errorf("supported cipher %q is unknown", cipherAlgo)
    23  		}
    24  	}
    25  	for _, cipherAlgo := range preferredCiphers {
    26  		if _, ok := cipherModes[cipherAlgo]; !ok {
    27  			t.Errorf("preferred cipher %q is unknown", cipherAlgo)
    28  		}
    29  	}
    30  }
    31  
    32  func TestPacketCiphers(t *testing.T) {
    33  	defaultMac := "hmac-sha2-256"
    34  	defaultCipher := "aes128-ctr"
    35  	for cipher := range cipherModes {
    36  		t.Run("cipher="+cipher,
    37  			func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) })
    38  	}
    39  	for mac := range macModes {
    40  		t.Run("mac="+mac,
    41  			func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) })
    42  	}
    43  }
    44  
    45  func testPacketCipher(t *testing.T, cipher, mac string) {
    46  	kr := &kexResult{Hash: crypto.SHA1}
    47  	algs := directionAlgorithms{
    48  		Cipher:      cipher,
    49  		MAC:         mac,
    50  		Compression: "none",
    51  	}
    52  	client, err := newPacketCipher(clientKeys, algs, kr)
    53  	if err != nil {
    54  		t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
    55  	}
    56  	server, err := newPacketCipher(clientKeys, algs, kr)
    57  	if err != nil {
    58  		t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
    59  	}
    60  
    61  	want := "bla bla"
    62  	input := []byte(want)
    63  	buf := &bytes.Buffer{}
    64  	if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil {
    65  		t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err)
    66  	}
    67  
    68  	packet, err := server.readCipherPacket(0, buf)
    69  	if err != nil {
    70  		t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err)
    71  	}
    72  
    73  	if string(packet) != want {
    74  		t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
    75  	}
    76  }
    77  
    78  func TestCBCOracleCounterMeasure(t *testing.T) {
    79  	kr := &kexResult{Hash: crypto.SHA1}
    80  	algs := directionAlgorithms{
    81  		Cipher:      aes128cbcID,
    82  		MAC:         "hmac-sha1",
    83  		Compression: "none",
    84  	}
    85  	client, err := newPacketCipher(clientKeys, algs, kr)
    86  	if err != nil {
    87  		t.Fatalf("newPacketCipher(client): %v", err)
    88  	}
    89  
    90  	want := "bla bla"
    91  	input := []byte(want)
    92  	buf := &bytes.Buffer{}
    93  	if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil {
    94  		t.Errorf("writeCipherPacket: %v", err)
    95  	}
    96  
    97  	packetSize := buf.Len()
    98  	buf.Write(make([]byte, 2*maxPacket))
    99  
   100  	// We corrupt each byte, but this usually will only test the
   101  	// 'packet too large' or 'MAC failure' cases.
   102  	lastRead := -1
   103  	for i := 0; i < packetSize; i++ {
   104  		server, err := newPacketCipher(clientKeys, algs, kr)
   105  		if err != nil {
   106  			t.Fatalf("newPacketCipher(client): %v", err)
   107  		}
   108  
   109  		fresh := &bytes.Buffer{}
   110  		fresh.Write(buf.Bytes())
   111  		fresh.Bytes()[i] ^= 0x01
   112  
   113  		before := fresh.Len()
   114  		_, err = server.readCipherPacket(0, fresh)
   115  		if err == nil {
   116  			t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i)
   117  			continue
   118  		}
   119  		if _, ok := err.(cbcError); !ok {
   120  			t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
   121  			continue
   122  		}
   123  
   124  		after := fresh.Len()
   125  		bytesRead := before - after
   126  		if bytesRead < maxPacket {
   127  			t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
   128  			continue
   129  		}
   130  
   131  		if i > 0 && bytesRead != lastRead {
   132  			t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
   133  		}
   134  		lastRead = bytesRead
   135  	}
   136  }
   137  
   138  func TestCVE202143565(t *testing.T) {
   139  	tests := []struct {
   140  		cipher          string
   141  		constructPacket func(packetCipher) io.Reader
   142  	}{
   143  		{
   144  			cipher: gcmCipherID,
   145  			constructPacket: func(client packetCipher) io.Reader {
   146  				internalCipher := client.(*gcmCipher)
   147  				b := &bytes.Buffer{}
   148  				prefix := [4]byte{}
   149  				if _, err := b.Write(prefix[:]); err != nil {
   150  					t.Fatal(err)
   151  				}
   152  				internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:])
   153  				if _, err := b.Write(internalCipher.buf); err != nil {
   154  					t.Fatal(err)
   155  				}
   156  				internalCipher.incIV()
   157  
   158  				return b
   159  			},
   160  		},
   161  		{
   162  			cipher: chacha20Poly1305ID,
   163  			constructPacket: func(client packetCipher) io.Reader {
   164  				internalCipher := client.(*chacha20Poly1305Cipher)
   165  				b := &bytes.Buffer{}
   166  
   167  				nonce := make([]byte, 12)
   168  				s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce)
   169  				if err != nil {
   170  					t.Fatal(err)
   171  				}
   172  				var polyKey, discardBuf [32]byte
   173  				s.XORKeyStream(polyKey[:], polyKey[:])
   174  				s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
   175  
   176  				internalCipher.buf = make([]byte, 4+poly1305.TagSize)
   177  				binary.BigEndian.PutUint32(internalCipher.buf, 0)
   178  				ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce)
   179  				if err != nil {
   180  					t.Fatal(err)
   181  				}
   182  				ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4])
   183  				if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil {
   184  					t.Fatal(err)
   185  				}
   186  
   187  				s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4])
   188  
   189  				var tag [poly1305.TagSize]byte
   190  				poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey)
   191  
   192  				copy(internalCipher.buf[4:], tag[:])
   193  
   194  				if _, err := b.Write(internalCipher.buf); err != nil {
   195  					t.Fatal(err)
   196  				}
   197  
   198  				return b
   199  			},
   200  		},
   201  	}
   202  
   203  	for _, tc := range tests {
   204  		mac := "hmac-sha2-256"
   205  
   206  		kr := &kexResult{Hash: crypto.SHA1}
   207  		algs := directionAlgorithms{
   208  			Cipher:      tc.cipher,
   209  			MAC:         mac,
   210  			Compression: "none",
   211  		}
   212  		client, err := newPacketCipher(clientKeys, algs, kr)
   213  		if err != nil {
   214  			t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err)
   215  		}
   216  		server, err := newPacketCipher(clientKeys, algs, kr)
   217  		if err != nil {
   218  			t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err)
   219  		}
   220  
   221  		b := tc.constructPacket(client)
   222  
   223  		wantErr := "ssh: empty packet"
   224  		_, err = server.readCipherPacket(0, b)
   225  		if err == nil {
   226  			t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac)
   227  		} else if err.Error() != wantErr {
   228  			t.Fatalf("readCipherPacket(%q, %q): unexpected error, got %q, want %q", tc.cipher, mac, err, wantErr)
   229  		}
   230  	}
   231  }