github.com/decred/dcrlnd@v0.7.6/watchtower/wtserver/server_test.go (about)

     1  package wtserver_test
     2  
     3  import (
     4  	"bytes"
     5  	"reflect"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/decred/dcrd/chaincfg/v3"
    10  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    11  	"github.com/decred/dcrd/txscript/v4/stdaddr"
    12  	"github.com/decred/dcrlnd/input"
    13  	"github.com/decred/dcrlnd/lnwire"
    14  	"github.com/decred/dcrlnd/watchtower/blob"
    15  	"github.com/decred/dcrlnd/watchtower/wtdb"
    16  	"github.com/decred/dcrlnd/watchtower/wtmock"
    17  	"github.com/decred/dcrlnd/watchtower/wtserver"
    18  	"github.com/decred/dcrlnd/watchtower/wtwire"
    19  )
    20  
    21  var (
    22  	// addr is the server's reward address given to watchtower clients.
    23  	addr, _ = stdaddr.DecodeAddress("TsVDyY1k1N2jZ7xYuoA1PEbwSP2mQnXR9qb",
    24  		chaincfg.TestNet3Params())
    25  
    26  	addrScript, _ = input.PayToAddrScript(addr)
    27  
    28  	testnetChainHash = chaincfg.TestNet3Params().GenesisHash
    29  
    30  	testBlob = make([]byte, blob.Size(blob.TypeAltruistCommit))
    31  )
    32  
    33  // randPubKey generates a new secp keypair, and returns the public key.
    34  func randPubKey(t *testing.T) *secp256k1.PublicKey {
    35  	t.Helper()
    36  
    37  	sk, err := secp256k1.GeneratePrivateKey()
    38  	if err != nil {
    39  		t.Fatalf("unable to generate pubkey: %v", err)
    40  	}
    41  
    42  	return sk.PubKey()
    43  }
    44  
    45  // initServer creates and starts a new server using the server.DB and timeout.
    46  // If the provided database is nil, a mock db will be used.
    47  func initServer(t *testing.T, db wtserver.DB,
    48  	timeout time.Duration) wtserver.Interface {
    49  
    50  	t.Helper()
    51  
    52  	if db == nil {
    53  		db = wtmock.NewTowerDB()
    54  	}
    55  
    56  	s, err := wtserver.New(&wtserver.Config{
    57  		DB:           db,
    58  		ReadTimeout:  timeout,
    59  		WriteTimeout: timeout,
    60  		NewAddress: func() (stdaddr.Address, error) {
    61  			return addr, nil
    62  		},
    63  		ChainHash: testnetChainHash,
    64  	})
    65  	if err != nil {
    66  		t.Fatalf("unable to create server: %v", err)
    67  	}
    68  
    69  	if err = s.Start(); err != nil {
    70  		t.Fatalf("unable to start server: %v", err)
    71  	}
    72  
    73  	return s
    74  }
    75  
    76  // TestServerOnlyAcceptOnePeer checks that the server will reject duplicate
    77  // peers with the same session id by disconnecting them. This is accomplished by
    78  // connecting two distinct peers with the same session id, and trying to send
    79  // messages on both connections. Since one should be rejected, we verify that
    80  // only one of the connections is able to send messages.
    81  func TestServerOnlyAcceptOnePeer(t *testing.T) {
    82  	t.Parallel()
    83  
    84  	const timeoutDuration = 500 * time.Millisecond
    85  
    86  	s := initServer(t, nil, timeoutDuration)
    87  	defer s.Stop()
    88  
    89  	localPub := randPubKey(t)
    90  
    91  	// Create two peers using the same session id.
    92  	peerPub := randPubKey(t)
    93  	peer1 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
    94  	peer2 := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
    95  
    96  	// Serialize a Init message to be sent by both peers.
    97  	init := wtwire.NewInitMessage(
    98  		lnwire.NewRawFeatureVector(), testnetChainHash,
    99  	)
   100  
   101  	var b bytes.Buffer
   102  	_, err := wtwire.WriteMessage(&b, init, 0)
   103  	if err != nil {
   104  		t.Fatalf("unable to write message: %v", err)
   105  	}
   106  
   107  	msg := b.Bytes()
   108  
   109  	// Connect both peers to the server simultaneously.
   110  	s.InboundPeerConnected(peer1)
   111  	s.InboundPeerConnected(peer2)
   112  
   113  	// Use a timeout of twice the server's timeouts, to ensure the server
   114  	// has time to process the messages.
   115  	timeout := time.After(2 * timeoutDuration)
   116  
   117  	// Try to send a message on either peer, and record the opposite peer as
   118  	// the one we assume to be rejected.
   119  	var (
   120  		rejectedPeer *wtmock.MockPeer
   121  		acceptedPeer *wtmock.MockPeer
   122  	)
   123  	select {
   124  	case peer1.IncomingMsgs <- msg:
   125  		acceptedPeer = peer1
   126  		rejectedPeer = peer2
   127  	case peer2.IncomingMsgs <- msg:
   128  		acceptedPeer = peer2
   129  		rejectedPeer = peer1
   130  	case <-timeout:
   131  		t.Fatalf("unable to send message via either peer")
   132  	}
   133  
   134  	// Try again to send a message, this time only via the assumed-rejected
   135  	// peer. We expect our conservative timeout to expire, as the server
   136  	// isn't reading from this peer. Before the timeout, the accepted peer
   137  	// should also receive a reply to its Init message.
   138  	select {
   139  	case <-acceptedPeer.OutgoingMsgs:
   140  		select {
   141  		case rejectedPeer.IncomingMsgs <- msg:
   142  			t.Fatalf("rejected peer should not have received message")
   143  		case <-timeout:
   144  			// Accepted peer got reply, rejected peer go nothing.
   145  		}
   146  	case rejectedPeer.IncomingMsgs <- msg:
   147  		t.Fatalf("rejected peer should not have received message")
   148  	case <-timeout:
   149  		t.Fatalf("accepted peer should have received init message")
   150  	}
   151  }
   152  
   153  type createSessionTestCase struct {
   154  	name            string
   155  	initMsg         *wtwire.Init
   156  	createMsg       *wtwire.CreateSession
   157  	expReply        *wtwire.CreateSessionReply
   158  	expDupReply     *wtwire.CreateSessionReply
   159  	sendStateUpdate bool
   160  }
   161  
   162  var createSessionTests = []createSessionTestCase{
   163  	{
   164  		name: "duplicate session create altruist anchor commit",
   165  		initMsg: wtwire.NewInitMessage(
   166  			lnwire.NewRawFeatureVector(),
   167  			testnetChainHash,
   168  		),
   169  		createMsg: &wtwire.CreateSession{
   170  			BlobType:     blob.TypeAltruistAnchorCommit,
   171  			MaxUpdates:   1000,
   172  			RewardBase:   0,
   173  			RewardRate:   0,
   174  			SweepFeeRate: 10000,
   175  		},
   176  		expReply: &wtwire.CreateSessionReply{
   177  			Code: wtwire.CodeOK,
   178  			Data: []byte{},
   179  		},
   180  		expDupReply: &wtwire.CreateSessionReply{
   181  			Code: wtwire.CodeOK,
   182  			Data: []byte{},
   183  		},
   184  	},
   185  	{
   186  		name: "duplicate session create",
   187  		initMsg: wtwire.NewInitMessage(
   188  			lnwire.NewRawFeatureVector(),
   189  			testnetChainHash,
   190  		),
   191  		createMsg: &wtwire.CreateSession{
   192  			BlobType:     blob.TypeAltruistCommit,
   193  			MaxUpdates:   1000,
   194  			RewardBase:   0,
   195  			RewardRate:   0,
   196  			SweepFeeRate: 10000,
   197  		},
   198  		expReply: &wtwire.CreateSessionReply{
   199  			Code: wtwire.CodeOK,
   200  			Data: []byte{},
   201  		},
   202  		expDupReply: &wtwire.CreateSessionReply{
   203  			Code: wtwire.CodeOK,
   204  			Data: []byte{},
   205  		},
   206  	},
   207  	{
   208  		name: "duplicate session create after use",
   209  		initMsg: wtwire.NewInitMessage(
   210  			lnwire.NewRawFeatureVector(),
   211  			testnetChainHash,
   212  		),
   213  		createMsg: &wtwire.CreateSession{
   214  			BlobType:     blob.TypeAltruistCommit,
   215  			MaxUpdates:   1000,
   216  			RewardBase:   0,
   217  			RewardRate:   0,
   218  			SweepFeeRate: 10000,
   219  		},
   220  		expReply: &wtwire.CreateSessionReply{
   221  			Code: wtwire.CodeOK,
   222  			Data: []byte{},
   223  		},
   224  		expDupReply: &wtwire.CreateSessionReply{
   225  			Code:        wtwire.CreateSessionCodeAlreadyExists,
   226  			LastApplied: 1,
   227  			Data:        []byte{},
   228  		},
   229  		sendStateUpdate: true,
   230  	},
   231  	{
   232  		name: "duplicate session create reward",
   233  		initMsg: wtwire.NewInitMessage(
   234  			lnwire.NewRawFeatureVector(),
   235  			testnetChainHash,
   236  		),
   237  		createMsg: &wtwire.CreateSession{
   238  			BlobType:     blob.TypeRewardCommit,
   239  			MaxUpdates:   1000,
   240  			RewardBase:   0,
   241  			RewardRate:   0,
   242  			SweepFeeRate: 10000,
   243  		},
   244  		expReply: &wtwire.CreateSessionReply{
   245  			Code: wtwire.CodeOK,
   246  			Data: addrScript,
   247  		},
   248  		expDupReply: &wtwire.CreateSessionReply{
   249  			Code: wtwire.CodeOK,
   250  			Data: addrScript,
   251  		},
   252  	},
   253  	{
   254  		name: "reject unsupported blob type",
   255  		initMsg: wtwire.NewInitMessage(
   256  			lnwire.NewRawFeatureVector(),
   257  			testnetChainHash,
   258  		),
   259  		createMsg: &wtwire.CreateSession{
   260  			BlobType:     0,
   261  			MaxUpdates:   1000,
   262  			RewardBase:   0,
   263  			RewardRate:   0,
   264  			SweepFeeRate: 10000,
   265  		},
   266  		expReply: &wtwire.CreateSessionReply{
   267  			Code: wtwire.CreateSessionCodeRejectBlobType,
   268  			Data: []byte{},
   269  		},
   270  	},
   271  	// TODO(conner): add policy rejection tests
   272  }
   273  
   274  // TestServerCreateSession checks the server's behavior in response to a
   275  // table-driven set of CreateSession messages.
   276  func TestServerCreateSession(t *testing.T) {
   277  	t.Parallel()
   278  
   279  	for i, test := range createSessionTests {
   280  		t.Run(test.name, func(t *testing.T) {
   281  			testServerCreateSession(t, i, test)
   282  		})
   283  	}
   284  }
   285  
   286  func testServerCreateSession(t *testing.T, i int, test createSessionTestCase) {
   287  	const timeoutDuration = 500 * time.Millisecond
   288  
   289  	s := initServer(t, nil, timeoutDuration)
   290  	defer s.Stop()
   291  
   292  	localPub := randPubKey(t)
   293  
   294  	// Create a new client and connect to server.
   295  	peerPub := randPubKey(t)
   296  	peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
   297  	connect(t, s, peer, test.initMsg, timeoutDuration)
   298  
   299  	// Send the CreateSession message, and wait for a reply.
   300  	sendMsg(t, test.createMsg, peer, timeoutDuration)
   301  
   302  	reply := recvReply(
   303  		t, "MsgCreateSessionReply", peer, timeoutDuration,
   304  	).(*wtwire.CreateSessionReply)
   305  
   306  	// Verify that the server's response matches our expectation.
   307  	if !reflect.DeepEqual(reply, test.expReply) {
   308  		t.Fatalf("[test %d] expected reply %v, got %d",
   309  			i, test.expReply, reply)
   310  	}
   311  
   312  	// Assert that the server closes the connection after processing the
   313  	// CreateSession.
   314  	assertConnClosed(t, peer, 2*timeoutDuration)
   315  
   316  	// If this test did not request sending a duplicate CreateSession, we can
   317  	// continue to the next test.
   318  	if test.expDupReply == nil {
   319  		return
   320  	}
   321  
   322  	if test.sendStateUpdate {
   323  		peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
   324  		connect(t, s, peer, test.initMsg, timeoutDuration)
   325  		update := &wtwire.StateUpdate{
   326  			SeqNum:        1,
   327  			IsComplete:    1,
   328  			EncryptedBlob: testBlob,
   329  		}
   330  		sendMsg(t, update, peer, timeoutDuration)
   331  
   332  		assertConnClosed(t, peer, 2*timeoutDuration)
   333  	}
   334  
   335  	// Simulate a peer with the same session id connection to the server
   336  	// again.
   337  	peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
   338  	connect(t, s, peer, test.initMsg, timeoutDuration)
   339  
   340  	// Send the _same_ CreateSession message as the first attempt.
   341  	sendMsg(t, test.createMsg, peer, timeoutDuration)
   342  
   343  	reply = recvReply(
   344  		t, "MsgCreateSessionReply", peer, timeoutDuration,
   345  	).(*wtwire.CreateSessionReply)
   346  
   347  	// Ensure that the server's reply matches our expected response for a
   348  	// duplicate send.
   349  	if !reflect.DeepEqual(reply, test.expDupReply) {
   350  		t.Fatalf("[test %d] expected reply %v, got %v",
   351  			i, test.expDupReply, reply)
   352  	}
   353  
   354  	// Finally, check that the server tore down the connection.
   355  	assertConnClosed(t, peer, 2*timeoutDuration)
   356  }
   357  
   358  type stateUpdateTestCase struct {
   359  	name      string
   360  	initMsg   *wtwire.Init
   361  	createMsg *wtwire.CreateSession
   362  	updates   []*wtwire.StateUpdate
   363  	replies   []*wtwire.StateUpdateReply
   364  }
   365  
   366  var stateUpdateTests = []stateUpdateTestCase{
   367  	// Valid update sequence, send seqnum == lastapplied as last update.
   368  	{
   369  		name: "perm fail after sending seqnum equal lastapplied",
   370  		initMsg: wtwire.NewInitMessage(
   371  			lnwire.NewRawFeatureVector(),
   372  			testnetChainHash,
   373  		),
   374  		createMsg: &wtwire.CreateSession{
   375  			BlobType:     blob.TypeAltruistCommit,
   376  			MaxUpdates:   3,
   377  			RewardBase:   0,
   378  			RewardRate:   0,
   379  			SweepFeeRate: 10000,
   380  		},
   381  		updates: []*wtwire.StateUpdate{
   382  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   383  			{SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
   384  			{SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
   385  			{SeqNum: 3, LastApplied: 3, EncryptedBlob: testBlob},
   386  		},
   387  		replies: []*wtwire.StateUpdateReply{
   388  			{Code: wtwire.CodeOK, LastApplied: 1},
   389  			{Code: wtwire.CodeOK, LastApplied: 2},
   390  			{Code: wtwire.CodeOK, LastApplied: 3},
   391  			{
   392  				Code:        wtwire.CodePermanentFailure,
   393  				LastApplied: 3,
   394  			},
   395  		},
   396  	},
   397  	// Send update that skips next expected sequence number.
   398  	{
   399  		name: "skip sequence number",
   400  		initMsg: wtwire.NewInitMessage(
   401  			lnwire.NewRawFeatureVector(),
   402  			testnetChainHash,
   403  		),
   404  		createMsg: &wtwire.CreateSession{
   405  			BlobType:     blob.TypeAltruistCommit,
   406  			MaxUpdates:   4,
   407  			RewardBase:   0,
   408  			RewardRate:   0,
   409  			SweepFeeRate: 10000,
   410  		},
   411  		updates: []*wtwire.StateUpdate{
   412  			{SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
   413  		},
   414  		replies: []*wtwire.StateUpdateReply{
   415  			{
   416  				Code:        wtwire.StateUpdateCodeSeqNumOutOfOrder,
   417  				LastApplied: 0,
   418  			},
   419  		},
   420  	},
   421  	// Send update that reverts to older sequence number.
   422  	{
   423  		name: "revert to older seqnum",
   424  		initMsg: wtwire.NewInitMessage(
   425  			lnwire.NewRawFeatureVector(),
   426  			testnetChainHash,
   427  		),
   428  		createMsg: &wtwire.CreateSession{
   429  			BlobType:     blob.TypeAltruistCommit,
   430  			MaxUpdates:   4,
   431  			RewardBase:   0,
   432  			RewardRate:   0,
   433  			SweepFeeRate: 10000,
   434  		},
   435  		updates: []*wtwire.StateUpdate{
   436  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   437  			{SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
   438  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   439  		},
   440  		replies: []*wtwire.StateUpdateReply{
   441  			{Code: wtwire.CodeOK, LastApplied: 1},
   442  			{Code: wtwire.CodeOK, LastApplied: 2},
   443  			{
   444  				Code:        wtwire.StateUpdateCodeSeqNumOutOfOrder,
   445  				LastApplied: 2,
   446  			},
   447  		},
   448  	},
   449  	// Send update echoing a last applied that is lower than previous value.
   450  	{
   451  		name: "revert to older lastapplied",
   452  		initMsg: wtwire.NewInitMessage(
   453  			lnwire.NewRawFeatureVector(),
   454  			testnetChainHash,
   455  		),
   456  		createMsg: &wtwire.CreateSession{
   457  			BlobType:     blob.TypeAltruistCommit,
   458  			MaxUpdates:   4,
   459  			RewardBase:   0,
   460  			RewardRate:   0,
   461  			SweepFeeRate: 10000,
   462  		},
   463  		updates: []*wtwire.StateUpdate{
   464  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   465  			{SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
   466  			{SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
   467  			{SeqNum: 4, LastApplied: 1, EncryptedBlob: testBlob},
   468  		},
   469  		replies: []*wtwire.StateUpdateReply{
   470  			{Code: wtwire.CodeOK, LastApplied: 1},
   471  			{Code: wtwire.CodeOK, LastApplied: 2},
   472  			{Code: wtwire.CodeOK, LastApplied: 3},
   473  			{Code: wtwire.StateUpdateCodeClientBehind, LastApplied: 3},
   474  		},
   475  	},
   476  	// Valid update sequence with disconnection, ensure resumes resume.
   477  	// Client echos last applied as they are received.
   478  	{
   479  		name: "resume after disconnect",
   480  		initMsg: wtwire.NewInitMessage(
   481  			lnwire.NewRawFeatureVector(),
   482  			testnetChainHash,
   483  		),
   484  		createMsg: &wtwire.CreateSession{
   485  			BlobType:     blob.TypeAltruistCommit,
   486  			MaxUpdates:   4,
   487  			RewardBase:   0,
   488  			RewardRate:   0,
   489  			SweepFeeRate: 10000,
   490  		},
   491  		updates: []*wtwire.StateUpdate{
   492  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   493  			{SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
   494  			nil, // Wait for read timeout to drop conn, then reconnect.
   495  			{SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
   496  			{SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
   497  		},
   498  		replies: []*wtwire.StateUpdateReply{
   499  			{Code: wtwire.CodeOK, LastApplied: 1},
   500  			{Code: wtwire.CodeOK, LastApplied: 2},
   501  			nil,
   502  			{Code: wtwire.CodeOK, LastApplied: 3},
   503  			{Code: wtwire.CodeOK, LastApplied: 4},
   504  		},
   505  	},
   506  	// Valid update sequence with disconnection, resume next update. Client
   507  	// doesn't echo last applied until last message.
   508  	{
   509  		name: "resume after disconnect lagging lastapplied",
   510  		initMsg: wtwire.NewInitMessage(
   511  			lnwire.NewRawFeatureVector(),
   512  			testnetChainHash,
   513  		),
   514  		createMsg: &wtwire.CreateSession{
   515  			BlobType:     blob.TypeAltruistCommit,
   516  			MaxUpdates:   4,
   517  			RewardBase:   0,
   518  			RewardRate:   0,
   519  			SweepFeeRate: 10000,
   520  		},
   521  		updates: []*wtwire.StateUpdate{
   522  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   523  			{SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
   524  			nil, // Wait for read timeout to drop conn, then reconnect.
   525  			{SeqNum: 3, LastApplied: 0, EncryptedBlob: testBlob},
   526  			{SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
   527  		},
   528  		replies: []*wtwire.StateUpdateReply{
   529  			{Code: wtwire.CodeOK, LastApplied: 1},
   530  			{Code: wtwire.CodeOK, LastApplied: 2},
   531  			nil,
   532  			{Code: wtwire.CodeOK, LastApplied: 3},
   533  			{Code: wtwire.CodeOK, LastApplied: 4},
   534  		},
   535  	},
   536  	// Valid update sequence with disconnection, resume last update.  Client
   537  	// doesn't echo last applied until last message.
   538  	{
   539  		name: "resume after disconnect lagging lastapplied",
   540  		initMsg: wtwire.NewInitMessage(
   541  			lnwire.NewRawFeatureVector(),
   542  			testnetChainHash,
   543  		),
   544  		createMsg: &wtwire.CreateSession{
   545  			BlobType:     blob.TypeAltruistCommit,
   546  			MaxUpdates:   4,
   547  			RewardBase:   0,
   548  			RewardRate:   0,
   549  			SweepFeeRate: 10000,
   550  		},
   551  		updates: []*wtwire.StateUpdate{
   552  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   553  			{SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
   554  			nil, // Wait for read timeout to drop conn, then reconnect.
   555  			{SeqNum: 2, LastApplied: 0, EncryptedBlob: testBlob},
   556  			{SeqNum: 3, LastApplied: 0, EncryptedBlob: testBlob},
   557  			{SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
   558  		},
   559  		replies: []*wtwire.StateUpdateReply{
   560  			{Code: wtwire.CodeOK, LastApplied: 1},
   561  			{Code: wtwire.CodeOK, LastApplied: 2},
   562  			nil,
   563  			{Code: wtwire.CodeOK, LastApplied: 2},
   564  			{Code: wtwire.CodeOK, LastApplied: 3},
   565  			{Code: wtwire.CodeOK, LastApplied: 4},
   566  		},
   567  	},
   568  	// Send update with sequence number that exceeds MaxUpdates.
   569  	{
   570  		name: "seqnum exceed maxupdates",
   571  		initMsg: wtwire.NewInitMessage(
   572  			lnwire.NewRawFeatureVector(),
   573  			testnetChainHash,
   574  		),
   575  		createMsg: &wtwire.CreateSession{
   576  			BlobType:     blob.TypeAltruistCommit,
   577  			MaxUpdates:   3,
   578  			RewardBase:   0,
   579  			RewardRate:   0,
   580  			SweepFeeRate: 10000,
   581  		},
   582  		updates: []*wtwire.StateUpdate{
   583  			{SeqNum: 1, LastApplied: 0, EncryptedBlob: testBlob},
   584  			{SeqNum: 2, LastApplied: 1, EncryptedBlob: testBlob},
   585  			{SeqNum: 3, LastApplied: 2, EncryptedBlob: testBlob},
   586  			{SeqNum: 4, LastApplied: 3, EncryptedBlob: testBlob},
   587  		},
   588  		replies: []*wtwire.StateUpdateReply{
   589  			{Code: wtwire.CodeOK, LastApplied: 1},
   590  			{Code: wtwire.CodeOK, LastApplied: 2},
   591  			{Code: wtwire.CodeOK, LastApplied: 3},
   592  			{
   593  				Code:        wtwire.StateUpdateCodeMaxUpdatesExceeded,
   594  				LastApplied: 3,
   595  			},
   596  		},
   597  	},
   598  	// Ensure sequence number 0 causes permanent failure.
   599  	{
   600  		name: "perm fail after seqnum 0",
   601  		initMsg: wtwire.NewInitMessage(
   602  			lnwire.NewRawFeatureVector(),
   603  			testnetChainHash,
   604  		),
   605  		createMsg: &wtwire.CreateSession{
   606  			BlobType:     blob.TypeAltruistCommit,
   607  			MaxUpdates:   3,
   608  			RewardBase:   0,
   609  			RewardRate:   0,
   610  			SweepFeeRate: 10000,
   611  		},
   612  		updates: []*wtwire.StateUpdate{
   613  			{SeqNum: 0, LastApplied: 0, EncryptedBlob: testBlob},
   614  		},
   615  		replies: []*wtwire.StateUpdateReply{
   616  			{
   617  				Code:        wtwire.CodePermanentFailure,
   618  				LastApplied: 0,
   619  			},
   620  		},
   621  	},
   622  }
   623  
   624  // TestServerStateUpdates tests the behavior of the server in response to
   625  // watchtower clients sending StateUpdate messages, after having already
   626  // established an open session. The test asserts that the server responds
   627  // with the appropriate failure codes in a number of failure conditions where
   628  // the server and client desynchronize. It also checks the ability of the client
   629  // to disconnect, connect, and continue updating from the last successful state
   630  // update.
   631  func TestServerStateUpdates(t *testing.T) {
   632  	t.Parallel()
   633  
   634  	for _, test := range stateUpdateTests {
   635  		t.Run(test.name, func(t *testing.T) {
   636  			testServerStateUpdates(t, test)
   637  		})
   638  	}
   639  }
   640  
   641  func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) {
   642  	const timeoutDuration = 100 * time.Millisecond
   643  
   644  	s := initServer(t, nil, timeoutDuration)
   645  	defer s.Stop()
   646  
   647  	localPub := randPubKey(t)
   648  
   649  	// Create a new client and connect to the server.
   650  	peerPub := randPubKey(t)
   651  	peer := wtmock.NewMockPeer(localPub, peerPub, nil, 0)
   652  	connect(t, s, peer, test.initMsg, timeoutDuration)
   653  
   654  	// Register a session for this client to use in the subsequent tests.
   655  	sendMsg(t, test.createMsg, peer, timeoutDuration)
   656  	initReply := recvReply(
   657  		t, "MsgCreateSessionReply", peer, timeoutDuration,
   658  	).(*wtwire.CreateSessionReply)
   659  
   660  	// Fail if the server rejected our proposed CreateSession message.
   661  	if initReply.Code != wtwire.CodeOK {
   662  		t.Fatalf("server rejected session init")
   663  	}
   664  
   665  	// Check that the server closed the connection used to register the
   666  	// session.
   667  	assertConnClosed(t, peer, 2*timeoutDuration)
   668  
   669  	// Now that the original connection has been closed, connect a new
   670  	// client with the same session id.
   671  	peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
   672  	connect(t, s, peer, test.initMsg, timeoutDuration)
   673  
   674  	// Send the intended StateUpdate messages in series.
   675  	for j, update := range test.updates {
   676  		// A nil update signals that we should wait for the prior
   677  		// connection to die, before re-register with the same session
   678  		// identifier.
   679  		if update == nil {
   680  			assertConnClosed(t, peer, 2*timeoutDuration)
   681  
   682  			peer = wtmock.NewMockPeer(localPub, peerPub, nil, 0)
   683  			connect(t, s, peer, test.initMsg, timeoutDuration)
   684  
   685  			continue
   686  		}
   687  
   688  		// Send the state update and verify it against our expected
   689  		// response.
   690  		sendMsg(t, update, peer, timeoutDuration)
   691  		reply := recvReply(
   692  			t, "MsgStateUpdateReply", peer, timeoutDuration,
   693  		).(*wtwire.StateUpdateReply)
   694  
   695  		if !reflect.DeepEqual(reply, test.replies[j]) {
   696  			t.Fatalf("[update %d] expected reply "+
   697  				"%v, got %d", j,
   698  				test.replies[j], reply)
   699  		}
   700  	}
   701  
   702  	// Check that the final connection is properly cleaned up by the server.
   703  	assertConnClosed(t, peer, 2*timeoutDuration)
   704  }
   705  
   706  // TestServerDeleteSession asserts the response to a DeleteSession request, and
   707  // checking that the proper error is returned when the session doesn't exist and
   708  // that a successful deletion does not disrupt other sessions.
   709  func TestServerDeleteSession(t *testing.T) {
   710  	db := wtmock.NewTowerDB()
   711  
   712  	localPub := randPubKey(t)
   713  
   714  	// Initialize two distinct peers with different session ids.
   715  	peerPub1 := randPubKey(t)
   716  	peerPub2 := randPubKey(t)
   717  
   718  	id1 := wtdb.NewSessionIDFromPubKey(peerPub1)
   719  	id2 := wtdb.NewSessionIDFromPubKey(peerPub2)
   720  
   721  	// Create closure to simplify assertions on session existence with the
   722  	// server's database.
   723  	hasSession := func(t *testing.T, id *wtdb.SessionID, shouldHave bool) {
   724  		t.Helper()
   725  
   726  		_, err := db.GetSessionInfo(id)
   727  		switch {
   728  		case shouldHave && err != nil:
   729  			t.Fatalf("expected server to have session %s, got: %v",
   730  				id, err)
   731  		case !shouldHave && err != wtdb.ErrSessionNotFound:
   732  			t.Fatalf("expected ErrSessionNotFound for session %s, "+
   733  				"got: %v", id, err)
   734  		}
   735  	}
   736  
   737  	initMsg := wtwire.NewInitMessage(
   738  		lnwire.NewRawFeatureVector(),
   739  		testnetChainHash,
   740  	)
   741  
   742  	createSession := &wtwire.CreateSession{
   743  		BlobType:     blob.TypeAltruistCommit,
   744  		MaxUpdates:   1000,
   745  		RewardBase:   0,
   746  		RewardRate:   0,
   747  		SweepFeeRate: 10000,
   748  	}
   749  
   750  	const timeoutDuration = 100 * time.Millisecond
   751  
   752  	s := initServer(t, db, timeoutDuration)
   753  	defer s.Stop()
   754  
   755  	// Create a session for peer2 so that the server's db isn't completely
   756  	// empty.
   757  	peer2 := wtmock.NewMockPeer(localPub, peerPub2, nil, 0)
   758  	connect(t, s, peer2, initMsg, timeoutDuration)
   759  	sendMsg(t, createSession, peer2, timeoutDuration)
   760  	assertConnClosed(t, peer2, 2*timeoutDuration)
   761  
   762  	// Our initial assertions are that peer2 has a valid session, but peer1
   763  	// has not created one.
   764  	hasSession(t, &id1, false)
   765  	hasSession(t, &id2, true)
   766  
   767  	peer1Msgs := []struct {
   768  		send   wtwire.Message
   769  		recv   wtwire.Message
   770  		assert func(t *testing.T)
   771  	}{
   772  		{
   773  			// Deleting unknown session should fail.
   774  			send: &wtwire.DeleteSession{},
   775  			recv: &wtwire.DeleteSessionReply{
   776  				Code: wtwire.DeleteSessionCodeNotFound,
   777  			},
   778  			assert: func(t *testing.T) {
   779  				// Peer2 should still be only session.
   780  				hasSession(t, &id1, false)
   781  				hasSession(t, &id2, true)
   782  			},
   783  		},
   784  		{
   785  			// Create session for peer1.
   786  			send: createSession,
   787  			recv: &wtwire.CreateSessionReply{
   788  				Code: wtwire.CodeOK,
   789  				Data: []byte{},
   790  			},
   791  			assert: func(t *testing.T) {
   792  				// Both peers should have sessions.
   793  				hasSession(t, &id1, true)
   794  				hasSession(t, &id2, true)
   795  			},
   796  		},
   797  
   798  		{
   799  			// Delete peer1's session.
   800  			send: &wtwire.DeleteSession{},
   801  			recv: &wtwire.DeleteSessionReply{
   802  				Code: wtwire.CodeOK,
   803  			},
   804  			assert: func(t *testing.T) {
   805  				// Peer1's session should have been removed.
   806  				hasSession(t, &id1, false)
   807  				hasSession(t, &id2, true)
   808  			},
   809  		},
   810  	}
   811  
   812  	// Now as peer1, process the canned messages defined above. This will:
   813  	// 1. Try to delete an unknown session and get a not found error code.
   814  	// 2. Create a new session using the same parameters as peer2.
   815  	// 3. Delete the newly created session and get an OK.
   816  	for _, msg := range peer1Msgs {
   817  		peer1 := wtmock.NewMockPeer(localPub, peerPub1, nil, 0)
   818  		connect(t, s, peer1, initMsg, timeoutDuration)
   819  		sendMsg(t, msg.send, peer1, timeoutDuration)
   820  		reply := recvReply(
   821  			t, msg.recv.MsgType().String(), peer1, timeoutDuration,
   822  		)
   823  
   824  		if !reflect.DeepEqual(reply, msg.recv) {
   825  			t.Fatalf("expected reply: %v, got: %v", msg.recv, reply)
   826  		}
   827  
   828  		assertConnClosed(t, peer1, 2*timeoutDuration)
   829  
   830  		// Invoke assertions after completing the request/response
   831  		// dance.
   832  		msg.assert(t)
   833  	}
   834  }
   835  
   836  func connect(t *testing.T, s wtserver.Interface, peer *wtmock.MockPeer,
   837  	initMsg *wtwire.Init, timeout time.Duration) {
   838  
   839  	t.Helper()
   840  
   841  	s.InboundPeerConnected(peer)
   842  	sendMsg(t, initMsg, peer, timeout)
   843  	recvReply(t, "MsgInit", peer, timeout)
   844  }
   845  
   846  // sendMsg sends a wtwire.Message message via a wtmock.MockPeer.
   847  func sendMsg(t *testing.T, msg wtwire.Message,
   848  	peer *wtmock.MockPeer, timeout time.Duration) {
   849  
   850  	t.Helper()
   851  
   852  	var b bytes.Buffer
   853  	_, err := wtwire.WriteMessage(&b, msg, 0)
   854  	if err != nil {
   855  		t.Fatalf("unable to encode %T message: %v",
   856  			msg, err)
   857  	}
   858  
   859  	select {
   860  	case peer.IncomingMsgs <- b.Bytes():
   861  	case <-time.After(2 * timeout):
   862  		t.Fatalf("unable to send %T message", msg)
   863  	}
   864  }
   865  
   866  // recvReply receives a message from the server, and parses it according to
   867  // expected reply type. The supported replies are CreateSessionReply and
   868  // StateUpdateReply.
   869  func recvReply(t *testing.T, name string, peer *wtmock.MockPeer,
   870  	timeout time.Duration) wtwire.Message {
   871  
   872  	t.Helper()
   873  
   874  	var (
   875  		msg wtwire.Message
   876  		err error
   877  	)
   878  
   879  	select {
   880  	case b := <-peer.OutgoingMsgs:
   881  		msg, err = wtwire.ReadMessage(bytes.NewReader(b), 0)
   882  		if err != nil {
   883  			t.Fatalf("unable to decode server "+
   884  				"reply: %v", err)
   885  		}
   886  
   887  	case <-time.After(2 * timeout):
   888  		t.Fatalf("server did not reply")
   889  	}
   890  
   891  	switch name {
   892  	case "MsgInit":
   893  		if _, ok := msg.(*wtwire.Init); !ok {
   894  			t.Fatalf("expected %s reply message, "+
   895  				"got %T", name, msg)
   896  		}
   897  	case "MsgCreateSessionReply":
   898  		if _, ok := msg.(*wtwire.CreateSessionReply); !ok {
   899  			t.Fatalf("expected %s reply message, "+
   900  				"got %T", name, msg)
   901  		}
   902  	case "MsgStateUpdateReply":
   903  		if _, ok := msg.(*wtwire.StateUpdateReply); !ok {
   904  			t.Fatalf("expected %s reply message, "+
   905  				"got %T", name, msg)
   906  		}
   907  	case "MsgDeleteSessionReply":
   908  		if _, ok := msg.(*wtwire.DeleteSessionReply); !ok {
   909  			t.Fatalf("expected %s reply message, "+
   910  				"got %T", name, msg)
   911  		}
   912  	}
   913  
   914  	return msg
   915  }
   916  
   917  // assertConnClosed checks that the peer's connection is closed before the
   918  // timeout expires.
   919  func assertConnClosed(t *testing.T, peer *wtmock.MockPeer, duration time.Duration) {
   920  	t.Helper()
   921  
   922  	select {
   923  	case <-peer.Quit:
   924  	case <-time.After(duration):
   925  		t.Fatalf("expected connection to be closed")
   926  	}
   927  }