github.com/decred/dcrlnd@v0.7.6/peer/test_utils.go (about)

     1  package peer
     2  
import (
	"bytes"
	crand "crypto/rand"
	"encoding/binary"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"os"
	"testing"
	"time"

	"github.com/decred/dcrd/chaincfg/chainhash"
	"github.com/decred/dcrd/chaincfg/v3"
	"github.com/decred/dcrd/dcrec/secp256k1/v4"
	"github.com/decred/dcrd/dcrutil/v4"
	"github.com/decred/dcrd/wire"
	"github.com/decred/dcrlnd/chainntnfs"
	"github.com/decred/dcrlnd/channeldb"
	"github.com/decred/dcrlnd/htlcswitch"
	"github.com/decred/dcrlnd/input"
	"github.com/decred/dcrlnd/keychain"
	"github.com/decred/dcrlnd/lntest/channels"
	"github.com/decred/dcrlnd/lntest/mock"
	"github.com/decred/dcrlnd/lnwallet"
	"github.com/decred/dcrlnd/lnwallet/chainfee"
	"github.com/decred/dcrlnd/lnwire"
	"github.com/decred/dcrlnd/netann"
	"github.com/decred/dcrlnd/queue"
	"github.com/decred/dcrlnd/shachain"
	"github.com/stretchr/testify/require"
)
    35  
    36  const (
    37  	broadcastHeight = 100
    38  
    39  	// timeout is a timeout value to use for tests which need to wait for
    40  	// a return value on a channel.
    41  	timeout = time.Second * 5
    42  )
    43  
    44  var (
    45  	// Just use some arbitrary bytes as delivery script.
    46  	dummyDeliveryScript = channels.AlicesPrivKey
    47  
    48  	testKeyLoc = keychain.KeyLocator{Family: keychain.KeyFamilyNodeKey}
    49  )
    50  
    51  // noUpdate is a function which can be used as a parameter in createTestPeer to
    52  // call the setup code with no custom values on the channels set up.
    53  var noUpdate = func(a, b *channeldb.OpenChannel) {}
    54  
    55  // createTestPeer creates a channel between two nodes, and returns a peer for
    56  // one of the nodes, together with the channel seen from both nodes. It takes
    57  // an updateChan function which can be used to modify the default values on
    58  // the channel states for each peer.
    59  func createTestPeer(notifier chainntnfs.ChainNotifier,
    60  	publTx chan *wire.MsgTx, updateChan func(a, b *channeldb.OpenChannel),
    61  	mockSwitch *mockMessageSwitch) (
    62  	*Brontide, *lnwallet.LightningChannel, func(), error) {
    63  
    64  	chainParams := chaincfg.RegNetParams()
    65  
    66  	nodeKeyLocator := keychain.KeyLocator{
    67  		Family: keychain.KeyFamilyNodeKey,
    68  	}
    69  	aliceKeyPriv, aliceKeyPub := channels.PrivKeyFromBytes(channels.AlicesPrivKey)
    70  	aliceKeySigner := keychain.NewPrivKeyMessageSigner(aliceKeyPriv, nodeKeyLocator)
    71  	bobKeyPriv, bobKeyPub := channels.PrivKeyFromBytes(channels.BobsPrivKey)
    72  
    73  	channelCapacity := dcrutil.Amount(10 * 1e8)
    74  	channelBal := channelCapacity / 2
    75  	aliceDustLimit := dcrutil.Amount(200)
    76  	bobDustLimit := dcrutil.Amount(1300)
    77  	csvTimeoutAlice := uint32(5)
    78  	csvTimeoutBob := uint32(4)
    79  	isAliceInitiator := true
    80  
    81  	prevOut := &wire.OutPoint{
    82  		Hash:  channels.TestHdSeed,
    83  		Index: 0,
    84  	}
    85  	fundingTxIn := wire.NewTxIn(prevOut, 0, nil) // TODO(decred): Need correct input value
    86  
    87  	aliceCfg := channeldb.ChannelConfig{
    88  		ChannelConstraints: channeldb.ChannelConstraints{
    89  			DustLimit:        aliceDustLimit,
    90  			MaxPendingAmount: lnwire.MilliAtom(rand.Int63()),
    91  			ChanReserve:      dcrutil.Amount(rand.Int63()),
    92  			MinHTLC:          lnwire.MilliAtom(rand.Int63()),
    93  			MaxAcceptedHtlcs: uint16(rand.Int31()),
    94  			CsvDelay:         uint16(csvTimeoutAlice),
    95  		},
    96  		MultiSigKey: keychain.KeyDescriptor{
    97  			PubKey: aliceKeyPub,
    98  		},
    99  		RevocationBasePoint: keychain.KeyDescriptor{
   100  			PubKey: aliceKeyPub,
   101  		},
   102  		PaymentBasePoint: keychain.KeyDescriptor{
   103  			PubKey: aliceKeyPub,
   104  		},
   105  		DelayBasePoint: keychain.KeyDescriptor{
   106  			PubKey: aliceKeyPub,
   107  		},
   108  		HtlcBasePoint: keychain.KeyDescriptor{
   109  			PubKey: aliceKeyPub,
   110  		},
   111  	}
   112  	bobCfg := channeldb.ChannelConfig{
   113  		ChannelConstraints: channeldb.ChannelConstraints{
   114  			DustLimit:        bobDustLimit,
   115  			MaxPendingAmount: lnwire.MilliAtom(rand.Int63()),
   116  			ChanReserve:      dcrutil.Amount(rand.Int63()),
   117  			MinHTLC:          lnwire.MilliAtom(rand.Int63()),
   118  			MaxAcceptedHtlcs: uint16(rand.Int31()),
   119  			CsvDelay:         uint16(csvTimeoutBob),
   120  		},
   121  		MultiSigKey: keychain.KeyDescriptor{
   122  			PubKey: bobKeyPub,
   123  		},
   124  		RevocationBasePoint: keychain.KeyDescriptor{
   125  			PubKey: bobKeyPub,
   126  		},
   127  		PaymentBasePoint: keychain.KeyDescriptor{
   128  			PubKey: bobKeyPub,
   129  		},
   130  		DelayBasePoint: keychain.KeyDescriptor{
   131  			PubKey: bobKeyPub,
   132  		},
   133  		HtlcBasePoint: keychain.KeyDescriptor{
   134  			PubKey: bobKeyPub,
   135  		},
   136  	}
   137  
   138  	bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
   139  	if err != nil {
   140  		return nil, nil, nil, err
   141  	}
   142  	bobPreimageProducer := shachain.NewRevocationProducer(shachain.ShaHash(*bobRoot))
   143  	bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
   144  	if err != nil {
   145  		return nil, nil, nil, err
   146  	}
   147  	bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
   148  
   149  	aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
   150  	if err != nil {
   151  		return nil, nil, nil, err
   152  	}
   153  	alicePreimageProducer := shachain.NewRevocationProducer(shachain.ShaHash(*aliceRoot))
   154  	aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
   155  	if err != nil {
   156  		return nil, nil, nil, err
   157  	}
   158  	aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
   159  
   160  	aliceCommitTx, bobCommitTx, err := lnwallet.CreateCommitmentTxns(
   161  		channelBal, channelBal, &aliceCfg, &bobCfg, aliceCommitPoint,
   162  		bobCommitPoint, *fundingTxIn, channeldb.SingleFunderTweaklessBit,
   163  		isAliceInitiator, 0, chainParams,
   164  	)
   165  	if err != nil {
   166  		return nil, nil, nil, err
   167  	}
   168  
   169  	alicePath, err := ioutil.TempDir("", "alicedb")
   170  	if err != nil {
   171  		return nil, nil, nil, err
   172  	}
   173  
   174  	dbAlice, err := channeldb.Open(alicePath)
   175  	if err != nil {
   176  		return nil, nil, nil, err
   177  	}
   178  
   179  	bobPath, err := ioutil.TempDir("", "bobdb")
   180  	if err != nil {
   181  		return nil, nil, nil, err
   182  	}
   183  
   184  	dbBob, err := channeldb.Open(bobPath)
   185  	if err != nil {
   186  		return nil, nil, nil, err
   187  	}
   188  
   189  	estimator := chainfee.NewStaticEstimator(12500, 0)
   190  	feePerKB, err := estimator.EstimateFeePerKB(1)
   191  	if err != nil {
   192  		return nil, nil, nil, err
   193  	}
   194  
   195  	// TODO(roasbeef): need to factor in commit fee?
   196  	aliceCommit := channeldb.ChannelCommitment{
   197  		CommitHeight:  0,
   198  		LocalBalance:  lnwire.NewMAtomsFromAtoms(channelBal),
   199  		RemoteBalance: lnwire.NewMAtomsFromAtoms(channelBal),
   200  		FeePerKB:      dcrutil.Amount(feePerKB),
   201  		CommitFee:     feePerKB.FeeForSize(input.CommitmentTxSize),
   202  		CommitTx:      aliceCommitTx,
   203  		CommitSig:     bytes.Repeat([]byte{1}, 71),
   204  	}
   205  	bobCommit := channeldb.ChannelCommitment{
   206  		CommitHeight:  0,
   207  		LocalBalance:  lnwire.NewMAtomsFromAtoms(channelBal),
   208  		RemoteBalance: lnwire.NewMAtomsFromAtoms(channelBal),
   209  		FeePerKB:      dcrutil.Amount(feePerKB),
   210  		CommitFee:     feePerKB.FeeForSize(input.CommitmentTxSize),
   211  		CommitTx:      bobCommitTx,
   212  		CommitSig:     bytes.Repeat([]byte{1}, 71),
   213  	}
   214  
   215  	var chanIDBytes [8]byte
   216  	if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil {
   217  		return nil, nil, nil, err
   218  	}
   219  
   220  	shortChanID := lnwire.NewShortChanIDFromInt(
   221  		binary.BigEndian.Uint64(chanIDBytes[:]),
   222  	)
   223  
   224  	aliceChannelState := &channeldb.OpenChannel{
   225  		LocalChanCfg:            aliceCfg,
   226  		RemoteChanCfg:           bobCfg,
   227  		IdentityPub:             aliceKeyPub,
   228  		FundingOutpoint:         *prevOut,
   229  		ShortChannelID:          shortChanID,
   230  		ChanType:                channeldb.SingleFunderTweaklessBit,
   231  		IsInitiator:             isAliceInitiator,
   232  		Capacity:                channelCapacity,
   233  		RemoteCurrentRevocation: bobCommitPoint,
   234  		RevocationProducer:      alicePreimageProducer,
   235  		RevocationStore:         shachain.NewRevocationStore(),
   236  		LocalCommitment:         aliceCommit,
   237  		RemoteCommitment:        aliceCommit,
   238  		Db:                      dbAlice.ChannelStateDB(),
   239  		Packager:                channeldb.NewChannelPackager(shortChanID),
   240  		FundingTxn:              channels.TestFundingTx,
   241  	}
   242  	bobChannelState := &channeldb.OpenChannel{
   243  		LocalChanCfg:            bobCfg,
   244  		RemoteChanCfg:           aliceCfg,
   245  		IdentityPub:             bobKeyPub,
   246  		FundingOutpoint:         *prevOut,
   247  		ChanType:                channeldb.SingleFunderTweaklessBit,
   248  		IsInitiator:             !isAliceInitiator,
   249  		Capacity:                channelCapacity,
   250  		RemoteCurrentRevocation: aliceCommitPoint,
   251  		RevocationProducer:      bobPreimageProducer,
   252  		RevocationStore:         shachain.NewRevocationStore(),
   253  		LocalCommitment:         bobCommit,
   254  		RemoteCommitment:        bobCommit,
   255  		Db:                      dbBob.ChannelStateDB(),
   256  		Packager:                channeldb.NewChannelPackager(shortChanID),
   257  	}
   258  
   259  	// Set custom values on the channel states.
   260  	updateChan(aliceChannelState, bobChannelState)
   261  
   262  	aliceAddr := &net.TCPAddr{
   263  		IP:   net.ParseIP("127.0.0.1"),
   264  		Port: 18555,
   265  	}
   266  
   267  	if err := aliceChannelState.SyncPending(aliceAddr, 0); err != nil {
   268  		return nil, nil, nil, err
   269  	}
   270  
   271  	bobAddr := &net.TCPAddr{
   272  		IP:   net.ParseIP("127.0.0.1"),
   273  		Port: 18556,
   274  	}
   275  
   276  	if err := bobChannelState.SyncPending(bobAddr, 0); err != nil {
   277  		return nil, nil, nil, err
   278  	}
   279  
   280  	cleanUpFunc := func() {
   281  		os.RemoveAll(bobPath)
   282  		os.RemoveAll(alicePath)
   283  	}
   284  
   285  	aliceSigner := &mock.SingleSigner{Privkey: aliceKeyPriv}
   286  	bobSigner := &mock.SingleSigner{Privkey: bobKeyPriv}
   287  
   288  	alicePool := lnwallet.NewSigPool(1, aliceSigner)
   289  	channelAlice, err := lnwallet.NewLightningChannel(
   290  		aliceSigner, aliceChannelState, alicePool, chainParams,
   291  	)
   292  	if err != nil {
   293  		return nil, nil, nil, err
   294  	}
   295  	_ = alicePool.Start()
   296  
   297  	bobPool := lnwallet.NewSigPool(1, bobSigner)
   298  	channelBob, err := lnwallet.NewLightningChannel(
   299  		bobSigner, bobChannelState, bobPool, chainParams,
   300  	)
   301  	if err != nil {
   302  		return nil, nil, nil, err
   303  	}
   304  	_ = bobPool.Start()
   305  
   306  	chainIO := &mock.ChainIO{
   307  		BestHeight: broadcastHeight,
   308  	}
   309  	wallet := &lnwallet.LightningWallet{
   310  		WalletController: &mock.WalletController{
   311  			RootKey:               aliceKeyPriv,
   312  			PublishedTransactions: publTx,
   313  		},
   314  	}
   315  
   316  	// If mockSwitch is not set by the caller, set it to the default as the
   317  	// caller does not need to control it.
   318  	if mockSwitch == nil {
   319  		mockSwitch = &mockMessageSwitch{}
   320  	}
   321  
   322  	nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner)
   323  
   324  	const chanActiveTimeout = time.Minute
   325  
   326  	chanStatusMgr, err := netann.NewChanStatusManager(&netann.ChanStatusConfig{
   327  		ChanStatusSampleInterval: 30 * time.Second,
   328  		ChanEnableTimeout:        chanActiveTimeout,
   329  		ChanDisableTimeout:       2 * time.Minute,
   330  		DB:                       dbAlice.ChannelStateDB(),
   331  		Graph:                    dbAlice.ChannelGraph(),
   332  		MessageSigner:            nodeSignerAlice,
   333  		OurPubKey:                aliceKeyPub,
   334  		OurKeyLoc:                testKeyLoc,
   335  		IsChannelActive:          func(lnwire.ChannelID) bool { return true },
   336  		ApplyChannelUpdate:       func(*lnwire.ChannelUpdate) error { return nil },
   337  	})
   338  	if err != nil {
   339  		return nil, nil, nil, err
   340  	}
   341  	if err = chanStatusMgr.Start(); err != nil {
   342  		return nil, nil, nil, err
   343  	}
   344  
   345  	errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize)
   346  	if err != nil {
   347  		return nil, nil, nil, err
   348  	}
   349  
   350  	var pubKey [33]byte
   351  	copy(pubKey[:], aliceKeyPub.SerializeCompressed())
   352  
   353  	cfgAddr := &lnwire.NetAddress{
   354  		IdentityKey: aliceKeyPub,
   355  		Address:     aliceAddr,
   356  		ChainNet:    wire.SimNet,
   357  	}
   358  
   359  	cfg := &Config{
   360  		Addr:        cfgAddr,
   361  		PubKeyBytes: pubKey,
   362  		ErrorBuffer: errBuffer,
   363  		ChainIO:     chainIO,
   364  		Switch:      mockSwitch,
   365  
   366  		ChanActiveTimeout: chanActiveTimeout,
   367  		InterceptSwitch:   htlcswitch.NewInterceptableSwitch(nil),
   368  
   369  		ChannelDB:      dbAlice.ChannelStateDB(),
   370  		FeeEstimator:   estimator,
   371  		Wallet:         wallet,
   372  		ChainNotifier:  notifier,
   373  		ChanStatusMgr:  chanStatusMgr,
   374  		DisconnectPeer: func(b *secp256k1.PublicKey) error { return nil },
   375  		ChainParams:    chainParams,
   376  	}
   377  
   378  	alicePeer := NewBrontide(*cfg)
   379  
   380  	chanID := lnwire.NewChanIDFromOutPoint(channelAlice.ChannelPoint())
   381  	alicePeer.activeChannels[chanID] = channelAlice
   382  
   383  	alicePeer.wg.Add(1)
   384  	go alicePeer.channelManager()
   385  
   386  	return alicePeer, channelBob, cleanUpFunc, nil
   387  }
   388  
   389  // mockMessageSwitch is a mock implementation of the messageSwitch interface
   390  // used for testing without relying on a *htlcswitch.Switch in unit tests.
   391  type mockMessageSwitch struct {
   392  	links []htlcswitch.ChannelUpdateHandler
   393  }
   394  
   395  // BestHeight currently returns a dummy value.
   396  func (m *mockMessageSwitch) BestHeight() uint32 {
   397  	return 0
   398  }
   399  
   400  // CircuitModifier currently returns a dummy value.
   401  func (m *mockMessageSwitch) CircuitModifier() htlcswitch.CircuitModifier {
   402  	return nil
   403  }
   404  
   405  // RemoveLink currently does nothing.
   406  func (m *mockMessageSwitch) RemoveLink(cid lnwire.ChannelID) {}
   407  
   408  // CreateAndAddLink currently returns a dummy value.
   409  func (m *mockMessageSwitch) CreateAndAddLink(cfg htlcswitch.ChannelLinkConfig,
   410  	lnChan *lnwallet.LightningChannel) error {
   411  
   412  	return nil
   413  }
   414  
   415  // GetLinksByInterface returns the active links.
   416  func (m *mockMessageSwitch) GetLinksByInterface(pub [33]byte) (
   417  	[]htlcswitch.ChannelUpdateHandler, error) {
   418  
   419  	return m.links, nil
   420  }
   421  
   422  // mockUpdateHandler is a mock implementation of the ChannelUpdateHandler
   423  // interface. It is used in mockMessageSwitch's GetLinksByInterface method.
   424  type mockUpdateHandler struct {
   425  	cid lnwire.ChannelID
   426  }
   427  
   428  // newMockUpdateHandler creates a new mockUpdateHandler.
   429  func newMockUpdateHandler(cid lnwire.ChannelID) *mockUpdateHandler {
   430  	return &mockUpdateHandler{
   431  		cid: cid,
   432  	}
   433  }
   434  
   435  // HandleChannelUpdate currently does nothing.
   436  func (m *mockUpdateHandler) HandleChannelUpdate(msg lnwire.Message) {}
   437  
   438  // ChanID returns the mockUpdateHandler's cid.
   439  func (m *mockUpdateHandler) ChanID() lnwire.ChannelID { return m.cid }
   440  
   441  // Bandwidth currently returns a dummy value.
   442  func (m *mockUpdateHandler) Bandwidth() lnwire.MilliAtom { return 0 }
   443  
   444  // EligibleToForward currently returns a dummy value.
   445  func (m *mockUpdateHandler) EligibleToForward() bool { return false }
   446  
   447  // MayAddOutgoingHtlc currently returns nil.
   448  func (m *mockUpdateHandler) MayAddOutgoingHtlc(lnwire.MilliAtom) error { return nil }
   449  
   450  // ShutdownIfChannelClean currently returns nil.
   451  func (m *mockUpdateHandler) ShutdownIfChannelClean() error { return nil }
   452  
   453  type mockMessageConn struct {
   454  	t *testing.T
   455  
   456  	// MessageConn embeds our interface so that the mock does not need to
   457  	// implement every function. The mock will panic if an unspecified function
   458  	// is called.
   459  	MessageConn
   460  
   461  	// writtenMessages is a channel that our mock pushes written messages into.
   462  	writtenMessages chan []byte
   463  
   464  	readMessages   chan []byte
   465  	curReadMessage []byte
   466  }
   467  
   468  func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
   469  	return &mockMessageConn{
   470  		t:               t,
   471  		writtenMessages: make(chan []byte, expectedMessages),
   472  		readMessages:    make(chan []byte, 1),
   473  	}
   474  }
   475  
   476  // SetWriteDeadline mocks setting write deadline for our conn.
   477  func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
   478  	return nil
   479  }
   480  
   481  // Flush mocks a message conn flush.
   482  func (m *mockMessageConn) Flush() (int, error) {
   483  	return 0, nil
   484  }
   485  
   486  // WriteMessage mocks sending of a message on our connection. It will push
   487  // the bytes sent into the mock's writtenMessages channel.
   488  func (m *mockMessageConn) WriteMessage(msg []byte) error {
   489  	select {
   490  	case m.writtenMessages <- msg:
   491  	case <-time.After(timeout):
   492  		m.t.Fatalf("timeout sending message: %v", msg)
   493  	}
   494  
   495  	return nil
   496  }
   497  
   498  // assertWrite asserts that our mock as had WriteMessage called with the byte
   499  // slice we expect.
   500  func (m *mockMessageConn) assertWrite(expected []byte) {
   501  	select {
   502  	case actual := <-m.writtenMessages:
   503  		require.Equal(m.t, expected, actual)
   504  
   505  	case <-time.After(timeout):
   506  		m.t.Fatalf("timeout waiting for write: %v", expected)
   507  	}
   508  }
   509  
   510  func (m *mockMessageConn) SetReadDeadline(t time.Time) error {
   511  	return nil
   512  }
   513  
   514  func (m *mockMessageConn) ReadNextHeader() (uint32, error) {
   515  	m.curReadMessage = <-m.readMessages
   516  	return uint32(len(m.curReadMessage)), nil
   517  }
   518  
   519  func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) {
   520  	return m.curReadMessage, nil
   521  }
   522  
   523  func (m *mockMessageConn) RemoteAddr() net.Addr {
   524  	return nil
   525  }
   526  
   527  func (m *mockMessageConn) LocalAddr() net.Addr {
   528  	return nil
   529  }