github.com/emmansun/gmsm@v0.29.1/internal/cryptotest/stream.go (about)

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cryptotest
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/cipher"
    10  	"fmt"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/emmansun/gmsm/internal/subtle"
    15  )
    16  
    17  // Each test is executed with each of the buffer lengths in bufLens.
    18  var (
    19  	bufLens = []int{0, 1, 3, 4, 8, 10, 15, 16, 20, 32, 50, 4096, 5000}
    20  	bufCap  = 10000
    21  )
    22  
// MakeStream returns a cipher.Stream instance.
//
// Multiple calls to MakeStream must return equivalent instances,
// so for example the key and/or IV must be fixed.
//
// TestStream calls the function repeatedly and compares the output of
// independently created streams (for example, it encrypts with one instance
// and decrypts with another). Each call should therefore yield a stream in
// its initial state, not a shared instance whose keystream has already
// advanced.
type MakeStream func() cipher.Stream
    28  
    29  // TestStream performs a set of tests on cipher.Stream implementations,
    30  // checking the documented requirements of XORKeyStream.
    31  func TestStream(t *testing.T, ms MakeStream) {
    32  
    33  	t.Run("XORSemantics", func(t *testing.T) {
    34  		if strings.Contains(t.Name(), "TestCFBStream") {
    35  			// This is ugly, but so is CFB's abuse of cipher.Stream.
    36  			// Don't want to make it easier for anyone else to do that.
    37  			t.Skip("CFB implements cipher.Stream but does not follow XOR semantics")
    38  		}
    39  
    40  		// Test that XORKeyStream inverts itself for encryption/decryption.
    41  		t.Run("Roundtrip", func(t *testing.T) {
    42  
    43  			for _, length := range bufLens {
    44  				t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
    45  					rng := newRandReader(t)
    46  
    47  					plaintext := make([]byte, length)
    48  					rng.Read(plaintext)
    49  
    50  					ciphertext := make([]byte, length)
    51  					decrypted := make([]byte, length)
    52  
    53  					ms().XORKeyStream(ciphertext, plaintext) // Encrypt plaintext
    54  					ms().XORKeyStream(decrypted, ciphertext) // Decrypt ciphertext
    55  					if !bytes.Equal(decrypted, plaintext) {
    56  						t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %s, want %s", truncateHex(decrypted), truncateHex(plaintext))
    57  					}
    58  				})
    59  			}
    60  		})
    61  
    62  		// Test that XORKeyStream behaves the same as directly XORing
    63  		// plaintext with the stream.
    64  		t.Run("DirectXOR", func(t *testing.T) {
    65  
    66  			for _, length := range bufLens {
    67  				t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
    68  					rng := newRandReader(t)
    69  
    70  					plaintext := make([]byte, length)
    71  					rng.Read(plaintext)
    72  
    73  					// Encrypting all zeros should reveal the stream itself
    74  					stream, directXOR := make([]byte, length), make([]byte, length)
    75  					ms().XORKeyStream(stream, stream)
    76  					// Encrypt plaintext by directly XORing the stream
    77  					subtle.XORBytes(directXOR, stream, plaintext)
    78  
    79  					// Encrypt plaintext with XORKeyStream
    80  					ciphertext := make([]byte, length)
    81  					ms().XORKeyStream(ciphertext, plaintext)
    82  					if !bytes.Equal(ciphertext, directXOR) {
    83  						t.Errorf("xor semantics were not preserved; got %s, want %s", truncateHex(ciphertext), truncateHex(directXOR))
    84  					}
    85  				})
    86  			}
    87  		})
    88  	})
    89  
    90  	t.Run("AlterInput", func(t *testing.T) {
    91  		rng := newRandReader(t)
    92  		src, dst, before := make([]byte, bufCap), make([]byte, bufCap), make([]byte, bufCap)
    93  		rng.Read(src)
    94  
    95  		for _, length := range bufLens {
    96  
    97  			t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
    98  				copy(before, src)
    99  
   100  				ms().XORKeyStream(dst[:length], src[:length])
   101  				if !bytes.Equal(src, before) {
   102  					t.Errorf("XORKeyStream modified src; got %s, want %s", truncateHex(src), truncateHex(before))
   103  				}
   104  			})
   105  		}
   106  	})
   107  
   108  	t.Run("Aliasing", func(t *testing.T) {
   109  		rng := newRandReader(t)
   110  
   111  		buff, expectedOutput := make([]byte, bufCap), make([]byte, bufCap)
   112  
   113  		for _, length := range bufLens {
   114  			// Record what output is when src and dst are different
   115  			rng.Read(buff)
   116  			ms().XORKeyStream(expectedOutput[:length], buff[:length])
   117  
   118  			// Check that the same output is generated when src=dst alias to the same
   119  			// memory
   120  			ms().XORKeyStream(buff[:length], buff[:length])
   121  			if !bytes.Equal(buff[:length], expectedOutput[:length]) {
   122  				t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff[:length], expectedOutput[:length])
   123  			}
   124  		}
   125  	})
   126  
   127  	t.Run("OutOfBoundsWrite", func(t *testing.T) { // Issue 21104
   128  		rng := newRandReader(t)
   129  
   130  		plaintext := make([]byte, bufCap)
   131  		rng.Read(plaintext)
   132  		ciphertext := make([]byte, bufCap)
   133  
   134  		for _, length := range bufLens {
   135  			copy(ciphertext, plaintext) // Reset ciphertext buffer
   136  
   137  			t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
   138  				mustPanic(t, "output smaller than input", func() { ms().XORKeyStream(ciphertext[:length], plaintext) })
   139  
   140  				if !bytes.Equal(ciphertext[length:], plaintext[length:]) {
   141  					t.Errorf("XORKeyStream did out of bounds write; got %s, want %s", truncateHex(ciphertext[length:]), truncateHex(plaintext[length:]))
   142  				}
   143  			})
   144  		}
   145  	})
   146  
   147  	t.Run("BufferOverlap", func(t *testing.T) {
   148  		rng := newRandReader(t)
   149  
   150  		buff := make([]byte, bufCap)
   151  		rng.Read(buff)
   152  
   153  		for _, length := range bufLens {
   154  			if length == 0 || length == 1 {
   155  				continue
   156  			}
   157  
   158  			t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
   159  				// Make src and dst slices point to same array with inexact overlap
   160  				src := buff[:length]
   161  				dst := buff[1 : length+1]
   162  				mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
   163  
   164  				// Only overlap on one byte
   165  				src = buff[:length]
   166  				dst = buff[length-1 : 2*length-1]
   167  				mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
   168  
   169  				// src comes after dst with one byte overlap
   170  				src = buff[length-1 : 2*length-1]
   171  				dst = buff[:length]
   172  				mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
   173  			})
   174  		}
   175  	})
   176  
   177  	t.Run("KeepState", func(t *testing.T) {
   178  		rng := newRandReader(t)
   179  
   180  		plaintext := make([]byte, bufCap)
   181  		rng.Read(plaintext)
   182  		ciphertext := make([]byte, bufCap)
   183  
   184  		// Make one long call to XORKeyStream
   185  		ms().XORKeyStream(ciphertext, plaintext)
   186  
   187  		for _, step := range bufLens {
   188  			if step == 0 {
   189  				continue
   190  			}
   191  			stepMsg := fmt.Sprintf("step %d: ", step)
   192  
   193  			dst := make([]byte, bufCap)
   194  
   195  			// Make a bunch of small calls to (stateful) XORKeyStream
   196  			stream := ms()
   197  			i := 0
   198  			for i+step < len(plaintext) {
   199  				stream.XORKeyStream(dst[i:], plaintext[i:i+step])
   200  				i += step
   201  			}
   202  			stream.XORKeyStream(dst[i:], plaintext[i:])
   203  
   204  			if !bytes.Equal(dst, ciphertext) {
   205  				t.Errorf(stepMsg+"successive XORKeyStream calls returned a different result than a single one; got %s, want %s", truncateHex(dst), truncateHex(ciphertext))
   206  			}
   207  		}
   208  	})
   209  }
   210  
   211  // TestStreamFromBlock creates a Stream from a cipher.Block used in a
   212  // cipher.BlockMode. It addresses Issue 68377 by checking for a panic when the
   213  // BlockMode uses an IV with incorrect length.
   214  // For a valid IV, it also runs all TestStream tests on the resulting stream.
   215  func TestStreamFromBlock(t *testing.T, block cipher.Block, blockMode func(b cipher.Block, iv []byte) cipher.Stream) {
   216  
   217  	t.Run("WrongIVLen", func(t *testing.T) {
   218  		t.Skip("see Issue 68377")
   219  
   220  		rng := newRandReader(t)
   221  		iv := make([]byte, block.BlockSize()+1)
   222  		rng.Read(iv)
   223  		mustPanic(t, "IV length must equal block size", func() { blockMode(block, iv) })
   224  	})
   225  
   226  	t.Run("BlockModeStream", func(t *testing.T) {
   227  		rng := newRandReader(t)
   228  		iv := make([]byte, block.BlockSize())
   229  		rng.Read(iv)
   230  
   231  		TestStream(t, func() cipher.Stream { return blockMode(block, iv) })
   232  	})
   233  }
   234  
   235  func truncateHex(b []byte) string {
   236  	numVals := 50
   237  
   238  	if len(b) <= numVals {
   239  		return fmt.Sprintf("%x", b)
   240  	}
   241  	return fmt.Sprintf("%x...", b[:numVals])
   242  }