github.com/decred/dcrlnd@v0.7.6/watchtower/wtdb/client_db_test.go (about)

     1  package wtdb_test
     2  
     3  import (
     4  	"bytes"
     5  	crand "crypto/rand"
     6  	"io"
     7  	"io/ioutil"
     8  	"net"
     9  	"os"
    10  	"reflect"
    11  	"testing"
    12  
    13  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    14  	"github.com/decred/dcrlnd/kvdb"
    15  	"github.com/decred/dcrlnd/lnwire"
    16  	"github.com/decred/dcrlnd/watchtower/blob"
    17  	"github.com/decred/dcrlnd/watchtower/wtclient"
    18  	"github.com/decred/dcrlnd/watchtower/wtdb"
    19  	"github.com/decred/dcrlnd/watchtower/wtmock"
    20  	"github.com/decred/dcrlnd/watchtower/wtpolicy"
    21  )
    22  
    23  // clientDBInit is a closure used to initialize a wtclient.DB instance its
    24  // cleanup function.
    25  type clientDBInit func(t *testing.T) (wtclient.DB, func())
    26  
    27  type clientDBHarness struct {
    28  	t  *testing.T
    29  	db wtclient.DB
    30  }
    31  
    32  func newClientDBHarness(t *testing.T, init clientDBInit) (*clientDBHarness, func()) {
    33  	db, cleanup := init(t)
    34  
    35  	h := &clientDBHarness{
    36  		t:  t,
    37  		db: db,
    38  	}
    39  
    40  	return h, cleanup
    41  }
    42  
    43  func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) {
    44  	h.t.Helper()
    45  
    46  	err := h.db.CreateClientSession(session)
    47  	if err != expErr {
    48  		h.t.Fatalf("expected create client session error: %v, got: %v",
    49  			expErr, err)
    50  	}
    51  }
    52  
    53  func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
    54  	h.t.Helper()
    55  
    56  	sessions, err := h.db.ListClientSessions(id)
    57  	if err != nil {
    58  		h.t.Fatalf("unable to list client sessions: %v", err)
    59  	}
    60  
    61  	return sessions
    62  }
    63  
    64  func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID,
    65  	blobType blob.Type) uint32 {
    66  
    67  	h.t.Helper()
    68  
    69  	index, err := h.db.NextSessionKeyIndex(id, blobType)
    70  	if err != nil {
    71  		h.t.Fatalf("unable to create next session key index: %v", err)
    72  	}
    73  
    74  	if index == 0 {
    75  		h.t.Fatalf("next key index should never be 0")
    76  	}
    77  
    78  	return index
    79  }
    80  
    81  func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress,
    82  	expErr error) *wtdb.Tower {
    83  
    84  	h.t.Helper()
    85  
    86  	tower, err := h.db.CreateTower(lnAddr)
    87  	if err != expErr {
    88  		h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err)
    89  	}
    90  
    91  	if tower.ID == 0 {
    92  		h.t.Fatalf("tower id should never be 0")
    93  	}
    94  
    95  	for _, session := range h.listSessions(&tower.ID) {
    96  		if session.Status != wtdb.CSessionActive {
    97  			h.t.Fatalf("expected status for session %v to be %v, "+
    98  				"got %v", session.ID, wtdb.CSessionActive,
    99  				session.Status)
   100  		}
   101  	}
   102  
   103  	return tower
   104  }
   105  
   106  func (h *clientDBHarness) removeTower(pubKey *secp256k1.PublicKey, addr net.Addr,
   107  	hasSessions bool, expErr error) {
   108  
   109  	h.t.Helper()
   110  
   111  	if err := h.db.RemoveTower(pubKey, addr); err != expErr {
   112  		h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err)
   113  	}
   114  	if expErr != nil {
   115  		return
   116  	}
   117  
   118  	if addr != nil {
   119  		tower, err := h.db.LoadTower(pubKey)
   120  		if err != nil {
   121  			h.t.Fatalf("expected tower %x to still exist",
   122  				pubKey.SerializeCompressed())
   123  		}
   124  
   125  		removedAddr := addr.String()
   126  		for _, towerAddr := range tower.Addresses {
   127  			if towerAddr.String() == removedAddr {
   128  				h.t.Fatalf("address %v not removed for tower %x",
   129  					removedAddr, pubKey.SerializeCompressed())
   130  			}
   131  		}
   132  	} else {
   133  		tower, err := h.db.LoadTower(pubKey)
   134  		if hasSessions && err != nil {
   135  			h.t.Fatalf("expected tower %x with sessions to still "+
   136  				"exist", pubKey.SerializeCompressed())
   137  		}
   138  		if !hasSessions && err == nil {
   139  			h.t.Fatalf("expected tower %x with no sessions to not "+
   140  				"exist", pubKey.SerializeCompressed())
   141  		}
   142  		if !hasSessions {
   143  			return
   144  		}
   145  		for _, session := range h.listSessions(&tower.ID) {
   146  			if session.Status != wtdb.CSessionInactive {
   147  				h.t.Fatalf("expected status for session %v to "+
   148  					"be %v, got %v", session.ID,
   149  					wtdb.CSessionInactive, session.Status)
   150  			}
   151  		}
   152  	}
   153  }
   154  
   155  func (h *clientDBHarness) loadTower(pubKey *secp256k1.PublicKey, expErr error) *wtdb.Tower {
   156  	h.t.Helper()
   157  
   158  	tower, err := h.db.LoadTower(pubKey)
   159  	if err != expErr {
   160  		h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
   161  	}
   162  
   163  	return tower
   164  }
   165  
   166  func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower {
   167  	h.t.Helper()
   168  
   169  	tower, err := h.db.LoadTowerByID(id)
   170  	if err != expErr {
   171  		h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err)
   172  	}
   173  
   174  	return tower
   175  }
   176  
   177  func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary {
   178  	h.t.Helper()
   179  
   180  	summaries, err := h.db.FetchChanSummaries()
   181  	if err != nil {
   182  		h.t.Fatalf("unable to fetch chan summaries: %v", err)
   183  	}
   184  
   185  	return summaries
   186  }
   187  
   188  func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
   189  	sweepPkScript []byte, expErr error) {
   190  
   191  	h.t.Helper()
   192  
   193  	err := h.db.RegisterChannel(chanID, sweepPkScript)
   194  	if err != expErr {
   195  		h.t.Fatalf("expected register channel error: %v, got: %v",
   196  			expErr, err)
   197  	}
   198  }
   199  
   200  func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID,
   201  	update *wtdb.CommittedUpdate, expErr error) uint16 {
   202  
   203  	h.t.Helper()
   204  
   205  	lastApplied, err := h.db.CommitUpdate(id, update)
   206  	if err != expErr {
   207  		h.t.Fatalf("expected commit update error: %v, got: %v",
   208  			expErr, err)
   209  	}
   210  
   211  	return lastApplied
   212  }
   213  
   214  func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16,
   215  	lastApplied uint16, expErr error) {
   216  
   217  	h.t.Helper()
   218  
   219  	err := h.db.AckUpdate(id, seqNum, lastApplied)
   220  	if err != expErr {
   221  		h.t.Fatalf("expected commit update error: %v, got: %v",
   222  			expErr, err)
   223  	}
   224  }
   225  
   226  // testCreateClientSession asserts various conditions regarding the creation of
   227  // a new ClientSession. The test asserts:
   228  //   - client sessions can only be created if a session key index is reserved.
   229  //   - client sessions cannot be created with an incorrect session key index .
   230  //   - inserting duplicate sessions fails.
   231  func testCreateClientSession(h *clientDBHarness) {
   232  	const blobType = blob.TypeAltruistAnchorCommit
   233  
   234  	// Create a test client session to insert.
   235  	session := &wtdb.ClientSession{
   236  		ClientSessionBody: wtdb.ClientSessionBody{
   237  			TowerID: wtdb.TowerID(3),
   238  			Policy: wtpolicy.Policy{
   239  				TxPolicy: wtpolicy.TxPolicy{
   240  					BlobType: blobType,
   241  				},
   242  				MaxUpdates: 100,
   243  			},
   244  			RewardPkScript: []byte{0x01, 0x02, 0x03},
   245  		},
   246  		ID: wtdb.SessionID([33]byte{0x01}),
   247  	}
   248  
   249  	// First, assert that this session is not already present in the
   250  	// database.
   251  	if _, ok := h.listSessions(nil)[session.ID]; ok {
   252  		h.t.Fatalf("session for id %x should not exist yet", session.ID)
   253  	}
   254  
   255  	// Attempting to insert the client session without reserving a session
   256  	// key index should fail.
   257  	h.insertSession(session, wtdb.ErrNoReservedKeyIndex)
   258  
   259  	// Now, reserve a session key for this tower.
   260  	keyIndex := h.nextKeyIndex(session.TowerID, blobType)
   261  
   262  	// The client session hasn't been updated with the reserved key index
   263  	// (since it's still zero). Inserting should fail due to the mismatch.
   264  	h.insertSession(session, wtdb.ErrIncorrectKeyIndex)
   265  
   266  	// Reserve another key for the same index. Since no session has been
   267  	// successfully created, it should return the same index to maintain
   268  	// idempotency across restarts.
   269  	keyIndex2 := h.nextKeyIndex(session.TowerID, blobType)
   270  	if keyIndex != keyIndex2 {
   271  		h.t.Fatalf("next key index should be idempotent: want: %v, "+
   272  			"got %v", keyIndex, keyIndex2)
   273  	}
   274  
   275  	// Now, set the client session's key index so that it is proper and
   276  	// insert it. This should succeed.
   277  	session.KeyIndex = keyIndex
   278  	h.insertSession(session, nil)
   279  
   280  	// Verify that the session now exists in the database.
   281  	if _, ok := h.listSessions(nil)[session.ID]; !ok {
   282  		h.t.Fatalf("session for id %x should exist now", session.ID)
   283  	}
   284  
   285  	// Attempt to insert the session again, which should fail due to the
   286  	// session already existing.
   287  	h.insertSession(session, wtdb.ErrClientSessionAlreadyExists)
   288  
   289  	// Finally, assert that reserving another key index succeeds with a
   290  	// different key index, now that the first one has been finalized.
   291  	keyIndex3 := h.nextKeyIndex(session.TowerID, blobType)
   292  	if keyIndex == keyIndex3 {
   293  		h.t.Fatalf("key index still reserved after creating session")
   294  	}
   295  }
   296  
   297  // testFilterClientSessions asserts that we can correctly filter client sessions
   298  // for a specific tower.
   299  func testFilterClientSessions(h *clientDBHarness) {
   300  	// We'll create three client sessions, the first two belonging to one
   301  	// tower, and the last belonging to another one.
   302  	const numSessions = 3
   303  	const blobType = blob.TypeAltruistCommit
   304  	towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID)
   305  	for i := 0; i < numSessions; i++ {
   306  		towerID := wtdb.TowerID(1)
   307  		if i == numSessions-1 {
   308  			towerID = wtdb.TowerID(2)
   309  		}
   310  		keyIndex := h.nextKeyIndex(towerID, blobType)
   311  		sessionID := wtdb.SessionID([33]byte{byte(i)})
   312  		h.insertSession(&wtdb.ClientSession{
   313  			ClientSessionBody: wtdb.ClientSessionBody{
   314  				TowerID: towerID,
   315  				Policy: wtpolicy.Policy{
   316  					TxPolicy: wtpolicy.TxPolicy{
   317  						BlobType: blobType,
   318  					},
   319  					MaxUpdates: 100,
   320  				},
   321  				RewardPkScript: []byte{0x01, 0x02, 0x03},
   322  				KeyIndex:       keyIndex,
   323  			},
   324  			ID: sessionID,
   325  		}, nil)
   326  		towerSessions[towerID] = append(towerSessions[towerID], sessionID)
   327  	}
   328  
   329  	// We should see the expected sessions for each tower when filtering
   330  	// them.
   331  	for towerID, expectedSessions := range towerSessions {
   332  		sessions := h.listSessions(&towerID)
   333  		if len(sessions) != len(expectedSessions) {
   334  			h.t.Fatalf("expected %v sessions for tower %v, got %v",
   335  				len(expectedSessions), towerID, len(sessions))
   336  		}
   337  		for _, expectedSession := range expectedSessions {
   338  			if _, ok := sessions[expectedSession]; !ok {
   339  				h.t.Fatalf("expected session %v for tower %v",
   340  					expectedSession, towerID)
   341  			}
   342  		}
   343  	}
   344  }
   345  
   346  // testCreateTower asserts the behavior of creating new Tower objects within the
   347  // database, and that the latest address is always prepended to the list of
   348  // known addresses for the tower.
   349  func testCreateTower(h *clientDBHarness) {
   350  	// Test that loading a tower with an arbitrary tower id fails.
   351  	h.loadTowerByID(20, wtdb.ErrTowerNotFound)
   352  
   353  	pk, err := randPubKey()
   354  	if err != nil {
   355  		h.t.Fatalf("unable to generate pubkey: %v", err)
   356  	}
   357  
   358  	addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
   359  	lnAddr := &lnwire.NetAddress{
   360  		IdentityKey: pk,
   361  		Address:     addr1,
   362  	}
   363  
   364  	// Insert a random tower into the database.
   365  	tower := h.createTower(lnAddr, nil)
   366  
   367  	// Load the tower from the database and assert that it matches the tower
   368  	// we created.
   369  	tower2 := h.loadTowerByID(tower.ID, nil)
   370  	if !reflect.DeepEqual(tower, tower2) {
   371  		h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
   372  			tower, tower2)
   373  	}
   374  	tower2 = h.loadTower(pk, err)
   375  	if !reflect.DeepEqual(tower, tower2) {
   376  		h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
   377  			tower, tower2)
   378  	}
   379  
   380  	// Insert the address again into the database. Since the address is the
   381  	// same, this should result in an unmodified tower record.
   382  	towerDupAddr := h.createTower(lnAddr, nil)
   383  	if len(towerDupAddr.Addresses) != 1 {
   384  		h.t.Fatalf("duplicate address should be deduped")
   385  	}
   386  	if !reflect.DeepEqual(tower, towerDupAddr) {
   387  		h.t.Fatalf("mismatch towers, want: %v, got: %v",
   388  			tower, towerDupAddr)
   389  	}
   390  
   391  	// Generate a new address for this tower.
   392  	addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
   393  
   394  	lnAddr2 := &lnwire.NetAddress{
   395  		IdentityKey: pk,
   396  		Address:     addr2,
   397  	}
   398  
   399  	// Insert the updated address, which should produce a tower with a new
   400  	// address.
   401  	towerNewAddr := h.createTower(lnAddr2, nil)
   402  
   403  	// Load the tower from the database, and assert that it matches the
   404  	// tower returned from creation.
   405  	towerNewAddr2 := h.loadTowerByID(tower.ID, nil)
   406  	if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
   407  		h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
   408  			towerNewAddr, towerNewAddr2)
   409  	}
   410  	towerNewAddr2 = h.loadTower(pk, nil)
   411  	if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) {
   412  		h.t.Fatalf("loaded tower mismatch, want: %v, got: %v",
   413  			towerNewAddr, towerNewAddr2)
   414  	}
   415  
   416  	// Assert that there are now two addresses on the tower object.
   417  	if len(towerNewAddr.Addresses) != 2 {
   418  		h.t.Fatalf("new address should be added")
   419  	}
   420  
   421  	// Finally, assert that the new address was prepended since it is deemed
   422  	// fresher.
   423  	if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) {
   424  		h.t.Fatalf("new address should be prepended")
   425  	}
   426  }
   427  
   428  // testRemoveTower asserts the behavior of removing Tower objects as a whole and
   429  // removing addresses from Tower objects within the database.
   430  func testRemoveTower(h *clientDBHarness) {
   431  	// Generate a random public key we'll use for our tower.
   432  	pk, err := randPubKey()
   433  	if err != nil {
   434  		h.t.Fatalf("unable to generate pubkey: %v", err)
   435  	}
   436  
   437  	// Removing a tower that does not exist within the database should
   438  	// result in a NOP.
   439  	h.removeTower(pk, nil, false, nil)
   440  
   441  	// We'll create a tower with two addresses.
   442  	addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911}
   443  	addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911}
   444  	h.createTower(&lnwire.NetAddress{
   445  		IdentityKey: pk,
   446  		Address:     addr1,
   447  	}, nil)
   448  	h.createTower(&lnwire.NetAddress{
   449  		IdentityKey: pk,
   450  		Address:     addr2,
   451  	}, nil)
   452  
   453  	// We'll then remove the second address. We should now only see the
   454  	// first.
   455  	h.removeTower(pk, addr2, false, nil)
   456  
   457  	// We'll then remove the first address. We should now see that the tower
   458  	// has no addresses left.
   459  	h.removeTower(pk, addr1, false, wtdb.ErrLastTowerAddr)
   460  
   461  	// Removing the tower as a whole from the database should succeed since
   462  	// there aren't any active sessions for it.
   463  	h.removeTower(pk, nil, false, nil)
   464  
   465  	// We'll then recreate the tower, but this time we'll create a session
   466  	// for it.
   467  	tower := h.createTower(&lnwire.NetAddress{
   468  		IdentityKey: pk,
   469  		Address:     addr1,
   470  	}, nil)
   471  
   472  	const blobType = blob.TypeAltruistCommit
   473  	session := &wtdb.ClientSession{
   474  		ClientSessionBody: wtdb.ClientSessionBody{
   475  			TowerID: tower.ID,
   476  			Policy: wtpolicy.Policy{
   477  				TxPolicy: wtpolicy.TxPolicy{
   478  					BlobType: blobType,
   479  				},
   480  				MaxUpdates: 100,
   481  			},
   482  			RewardPkScript: []byte{0x01, 0x02, 0x03},
   483  			KeyIndex:       h.nextKeyIndex(tower.ID, blobType),
   484  		},
   485  		ID: wtdb.SessionID([33]byte{0x01}),
   486  	}
   487  	h.insertSession(session, nil)
   488  	update := randCommittedUpdate(h.t, 1)
   489  	h.commitUpdate(&session.ID, update, nil)
   490  
   491  	// We should not be able to fully remove it from the database since
   492  	// there's a session and it has unacked updates.
   493  	h.removeTower(pk, nil, true, wtdb.ErrTowerUnackedUpdates)
   494  
   495  	// Removing the tower after all sessions no longer have unacked updates
   496  	// should result in the sessions becoming inactive.
   497  	h.ackUpdate(&session.ID, 1, 1, nil)
   498  	h.removeTower(pk, nil, true, nil)
   499  
   500  	// Creating the tower again should mark all of the sessions active once
   501  	// again.
   502  	h.createTower(&lnwire.NetAddress{
   503  		IdentityKey: pk,
   504  		Address:     addr1,
   505  	}, nil)
   506  }
   507  
   508  // testChanSummaries tests the process of a registering a channel and its
   509  // associated sweep pkscript.
   510  func testChanSummaries(h *clientDBHarness) {
   511  	// First, assert that this channel is not already registered.
   512  	var chanID lnwire.ChannelID
   513  	if _, ok := h.fetchChanSummaries()[chanID]; ok {
   514  		h.t.Fatalf("pkscript for channel %x should not exist yet",
   515  			chanID)
   516  	}
   517  
   518  	// Generate a random sweep pkscript and register it for this channel.
   519  	expPkScript := make([]byte, 22)
   520  	if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil {
   521  		h.t.Fatalf("unable to generate pkscript: %v", err)
   522  	}
   523  	h.registerChan(chanID, expPkScript, nil)
   524  
   525  	// Assert that the channel exists and that its sweep pkscript matches
   526  	// the one we registered.
   527  	summary, ok := h.fetchChanSummaries()[chanID]
   528  	if !ok {
   529  		h.t.Fatalf("pkscript for channel %x should not exist yet",
   530  			chanID)
   531  	} else if !bytes.Equal(expPkScript, summary.SweepPkScript) {
   532  		h.t.Fatalf("pkscript mismatch, want: %x, got: %x",
   533  			expPkScript, summary.SweepPkScript)
   534  	}
   535  
   536  	// Finally, assert that re-registering the same channel produces a
   537  	// failure.
   538  	h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered)
   539  }
   540  
   541  // testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can
   542  func testCommitUpdate(h *clientDBHarness) {
   543  	const blobType = blob.TypeAltruistCommit
   544  	session := &wtdb.ClientSession{
   545  		ClientSessionBody: wtdb.ClientSessionBody{
   546  			TowerID: wtdb.TowerID(3),
   547  			Policy: wtpolicy.Policy{
   548  				TxPolicy: wtpolicy.TxPolicy{
   549  					BlobType: blobType,
   550  				},
   551  				MaxUpdates: 100,
   552  			},
   553  			RewardPkScript: []byte{0x01, 0x02, 0x03},
   554  		},
   555  		ID: wtdb.SessionID([33]byte{0x02}),
   556  	}
   557  
   558  	// Generate a random update and try to commit before inserting the
   559  	// session, which should fail.
   560  	update1 := randCommittedUpdate(h.t, 1)
   561  	h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
   562  
   563  	// Reserve a session key index and insert the session.
   564  	session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType)
   565  	h.insertSession(session, nil)
   566  
   567  	// Now, try to commit the update that failed initially which should
   568  	// succeed. The lastApplied value should be 0 since we have not received
   569  	// an ack from the tower.
   570  	lastApplied := h.commitUpdate(&session.ID, update1, nil)
   571  	if lastApplied != 0 {
   572  		h.t.Fatalf("last applied mismatch, want: 0, got: %v",
   573  			lastApplied)
   574  	}
   575  
   576  	// Assert that the committed update appears in the client session's
   577  	// CommittedUpdates map when loaded from disk and that there are no
   578  	// AckedUpdates.
   579  	dbSession := h.listSessions(nil)[session.ID]
   580  	checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
   581  		*update1,
   582  	})
   583  	checkAckedUpdates(h.t, dbSession, nil)
   584  
   585  	// Try to commit the same update, which should succeed due to
   586  	// idempotency (which is preserved when the breach hint is identical to
   587  	// the on-disk update's hint). The lastApplied value should remain
   588  	// unchanged.
   589  	lastApplied2 := h.commitUpdate(&session.ID, update1, nil)
   590  	if lastApplied2 != lastApplied {
   591  		h.t.Fatalf("last applied should not have changed, got %v",
   592  			lastApplied2)
   593  	}
   594  
   595  	// Assert that the loaded ClientSession is the same as before.
   596  	dbSession = h.listSessions(nil)[session.ID]
   597  	checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
   598  		*update1,
   599  	})
   600  	checkAckedUpdates(h.t, dbSession, nil)
   601  
   602  	// Generate another random update and try to commit it at the identical
   603  	// sequence number. Since the breach hint has changed, this should fail.
   604  	update2 := randCommittedUpdate(h.t, 1)
   605  	h.commitUpdate(&session.ID, update2, wtdb.ErrUpdateAlreadyCommitted)
   606  
   607  	// Next, insert the new update at the next unallocated sequence number
   608  	// which should succeed.
   609  	update2.SeqNum = 2
   610  	lastApplied3 := h.commitUpdate(&session.ID, update2, nil)
   611  	if lastApplied3 != lastApplied {
   612  		h.t.Fatalf("last applied should not have changed, got %v",
   613  			lastApplied3)
   614  	}
   615  
   616  	// Check that both updates now appear as committed on the ClientSession
   617  	// loaded from disk.
   618  	dbSession = h.listSessions(nil)[session.ID]
   619  	checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
   620  		*update1,
   621  		*update2,
   622  	})
   623  	checkAckedUpdates(h.t, dbSession, nil)
   624  
   625  	// Finally, create one more random update and try to commit it at index
   626  	// 4, which should be rejected since 3 is the next slot the database
   627  	// expects.
   628  	update4 := randCommittedUpdate(h.t, 4)
   629  	h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
   630  
   631  	// Assert that the ClientSession loaded from disk remains unchanged.
   632  	dbSession = h.listSessions(nil)[session.ID]
   633  	checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
   634  		*update1,
   635  		*update2,
   636  	})
   637  	checkAckedUpdates(h.t, dbSession, nil)
   638  }
   639  
   640  // testAckUpdate asserts the behavior of AckUpdate.
   641  func testAckUpdate(h *clientDBHarness) {
   642  	const blobType = blob.TypeAltruistCommit
   643  
   644  	// Create a new session that the updates in this will be tied to.
   645  	session := &wtdb.ClientSession{
   646  		ClientSessionBody: wtdb.ClientSessionBody{
   647  			TowerID: wtdb.TowerID(3),
   648  			Policy: wtpolicy.Policy{
   649  				TxPolicy: wtpolicy.TxPolicy{
   650  					BlobType: blobType,
   651  				},
   652  				MaxUpdates: 100,
   653  			},
   654  			RewardPkScript: []byte{0x01, 0x02, 0x03},
   655  		},
   656  		ID: wtdb.SessionID([33]byte{0x03}),
   657  	}
   658  
   659  	// Try to ack an update before inserting the client session, which
   660  	// should fail.
   661  	h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound)
   662  
   663  	// Reserve a session key and insert the client session.
   664  	session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType)
   665  	h.insertSession(session, nil)
   666  
   667  	// Now, try to ack update 1. This should fail since update 1 was never
   668  	// committed.
   669  	h.ackUpdate(&session.ID, 1, 0, wtdb.ErrCommittedUpdateNotFound)
   670  
   671  	// Commit to a random update at seqnum 1.
   672  	update1 := randCommittedUpdate(h.t, 1)
   673  	lastApplied := h.commitUpdate(&session.ID, update1, nil)
   674  	if lastApplied != 0 {
   675  		h.t.Fatalf("last applied mismatch, want: 0, got: %v",
   676  			lastApplied)
   677  	}
   678  
   679  	// Acking seqnum 1 should succeed.
   680  	h.ackUpdate(&session.ID, 1, 1, nil)
   681  
   682  	// Acking seqnum 1 again should fail.
   683  	h.ackUpdate(&session.ID, 1, 1, wtdb.ErrCommittedUpdateNotFound)
   684  
   685  	// Acking a valid seqnum with a reverted last applied value should fail.
   686  	h.ackUpdate(&session.ID, 1, 0, wtdb.ErrLastAppliedReversion)
   687  
   688  	// Acking with a last applied greater than any allocated seqnum should
   689  	// fail.
   690  	h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
   691  
   692  	// Assert that the ClientSession loaded from disk has one update in it's
   693  	// AckedUpdates map, and that the committed update has been removed.
   694  	dbSession := h.listSessions(nil)[session.ID]
   695  	checkCommittedUpdates(h.t, dbSession, nil)
   696  	checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
   697  		1: update1.BackupID,
   698  	})
   699  
   700  	// Commit to another random update, and assert that the last applied
   701  	// value is 1, since this was what was provided in the last successful
   702  	// ack.
   703  	update2 := randCommittedUpdate(h.t, 2)
   704  	lastApplied = h.commitUpdate(&session.ID, update2, nil)
   705  	if lastApplied != 1 {
   706  		h.t.Fatalf("last applied mismatch, want: 1, got: %v",
   707  			lastApplied)
   708  	}
   709  
   710  	// Ack seqnum 2.
   711  	h.ackUpdate(&session.ID, 2, 2, nil)
   712  
   713  	// Assert that both updates exist as AckedUpdates when loaded from disk.
   714  	dbSession = h.listSessions(nil)[session.ID]
   715  	checkCommittedUpdates(h.t, dbSession, nil)
   716  	checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
   717  		1: update1.BackupID,
   718  		2: update2.BackupID,
   719  	})
   720  
   721  	// Acking again with a lower last applied should fail.
   722  	h.ackUpdate(&session.ID, 2, 1, wtdb.ErrLastAppliedReversion)
   723  
   724  	// Acking an unallocated seqnum should fail.
   725  	h.ackUpdate(&session.ID, 4, 2, wtdb.ErrCommittedUpdateNotFound)
   726  
   727  	// Acking with a last applied greater than any allocated seqnum should
   728  	// fail.
   729  	h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
   730  }
   731  
   732  // checkCommittedUpdates asserts that the CommittedUpdates on session match the
   733  // expUpdates provided.
   734  func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
   735  	expUpdates []wtdb.CommittedUpdate) {
   736  
   737  	t.Helper()
   738  
   739  	// We promote nil expUpdates to an initialized slice since the database
   740  	// should never return a nil slice. This promotion is done purely out of
   741  	// convenience for the testing framework.
   742  	if expUpdates == nil {
   743  		expUpdates = make([]wtdb.CommittedUpdate, 0)
   744  	}
   745  
   746  	if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) {
   747  		t.Fatalf("committed updates mismatch, want: %v, got: %v",
   748  			expUpdates, session.CommittedUpdates)
   749  	}
   750  }
   751  
   752  // checkAckedUpdates asserts that the AckedUpdates on a sessio match the
   753  // expUpdates provided.
   754  func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
   755  	expUpdates map[uint16]wtdb.BackupID) {
   756  
   757  	// We promote nil expUpdates to an initialized map since the database
   758  	// should never return a nil map. This promotion is done purely out of
   759  	// convenience for the testing framework.
   760  	if expUpdates == nil {
   761  		expUpdates = make(map[uint16]wtdb.BackupID)
   762  	}
   763  
   764  	if !reflect.DeepEqual(session.AckedUpdates, expUpdates) {
   765  		t.Fatalf("acked updates mismatch, want: %v, got: %v",
   766  			expUpdates, session.AckedUpdates)
   767  	}
   768  }
   769  
   770  // TestClientDB asserts the behavior of a fresh client db, a reopened client db,
   771  // and the mock implementation. This ensures that all databases function
   772  // identically, especially in the negative paths.
   773  func TestClientDB(t *testing.T) {
   774  	dbCfg := &kvdb.BoltConfig{DBTimeout: kvdb.DefaultDBTimeout}
   775  	dbs := []struct {
   776  		name string
   777  		init clientDBInit
   778  	}{
   779  		{
   780  			name: "fresh clientdb",
   781  			init: func(t *testing.T) (wtclient.DB, func()) {
   782  				path, err := ioutil.TempDir("", "clientdb")
   783  				if err != nil {
   784  					t.Fatalf("unable to make temp dir: %v",
   785  						err)
   786  				}
   787  
   788  				bdb, err := wtdb.NewBoltBackendCreator(
   789  					true, path, "wtclient.db",
   790  				)(dbCfg)
   791  				if err != nil {
   792  					os.RemoveAll(path)
   793  					t.Fatalf("unable to open db: %v", err)
   794  				}
   795  
   796  				db, err := wtdb.OpenClientDB(bdb)
   797  				if err != nil {
   798  					os.RemoveAll(path)
   799  					t.Fatalf("unable to open db: %v", err)
   800  				}
   801  
   802  				cleanup := func() {
   803  					db.Close()
   804  					os.RemoveAll(path)
   805  				}
   806  
   807  				return db, cleanup
   808  			},
   809  		},
   810  		{
   811  			name: "reopened clientdb",
   812  			init: func(t *testing.T) (wtclient.DB, func()) {
   813  				path, err := ioutil.TempDir("", "clientdb")
   814  				if err != nil {
   815  					t.Fatalf("unable to make temp dir: %v",
   816  						err)
   817  				}
   818  
   819  				bdb, err := wtdb.NewBoltBackendCreator(
   820  					true, path, "wtclient.db",
   821  				)(dbCfg)
   822  				if err != nil {
   823  					os.RemoveAll(path)
   824  					t.Fatalf("unable to open db: %v", err)
   825  				}
   826  
   827  				db, err := wtdb.OpenClientDB(bdb)
   828  				if err != nil {
   829  					os.RemoveAll(path)
   830  					t.Fatalf("unable to open db: %v", err)
   831  				}
   832  				db.Close()
   833  
   834  				bdb, err = wtdb.NewBoltBackendCreator(
   835  					true, path, "wtclient.db",
   836  				)(dbCfg)
   837  				if err != nil {
   838  					os.RemoveAll(path)
   839  					t.Fatalf("unable to open db: %v", err)
   840  				}
   841  
   842  				db, err = wtdb.OpenClientDB(bdb)
   843  				if err != nil {
   844  					os.RemoveAll(path)
   845  					t.Fatalf("unable to reopen db: %v", err)
   846  				}
   847  
   848  				cleanup := func() {
   849  					db.Close()
   850  					os.RemoveAll(path)
   851  				}
   852  
   853  				return db, cleanup
   854  			},
   855  		},
   856  		{
   857  			name: "mock",
   858  			init: func(t *testing.T) (wtclient.DB, func()) {
   859  				return wtmock.NewClientDB(), func() {}
   860  			},
   861  		},
   862  	}
   863  
   864  	tests := []struct {
   865  		name string
   866  		run  func(*clientDBHarness)
   867  	}{
   868  		{
   869  			name: "create client session",
   870  			run:  testCreateClientSession,
   871  		},
   872  		{
   873  			name: "filter client sessions",
   874  			run:  testFilterClientSessions,
   875  		},
   876  		{
   877  			name: "create tower",
   878  			run:  testCreateTower,
   879  		},
   880  		{
   881  			name: "remove tower",
   882  			run:  testRemoveTower,
   883  		},
   884  		{
   885  			name: "chan summaries",
   886  			run:  testChanSummaries,
   887  		},
   888  		{
   889  			name: "commit update",
   890  			run:  testCommitUpdate,
   891  		},
   892  		{
   893  			name: "ack update",
   894  			run:  testAckUpdate,
   895  		},
   896  	}
   897  
   898  	for _, database := range dbs {
   899  		db := database
   900  		t.Run(db.name, func(t *testing.T) {
   901  			t.Parallel()
   902  
   903  			for _, test := range tests {
   904  				t.Run(test.name, func(t *testing.T) {
   905  					h, cleanup := newClientDBHarness(
   906  						t, db.init,
   907  					)
   908  					defer cleanup()
   909  
   910  					test.run(h)
   911  				})
   912  			}
   913  		})
   914  	}
   915  }
   916  
   917  // randCommittedUpdate generates a random committed update.
   918  func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate {
   919  	var chanID lnwire.ChannelID
   920  	if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil {
   921  		t.Fatalf("unable to generate chan id: %v", err)
   922  	}
   923  
   924  	var hint blob.BreachHint
   925  	if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil {
   926  		t.Fatalf("unable to generate breach hint: %v", err)
   927  	}
   928  
   929  	encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type()))
   930  	if _, err := io.ReadFull(crand.Reader, encBlob); err != nil {
   931  		t.Fatalf("unable to generate encrypted blob: %v", err)
   932  	}
   933  
   934  	return &wtdb.CommittedUpdate{
   935  		SeqNum: seqNum,
   936  		CommittedUpdateBody: wtdb.CommittedUpdateBody{
   937  			BackupID: wtdb.BackupID{
   938  				ChanID:       chanID,
   939  				CommitHeight: 666,
   940  			},
   941  			Hint:          hint,
   942  			EncryptedBlob: encBlob,
   943  		},
   944  	}
   945  }