decred.org/dcrwallet/v3@v3.1.0/wallet/locking_test.go (about)

     1  // Copyright (c) 2019 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package wallet
     6  
     7  import (
     8  	"context"
     9  	"testing"
    10  	"time"
    11  
    12  	"decred.org/dcrwallet/v3/errors"
    13  )
    14  
    15  var testPrivPass = []byte("private")
    16  
    17  func TestLocking(t *testing.T) {
    18  	ctx := context.Background()
    19  
    20  	w, teardown := testWallet(ctx, t, &basicWalletConfig)
    21  	defer teardown()
    22  
    23  	var tests = []func(ctx context.Context, t *testing.T, w *Wallet){
    24  		testUnlock,
    25  		testLockOnBadPassphrase,
    26  		testNoNilTimeoutReplacement,
    27  		testNonNilTimeoutLock,
    28  		testTimeoutReplacement,
    29  	}
    30  	for _, test := range tests {
    31  		test(ctx, t, w)
    32  		w.Lock()
    33  	}
    34  }
    35  
    36  func testUnlock(ctx context.Context, t *testing.T, w *Wallet) {
    37  	if !w.Locked() {
    38  		t.Fatal("expected wallet to be locked")
    39  	}
    40  	// Unlock without timeout
    41  	err := w.Unlock(ctx, testPrivPass, nil)
    42  	if err != nil {
    43  		t.Fatal("failed to unlock wallet")
    44  	}
    45  	if w.Locked() {
    46  		t.Fatal("expected wallet to be unlocked")
    47  	}
    48  	completedLock := make(chan struct{})
    49  	go func() {
    50  		w.Lock()
    51  		completedLock <- struct{}{}
    52  	}()
    53  	time.Sleep(100 * time.Millisecond)
    54  	select {
    55  	case <-completedLock:
    56  	default:
    57  		t.Fatal("expected wallet to lock")
    58  	}
    59  	if !w.Locked() {
    60  		t.Fatal("expected wallet to be locked")
    61  	}
    62  }
    63  
    64  func testLockOnBadPassphrase(ctx context.Context, t *testing.T, w *Wallet) {
    65  	err := w.Unlock(ctx, testPrivPass, nil)
    66  	if err != nil {
    67  		t.Fatal("failed to unlock wallet")
    68  	}
    69  	err = w.Unlock(ctx, []byte("incorrect"), nil)
    70  	if !errors.Is(err, errors.Passphrase) {
    71  		t.Fatal("expected Passphrase error on bad Unlock")
    72  	}
    73  	if !w.Locked() {
    74  		t.Fatal("expected wallet to be locked after failed Unlock")
    75  	}
    76  
    77  	err = w.Unlock(ctx, testPrivPass, nil)
    78  	if err != nil {
    79  		t.Fatal("failed to unlock wallet")
    80  	}
    81  	err = w.Unlock(ctx, []byte("incorrect"), nil)
    82  	if !errors.Is(err, errors.Passphrase) {
    83  		t.Fatal("expected Passphrase error on bad Unlock")
    84  	}
    85  	if !w.Locked() {
    86  		t.Fatal("expected wallet to lock after unlocking with bad passphrase")
    87  	}
    88  }
    89  
    90  // Test:
    91  // If the wallet is currently unlocked without any timeout, timeout is ignored
    92  // and if non-nil, is read in a background goroutine to avoid blocking sends.
    93  func testNoNilTimeoutReplacement(ctx context.Context, t *testing.T, w *Wallet) {
    94  	err := w.Unlock(ctx, testPrivPass, nil)
    95  	if err != nil {
    96  		t.Fatal("failed to unlock wallet")
    97  	}
    98  	timeChan := make(chan time.Time)
    99  	err = w.Unlock(ctx, testPrivPass, timeChan)
   100  	if err != nil {
   101  		t.Fatal("failed to unlock wallet with time channel")
   102  	}
   103  	select {
   104  	case timeChan <- time.Time{}:
   105  	case <-time.After(100 * time.Millisecond):
   106  		t.Fatal("time channel was not read in 100ms")
   107  	}
   108  	if w.Locked() {
   109  		t.Fatal("expected wallet to remain unlocked due to previous unlock without timeout")
   110  	}
   111  }
   112  
   113  // Test:
   114  // If the wallet is locked and a non-nil timeout is provided, the wallet will be
   115  // locked in the background after reading from the channel.
   116  func testNonNilTimeoutLock(ctx context.Context, t *testing.T, w *Wallet) {
   117  	timeChan := make(chan time.Time)
   118  	err := w.Unlock(ctx, testPrivPass, timeChan)
   119  	if err != nil {
   120  		t.Fatal("failed to unlock wallet")
   121  	}
   122  	timeChan <- time.Time{}
   123  	time.Sleep(100 * time.Millisecond) // Allow time for lock in background
   124  	if !w.Locked() {
   125  		t.Fatal("wallet should have locked after timeout")
   126  	}
   127  }
   128  
   129  // Test:
   130  // If the wallet is already unlocked with a previous timeout, the new timeout
   131  // replaces the prior.
   132  func testTimeoutReplacement(ctx context.Context, t *testing.T, w *Wallet) {
   133  	timeChan1 := make(chan time.Time)
   134  	timeChan2 := make(chan time.Time)
   135  	err := w.Unlock(ctx, testPrivPass, timeChan1)
   136  	if err != nil {
   137  		t.Fatal("failed to unlock wallet")
   138  	}
   139  	err = w.Unlock(ctx, testPrivPass, timeChan2)
   140  	if err != nil {
   141  		t.Fatal("failed to unlock wallet")
   142  	}
   143  	timeChan2 <- time.Time{}
   144  	time.Sleep(100 * time.Millisecond) // Allow time for lock in background
   145  	if !w.Locked() {
   146  		t.Fatal("wallet did not lock using replacement timeout")
   147  	}
   148  	select {
   149  	case timeChan1 <- time.Time{}:
   150  	default:
   151  		t.Fatal("previous timeout was not read in background")
   152  	}
   153  }