github.com/decred/dcrlnd@v0.7.6/sweep/test_utils.go (about)

     1  package sweep
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/decred/dcrd/chaincfg/chainhash"
    10  	"github.com/decred/dcrd/wire"
    11  	"github.com/decred/dcrlnd/chainntnfs"
    12  )
    13  
    14  var (
    15  	defaultTestTimeout = 5 * time.Second
    16  	processingDelay    = 1 * time.Second
    17  	mockChainHash, _   = chainhash.NewHashFromStr("00aabbccddeeff")
    18  	mockChainHeight    = int32(100)
    19  )
    20  
    21  // MockNotifier simulates the chain notifier for test purposes. This type is
    22  // exported because it is used in nursery tests.
    23  type MockNotifier struct {
    24  	confChannel map[chainhash.Hash]chan *chainntnfs.TxConfirmation
    25  	epochChan   map[chan *chainntnfs.BlockEpoch]int32
    26  	spendChan   map[wire.OutPoint][]chan *chainntnfs.SpendDetail
    27  	spends      map[wire.OutPoint]*wire.MsgTx
    28  	mutex       sync.RWMutex
    29  	t           *testing.T
    30  }
    31  
    32  // NewMockNotifier instantiates a new mock notifier.
    33  func NewMockNotifier(t *testing.T) *MockNotifier {
    34  	return &MockNotifier{
    35  		confChannel: make(map[chainhash.Hash]chan *chainntnfs.TxConfirmation),
    36  		epochChan:   make(map[chan *chainntnfs.BlockEpoch]int32),
    37  		spendChan:   make(map[wire.OutPoint][]chan *chainntnfs.SpendDetail),
    38  		spends:      make(map[wire.OutPoint]*wire.MsgTx),
    39  		t:           t,
    40  	}
    41  }
    42  
    43  // NotifyEpoch simulates a new epoch arriving.
    44  func (m *MockNotifier) NotifyEpoch(height int32) {
    45  	m.t.Helper()
    46  
    47  	for epochChan, chanHeight := range m.epochChan {
    48  		// Only send notifications if the height is greater than the
    49  		// height the caller passed into the register call.
    50  		if chanHeight >= height {
    51  			continue
    52  		}
    53  
    54  		log.Debugf("Notifying height %v to listener", height)
    55  
    56  		select {
    57  		case epochChan <- &chainntnfs.BlockEpoch{
    58  			Height: height,
    59  		}:
    60  		case <-time.After(defaultTestTimeout):
    61  			m.t.Fatal("epoch event not consumed")
    62  		}
    63  	}
    64  }
    65  
    66  // ConfirmTx simulates a tx confirming.
    67  func (m *MockNotifier) ConfirmTx(txid *chainhash.Hash, height uint32) error {
    68  	confirm := &chainntnfs.TxConfirmation{
    69  		BlockHeight: height,
    70  	}
    71  	select {
    72  	case m.getConfChannel(txid) <- confirm:
    73  	case <-time.After(defaultTestTimeout):
    74  		return fmt.Errorf("confirmation not consumed")
    75  	}
    76  	return nil
    77  }
    78  
    79  // SpendOutpoint simulates a utxo being spent.
    80  func (m *MockNotifier) SpendOutpoint(outpoint wire.OutPoint,
    81  	spendingTx wire.MsgTx) {
    82  
    83  	log.Debugf("Spending outpoint %v", outpoint)
    84  
    85  	m.mutex.Lock()
    86  	defer m.mutex.Unlock()
    87  
    88  	channels, ok := m.spendChan[outpoint]
    89  	if ok {
    90  		for _, channel := range channels {
    91  			m.sendSpend(channel, &outpoint, &spendingTx)
    92  		}
    93  	}
    94  
    95  	m.spends[outpoint] = &spendingTx
    96  }
    97  
    98  func (m *MockNotifier) sendSpend(channel chan *chainntnfs.SpendDetail,
    99  	outpoint *wire.OutPoint,
   100  	spendingTx *wire.MsgTx) {
   101  
   102  	spenderTxHash := spendingTx.TxHash()
   103  	channel <- &chainntnfs.SpendDetail{
   104  		SpenderTxHash: &spenderTxHash,
   105  		SpendingTx:    spendingTx,
   106  		SpentOutPoint: outpoint,
   107  	}
   108  }
   109  
   110  // RegisterConfirmationsNtfn registers for tx confirm notifications.
   111  func (m *MockNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash,
   112  	_ []byte, numConfs, heightHint uint32) (*chainntnfs.ConfirmationEvent,
   113  	error) {
   114  
   115  	return &chainntnfs.ConfirmationEvent{
   116  		Confirmed: m.getConfChannel(txid),
   117  	}, nil
   118  }
   119  
   120  func (m *MockNotifier) getConfChannel(
   121  	txid *chainhash.Hash) chan *chainntnfs.TxConfirmation {
   122  
   123  	m.mutex.Lock()
   124  	defer m.mutex.Unlock()
   125  
   126  	channel, ok := m.confChannel[*txid]
   127  	if ok {
   128  		return channel
   129  	}
   130  	channel = make(chan *chainntnfs.TxConfirmation)
   131  	m.confChannel[*txid] = channel
   132  
   133  	return channel
   134  }
   135  
   136  // RegisterBlockEpochNtfn registers a block notification.
   137  func (m *MockNotifier) RegisterBlockEpochNtfn(
   138  	bestBlock *chainntnfs.BlockEpoch) (*chainntnfs.BlockEpochEvent, error) {
   139  
   140  	log.Tracef("Mock block ntfn registered")
   141  
   142  	m.mutex.Lock()
   143  	epochChan := make(chan *chainntnfs.BlockEpoch, 1)
   144  
   145  	// The real notifier returns a notification with the current block hash
   146  	// and height immediately if no best block hash or height is specified
   147  	// in the request. We want to emulate this behaviour as well for the
   148  	// mock.
   149  	switch {
   150  	case bestBlock == nil:
   151  		epochChan <- &chainntnfs.BlockEpoch{
   152  			Hash:   mockChainHash,
   153  			Height: mockChainHeight,
   154  		}
   155  		m.epochChan[epochChan] = mockChainHeight
   156  	default:
   157  		m.epochChan[epochChan] = bestBlock.Height
   158  	}
   159  	m.mutex.Unlock()
   160  
   161  	return &chainntnfs.BlockEpochEvent{
   162  		Epochs: epochChan,
   163  		Cancel: func() {
   164  			log.Tracef("Mock block ntfn canceled")
   165  			m.mutex.Lock()
   166  			delete(m.epochChan, epochChan)
   167  			m.mutex.Unlock()
   168  		},
   169  	}, nil
   170  }
   171  
   172  // Start the notifier.
   173  func (m *MockNotifier) Start() error {
   174  	return nil
   175  }
   176  
   177  // Started checks if started
   178  func (m *MockNotifier) Started() bool {
   179  	return true
   180  }
   181  
   182  // Stop the notifier.
   183  func (m *MockNotifier) Stop() error {
   184  	return nil
   185  }
   186  
   187  // RegisterSpendNtfn registers for spend notifications.
   188  func (m *MockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint,
   189  	_ []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {
   190  
   191  	// Add channel to global spend ntfn map.
   192  	m.mutex.Lock()
   193  
   194  	channels, ok := m.spendChan[*outpoint]
   195  	if !ok {
   196  		channels = make([]chan *chainntnfs.SpendDetail, 0)
   197  	}
   198  
   199  	channel := make(chan *chainntnfs.SpendDetail, 1)
   200  	channels = append(channels, channel)
   201  	m.spendChan[*outpoint] = channels
   202  
   203  	// Check if this output has already been spent.
   204  	spendingTx, spent := m.spends[*outpoint]
   205  
   206  	m.mutex.Unlock()
   207  
   208  	// If output has been spent already, signal now. Do this outside the
   209  	// lock to prevent a deadlock.
   210  	if spent {
   211  		m.sendSpend(channel, outpoint, spendingTx)
   212  	}
   213  
   214  	return &chainntnfs.SpendEvent{
   215  		Spend: channel,
   216  		Cancel: func() {
   217  			log.Infof("Cancelling RegisterSpendNtfn for %v",
   218  				outpoint)
   219  
   220  			m.mutex.Lock()
   221  			defer m.mutex.Unlock()
   222  			channels := m.spendChan[*outpoint]
   223  			for i, c := range channels {
   224  				if c == channel {
   225  					channels[i] = channels[len(channels)-1]
   226  					m.spendChan[*outpoint] =
   227  						channels[:len(channels)-1]
   228  				}
   229  			}
   230  
   231  			close(channel)
   232  
   233  			log.Infof("Spend ntfn channel closed for %v",
   234  				outpoint)
   235  		},
   236  	}, nil
   237  }