github.com/ericwq/aprilsh@v0.0.0-20240517091432-958bc568daa0/encrypt/encrypt_test.go (about)

     1  // Copyright 2022 wangqi. All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package encrypt
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"log/slog"
    11  	"reflect"
    12  	"syscall"
    13  	"testing"
    14  
    15  	"github.com/ericwq/aprilsh/util"
    16  )
    17  
    18  func TestPrng(t *testing.T) {
    19  	tc := []int{0, 1, 2, 4, 8, 16, 32}
    20  
    21  	for _, v := range tc {
    22  		got := PrngFill(v)
    23  		if v != len(got) {
    24  			t.Errorf("prngFill got %#v\n", got)
    25  		}
    26  	}
    27  
    28  	for i := 0; i < 8; i++ {
    29  		got := PrngUint8()
    30  		if got == 0 {
    31  			t.Errorf("prngUint8 got %#v\n", got)
    32  		}
    33  	}
    34  }
    35  
    36  func TestBase64Key(t *testing.T) {
    37  	util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
    38  
    39  	// normal key
    40  	normalKey := NewBase64Key()
    41  	printKey := normalKey.printableKey()
    42  	gotNormal := NewBase64Key2(printKey)
    43  	if !reflect.DeepEqual(normalKey.data(), gotNormal.data()) {
    44  		t.Errorf("two keys should be the same. got key1=\n%v, key2=\n%v\n", normalKey, gotNormal)
    45  	}
    46  
    47  	// malform key
    48  	malformBase64 := "/msvMB1KwXL+ygJHdJwwQ=="
    49  	malformKey := NewBase64Key2(malformBase64)
    50  	if malformKey != nil {
    51  		t.Error("malform key should be nil.")
    52  	}
    53  
    54  	// key length is short
    55  	shortLengthKey := &Base64Key{key: PrngFill(8)}
    56  	key4 := NewBase64Key2(shortLengthKey.String())
    57  	if key4 != nil {
    58  		t.Error("key length is short.")
    59  		t.Errorf("key length is short. %q\n", shortLengthKey.printableKey())
    60  	}
    61  }
    62  
    63  func TestUnique(t *testing.T) {
    64  	for i := 0; i < 10; i++ {
    65  		v := Unique()
    66  		expect := i + 1
    67  		if v != uint64(i+1) {
    68  			t.Errorf("Unique expect %d, got %d\n", expect, v)
    69  		}
    70  	}
    71  }
    72  
    73  func TestSession(t *testing.T) {
    74  	tc := []struct {
    75  		name      string
    76  		plainText string
    77  	}{
    78  		{"english plain text", "Datagrams are encrypted and authenticated using AES-128 in the Offset Codebook mode [1]"},
    79  		{"chinese plain text", "原子操作是比其它同步技术更基础的操作。原子操作是无锁的,常常直接通过CPU指令直接实现。"},
    80  	}
    81  
    82  	s, _ := NewSession(*NewBase64Key())
    83  	for _, v := range tc {
    84  		nonce, _ := randomNonce()
    85  		message := Message{nonce: nonce, text: []byte(v.plainText)}
    86  
    87  		// fmt.Printf("#before message nonce=% x, nonce=%p\n", message.nonce, message.nonce)
    88  		cipherText := s.Encrypt(&message)
    89  		// fmt.Printf("#after cipherText=% x\n", cipherText)
    90  
    91  		message2, _ := s.Decrypt(cipherText)
    92  		gotNonce := message2.nonce
    93  		gotText := message2.text
    94  
    95  		if !reflect.DeepEqual(nonce, gotNonce) {
    96  			t.Errorf("%q expect nonce %v, got %v\n", v.name, nonce, gotNonce)
    97  		}
    98  		if v.plainText != string(gotText) {
    99  			t.Errorf("%q expect plain text \n%q, got \n%q\n", v.name, v.plainText, gotText)
   100  		}
   101  	}
   102  }
   103  
   104  func TestSessionError(t *testing.T) {
   105  	util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
   106  
   107  	b := Base64Key{}
   108  	b.key = PrngFill(9)
   109  
   110  	if _, err := NewSession(b); err == nil {
   111  		t.Errorf("expect wrong key size error, got %s\n", err)
   112  	}
   113  
   114  	b.key = PrngFill(32)
   115  	s, _ := NewSession(b)
   116  	nilMessage, _ := s.Decrypt([]byte("zb0SLh88rdSHswjcgcC6949ZUuopGXTt"))
   117  	if nilMessage != nil {
   118  		t.Errorf("expect nil message returned from decrypt(), got %v\n", nilMessage)
   119  	}
   120  }
   121  
   122  func fakeRand(io.Reader, []byte) (int, error) {
   123  	return -2, errors.New("design this error on purpose.")
   124  }
   125  
   126  func TestRandomNonce(t *testing.T) {
   127  	util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
   128  
   129  	nonce, err := _randomNonce(fakeRand)
   130  
   131  	if nonce != nil {
   132  		t.Errorf("expect nil nonce, got %v\n %s\n", nonce, err)
   133  	}
   134  }
   135  
   136  func TestMessage(t *testing.T) {
   137  	tc := []struct {
   138  		name           string
   139  		seqNonce       uint64
   140  		mixPayload     string
   141  		timestamp      uint16
   142  		timestampReply uint16
   143  		payload        string
   144  	}{
   145  		{"english message", uint64(0x5223), "\x12\x23\x34\x45normal message", 0x1223, 0x3445, "normal message"},
   146  		{
   147  			"chinese message", uint64(0x7226) | (uint64(1) << 63), "\x42\x23\x64\x45大端字节序就和我们平时的写法顺序一样",
   148  			0x4223, 0x6445, "大端字节序就和我们平时的写法顺序一样",
   149  		},
   150  	}
   151  
   152  	for _, v := range tc {
   153  		m := NewMessage(v.seqNonce, []byte(v.mixPayload))
   154  
   155  		if len(m.nonce) != 12 {
   156  			t.Errorf("%q expect nonce length %d, got %d\n", v.name, 12, len(m.nonce))
   157  		}
   158  
   159  		if m.NonceVal() != v.seqNonce {
   160  			t.Errorf("%q expect seqNonce %x got %x\n", v.name, v.seqNonce, m.NonceVal())
   161  		}
   162  
   163  		if m.GetTimestamp() != v.timestamp {
   164  			t.Errorf("%q expect timestamp %x got %x\n", v.name, v.timestamp, m.GetTimestamp())
   165  		}
   166  
   167  		if m.GetTimestampReply() != v.timestampReply {
   168  			t.Errorf("%q expect timestampReply %x got %x\n", v.name, v.timestampReply, m.GetTimestampReply())
   169  		}
   170  
   171  		if string(m.GetPayload()) != v.payload {
   172  			t.Errorf("%q expect payload %x got %x\n", v.name, v.payload, m.GetPayload())
   173  		}
   174  	}
   175  }
   176  
   177  func TestDisableDumpingCore(t *testing.T) {
   178  	// get the RLIMIT_CORE
   179  	var rlim syscall.Rlimit
   180  	syscall.Getrlimit(syscall.RLIMIT_CORE, &rlim)
   181  	expect := rlim.Cur
   182  
   183  	DisableDumpingCore()
   184  
   185  	// validate the result
   186  	if savedCoreLimit != expect {
   187  		t.Errorf("#test DisableDumpingCore should be %d, got %d\n", expect, savedCoreLimit)
   188  	}
   189  
   190  	ReenableDumpingCore()
   191  	syscall.Getrlimit(syscall.RLIMIT_CORE, &rlim)
   192  }
   193  
   194  func TestDisableDumpingCoreError(t *testing.T) {
   195  	f0 := func(rlim *syscall.Rlimit, value uint64) {
   196  		// do nothing
   197  	}
   198  
   199  	// test get fail
   200  	// the resouce argument 20 is invalid
   201  	if err := accessRlimit(20, f0, 0); err == nil {
   202  		t.Errorf("#test accessRlimit should return error, got nil\n")
   203  	}
   204  
   205  	f1 := func(rlim *syscall.Rlimit, value uint64) {
   206  		// increase the hard limit
   207  		rlim.Cur = rlim.Max + 1
   208  	}
   209  
   210  	// test set fail
   211  	// increase hard limit is a privilege action
   212  	if err := accessRlimit(syscall.RLIMIT_NOFILE, f1, 0); err == nil {
   213  		t.Errorf("#test accessRlimit should return error, got nil\n")
   214  	}
   215  }