github.com/decred/dcrlnd@v0.7.6/watchtower/wtclient/client_test.go

     1  package wtclient_test
     2  
     3  import (
     4  	"encoding/binary"
     5  	"net"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/decred/dcrd/chaincfg/v3"
    11  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    12  	"github.com/decred/dcrd/txscript/v4"
    13  	"github.com/decred/dcrd/txscript/v4/stdaddr"
    14  	"github.com/decred/dcrd/wire"
    15  	"github.com/decred/dcrlnd/channeldb"
    16  	"github.com/decred/dcrlnd/input"
    17  	"github.com/decred/dcrlnd/keychain"
    18  	"github.com/decred/dcrlnd/lnwallet"
    19  	"github.com/decred/dcrlnd/lnwire"
    20  	"github.com/decred/dcrlnd/tor"
    21  	"github.com/decred/dcrlnd/watchtower/blob"
    22  	"github.com/decred/dcrlnd/watchtower/wtclient"
    23  	"github.com/decred/dcrlnd/watchtower/wtdb"
    24  	"github.com/decred/dcrlnd/watchtower/wtmock"
    25  	"github.com/decred/dcrlnd/watchtower/wtpolicy"
    26  	"github.com/decred/dcrlnd/watchtower/wtserver"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  const (
    31  	towerAddrStr = "18.28.243.2:9911"
    32  )
    33  
    34  var (
    35  	// addr is the server's reward address given to watchtower clients.
    36  	addr, _ = stdaddr.DecodeAddress(
    37  		"Tsi6gGYNSMmFwi7JoL5Li39SrERZTTMu6vY",
    38  		chaincfg.TestNet3Params(),
    39  	)
    40  
    41  	addrScript, _ = input.PayToAddrScript(addr)
    42  )
    43  
    44  // randPrivKey generates a new secp256k1 keypair and returns the private key.
    45  func randPrivKey(t *testing.T) *secp256k1.PrivateKey {
    46  	t.Helper()
    47  
    48  	sk, err := secp256k1.GeneratePrivateKey()
    49  	if err != nil {
    50  		t.Fatalf("unable to generate private key: %v", err)
    51  	}
    52  
    53  	return sk
    54  }
    55  
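        // mockNet provides the dial and DNS-lookup hooks required by the client
        // and server configs. Only AuthDial is functional: it builds an in-memory
        // connection pair and hands the server-side end to connCallback.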
    56  type mockNet struct {
    57  	mu           sync.RWMutex
    58  	connCallback func(wtserver.Peer)
    59  }
    60  
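        // newMockNet returns a mockNet whose AuthDial delivers inbound peers to cb.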
    61  func newMockNet(cb func(wtserver.Peer)) *mockNet {
    62  	return &mockNet{
    63  		connCallback: cb,
    64  	}
    65  }
    66  
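        // Dial is a stub satisfying the client's Dial config option; the tests
        // connect exclusively through AuthDial.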
    67  func (m *mockNet) Dial(network string, address string,
    68  	timeout time.Duration) (net.Conn, error) {
    69  
    70  	return nil, nil
    71  }
    72  
    73  func (m *mockNet) LookupHost(host string) ([]string, error) {
    74  	panic("not implemented")
    75  }
    76  
    77  func (m *mockNet) LookupSRV(service string, proto string, name string) (string, []*net.SRV, error) {
    78  	panic("not implemented")
    79  }
    80  
    81  func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr, error) {
    82  	panic("not implemented")
    83  }
    84  
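        // AuthDial creates an in-memory connection between a client-side peer and
        // a server-side peer, delivers the server-side end to the registered
        // connection callback, and returns the client-side end to the caller.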
    85  func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
    86  	netAddr *lnwire.NetAddress,
    87  	dialer tor.DialFunc) (wtserver.Peer, error) {
    88  
    89  	localPk := local.PubKey()
    90  	localAddr := &net.TCPAddr{
    91  		IP:   net.IP{0x32, 0x31, 0x30, 0x29},
    92  		Port: 36723,
    93  	}
    94  
    95  	localPeer, remotePeer := wtmock.NewMockConn(
    96  		localPk, netAddr.IdentityKey, localAddr, netAddr.Address, 0,
    97  	)
    98  
    99  	m.mu.RLock()
   100  	m.connCallback(remotePeer)
   101  	m.mu.RUnlock()
   102  
   103  	return localPeer, nil
   104  }
   105  
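        // setConnCallback swaps the callback invoked with newly connected peers,
        // letting the harness redirect inbound connections to a restarted server.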
   106  func (m *mockNet) setConnCallback(cb func(wtserver.Peer)) {
   107  	m.mu.Lock()
   108  	defer m.mu.Unlock()
   109  	m.connCallback = cb
   110  }
   111  
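        // mockChannel models just enough channel state for the tests: it tracks
        // balances and a commit height, and fabricates a remote commitment
        // transaction plus a matching BreachRetribution for every revoked state.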
   112  type mockChannel struct {
   113  	mu            sync.Mutex
   114  	commitHeight  uint64
   115  	retributions  map[uint64]*lnwallet.BreachRetribution
   116  	localBalance  lnwire.MilliAtom
   117  	remoteBalance lnwire.MilliAtom
   118  
   119  	revSK     *secp256k1.PrivateKey
   120  	revPK     *secp256k1.PublicKey
   121  	revKeyLoc keychain.KeyLocator
   122  
   123  	toRemoteSK     *secp256k1.PrivateKey
   124  	toRemotePK     *secp256k1.PublicKey
   125  	toRemoteKeyLoc keychain.KeyLocator
   126  
   127  	toLocalPK *secp256k1.PublicKey // only need to generate to-local script
   128  
   129  	dustLimit lnwire.MilliAtom
   130  	csvDelay  uint32
   131  }
   132  
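        // newMockChannel creates a mockChannel with the provided starting
        // balances, generating fresh revocation, to-local, and to-remote keypairs
        // and registering the keys we must sign with against the signer.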
   133  func newMockChannel(t *testing.T, signer *wtmock.MockSigner,
   134  	localAmt, remoteAmt lnwire.MilliAtom) *mockChannel {
   135  
   136  	// Generate the revocation, to-local, and to-remote keypairs.
   137  	revSK := randPrivKey(t)
   138  	revPK := revSK.PubKey()
   139  
   140  	toLocalSK := randPrivKey(t)
   141  	toLocalPK := toLocalSK.PubKey()
   142  
   143  	toRemoteSK := randPrivKey(t)
   144  	toRemotePK := toRemoteSK.PubKey()
   145  
   146  	// Register the revocation secret key and the to-remote secret key with
   147  	// the signer. We will not need to sign with the to-local key, as this
   148  	// is to be known only by the counterparty.
   149  	revKeyLoc := signer.AddPrivKey(revSK)
   150  	toRemoteKeyLoc := signer.AddPrivKey(toRemoteSK)
   151  
   152  	c := &mockChannel{
   153  		retributions:   make(map[uint64]*lnwallet.BreachRetribution),
   154  		localBalance:   localAmt,
   155  		remoteBalance:  remoteAmt,
   156  		revSK:          revSK,
   157  		revPK:          revPK,
   158  		revKeyLoc:      revKeyLoc,
   159  		toLocalPK:      toLocalPK,
   160  		toRemoteSK:     toRemoteSK,
   161  		toRemotePK:     toRemotePK,
   162  		toRemoteKeyLoc: toRemoteKeyLoc,
   163  		dustLimit:      546000,
   164  		csvDelay:       144,
   165  	}
   166  
   167  	// Create the initial remote commitment with the initial balances.
   168  	c.createRemoteCommitTx(t)
   169  
   170  	return c
   171  }
   172  
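        // createRemoteCommitTx constructs the remote commitment transaction for
        // the current commit height, records the corresponding BreachRetribution,
        // and increments the commit height.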
   173  func (c *mockChannel) createRemoteCommitTx(t *testing.T) {
   174  	t.Helper()
   175  
   176  	// Construct the to-local witness script.
   177  	toLocalScript, err := input.CommitScriptToSelf(
   178  		c.csvDelay, c.toLocalPK, c.revPK,
   179  	)
   180  	if err != nil {
   181  		t.Fatalf("unable to create to-local script: %v", err)
   182  	}
   183  
   184  	// Compute the to-local witness script hash.
   185  	toLocalScriptHash, err := input.ScriptHashPkScript(toLocalScript)
   186  	if err != nil {
   187  		t.Fatalf("unable to create to-local witness script hash: %v", err)
   188  	}
   189  
   190  	// Compute the to-remote witness script hash.
   191  	toRemoteScriptHash, err := input.CommitScriptUnencumbered(c.toRemotePK)
   192  	if err != nil {
   193  		t.Fatalf("unable to create to-remote script: %v", err)
   194  	}
   195  
   196  	// Construct the remote commitment txn, containing the to-local and
   197  	// to-remote outputs. The balances are flipped since the transaction is
   198  	// from the PoV of the remote party. We don't need any inputs for this
   199  	// test. We increment the version with the commit height to ensure that
   200  	// all commitment transactions are unique even if the same distribution
   201  	// of funds is used more than once.
   202  	commitTxn := &wire.MsgTx{
   203  		Version: uint16(c.commitHeight + 2),
   204  	}
   205  
   206  	var (
   207  		toLocalSignDesc  *input.SignDescriptor
   208  		toRemoteSignDesc *input.SignDescriptor
   209  	)
   210  
   211  	var outputIndex int
   212  	if c.remoteBalance >= c.dustLimit {
   213  		commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
   214  			Value:    int64(c.remoteBalance.ToAtoms()),
   215  			PkScript: toLocalScriptHash,
   216  		})
   217  
   218  		// Create the sign descriptor used to sign for the to-local
   219  		// input.
   220  		toLocalSignDesc = &input.SignDescriptor{
   221  			KeyDesc: keychain.KeyDescriptor{
   222  				KeyLocator: c.revKeyLoc,
   223  				PubKey:     c.revPK,
   224  			},
   225  			WitnessScript: toLocalScript,
   226  			Output:        commitTxn.TxOut[outputIndex],
   227  			HashType:      txscript.SigHashAll,
   228  		}
   229  		outputIndex++
   230  	}
   231  	if c.localBalance >= c.dustLimit {
   232  		commitTxn.TxOut = append(commitTxn.TxOut, &wire.TxOut{
   233  			Value:    int64(c.localBalance.ToAtoms()),
   234  			PkScript: toRemoteScriptHash,
   235  		})
   236  
   237  		// Create the sign descriptor used to sign for the to-remote
   238  		// input.
   239  		toRemoteSignDesc = &input.SignDescriptor{
   240  			KeyDesc: keychain.KeyDescriptor{
   241  				KeyLocator: c.toRemoteKeyLoc,
   242  				PubKey:     c.toRemotePK,
   243  			},
   244  			WitnessScript: toRemoteScriptHash,
   245  			Output:        commitTxn.TxOut[outputIndex],
   246  			HashType:      txscript.SigHashAll,
   247  		}
   248  	}
   249  
   250  	txid := commitTxn.TxHash()
   251  
   252  	var (
   253  		toLocalOutPoint  wire.OutPoint
   254  		toRemoteOutPoint wire.OutPoint
   255  	)
   256  
   257  	outputIndex = 0
   258  	if toLocalSignDesc != nil {
   259  		toLocalOutPoint = wire.OutPoint{
   260  			Hash:  txid,
   261  			Index: uint32(outputIndex),
   262  		}
   263  		outputIndex++
   264  	}
   265  	if toRemoteSignDesc != nil {
   266  		toRemoteOutPoint = wire.OutPoint{
   267  			Hash:  txid,
   268  			Index: uint32(outputIndex),
   269  		}
   270  	}
   271  
   272  	commitKeyRing := &lnwallet.CommitmentKeyRing{
   273  		RevocationKey: c.revPK,
   274  		ToRemoteKey:   c.toLocalPK,
   275  		ToLocalKey:    c.toRemotePK,
   276  	}
   277  
   278  	retribution := &lnwallet.BreachRetribution{
   279  		BreachTransaction:    commitTxn,
   280  		RevokedStateNum:      c.commitHeight,
   281  		KeyRing:              commitKeyRing,
   282  		RemoteDelay:          c.csvDelay,
   283  		LocalOutpoint:        toRemoteOutPoint,
   284  		LocalOutputSignDesc:  toRemoteSignDesc,
   285  		RemoteOutpoint:       toLocalOutPoint,
   286  		RemoteOutputSignDesc: toLocalSignDesc,
   287  	}
   288  
   289  	c.retributions[c.commitHeight] = retribution
   290  	c.commitHeight++
   291  }
   292  
   293  // advanceState creates the next channel state and retribution without altering
   294  // channel balances.
   295  func (c *mockChannel) advanceState(t *testing.T) {
   296  	c.mu.Lock()
   297  	defer c.mu.Unlock()
   298  
   299  	c.createRemoteCommitTx(t)
   300  }
   301  
   302  // sendPayment creates the next channel state and retribution after transferring
   303  // amt to the remote party.
   304  func (c *mockChannel) sendPayment(t *testing.T, amt lnwire.MilliAtom) {
   305  	t.Helper()
   306  
   307  	c.mu.Lock()
   308  	defer c.mu.Unlock()
   309  
   310  	if c.localBalance < amt {
   311  		t.Fatalf("insufficient funds to send, need: %v, have: %v",
   312  			amt, c.localBalance)
   313  	}
   314  
   315  	c.localBalance -= amt
   316  	c.remoteBalance += amt
   317  	c.createRemoteCommitTx(t)
   318  }
   319  
   320  // receivePayment creates the next channel state and retribution after
   321  // transferring amt to the local party.
   322  func (c *mockChannel) receivePayment(t *testing.T, amt lnwire.MilliAtom) {
   323  	t.Helper()
   324  
   325  	c.mu.Lock()
   326  	defer c.mu.Unlock()
   327  
   328  	if c.remoteBalance < amt {
   329  		t.Fatalf("insufficient funds to recv, need: %v, have: %v",
   330  			amt, c.remoteBalance)
   331  	}
   332  
   333  	c.localBalance += amt
   334  	c.remoteBalance -= amt
   335  	c.createRemoteCommitTx(t)
   336  }
   337  
   338  // getState retrieves the channel's commitment and retribution at state i.
   339  func (c *mockChannel) getState(i uint64) (*wire.MsgTx, *lnwallet.BreachRetribution) {
   340  	c.mu.Lock()
   341  	defer c.mu.Unlock()
   342  
   343  	retribution := c.retributions[i]
   344  
   345  	return retribution.BreachTransaction, retribution
   346  }
   347  
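        // testHarness wires a watchtower server and client together over a
        // mockNet and tracks the mock channels whose revoked states are backed up
        // during a test.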
   348  type testHarness struct {
   349  	t          *testing.T
   350  	cfg        harnessCfg
   351  	signer     *wtmock.MockSigner
   352  	capacity   lnwire.MilliAtom
   353  	clientDB   *wtmock.ClientDB
   354  	clientCfg  *wtclient.Config
   355  	client     wtclient.Client
   356  	serverAddr *lnwire.NetAddress
   357  	serverDB   *wtmock.TowerDB
   358  	serverCfg  *wtserver.Config
   359  	server     *wtserver.Server
   360  	net        *mockNet
   361  
   362  	mu       sync.Mutex
   363  	channels map[lnwire.ChannelID]*mockChannel
   364  }
   365  
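        // harnessCfg controls the initial channel balances, the client's policy,
        // and optional toggles that skip registering channel 0 or prevent the
        // server from acking session creation.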
   366  type harnessCfg struct {
   367  	localBalance       lnwire.MilliAtom
   368  	remoteBalance      lnwire.MilliAtom
   369  	policy             wtpolicy.Policy
   370  	noRegisterChan0    bool
   371  	noAckCreateSession bool
   372  }
   373  
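        // newHarness starts a watchtower server and client connected via a
        // mockNet, adds the tower to the client, and creates channel 0 with the
        // configured balances, registering it unless noRegisterChan0 is set.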
   374  func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
   375  	towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr)
   376  	if err != nil {
   377  		t.Fatalf("Unable to resolve tower TCP addr: %v", err)
   378  	}
   379  
   380  	privKey, err := secp256k1.GeneratePrivateKey()
   381  	if err != nil {
   382  		t.Fatalf("Unable to generate tower private key: %v", err)
   383  	}
   384  	privKeyECDH := &keychain.PrivKeyECDH{PrivKey: privKey}
   385  
   386  	towerPubKey := privKey.PubKey()
   387  
   388  	towerAddr := &lnwire.NetAddress{
   389  		IdentityKey: towerPubKey,
   390  		Address:     towerTCPAddr,
   391  	}
   392  
   393  	const timeout = 200 * time.Millisecond
   394  	serverDB := wtmock.NewTowerDB()
   395  
   396  	serverCfg := &wtserver.Config{
   397  		DB:           serverDB,
   398  		ReadTimeout:  timeout,
   399  		WriteTimeout: timeout,
   400  		NodeKeyECDH:  privKeyECDH,
   401  		NewAddress: func() (stdaddr.Address, error) {
   402  			return addr, nil
   403  		},
   404  		NoAckCreateSession: cfg.noAckCreateSession,
   405  	}
   406  
   407  	server, err := wtserver.New(serverCfg)
   408  	if err != nil {
   409  		t.Fatalf("unable to create wtserver: %v", err)
   410  	}
   411  
   412  	signer := wtmock.NewMockSigner()
   413  	mockNet := newMockNet(server.InboundPeerConnected)
   414  	clientDB := wtmock.NewClientDB()
   415  
   416  	clientCfg := &wtclient.Config{
   417  		Signer:        signer,
   418  		Dial:          mockNet.Dial,
   419  		DB:            clientDB,
   420  		AuthDial:      mockNet.AuthDial,
   421  		SecretKeyRing: wtmock.NewSecretKeyRing(),
   422  		Policy:        cfg.policy,
   423  		NewAddress: func() ([]byte, error) {
   424  			return addrScript, nil
   425  		},
   426  		ReadTimeout:    timeout,
   427  		WriteTimeout:   timeout,
   428  		MinBackoff:     time.Millisecond,
   429  		MaxBackoff:     time.Second,
   430  		ForceQuitDelay: 10 * time.Second,
   431  		ChainParams:    chaincfg.TestNet3Params(),
   432  	}
   433  	client, err := wtclient.New(clientCfg)
   434  	if err != nil {
   435  		t.Fatalf("Unable to create wtclient: %v", err)
   436  	}
   437  
   438  	if err := server.Start(); err != nil {
   439  		t.Fatalf("Unable to start wtserver: %v", err)
   440  	}
   441  
   442  	if err = client.Start(); err != nil {
   443  		server.Stop()
   444  		t.Fatalf("Unable to start wtclient: %v", err)
   445  	}
   446  	if err := client.AddTower(towerAddr); err != nil {
   447  		server.Stop()
   448  		t.Fatalf("Unable to add tower to wtclient: %v", err)
   449  	}
   450  
   451  	h := &testHarness{
   452  		t:          t,
   453  		cfg:        cfg,
   454  		signer:     signer,
   455  		capacity:   cfg.localBalance + cfg.remoteBalance,
   456  		clientDB:   clientDB,
   457  		clientCfg:  clientCfg,
   458  		client:     client,
   459  		serverAddr: towerAddr,
   460  		serverDB:   serverDB,
   461  		serverCfg:  serverCfg,
   462  		server:     server,
   463  		net:        mockNet,
   464  		channels:   make(map[lnwire.ChannelID]*mockChannel),
   465  	}
   466  
   467  	h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance)
   468  	if !cfg.noRegisterChan0 {
   469  		h.registerChannel(0)
   470  	}
   471  
   472  	return h
   473  }
   474  
   475  // startServer creates a new server using the harness's current serverCfg and
   476  // starts it after pointing the mockNet's callback to the new server.
   477  func (h *testHarness) startServer() {
   478  	h.t.Helper()
   479  
   480  	var err error
   481  	h.server, err = wtserver.New(h.serverCfg)
   482  	if err != nil {
   483  		h.t.Fatalf("unable to create wtserver: %v", err)
   484  	}
   485  
   486  	h.net.setConnCallback(h.server.InboundPeerConnected)
   487  
   488  	if err := h.server.Start(); err != nil {
   489  		h.t.Fatalf("unable to start wtserver: %v", err)
   490  	}
   491  }
   492  
   493  // startClient creates a new client using the harness's current clientCfg,
   494  // starts it, and re-adds the tower to it.
   495  func (h *testHarness) startClient() {
   496  	h.t.Helper()
   497  
   498  	towerTCPAddr, err := net.ResolveTCPAddr("tcp", towerAddrStr)
   499  	if err != nil {
   500  		h.t.Fatalf("Unable to resolve tower TCP addr: %v", err)
   501  	}
   502  	towerAddr := &lnwire.NetAddress{
   503  		IdentityKey: h.serverCfg.NodeKeyECDH.PubKey(),
   504  		Address:     towerTCPAddr,
   505  	}
   506  
   507  	h.client, err = wtclient.New(h.clientCfg)
   508  	if err != nil {
   509  		h.t.Fatalf("unable to create wtclient: %v", err)
   510  	}
   511  	if err := h.client.Start(); err != nil {
   512  		h.t.Fatalf("unable to start wtclient: %v", err)
   513  	}
   514  	if err := h.client.AddTower(towerAddr); err != nil {
   515  		h.t.Fatalf("unable to add tower to wtclient: %v", err)
   516  	}
   517  }
   518  
   519  // chanIDFromInt creates a unique channel id given a unique integral id.
   520  func chanIDFromInt(id uint64) lnwire.ChannelID {
   521  	var chanID lnwire.ChannelID
   522  	binary.BigEndian.PutUint64(chanID[:8], id)
   523  	return chanID
   524  }
   525  
   526  // makeChannel creates a new channel with id, using localAmt and remoteAmt as
   527  // the starting balances. The channel will be available by using h.channel(id).
   528  //
   529  // NOTE: The method fails if a channel for id already exists.
   530  func (h *testHarness) makeChannel(id uint64,
   531  	localAmt, remoteAmt lnwire.MilliAtom) {
   532  
   533  	h.t.Helper()
   534  
   535  	chanID := chanIDFromInt(id)
   536  	c := newMockChannel(h.t, h.signer, localAmt, remoteAmt)
   537  
   538  	h.mu.Lock()
   539  	_, ok := h.channels[chanID]
   540  	if !ok {
   541  		h.channels[chanID] = c
   542  	}
   543  	h.mu.Unlock()
   544  
   545  	if ok {
   546  		h.t.Fatalf("channel %d already created", id)
   547  	}
   548  }
   549  
   550  // channel retrieves the channel corresponding to id.
   551  //
   552  // NOTE: The method fails if a channel for id does not exist.
   553  func (h *testHarness) channel(id uint64) *mockChannel {
   554  	h.t.Helper()
   555  
   556  	h.mu.Lock()
   557  	c, ok := h.channels[chanIDFromInt(id)]
   558  	h.mu.Unlock()
   559  	if !ok {
   560  		h.t.Fatalf("unable to fetch channel %d", id)
   561  	}
   562  
   563  	return c
   564  }
   565  
   566  // registerChannel registers the channel identified by id with the client.
   567  func (h *testHarness) registerChannel(id uint64) {
   568  	h.t.Helper()
   569  
   570  	chanID := chanIDFromInt(id)
   571  	err := h.client.RegisterChannel(chanID)
   572  	if err != nil {
   573  		h.t.Fatalf("unable to register channel %d: %v", id, err)
   574  	}
   575  }
   576  
   577  // advanceChannelN calls advanceState on the channel identified by id the
   578  // provided number of times and returns the breach hints corresponding to the
   579  // new states.
   580  func (h *testHarness) advanceChannelN(id uint64, n int) []blob.BreachHint {
   581  	h.t.Helper()
   582  
   583  	channel := h.channel(id)
   584  
   585  	var hints []blob.BreachHint
   586  	for i := uint64(0); i < uint64(n); i++ {
   587  		channel.advanceState(h.t)
   588  		commitTx, _ := h.channel(id).getState(i)
   589  		breachTxID := commitTx.TxHash()
   590  		hints = append(hints, blob.NewBreachHintFromHash(&breachTxID))
   591  	}
   592  
   593  	return hints
   594  }
   595  
   596  // backupStates instructs the channel identified by id to send backups to the
   597  // client for states in the range [from, to).
   598  func (h *testHarness) backupStates(id, from, to uint64, expErr error) {
   599  	h.t.Helper()
   600  
   601  	for i := from; i < to; i++ {
   602  		h.backupState(id, i, expErr)
   603  	}
   604  }
   605  
   606  // backupState instructs the channel identified by id to send a backup for
   607  // state i.
   608  func (h *testHarness) backupState(id, i uint64, expErr error) {
   609  	h.t.Helper()
   610  
   611  	_, retribution := h.channel(id).getState(i)
   612  
   613  	chanID := chanIDFromInt(id)
   614  	err := h.client.BackupState(&chanID, retribution, channeldb.SingleFunderBit)
   615  	if err != expErr {
   616  		h.t.Fatalf("backup error mismatch, want: %v, got: %v",
   617  			expErr, err)
   618  	}
   619  }
   620  
   621  // sendPayments instructs the channel identified by id to send amt to the remote
   622  // party once for each state in the range [from, to) and returns the breach
   623  // hints for those states.
   624  func (h *testHarness) sendPayments(id, from, to uint64,
   625  	amt lnwire.MilliAtom) []blob.BreachHint {
   626  
   627  	h.t.Helper()
   628  
   629  	channel := h.channel(id)
   630  
   631  	var hints []blob.BreachHint
   632  	for i := from; i < to; i++ {
   633  		h.channel(id).sendPayment(h.t, amt)
   634  		commitTx, _ := channel.getState(i)
   635  		breachTxID := commitTx.TxHash()
   636  		hints = append(hints, blob.NewBreachHintFromHash(&breachTxID))
   637  	}
   638  
   639  	return hints
   640  }
   641  
   642  // recvPayments instructs the channel identified by id to receive amt from the
   643  // remote party once for each state in the range [from, to) and returns the
   644  // breach hints for those states.
   645  func (h *testHarness) recvPayments(id, from, to uint64,
   646  	amt lnwire.MilliAtom) []blob.BreachHint {
   647  
   648  	h.t.Helper()
   649  
   650  	channel := h.channel(id)
   651  
   652  	var hints []blob.BreachHint
   653  	for i := from; i < to; i++ {
   654  		channel.receivePayment(h.t, amt)
   655  		commitTx, _ := channel.getState(i)
   656  		breachTxID := commitTx.TxHash()
   657  		hints = append(hints, blob.NewBreachHintFromHash(&breachTxID))
   658  	}
   659  
   660  	return hints
   661  }
   662  
   663  // waitServerUpdates blocks until the breach hints provided all appear in the
   664  // watchtower's database or the timeout expires. This is used to test that the
   665  // client in fact sends the updates to the server, even if it is offline.
   666  func (h *testHarness) waitServerUpdates(hints []blob.BreachHint,
   667  	timeout time.Duration) {
   668  
   669  	h.t.Helper()
   670  
   671  	// If no breach hints are provided, we will wait out the full timeout to
   672  	// assert that no updates appear.
   673  	wantUpdates := len(hints) > 0
   674  
   675  	hintSet := make(map[blob.BreachHint]struct{})
   676  	for _, hint := range hints {
   677  		hintSet[hint] = struct{}{}
   678  	}
   679  
   680  	if len(hints) != len(hintSet) {
   681  		h.t.Fatalf("breach hints are not unique, list-len: %d "+
   682  			"set-len: %d", len(hints), len(hintSet))
   683  	}
   684  
   685  	// Closure to assert the server's matches are consistent with the hint
   686  	// set.
   687  	serverHasHints := func(matches []wtdb.Match) bool {
   688  		if len(hintSet) != len(matches) {
   689  			return false
   690  		}
   691  
   692  		for _, match := range matches {
   693  			if _, ok := hintSet[match.Hint]; ok {
   694  				continue
   695  			}
   696  
   697  			h.t.Fatalf("match %v in db is not in hint set",
   698  				match.Hint)
   699  		}
   700  
   701  		return true
   702  	}
   703  
   704  	failTimeout := time.After(timeout)
   705  	for {
   706  		select {
   707  		case <-time.After(time.Second):
   708  			matches, err := h.serverDB.QueryMatches(hints)
   709  			switch {
   710  			case err != nil:
   711  				h.t.Fatalf("unable to query for hints: %v", err)
   712  
   713  			case wantUpdates && serverHasHints(matches):
   714  				return
   715  
   716  			case wantUpdates:
   717  				h.t.Logf("Received %d/%d\n", len(matches),
   718  					len(hints))
   719  			}
   720  
   721  		case <-failTimeout:
   722  			matches, err := h.serverDB.QueryMatches(hints)
   723  			switch {
   724  			case err != nil:
   725  				h.t.Fatalf("unable to query for hints: %v", err)
   726  
   727  			case serverHasHints(matches):
   728  				return
   729  
   730  			default:
   731  				h.t.Fatalf("breach hints not received, only "+
   732  					"got %d/%d", len(matches), len(hints))
   733  			}
   734  		}
   735  	}
   736  }
   737  
   738  // assertUpdatesForPolicy queries the server db for matches using the provided
   739  // breach hints, then asserts that each match has a session with the expected
   740  // policy.
   741  func (h *testHarness) assertUpdatesForPolicy(hints []blob.BreachHint,
   742  	expPolicy wtpolicy.Policy) {
   743  
   744  	// Query for matches on the provided hints.
   745  	matches, err := h.serverDB.QueryMatches(hints)
   746  	if err != nil {
   747  		h.t.Fatalf("unable to query for matches: %v", err)
   748  	}
   749  
   750  	// Assert that the number of matches is exactly the number of provided
   751  	// hints.
   752  	if len(matches) != len(hints) {
   753  		h.t.Fatalf("expected: %d matches, got: %d", len(hints),
   754  			len(matches))
   755  	}
   756  
   757  	// Assert that all of the matches correspond to a session with the
   758  	// expected policy.
   759  	for _, match := range matches {
   760  		matchPolicy := match.SessionInfo.Policy
   761  		if expPolicy != matchPolicy {
   762  			h.t.Fatalf("expected session to have policy: %v, "+
   763  				"got: %v", expPolicy, matchPolicy)
   764  		}
   765  	}
   766  }
   767  
   768  // addTower adds a tower found at `addr` to the client.
   769  func (h *testHarness) addTower(addr *lnwire.NetAddress) {
   770  	h.t.Helper()
   771  
   772  	if err := h.client.AddTower(addr); err != nil {
   773  		h.t.Fatalf("unable to add tower: %v", err)
   774  	}
   775  }
   776  
   777  // removeTower removes a tower from the client. If addr is non-nil, only that
   778  // address is removed from the tower.
   779  func (h *testHarness) removeTower(pubKey *secp256k1.PublicKey, addr net.Addr) {
   780  	h.t.Helper()
   781  
   782  	if err := h.client.RemoveTower(pubKey, addr); err != nil {
   783  		h.t.Fatalf("unable to remove tower: %v", err)
   784  	}
   785  }
   786  
   787  const (
   788  	localBalance  = lnwire.MilliAtom(100000000)
   789  	remoteBalance = lnwire.MilliAtom(200000000)
   790  )
   791  
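        // clientTest couples a test case's name with the harness configuration it
        // requires and the function that exercises the scenario.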
   792  type clientTest struct {
   793  	name string
   794  	cfg  harnessCfg
   795  	fn   func(*testHarness)
   796  }
   797  
   798  var clientTests = []clientTest{
   799  	{
   800  		// Asserts that the client will return the ErrUnregisteredChannel
   801  		// error when trying to back up states for a channel that has not
   802  		// been registered (and received its pkscript).
   803  		name: "backup unregistered channel",
   804  		cfg: harnessCfg{
   805  			localBalance:  localBalance,
   806  			remoteBalance: remoteBalance,
   807  			policy: wtpolicy.Policy{
   808  				TxPolicy: wtpolicy.TxPolicy{
   809  					BlobType:     blob.TypeAltruistCommit,
   810  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
   811  				},
   812  				MaxUpdates: 20000,
   813  			},
   814  			noRegisterChan0: true,
   815  		},
   816  		fn: func(h *testHarness) {
   817  			const (
   818  				numUpdates = 5
   819  				chanID     = 0
   820  			)
   821  
   822  			// Advance the channel and backup the retributions. We
   823  			// expect ErrUnregisteredChannel to be returned since
   824  			// the channel was not registered during harness
   825  			// creation.
   826  			h.advanceChannelN(chanID, numUpdates)
   827  			h.backupStates(
   828  				chanID, 0, numUpdates,
   829  				wtclient.ErrUnregisteredChannel,
   830  			)
   831  		},
   832  	},
   833  	{
   834  		// Asserts that the client returns an ErrClientExiting when
   835  		// trying to backup channels after the Stop method has been
   836  		// called.
   837  		name: "backup after stop",
   838  		cfg: harnessCfg{
   839  			localBalance:  localBalance,
   840  			remoteBalance: remoteBalance,
   841  			policy: wtpolicy.Policy{
   842  				TxPolicy: wtpolicy.TxPolicy{
   843  					BlobType:     blob.TypeAltruistCommit,
   844  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
   845  				},
   846  				MaxUpdates: 20000,
   847  			},
   848  		},
   849  		fn: func(h *testHarness) {
   850  			const (
   851  				numUpdates = 5
   852  				chanID     = 0
   853  			)
   854  
   855  			// Stop the client; subsequent backups should fail.
   856  			h.client.Stop()
   857  
   858  			// Advance the channel and try to back up the states. We
   859  			// expect ErrClientExiting to be returned from
   860  			// BackupState.
   861  			h.advanceChannelN(chanID, numUpdates)
   862  			h.backupStates(
   863  				chanID, 0, numUpdates,
   864  				wtclient.ErrClientExiting,
   865  			)
   866  		},
   867  	},
   868  	{
   869  		// Asserts that the client will continue to back up all states
   870  		// that have previously been enqueued before it finishes
   871  		// exiting.
   872  		name: "backup reliable flush",
   873  		cfg: harnessCfg{
   874  			localBalance:  localBalance,
   875  			remoteBalance: remoteBalance,
   876  			policy: wtpolicy.Policy{
   877  				TxPolicy: wtpolicy.TxPolicy{
   878  					BlobType:     blob.TypeAltruistCommit,
   879  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
   880  				},
   881  				MaxUpdates: 5,
   882  			},
   883  		},
   884  		fn: func(h *testHarness) {
   885  			const (
   886  				numUpdates = 5
   887  				chanID     = 0
   888  			)
   889  
   890  			// Generate numUpdates retributions and back them up to
   891  			// the tower.
   892  			hints := h.advanceChannelN(chanID, numUpdates)
   893  			h.backupStates(chanID, 0, numUpdates, nil)
   894  
   895  			// Stop the client in the background, to assert the
   896  			// pipeline is always flushed before it exits.
   897  			go h.client.Stop()
   898  
   899  			// Wait for all of the updates to be populated in the
   900  			// server's database.
   901  			h.waitServerUpdates(hints, time.Second)
   902  		},
   903  	},
   904  	{
   905  		// Assert that the client will not send out backups for states
   906  		// whose justice transactions are ineligible for backup, e.g.
   907  		// creating dust outputs.
   908  		name: "backup dust ineligible",
   909  		cfg: harnessCfg{
   910  			localBalance:  localBalance,
   911  			remoteBalance: remoteBalance,
   912  			policy: wtpolicy.Policy{
   913  				TxPolicy: wtpolicy.TxPolicy{
   914  					BlobType:     blob.TypeAltruistCommit,
   915  					SweepFeeRate: 1000000, // high sweep fee creates dust
   916  				},
   917  				MaxUpdates: 20000,
   918  			},
   919  		},
   920  		fn: func(h *testHarness) {
   921  			const (
   922  				numUpdates = 5
   923  				chanID     = 0
   924  			)
   925  
   926  			// Create the retributions and queue them for backup.
   927  			h.advanceChannelN(chanID, numUpdates)
   928  			h.backupStates(chanID, 0, numUpdates, nil)
   929  
   930  			// Ensure that no updates are received by the server,
   931  			// since they should all be marked as ineligible.
   932  			h.waitServerUpdates(nil, time.Second)
   933  		},
   934  	},
   935  	{
   936  		// Verifies that the client will properly retransmit a committed
   937  		// state update to the watchtower after a restart if the update
   938  		// was not acked while the client was last active.
   939  		name: "committed update restart",
   940  		cfg: harnessCfg{
   941  			localBalance:  localBalance,
   942  			remoteBalance: remoteBalance,
   943  			policy: wtpolicy.Policy{
   944  				TxPolicy: wtpolicy.TxPolicy{
   945  					BlobType:     blob.TypeAltruistCommit,
   946  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
   947  				},
   948  				MaxUpdates: 20000,
   949  			},
   950  		},
   951  		fn: func(h *testHarness) {
   952  			const (
   953  				numUpdates = 5
   954  				chanID     = 0
   955  			)
   956  
   957  			hints := h.advanceChannelN(0, numUpdates)
   958  
   959  			var numSent uint64
   960  
   961  			// Add the first two states to the client's pipeline.
   962  			h.backupStates(chanID, 0, 2, nil)
   963  			numSent = 2
   964  
   965  			// Wait for both to be reflected in the server's
   966  			// database.
   967  			h.waitServerUpdates(hints[:numSent], time.Second)
   968  
   969  			// Now, restart the server and prevent it from acking
   970  			// state updates.
   971  			h.server.Stop()
   972  			h.serverCfg.NoAckUpdates = true
   973  			h.startServer()
   974  			defer h.server.Stop()
   975  
   976  			// Send the next state update to the tower. Since the
   977  			// tower isn't acking state updates, we expect this
   978  			// update to be committed and sent by the session queue,
   979  			// but it will never receive an ack.
   980  			h.backupState(chanID, numSent, nil)
   981  			numSent++
   982  
   983  			// Force quit the client to abort the state updates it
   984  			// has queued. The sleep ensures that the session queues
   985  			// have enough time to commit the state updates before
   986  			// the client is killed.
   987  			time.Sleep(time.Second)
   988  			h.client.ForceQuit()
   989  
   990  			// Restart the server and allow it to ack the updates
   991  			// after the client retransmits the unacked update.
   992  			h.server.Stop()
   993  			h.serverCfg.NoAckUpdates = false
   994  			h.startServer()
   995  			defer h.server.Stop()
   996  
   997  			// Restart the client and allow it to process the
   998  			// committed update.
   999  			h.startClient()
  1000  			defer h.client.ForceQuit()
  1001  
  1002  			// Wait for the committed update to be accepted by the
  1003  			// tower.
  1004  			h.waitServerUpdates(hints[:numSent], time.Second)
  1005  
  1006  			// Finally, send the rest of the updates and wait for
  1007  			// the tower to receive the remaining states.
  1008  			h.backupStates(chanID, numSent, numUpdates, nil)
  1009  
  1010  			// Wait for all of the updates to be populated in the
  1011  			// server's database.
  1012  			h.waitServerUpdates(hints, time.Second)
  1013  
  1014  		},
  1015  	},
  1016  	{
  1017  		// Asserts that the client will continue to retry sending state
  1018  		// updates if it doesn't receive an ack from the server. The
  1019  		// client is expected to flush everything in its in-memory
  1020  		// pipeline once the server begins sending acks again.
  1021  		name: "no ack from server",
  1022  		cfg: harnessCfg{
  1023  			localBalance:  localBalance,
  1024  			remoteBalance: remoteBalance,
  1025  			policy: wtpolicy.Policy{
  1026  				TxPolicy: wtpolicy.TxPolicy{
  1027  					BlobType:     blob.TypeAltruistCommit,
  1028  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1029  				},
  1030  				MaxUpdates: 5,
  1031  			},
  1032  		},
  1033  		fn: func(h *testHarness) {
  1034  			const (
  1035  				numUpdates = 100
  1036  				chanID     = 0
  1037  			)
  1038  
  1039  			// Generate the retributions that will be backed up.
  1040  			hints := h.advanceChannelN(chanID, numUpdates)
  1041  
  1042  			// Restart the server and prevent it from acking state
  1043  			// updates.
  1044  			h.server.Stop()
  1045  			h.serverCfg.NoAckUpdates = true
  1046  			h.startServer()
  1047  			defer h.server.Stop()
  1048  
  1049  			// Now, queue the retributions for backup.
  1050  			h.backupStates(chanID, 0, numUpdates, nil)
  1051  
  1052  			// Stop the client in the background, to assert the
  1053  			// pipeline is always flushed before it exits.
  1054  			go h.client.Stop()
  1055  
  1056  			// Give the client time to saturate a large number of
  1057  			// session queues for which the server has not acked the
  1058  			// state updates that it has received.
  1059  			time.Sleep(time.Second)
  1060  
  1061  			// Restart the server and allow it to ack the updates
  1062  			// after the client retransmits the unacked updates.
  1063  			h.server.Stop()
  1064  			h.serverCfg.NoAckUpdates = false
  1065  			h.startServer()
  1066  			defer h.server.Stop()
  1067  
  1068  			// Wait for all of the updates to be populated in the
  1069  			// server's database.
  1070  			h.waitServerUpdates(hints, 5*time.Second)
  1071  		},
  1072  	},
  1073  	{
  1074  		// Asserts that the client is able to send state updates to the
  1075  		// tower for a full range of channel values, assuming the sweep
  1076  		// fee rates permit it. We expect all of these to be successful
  1077  		// since a sweep transaction spending only from one output is
  1078  		// less expensive than one that sweeps both.
  1079  		name: "send and recv",
  1080  		cfg: harnessCfg{
  1081  			localBalance:  10000001, // ensure (% amt != 0)
  1082  			remoteBalance: 20000001, // ensure (% amt != 0)
  1083  			policy: wtpolicy.Policy{
  1084  				TxPolicy: wtpolicy.TxPolicy{
  1085  					BlobType:     blob.TypeAltruistCommit,
  1086  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1087  				},
  1088  				MaxUpdates: 1000,
  1089  			},
  1090  		},
  1091  		fn: func(h *testHarness) {
  1092  			var (
  1093  				capacity   = h.cfg.localBalance + h.cfg.remoteBalance
  1094  				paymentAmt = lnwire.MilliAtom(200000)
  1095  				numSends   = uint64(h.cfg.localBalance / paymentAmt)
  1096  				numRecvs   = uint64(capacity / paymentAmt)
  1097  				numUpdates = numSends + numRecvs // 200 updates
  1098  				chanID     = uint64(0)
  1099  			)
  1100  
  1101  			// Send money to the remote party until all funds are
  1102  			// depleted.
  1103  			sendHints := h.sendPayments(chanID, 0, numSends, paymentAmt)
  1104  
  1105  			// Now, sequentially receive the entire channel balance
  1106  			// from the remote party.
  1107  			recvHints := h.recvPayments(chanID, numSends, numUpdates, paymentAmt)
  1108  
  1109  			// Collect the hints generated by both sending and
  1110  			// receiving.
  1111  			hints := append(sendHints, recvHints...)
  1112  
  1113  			// Back up the channel's states to the client.
  1114  			h.backupStates(chanID, 0, numUpdates, nil)
  1115  
  1116  			// Wait for all of the updates to be populated in the
  1117  			// server's database.
  1118  			h.waitServerUpdates(hints, 5*time.Second)
  1119  		},
  1120  	},
  1121  	{
  1122  		// Asserts that the client is able to support multiple links.
  1123  		name: "multiple link backup",
  1124  		cfg: harnessCfg{
  1125  			localBalance:  localBalance,
  1126  			remoteBalance: remoteBalance,
  1127  			policy: wtpolicy.Policy{
  1128  				TxPolicy: wtpolicy.TxPolicy{
  1129  					BlobType:     blob.TypeAltruistCommit,
  1130  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1131  				},
  1132  				MaxUpdates: 5,
  1133  			},
  1134  		},
  1135  		fn: func(h *testHarness) {
  1136  			const (
  1137  				numUpdates = 5
  1138  				numChans   = 10
  1139  			)
  1140  
  1141  			// Initialize and register an additional 9 channels.
  1142  			for id := uint64(1); id < numChans; id++ {
  1143  				h.makeChannel(
  1144  					id, h.cfg.localBalance,
  1145  					h.cfg.remoteBalance,
  1146  				)
  1147  				h.registerChannel(id)
  1148  			}
  1149  
  1150  			// Generate the retributions for all 10 channels and
  1151  			// collect the breach hints.
  1152  			var hints []blob.BreachHint
  1153  			for id := uint64(0); id < numChans; id++ {
  1154  				chanHints := h.advanceChannelN(id, numUpdates)
  1155  				hints = append(hints, chanHints...)
  1156  			}
  1157  
  1158  			// Provide all retributions from all channels to the
  1159  			// client.
  1160  			for id := uint64(0); id < numChans; id++ {
  1161  				h.backupStates(id, 0, numUpdates, nil)
  1162  			}
  1163  
  1164  			// Test reliable flush under a multi-channel scenario.
  1165  			go h.client.Stop()
  1166  
  1167  			// Wait for all of the updates to be populated in the
  1168  			// server's database.
  1169  			h.waitServerUpdates(hints, 10*time.Second)
  1170  		},
  1171  	},
  1172  	{
  1173  		name: "create session no ack",
  1174  		cfg: harnessCfg{
  1175  			localBalance:  localBalance,
  1176  			remoteBalance: remoteBalance,
  1177  			policy: wtpolicy.Policy{
  1178  				TxPolicy: wtpolicy.TxPolicy{
  1179  					BlobType:     blob.TypeAltruistCommit,
  1180  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1181  				},
  1182  				MaxUpdates: 5,
  1183  			},
  1184  			noAckCreateSession: true,
  1185  		},
  1186  		fn: func(h *testHarness) {
  1187  			const (
  1188  				chanID     = 0
  1189  				numUpdates = 3
  1190  			)
  1191  
  1192  			// Generate the retributions that will be backed up.
  1193  			hints := h.advanceChannelN(chanID, numUpdates)
  1194  
  1195  			// Now, queue the retributions for backup.
  1196  			h.backupStates(chanID, 0, numUpdates, nil)
  1197  
  1198  			// Since the client is unable to create a session, the
  1199  			// server should have no updates.
  1200  			h.waitServerUpdates(nil, time.Second)
  1201  
  1202  			// Force quit the client since it has queued backups.
  1203  			h.client.ForceQuit()
  1204  
  1205  			// Restart the server and allow it to ack session
  1206  			// creation.
  1207  			h.server.Stop()
  1208  			h.serverCfg.NoAckCreateSession = false
  1209  			h.startServer()
  1210  			defer h.server.Stop()
  1211  
  1212  			// Restart the client with the same policy, which will
  1213  			// immediately try to overwrite the old session with an
  1214  			// identical one.
  1215  			h.startClient()
  1216  			defer h.client.ForceQuit()
  1217  
  1218  			// Now, queue the retributions for backup.
  1219  			h.backupStates(chanID, 0, numUpdates, nil)
  1220  
  1221  			// Wait for all of the updates to be populated in the
  1222  			// server's database.
  1223  			h.waitServerUpdates(hints, 6*time.Second)
  1224  
  1225  			// Assert that the server has updates for the client's
  1226  			// most recent policy.
  1227  			h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
  1228  		},
  1229  	},
  1230  	{
  1231  		name: "create session no ack change policy",
  1232  		cfg: harnessCfg{
  1233  			localBalance:  localBalance,
  1234  			remoteBalance: remoteBalance,
  1235  			policy: wtpolicy.Policy{
  1236  				TxPolicy: wtpolicy.TxPolicy{
  1237  					BlobType:     blob.TypeAltruistCommit,
  1238  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1239  				},
  1240  				MaxUpdates: 5,
  1241  			},
  1242  			noAckCreateSession: true,
  1243  		},
  1244  		fn: func(h *testHarness) {
  1245  			const (
  1246  				chanID     = 0
  1247  				numUpdates = 3
  1248  			)
  1249  
  1250  			// Generate the retributions that will be backed up.
  1251  			hints := h.advanceChannelN(chanID, numUpdates)
  1252  
  1253  			// Now, queue the retributions for backup.
  1254  			h.backupStates(chanID, 0, numUpdates, nil)
  1255  
  1256  			// Since the client is unable to create a session, the
  1257  			// server should have no updates.
  1258  			h.waitServerUpdates(nil, time.Second)
  1259  
  1260  			// Force quit the client since it has queued backups.
  1261  			h.client.ForceQuit()
  1262  
  1263  			// Restart the server and allow it to ack session
  1264  			// creation.
  1265  			h.server.Stop()
  1266  			h.serverCfg.NoAckCreateSession = false
  1267  			h.startServer()
  1268  			defer h.server.Stop()
  1269  
  1270  			// Restart the client with a new policy, which will
  1271  			// immediately try to overwrite the prior session with
  1272  			// the old policy.
  1273  			h.clientCfg.Policy.SweepFeeRate *= 2
  1274  			h.startClient()
  1275  			defer h.client.ForceQuit()
  1276  
  1277  			// Now, queue the retributions for backup.
  1278  			h.backupStates(chanID, 0, numUpdates, nil)
  1279  
  1280  			// Wait for all of the updates to be populated in the
  1281  			// server's database.
  1282  			h.waitServerUpdates(hints, 5*time.Second)
  1283  
  1284  			// Assert that the server has updates for the client's
  1285  			// most recent policy.
  1286  			h.assertUpdatesForPolicy(hints, h.clientCfg.Policy)
  1287  		},
  1288  	},
  1289  	{
  1290  		// Asserts that the client will not request a new session if
  1291  		// it already has an existing session with the same TxPolicy. This
  1292  		// permits the client to continue using policies that differ in
  1293  		// operational parameters, but don't manifest in different
  1294  		// justice transactions.
  1295  		name: "create session change policy same txpolicy",
  1296  		cfg: harnessCfg{
  1297  			localBalance:  localBalance,
  1298  			remoteBalance: remoteBalance,
  1299  			policy: wtpolicy.Policy{
  1300  				TxPolicy: wtpolicy.TxPolicy{
  1301  					BlobType:     blob.TypeAltruistCommit,
  1302  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1303  				},
  1304  				MaxUpdates: 10,
  1305  			},
  1306  		},
  1307  		fn: func(h *testHarness) {
  1308  			const (
  1309  				chanID     = 0
  1310  				numUpdates = 6
  1311  			)
  1312  
  1313  			// Generate the retributions that will be backed up.
  1314  			hints := h.advanceChannelN(chanID, numUpdates)
  1315  
  1316  			// Now, queue the first half of the retributions.
  1317  			h.backupStates(chanID, 0, numUpdates/2, nil)
  1318  
  1319  			// Wait for the server to collect the first half.
  1320  			h.waitServerUpdates(hints[:numUpdates/2], time.Second)
  1321  
  1322  			// Stop the client, which should have no more backups.
  1323  			h.client.Stop()
  1324  
  1325  			// Record the policy that the first half was stored
  1326  			// under. We'll expect the second half to also be stored
  1327  			// under the original policy, since we are only adjusting
  1328  			// the MaxUpdates. The client should detect that the
  1329  			// two policies have equivalent TxPolicies and continue
  1330  			// using the first.
  1331  			expPolicy := h.clientCfg.Policy
  1332  
  1333  			// Restart the client with a new policy.
  1334  			h.clientCfg.Policy.MaxUpdates = 20
  1335  			h.startClient()
  1336  			defer h.client.ForceQuit()
  1337  
  1338  			// Now, queue the second half of the retributions.
  1339  			h.backupStates(chanID, numUpdates/2, numUpdates, nil)
  1340  
  1341  			// Wait for all of the updates to be populated in the
  1342  			// server's database.
  1343  			h.waitServerUpdates(hints, 5*time.Second)
  1344  
  1345  			// Assert that the server has updates for the client's
  1346  			// original policy.
  1347  			h.assertUpdatesForPolicy(hints, expPolicy)
  1348  		},
  1349  	},
  1350  	{
  1351  		// Asserts that the client will deduplicate backups presented by
  1352  		// a channel both in memory and after a restart. The client
  1353  		// should only accept backups with a commit height greater than
  1354  		// any already processed for a given policy.
  1355  		name: "dedup backups",
  1356  		cfg: harnessCfg{
  1357  			localBalance:  localBalance,
  1358  			remoteBalance: remoteBalance,
  1359  			policy: wtpolicy.Policy{
  1360  				TxPolicy: wtpolicy.TxPolicy{
  1361  					BlobType:     blob.TypeAltruistCommit,
  1362  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1363  				},
  1364  				MaxUpdates: 5,
  1365  			},
  1366  		},
  1367  		fn: func(h *testHarness) {
  1368  			const (
  1369  				numUpdates = 10
  1370  				chanID     = 0
  1371  			)
  1372  
  1373  			// Generate the retributions that will be backed up.
  1374  			hints := h.advanceChannelN(chanID, numUpdates)
  1375  
  1376  			// Queue the first half of the retributions twice, the
  1377  			// second batch should be entirely deduped by the
  1378  			// client's in-memory tracking.
  1379  			h.backupStates(chanID, 0, numUpdates/2, nil)
  1380  			h.backupStates(chanID, 0, numUpdates/2, nil)
  1381  
  1382  			// Wait for the first half of the updates to be
  1383  			// populated in the server's database.
  1384  			h.waitServerUpdates(hints[:len(hints)/2], 5*time.Second)
  1385  
  1386  			// Restart the client, so we can ensure the deduping is
  1387  			// maintained across restarts.
  1388  			h.client.Stop()
  1389  			h.startClient()
  1390  			defer h.client.ForceQuit()
  1391  
  1392  			// Try to back up the full range of retributions. Only
  1393  			// the second half should actually be sent.
  1394  			h.backupStates(chanID, 0, numUpdates, nil)
  1395  
  1396  			// Wait for all of the updates to be populated in the
  1397  			// server's database.
  1398  			h.waitServerUpdates(hints, 5*time.Second)
  1399  		},
  1400  	},
  1401  	{
  1402  		// Asserts that the client can continue making backups to a
  1403  		// tower that's been re-added after it's been removed.
  1404  		name: "re-add removed tower",
  1405  		cfg: harnessCfg{
  1406  			localBalance:  localBalance,
  1407  			remoteBalance: remoteBalance,
  1408  			policy: wtpolicy.Policy{
  1409  				TxPolicy: wtpolicy.TxPolicy{
  1410  					BlobType:     blob.TypeAltruistCommit,
  1411  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1412  				},
  1413  				MaxUpdates: 5,
  1414  			},
  1415  		},
  1416  		fn: func(h *testHarness) {
  1417  			const (
  1418  				chanID     = 0
  1419  				numUpdates = 4
  1420  			)
  1421  
  1422  			// Create four channel updates and only back up the
  1423  			// first two.
  1424  			hints := h.advanceChannelN(chanID, numUpdates)
  1425  			h.backupStates(chanID, 0, numUpdates/2, nil)
  1426  			h.waitServerUpdates(hints[:numUpdates/2], 5*time.Second)
  1427  
  1428  			// Fully remove the tower, causing its existing sessions
  1429  			// to be marked inactive.
  1430  			h.removeTower(h.serverAddr.IdentityKey, nil)
  1431  
  1432  			// Back up the remaining states. Since the tower has
  1433  			// been removed, it shouldn't receive any updates.
  1434  			h.backupStates(chanID, numUpdates/2, numUpdates, nil)
  1435  			h.waitServerUpdates(nil, time.Second)
  1436  
  1437  			// Re-add the tower. We prevent the tower from acking
  1438  			// session creation to ensure the inactive sessions are
  1439  			// not used.
  1440  			err := h.server.Stop()
  1441  			require.Nil(h.t, err)
  1442  			h.serverCfg.NoAckCreateSession = true
  1443  			h.startServer()
  1444  			h.addTower(h.serverAddr)
  1445  			h.waitServerUpdates(nil, time.Second)
  1446  
  1447  			// Finally, allow the tower to ack session creation,
  1448  			// allowing the state updates to be sent through the new
  1449  			// session.
  1450  			err = h.server.Stop()
  1451  			require.Nil(h.t, err)
  1452  			h.serverCfg.NoAckCreateSession = false
  1453  			h.startServer()
  1454  			h.waitServerUpdates(hints[numUpdates/2:], 5*time.Second)
  1455  		},
  1456  	},
  1457  	{
  1458  		// Asserts that the client's force quit delay will properly
  1459  		// shut down the client if it is unable to completely drain the
  1460  		// task pipeline.
  1461  		name: "force unclean shutdown",
  1462  		cfg: harnessCfg{
  1463  			localBalance:  localBalance,
  1464  			remoteBalance: remoteBalance,
  1465  			policy: wtpolicy.Policy{
  1466  				TxPolicy: wtpolicy.TxPolicy{
  1467  					BlobType:     blob.TypeAltruistCommit,
  1468  					SweepFeeRate: wtpolicy.DefaultSweepFeeRate,
  1469  				},
  1470  				MaxUpdates: 5,
  1471  			},
  1472  		},
  1473  		fn: func(h *testHarness) {
  1474  			const (
  1475  				chanID     = 0
  1476  				numUpdates = 6
  1477  				maxUpdates = 5
  1478  			)
  1479  
  1480  			// Advance the channel to create all states.
  1481  			hints := h.advanceChannelN(chanID, numUpdates)
  1482  
  1483  			// Back up 4 of the 5 states for the negotiated session.
  1484  			h.backupStates(chanID, 0, maxUpdates-1, nil)
  1485  			h.waitServerUpdates(hints[:maxUpdates-1], 5*time.Second)
  1486  
  1487  			// Now, restart the tower and prevent it from acking any
  1488  			// new sessions. We do this here as once the last slot
  1489  			// is exhausted the client will attempt to renegotiate.
  1490  			err := h.server.Stop()
  1491  			require.Nil(h.t, err)
  1492  			h.serverCfg.NoAckCreateSession = true
  1493  			h.startServer()
  1494  
  1495  			// Back up the remaining two states. Once the first is
  1496  			// processed, the session will be exhausted but the
  1497  			// client won't be able to renegotiate a session for
  1498  			// the final state. We'll only wait for the first five
  1499  			// states to arrive at the tower.
  1500  			h.backupStates(chanID, maxUpdates-1, numUpdates, nil)
  1501  			h.waitServerUpdates(hints[:maxUpdates], 5*time.Second)
  1502  
  1503  			// Finally, stop the client which will continue to
  1504  			// attempt session negotiation since it has one more
  1505  			// state to process. After the force quit delay
  1506  			// expires, the client should force quit itself and
  1507  			// allow the test to complete.
  1508  			err = h.client.Stop()
  1509  			require.Nil(h.t, err)
  1510  		},
  1511  	},
  1512  }
  1513  
  1514  // TestClient executes the client test suite, asserting the ability to back up
  1515  // states in a number of failure cases and its reliability during shutdown.
  1516  func TestClient(t *testing.T) {
  1517  	for _, test := range clientTests {
  1518  		tc := test
  1519  		t.Run(tc.name, func(t *testing.T) {
  1520  			t.Parallel()
  1521  
  1522  			h := newHarness(t, tc.cfg)
  1523  			defer h.server.Stop()
  1524  			defer h.client.ForceQuit()
  1525  
  1526  			tc.fn(h)
  1527  		})
  1528  	}
  1529  }