github.com/emmansun/gmsm@v0.29.1/internal/cryptotest/aead.go (about)

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cryptotest
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/cipher"
    10  	"fmt"
    11  	"testing"
    12  )
    13  
    14  var lengths = []int{0, 156, 8192, 8193, 8208}
    15  
    16  // MakeAEAD returns a cipher.AEAD instance.
    17  //
    18  // Multiple calls to MakeAEAD must return equivalent instances, so for example
    19  // the key must be fixed.
    20  type MakeAEAD func() (cipher.AEAD, error)
    21  
    22  // TestAEAD performs a set of tests on cipher.AEAD implementations, checking
    23  // the documented requirements of NonceSize, Overhead, Seal and Open.
    24  func TestAEAD(t *testing.T, mAEAD MakeAEAD) {
    25  	aead, err := mAEAD()
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  
    30  	t.Run("Roundtrip", func(t *testing.T) {
    31  
    32  		// Test all combinations of plaintext and additional data lengths.
    33  		for _, ptLen := range lengths {
    34  			for _, adLen := range lengths {
    35  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    36  					rng := newRandReader(t)
    37  
    38  					nonce := make([]byte, aead.NonceSize())
    39  					rng.Read(nonce)
    40  
    41  					before, addData := make([]byte, adLen), make([]byte, ptLen)
    42  					rng.Read(before)
    43  					rng.Read(addData)
    44  
    45  					ciphertext := sealMsg(t, aead, nil, nonce, before, addData)
    46  					after := openWithoutError(t, aead, nil, nonce, ciphertext, addData)
    47  
    48  					if !bytes.Equal(after, before) {
    49  						t.Errorf("plaintext is different after a seal/open cycle; got %s, want %s", truncateHex(after), truncateHex(before))
    50  					}
    51  				})
    52  			}
    53  		}
    54  	})
    55  
    56  	t.Run("InputNotModified", func(t *testing.T) {
    57  
    58  		// Test all combinations of plaintext and additional data lengths.
    59  		for _, ptLen := range lengths {
    60  			for _, adLen := range lengths {
    61  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
    62  					t.Run("Seal", func(t *testing.T) {
    63  						rng := newRandReader(t)
    64  
    65  						nonce := make([]byte, aead.NonceSize())
    66  						rng.Read(nonce)
    67  
    68  						src, before := make([]byte, ptLen), make([]byte, ptLen)
    69  						rng.Read(src)
    70  						copy(before, src)
    71  
    72  						addData := make([]byte, adLen)
    73  						rng.Read(addData)
    74  
    75  						sealMsg(t, aead, nil, nonce, src, addData)
    76  						if !bytes.Equal(src, before) {
    77  							t.Errorf("Seal modified src; got %s, want %s", truncateHex(src), truncateHex(before))
    78  						}
    79  					})
    80  
    81  					t.Run("Open", func(t *testing.T) {
    82  						rng := newRandReader(t)
    83  
    84  						nonce := make([]byte, aead.NonceSize())
    85  						rng.Read(nonce)
    86  
    87  						plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
    88  						rng.Read(plaintext)
    89  						rng.Read(addData)
    90  
    91  						// Record the ciphertext that shouldn't be modified as the input of
    92  						// Open.
    93  						ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
    94  						before := make([]byte, len(ciphertext))
    95  						copy(before, ciphertext)
    96  
    97  						openWithoutError(t, aead, nil, nonce, ciphertext, addData)
    98  						if !bytes.Equal(ciphertext, before) {
    99  							t.Errorf("Open modified src; got %s, want %s", truncateHex(ciphertext), truncateHex(before))
   100  						}
   101  					})
   102  				})
   103  			}
   104  		}
   105  	})
   106  
   107  	t.Run("BufferOverlap", func(t *testing.T) {
   108  
   109  		// Test all combinations of plaintext and additional data lengths.
   110  		for _, ptLen := range lengths {
   111  			if ptLen <= 1 { // We need enough room for an inexact overlap to occur.
   112  				continue
   113  			}
   114  			for _, adLen := range lengths {
   115  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
   116  					t.Run("Seal", func(t *testing.T) {
   117  						rng := newRandReader(t)
   118  
   119  						nonce := make([]byte, aead.NonceSize())
   120  						rng.Read(nonce)
   121  
   122  						// Make a buffer that can hold a plaintext and ciphertext as we
   123  						// overlap their slices to check for panic on inexact overlaps.
   124  						ctLen := ptLen + aead.Overhead()
   125  						buff := make([]byte, ptLen+ctLen)
   126  						rng.Read(buff)
   127  
   128  						addData := make([]byte, adLen)
   129  						rng.Read(addData)
   130  
   131  						// Make plaintext and dst slices point to same array with inexact overlap.
   132  						plaintext := buff[:ptLen]
   133  						dst := buff[1:1] // Shift dst to not start at start of plaintext.
   134  						mustPanic(t, "invalid buffer overlap", func() { sealMsg(t, aead, dst, nonce, plaintext, addData) })
   135  
   136  						// Only overlap on one byte
   137  						plaintext = buff[:ptLen]
   138  						dst = buff[ptLen-1 : ptLen-1]
   139  						mustPanic(t, "invalid buffer overlap", func() { sealMsg(t, aead, dst, nonce, plaintext, addData) })
   140  					})
   141  
   142  					t.Run("Open", func(t *testing.T) {
   143  						rng := newRandReader(t)
   144  
   145  						nonce := make([]byte, aead.NonceSize())
   146  						rng.Read(nonce)
   147  
   148  						// Create a valid ciphertext to test Open with.
   149  						plaintext := make([]byte, ptLen)
   150  						rng.Read(plaintext)
   151  						addData := make([]byte, adLen)
   152  						rng.Read(addData)
   153  						validCT := sealMsg(t, aead, nil, nonce, plaintext, addData)
   154  
   155  						// Make a buffer that can hold a plaintext and ciphertext as we
   156  						// overlap their slices to check for panic on inexact overlaps.
   157  						buff := make([]byte, ptLen+len(validCT))
   158  
   159  						// Make ciphertext and dst slices point to same array with inexact overlap.
   160  						ciphertext := buff[:len(validCT)]
   161  						copy(ciphertext, validCT)
   162  						dst := buff[1:1] // Shift dst to not start at start of ciphertext.
   163  						mustPanic(t, "invalid buffer overlap", func() { aead.Open(dst, nonce, ciphertext, addData) })
   164  
   165  						// Only overlap on one byte.
   166  						ciphertext = buff[:len(validCT)]
   167  						copy(ciphertext, validCT)
   168  						// Make sure it is the actual ciphertext being overlapped and not
   169  						// the hash digest which might be extracted/truncated in some
   170  						// implementations: Go one byte past the hash digest/tag and into
   171  						// the ciphertext.
   172  						beforeTag := len(validCT) - aead.Overhead()
   173  						dst = buff[beforeTag-1 : beforeTag-1]
   174  						mustPanic(t, "invalid buffer overlap", func() { aead.Open(dst, nonce, ciphertext, addData) })
   175  					})
   176  				})
   177  			}
   178  		}
   179  	})
   180  
   181  	t.Run("AppendDst", func(t *testing.T) {
   182  
   183  		// Test all combinations of plaintext and additional data lengths.
   184  		for _, ptLen := range lengths {
   185  			for _, adLen := range lengths {
   186  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
   187  
   188  					t.Run("Seal", func(t *testing.T) {
   189  						rng := newRandReader(t)
   190  
   191  						nonce := make([]byte, aead.NonceSize())
   192  						rng.Read(nonce)
   193  
   194  						shortBuff := []byte("a")
   195  						longBuff := make([]byte, 512)
   196  						rng.Read(longBuff)
   197  						prefixes := [][]byte{shortBuff, longBuff}
   198  
   199  						// Check each prefix gets appended to by Seal without altering them.
   200  						for _, prefix := range prefixes {
   201  							plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
   202  							rng.Read(plaintext)
   203  							rng.Read(addData)
   204  							out := sealMsg(t, aead, prefix, nonce, plaintext, addData)
   205  
   206  							// Check that Seal didn't alter the prefix
   207  							if !bytes.Equal(out[:len(prefix)], prefix) {
   208  								t.Errorf("Seal alters dst instead of appending; got %s, want %s", truncateHex(out[:len(prefix)]), truncateHex(prefix))
   209  							}
   210  
   211  							ciphertext := out[len(prefix):]
   212  							// Check that the appended ciphertext wasn't affected by the prefix
   213  							if expectedCT := sealMsg(t, aead, nil, nonce, plaintext, addData); !bytes.Equal(ciphertext, expectedCT) {
   214  								t.Errorf("Seal behavior affected by pre-existing data in dst; got %s, want %s", truncateHex(ciphertext), truncateHex(expectedCT))
   215  							}
   216  						}
   217  					})
   218  
   219  					t.Run("Open", func(t *testing.T) {
   220  						rng := newRandReader(t)
   221  
   222  						nonce := make([]byte, aead.NonceSize())
   223  						rng.Read(nonce)
   224  
   225  						shortBuff := []byte("a")
   226  						longBuff := make([]byte, 512)
   227  						rng.Read(longBuff)
   228  						prefixes := [][]byte{shortBuff, longBuff}
   229  
   230  						// Check each prefix gets appended to by Open without altering them.
   231  						for _, prefix := range prefixes {
   232  							before, addData := make([]byte, adLen), make([]byte, ptLen)
   233  							rng.Read(before)
   234  							rng.Read(addData)
   235  							ciphertext := sealMsg(t, aead, nil, nonce, before, addData)
   236  
   237  							out := openWithoutError(t, aead, prefix, nonce, ciphertext, addData)
   238  
   239  							// Check that Open didn't alter the prefix
   240  							if !bytes.Equal(out[:len(prefix)], prefix) {
   241  								t.Errorf("Open alters dst instead of appending; got %s, want %s", truncateHex(out[:len(prefix)]), truncateHex(prefix))
   242  							}
   243  
   244  							after := out[len(prefix):]
   245  							// Check that the appended plaintext wasn't affected by the prefix
   246  							if !bytes.Equal(after, before) {
   247  								t.Errorf("Open behavior affected by pre-existing data in dst; got %s, want %s", truncateHex(after), truncateHex(before))
   248  							}
   249  						}
   250  					})
   251  				})
   252  			}
   253  		}
   254  	})
   255  
   256  	t.Run("WrongNonce", func(t *testing.T) {
   257  
   258  		// Test all combinations of plaintext and additional data lengths.
   259  		for _, ptLen := range lengths {
   260  			for _, adLen := range lengths {
   261  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
   262  					rng := newRandReader(t)
   263  
   264  					nonce := make([]byte, aead.NonceSize())
   265  					rng.Read(nonce)
   266  
   267  					plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
   268  					rng.Read(plaintext)
   269  					rng.Read(addData)
   270  
   271  					ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
   272  
   273  					// Perturb the nonce and check for an error when Opening
   274  					alterNonce := make([]byte, aead.NonceSize())
   275  					copy(alterNonce, nonce)
   276  					alterNonce[len(alterNonce)-1] += 1
   277  					_, err := aead.Open(nil, alterNonce, ciphertext, addData)
   278  
   279  					if err == nil {
   280  						t.Errorf("Open did not error when given different nonce than Sealed with")
   281  					}
   282  				})
   283  			}
   284  		}
   285  	})
   286  
   287  	t.Run("WrongAddData", func(t *testing.T) {
   288  
   289  		// Test all combinations of plaintext and additional data lengths.
   290  		for _, ptLen := range lengths {
   291  			for _, adLen := range lengths {
   292  				if adLen == 0 {
   293  					continue
   294  				}
   295  
   296  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
   297  					rng := newRandReader(t)
   298  
   299  					nonce := make([]byte, aead.NonceSize())
   300  					rng.Read(nonce)
   301  
   302  					plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
   303  					rng.Read(plaintext)
   304  					rng.Read(addData)
   305  
   306  					ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
   307  
   308  					// Perturb the Additional Data and check for an error when Opening
   309  					alterAD := make([]byte, adLen)
   310  					copy(alterAD, addData)
   311  					alterAD[len(alterAD)-1] += 1
   312  					_, err := aead.Open(nil, nonce, ciphertext, alterAD)
   313  
   314  					if err == nil {
   315  						t.Errorf("Open did not error when given different Additional Data than Sealed with")
   316  					}
   317  				})
   318  			}
   319  		}
   320  	})
   321  
   322  	t.Run("WrongCiphertext", func(t *testing.T) {
   323  
   324  		// Test all combinations of plaintext and additional data lengths.
   325  		for _, ptLen := range lengths {
   326  			for _, adLen := range lengths {
   327  
   328  				t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) {
   329  					rng := newRandReader(t)
   330  
   331  					nonce := make([]byte, aead.NonceSize())
   332  					rng.Read(nonce)
   333  
   334  					plaintext, addData := make([]byte, ptLen), make([]byte, adLen)
   335  					rng.Read(plaintext)
   336  					rng.Read(addData)
   337  
   338  					ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData)
   339  
   340  					// Perturb the ciphertext and check for an error when Opening
   341  					alterCT := make([]byte, len(ciphertext))
   342  					copy(alterCT, ciphertext)
   343  					alterCT[len(alterCT)-1] += 1
   344  					_, err := aead.Open(nil, nonce, alterCT, addData)
   345  
   346  					if err == nil {
   347  						t.Errorf("Open did not error when given different ciphertext than was produced by Seal")
   348  					}
   349  				})
   350  			}
   351  		}
   352  	})
   353  }
   354  
   355  // Helper function to Seal a plaintext with additional data. Checks that
   356  // ciphertext isn't bigger than the plaintext length plus Overhead()
   357  func sealMsg(t *testing.T, aead cipher.AEAD, ciphertext, nonce, plaintext, addData []byte) []byte {
   358  	t.Helper()
   359  
   360  	initialLen := len(ciphertext)
   361  
   362  	ciphertext = aead.Seal(ciphertext, nonce, plaintext, addData)
   363  
   364  	lenCT := len(ciphertext) - initialLen
   365  
   366  	// Appended ciphertext shouldn't ever be longer than the length of the
   367  	// plaintext plus Overhead
   368  	if lenCT > len(plaintext)+aead.Overhead() {
   369  		t.Errorf("length of ciphertext from Seal exceeds length of plaintext by more than Overhead(); got %d, want <=%d", lenCT, len(plaintext)+aead.Overhead())
   370  	}
   371  
   372  	return ciphertext
   373  }
   374  
   375  // Helper function to Open and authenticate ciphertext. Checks that Open
   376  // doesn't error (assuming ciphertext was well-formed with corresponding nonce
   377  // and additional data).
   378  func openWithoutError(t *testing.T, aead cipher.AEAD, plaintext, nonce, ciphertext, addData []byte) []byte {
   379  	t.Helper()
   380  
   381  	plaintext, err := aead.Open(plaintext, nonce, ciphertext, addData)
   382  	if err != nil {
   383  		t.Fatalf("Open returned error on properly formed ciphertext; got \"%s\", want \"nil\"", err)
   384  	}
   385  
   386  	return plaintext
   387  }