github.com/wfusion/gofusion@v1.1.14/test/common/utils/cases/cipher_test.go (about)

     1  package cases
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/aes"
     7  	"crypto/des"
     8  	"errors"
     9  	"math/rand"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/suite"
    15  	"golang.org/x/crypto/chacha20poly1305"
    16  
    17  	"github.com/wfusion/gofusion/common/utils"
    18  	"github.com/wfusion/gofusion/common/utils/cipher"
    19  	"github.com/wfusion/gofusion/log"
    20  
    21  	testUtl "github.com/wfusion/gofusion/test/common/utils"
    22  )
    23  
    24  func TestCipher(t *testing.T) {
    25  	t.Parallel()
    26  	testingSuite := &Cipher{Test: new(testUtl.Test)}
    27  	suite.Run(t, testingSuite)
    28  }
    29  
    30  type Cipher struct {
    31  	*testUtl.Test
    32  }
    33  
    34  func (t *Cipher) BeforeTest(suiteName, testName string) {
    35  	t.Catch(func() {
    36  		log.Info(context.Background(), "right before %s %s", suiteName, testName)
    37  	})
    38  }
    39  
    40  func (t *Cipher) AfterTest(suiteName, testName string) {
    41  	t.Catch(func() {
    42  		log.Info(context.Background(), "right after %s %s", suiteName, testName)
    43  	})
    44  }
    45  
    46  func (t *Cipher) TestDES() {
    47  	t.Catch(func() {
    48  		var (
    49  			key [8]byte
    50  			iv  [des.BlockSize]byte
    51  		)
    52  		_, err := utils.Random(key[:], 0)
    53  		t.NoError(err)
    54  		_, err = utils.Random(iv[:], 0)
    55  		t.NoError(err)
    56  
    57  		t.runTest(cipher.AlgorithmDES, key[:], iv[:], cipher.ModeGCM)
    58  	})
    59  }
    60  
    61  func (t *Cipher) Test3DES() {
    62  	t.Catch(func() {
    63  		var (
    64  			key [8 * 3]byte
    65  			iv  [des.BlockSize]byte
    66  		)
    67  		_, err := utils.Random(key[:], 0)
    68  		t.NoError(err)
    69  		_, err = utils.Random(iv[:], 0)
    70  		t.NoError(err)
    71  
    72  		t.runTest(cipher.Algorithm3DES, key[:], iv[:], cipher.ModeGCM)
    73  	})
    74  }
    75  
    76  func (t *Cipher) TestAES128() {
    77  	t.Catch(func() {
    78  		var (
    79  			key [16]byte
    80  			iv  [aes.BlockSize]byte
    81  		)
    82  		_, err := utils.Random(key[:], 0)
    83  		t.NoError(err)
    84  		_, err = utils.Random(iv[:], 0)
    85  		t.NoError(err)
    86  
    87  		t.runTest(cipher.AlgorithmAES, key[:], iv[:])
    88  	})
    89  }
    90  
    91  func (t *Cipher) TestAES192() {
    92  	t.Catch(func() {
    93  		var (
    94  			key [24]byte
    95  			iv  [aes.BlockSize]byte
    96  		)
    97  		_, err := utils.Random(key[:], 0)
    98  		t.NoError(err)
    99  		_, err = utils.Random(iv[:], 0)
   100  		t.NoError(err)
   101  
   102  		t.runTest(cipher.AlgorithmAES, key[:], iv[:])
   103  	})
   104  }
   105  
   106  func (t *Cipher) TestAES256() {
   107  	t.Catch(func() {
   108  		var (
   109  			key [32]byte
   110  			iv  [aes.BlockSize]byte
   111  		)
   112  		_, err := utils.Random(key[:], 0)
   113  		t.NoError(err)
   114  		_, err = utils.Random(iv[:], 0)
   115  		t.NoError(err)
   116  
   117  		t.runTest(cipher.AlgorithmAES, key[:], iv[:])
   118  	})
   119  }
   120  
   121  func (t *Cipher) TestRC4_8() {
   122  	t.Catch(func() {
   123  		var (
   124  			key [1]byte
   125  		)
   126  		_, err := utils.Random(key[:], 0)
   127  		t.NoError(err)
   128  
   129  		t.runTest(cipher.AlgorithmRC4, key[:], nil,
   130  			cipher.ModeCBC, cipher.ModeCFB, cipher.ModeCTR, cipher.ModeOFB, cipher.ModeGCM)
   131  	})
   132  }
   133  
   134  func (t *Cipher) TestRC4_256() {
   135  	t.Catch(func() {
   136  		var (
   137  			key [32]byte
   138  		)
   139  		_, err := utils.Random(key[:], 0)
   140  		t.NoError(err)
   141  
   142  		t.runTest(cipher.AlgorithmRC4, key[:], nil,
   143  			cipher.ModeCBC, cipher.ModeCFB, cipher.ModeCTR, cipher.ModeOFB, cipher.ModeGCM)
   144  	})
   145  }
   146  
   147  func (t *Cipher) TestRC4_2048() {
   148  	t.Catch(func() {
   149  		var (
   150  			key [256]byte
   151  		)
   152  		_, err := utils.Random(key[:], 0)
   153  		t.NoError(err)
   154  
   155  		t.runTest(cipher.AlgorithmRC4, key[:], nil,
   156  			cipher.ModeCBC, cipher.ModeCFB, cipher.ModeCTR, cipher.ModeOFB, cipher.ModeGCM)
   157  	})
   158  }
   159  
   160  func (t *Cipher) TestChaCha20poly1305() {
   161  	t.Catch(func() {
   162  		var (
   163  			key [chacha20poly1305.KeySize]byte
   164  		)
   165  		_, err := utils.Random(key[:], 0)
   166  		t.NoError(err)
   167  
   168  		t.runTest(cipher.AlgorithmChaCha20poly1305, key[:], nil,
   169  			cipher.ModeCBC, cipher.ModeCFB, cipher.ModeCTR, cipher.ModeOFB, cipher.ModeGCM)
   170  	})
   171  }
   172  
   173  func (t *Cipher) TestXChaCha20poly1305() {
   174  	t.Catch(func() {
   175  		var (
   176  			key [chacha20poly1305.KeySize]byte
   177  		)
   178  		_, err := utils.Random(key[:], 0)
   179  		t.NoError(err)
   180  
   181  		t.runTest(cipher.AlgorithmXChaCha20poly1305, key[:], nil,
   182  			cipher.ModeCBC, cipher.ModeCFB, cipher.ModeCTR, cipher.ModeOFB, cipher.ModeGCM)
   183  	})
   184  }
   185  
   186  func (t *Cipher) TestSM4() {
   187  	t.Catch(func() {
   188  		var (
   189  			key [16]byte
   190  			iv  [aes.BlockSize]byte
   191  		)
   192  		_, err := utils.Random(key[:], 0)
   193  		t.NoError(err)
   194  		_, err = utils.Random(iv[:], 0)
   195  		t.NoError(err)
   196  
   197  		t.runTest(cipher.AlgorithmSM4, key[:], iv[:])
   198  	})
   199  }
   200  
   201  func (t *Cipher) runTest(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode) {
   202  	testCases := []testCipherFunc{
   203  		t.testModes,
   204  		t.testLargeBytes,
   205  		t.testBytesParallel,
   206  		t.testStreaming,
   207  		t.testStreamingParallel,
   208  	}
   209  
   210  	rand.Seed(utils.GetTimeStamp(time.Now()))
   211  	for _, testCase := range testCases {
   212  		testCase(algo, key, iv, ignoreModes...)
   213  	}
   214  }
   215  
   216  type testCipherFunc func(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode)
   217  
   218  func (t *Cipher) testModes(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode) {
   219  	type caseStruct struct {
   220  		data []byte
   221  		mode cipher.Mode
   222  	}
   223  
   224  	caseList := []caseStruct{
   225  		{
   226  			data: []byte("this is a plain text."),
   227  			mode: cipher.ModeECB,
   228  		},
   229  		{
   230  			data: []byte("this is a plain text."),
   231  			mode: cipher.ModeCBC,
   232  		},
   233  		{
   234  			data: []byte("this is a plain text."),
   235  			mode: cipher.ModeCFB,
   236  		},
   237  		{
   238  			data: []byte("this is a plain text."),
   239  			mode: cipher.ModeCTR,
   240  		},
   241  		{
   242  			data: []byte("this is a plain text."),
   243  			mode: cipher.ModeOFB,
   244  		},
   245  		{
   246  			data: []byte("this is a plain text."),
   247  			mode: cipher.ModeGCM,
   248  		},
   249  	}
   250  
   251  	ignored := utils.NewSet(ignoreModes...)
   252  	for _, cs := range caseList {
   253  		if ignored.Contains(cs.mode) {
   254  			continue
   255  		}
   256  
   257  		name := algo.String() + "_" + cs.mode.String()
   258  		t.Run(name, func() {
   259  			enc, err := cipher.EncryptBytesFunc(algo, cs.mode, key, iv)
   260  			t.NoError(err)
   261  			dec, err := cipher.DecryptBytesFunc(algo, cs.mode, key, iv)
   262  			t.NoError(err)
   263  
   264  			ciphertext, err := enc(cs.data)
   265  			t.NoError(err)
   266  			t.NotEmpty(ciphertext)
   267  			t.NotEqualValues(cs.data, ciphertext)
   268  
   269  			actual, err := dec(ciphertext)
   270  			t.NoError(err)
   271  
   272  			t.EqualValues(cs.data, actual)
   273  		})
   274  	}
   275  }
   276  
   277  func (t *Cipher) testLargeBytes(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode) {
   278  	type caseStruct struct {
   279  		mode cipher.Mode
   280  	}
   281  
   282  	caseList := []caseStruct{
   283  		{
   284  			mode: cipher.ModeECB,
   285  		},
   286  		{
   287  			mode: cipher.ModeCBC,
   288  		},
   289  		{
   290  			mode: cipher.ModeCFB,
   291  		},
   292  		{
   293  			mode: cipher.ModeCTR,
   294  		},
   295  		{
   296  			mode: cipher.ModeOFB,
   297  		},
   298  		{
   299  			mode: cipher.ModeGCM,
   300  		},
   301  	}
   302  
   303  	ignored := utils.NewSet(ignoreModes...)
   304  	data := t.randomData()
   305  	for _, cs := range caseList {
   306  		if ignored.Contains(cs.mode) {
   307  			continue
   308  		}
   309  		name := algo.String() + "_" + cs.mode.String() + "_large_bytes"
   310  		t.Run(name, func() {
   311  			enc, err := cipher.EncryptBytesFunc(algo, cs.mode, key, iv)
   312  			t.NoError(err)
   313  			dec, err := cipher.DecryptBytesFunc(algo, cs.mode, key, iv)
   314  			t.NoError(err)
   315  
   316  			ciphertext, err := enc(data)
   317  			t.NoError(err)
   318  			t.NotEmpty(ciphertext)
   319  			t.NotEqualValues(data, ciphertext)
   320  
   321  			actual, err := dec(ciphertext)
   322  			t.NoError(err)
   323  
   324  			t.EqualValues(data, actual)
   325  		})
   326  	}
   327  }
   328  
   329  func (t *Cipher) testBytesParallel(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode) {
   330  	type caseStruct struct {
   331  		mode cipher.Mode
   332  	}
   333  
   334  	caseList := []caseStruct{
   335  		{
   336  			mode: cipher.ModeECB,
   337  		},
   338  		{
   339  			mode: cipher.ModeCBC,
   340  		},
   341  		{
   342  			mode: cipher.ModeCFB,
   343  		},
   344  		{
   345  			mode: cipher.ModeCTR,
   346  		},
   347  		{
   348  			mode: cipher.ModeOFB,
   349  		},
   350  		{
   351  			mode: cipher.ModeGCM,
   352  		},
   353  	}
   354  
   355  	ignored := utils.NewSet(ignoreModes...)
   356  	name := algo.String() + "_bytes_parallel"
   357  	t.Run(name, func() {
   358  		wg := new(sync.WaitGroup)
   359  		defer wg.Wait()
   360  
   361  		data := t.randomData()
   362  		for _, cs := range caseList {
   363  			if ignored.Contains(cs.mode) {
   364  				continue
   365  			}
   366  
   367  			mode := cs.mode
   368  			for i := 0; i < 5; i++ {
   369  				wg.Add(1)
   370  				go func() {
   371  					defer wg.Done()
   372  
   373  					enc, err := cipher.EncryptBytesFunc(algo, mode, key, iv)
   374  					t.NoError(err)
   375  					dec, err := cipher.DecryptBytesFunc(algo, mode, key, iv)
   376  					t.NoError(err)
   377  
   378  					ciphertext, err := enc(data)
   379  					t.NoError(err)
   380  					t.NotEmpty(ciphertext)
   381  					t.NotEqualValues(data, ciphertext)
   382  
   383  					actual, err := dec(ciphertext)
   384  					t.NoError(err)
   385  
   386  					t.EqualValues(data, actual)
   387  				}()
   388  			}
   389  		}
   390  	})
   391  }
   392  
   393  func (t *Cipher) testStreaming(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode) {
   394  	type caseStruct struct {
   395  		mode cipher.Mode
   396  	}
   397  
   398  	caseList := []caseStruct{
   399  		{
   400  			mode: cipher.ModeECB,
   401  		},
   402  		{
   403  			mode: cipher.ModeCBC,
   404  		},
   405  		{
   406  			mode: cipher.ModeCFB,
   407  		},
   408  		{
   409  			mode: cipher.ModeCTR,
   410  		},
   411  		{
   412  			mode: cipher.ModeOFB,
   413  		},
   414  		{
   415  			mode: cipher.ModeGCM,
   416  		},
   417  	}
   418  
   419  	ignored := utils.NewSet(ignoreModes...)
   420  	data := t.randomData()
   421  	for _, cs := range caseList {
   422  		if ignored.Contains(cs.mode) {
   423  			continue
   424  		}
   425  		if _, err := cipher.EncryptStreamFunc(algo, cs.mode, key, iv); errors.Is(err, cipher.ErrNotSupportStream) {
   426  			continue
   427  		}
   428  
   429  		mode := cs.mode
   430  		name := algo.String() + "_" + mode.String() + "_streaming"
   431  		t.Run(name, func() {
   432  			dataBuffer := bytes.NewReader(data)
   433  
   434  			enc, err := cipher.EncryptStreamFunc(algo, mode, key, iv)
   435  			t.NoError(err)
   436  			dec, err := cipher.DecryptStreamFunc(algo, mode, key, iv)
   437  			t.NoError(err)
   438  
   439  			cipherBuffer := bytes.NewBuffer(nil)
   440  			err = enc(cipherBuffer, dataBuffer)
   441  			t.NoError(err)
   442  			t.NotZero(cipherBuffer.Len())
   443  			t.NotEqualValues(data, cipherBuffer.Bytes())
   444  
   445  			plainBuffer := bytes.NewBuffer(nil)
   446  			err = dec(plainBuffer, cipherBuffer)
   447  			t.NoError(err)
   448  
   449  			t.EqualValues(data, plainBuffer.Bytes())
   450  		})
   451  	}
   452  }
   453  
   454  func (t *Cipher) testStreamingParallel(algo cipher.Algorithm, key, iv []byte, ignoreModes ...cipher.Mode) {
   455  	type caseStruct struct {
   456  		mode cipher.Mode
   457  	}
   458  
   459  	caseList := []caseStruct{
   460  		{
   461  			mode: cipher.ModeECB,
   462  		},
   463  		{
   464  			mode: cipher.ModeCBC,
   465  		},
   466  		{
   467  			mode: cipher.ModeCFB,
   468  		},
   469  		{
   470  			mode: cipher.ModeCTR,
   471  		},
   472  		{
   473  			mode: cipher.ModeOFB,
   474  		},
   475  		{
   476  			mode: cipher.ModeGCM,
   477  		},
   478  	}
   479  
   480  	ignored := utils.NewSet(ignoreModes...)
   481  	name := algo.String() + "_stream_parallel"
   482  	t.Run(name, func() {
   483  		wg := new(sync.WaitGroup)
   484  		defer wg.Wait()
   485  
   486  		data := t.randomData()
   487  		for _, cs := range caseList {
   488  			if ignored.Contains(cs.mode) {
   489  				continue
   490  			}
   491  			if _, err := cipher.EncryptStreamFunc(algo, cs.mode, key, iv); errors.Is(err, cipher.ErrNotSupportStream) {
   492  				continue
   493  			}
   494  
   495  			mode := cs.mode
   496  			for i := 0; i < 5; i++ {
   497  				wg.Add(1)
   498  				go func() {
   499  					defer wg.Done()
   500  
   501  					dataBuffer := bytes.NewReader(data)
   502  
   503  					enc, err := cipher.EncryptStreamFunc(algo, mode, key, iv)
   504  					t.NoError(err)
   505  					dec, err := cipher.DecryptStreamFunc(algo, mode, key, iv)
   506  					t.NoError(err)
   507  
   508  					cipherBuffer := bytes.NewBuffer(nil)
   509  					err = enc(cipherBuffer, dataBuffer)
   510  					t.NoError(err)
   511  					t.NotZero(cipherBuffer.Len())
   512  					t.NotEqualValues(data, cipherBuffer.Bytes())
   513  
   514  					plainBuffer := bytes.NewBuffer(nil)
   515  					err = dec(plainBuffer, cipherBuffer)
   516  					t.NoError(err)
   517  
   518  					t.EqualValues(data, plainBuffer.Bytes())
   519  				}()
   520  			}
   521  		}
   522  	})
   523  
   524  }
   525  
   526  func (t *Cipher) randomData() (data []byte) {
   527  	const (
   528  		jitterLength = 4 * 1024                   // 4kb
   529  		largeLength  = 1024*1024 - jitterLength/2 // 1m - 2kb
   530  	)
   531  
   532  	// 1m ± 2kb
   533  	data = make([]byte, largeLength+rand.Int()%(jitterLength/2))
   534  	//data = make([]byte, 10)
   535  	_, err := utils.Random(data, utils.GetTimeStamp(time.Now()))
   536  	t.NoError(err)
   537  	return
   538  }