github.com/decred/dcrlnd@v0.7.6/watchtower/wtmock/client_db.go (about)

     1  package wtmock
     2  
     3  import (
     4  	"net"
     5  	"sync"
     6  	"sync/atomic"
     7  
     8  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
     9  	"github.com/decred/dcrlnd/lnwire"
    10  	"github.com/decred/dcrlnd/watchtower/blob"
    11  	"github.com/decred/dcrlnd/watchtower/wtdb"
    12  )
    13  
// towerPK is a fixed-size array holding the compressed serialization of a
// tower's identity public key. It is used as the key of ClientDB.towerIndex.
type towerPK [33]byte
    15  
// keyIndexKey identifies a session key index reservation. Reservations are
// tracked per (tower, blob type) pair in ClientDB.indexes.
type keyIndexKey struct {
	// towerID is the tower the reservation was made for.
	towerID wtdb.TowerID

	// blobType is the blob type the reservation was made for.
	blobType blob.Type
}
    20  
// ClientDB is a mock, in-memory database for testing the watchtower client
// behavior.
type ClientDB struct {
	nextTowerID uint64 // to be used atomically

	// mu is acquired by the exported methods before they touch the maps
	// and counters below.
	mu sync.Mutex

	// summaries holds the registered channels and their sweep pkscripts.
	summaries map[lnwire.ChannelID]wtdb.ClientChanSummary

	// activeSessions holds every known session by value, keyed by session
	// ID. Entries are replaced wholesale whenever a session is modified.
	activeSessions map[wtdb.SessionID]wtdb.ClientSession

	// towerIndex maps a tower's compressed public key to its tower ID.
	towerIndex map[towerPK]wtdb.TowerID

	// towers maps a tower ID to the stored tower record.
	towers map[wtdb.TowerID]*wtdb.Tower

	// nextIndex is the most recently handed out session key index; it is
	// incremented before each new reservation.
	nextIndex uint32

	// indexes holds the outstanding key index reservations.
	indexes map[keyIndexKey]uint32

	// legacyIndexes is consulted as a fallback for reservations of type
	// blob.TypeAltruistCommit. NOTE(review): nothing in this file writes to
	// it, presumably it mirrors a legacy on-disk format — confirm.
	legacyIndexes map[wtdb.TowerID]uint32
}
    36  
    37  // NewClientDB initializes a new mock ClientDB.
    38  func NewClientDB() *ClientDB {
    39  	return &ClientDB{
    40  		summaries:      make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
    41  		activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
    42  		towerIndex:     make(map[towerPK]wtdb.TowerID),
    43  		towers:         make(map[wtdb.TowerID]*wtdb.Tower),
    44  		indexes:        make(map[keyIndexKey]uint32),
    45  		legacyIndexes:  make(map[wtdb.TowerID]uint32),
    46  	}
    47  }
    48  
    49  // CreateTower initialize an address record used to communicate with a
    50  // watchtower. Each Tower is assigned a unique ID, that is used to amortize
    51  // storage costs of the public key when used by multiple sessions. If the tower
    52  // already exists, the address is appended to the list of all addresses used to
    53  // that tower previously and its corresponding sessions are marked as active.
    54  func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
    55  	m.mu.Lock()
    56  	defer m.mu.Unlock()
    57  
    58  	var towerPubKey towerPK
    59  	copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())
    60  
    61  	var tower *wtdb.Tower
    62  	towerID, ok := m.towerIndex[towerPubKey]
    63  	if ok {
    64  		tower = m.towers[towerID]
    65  		tower.AddAddress(lnAddr.Address)
    66  
    67  		towerSessions, err := m.listClientSessions(&towerID)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		for id, session := range towerSessions {
    72  			session.Status = wtdb.CSessionActive
    73  			m.activeSessions[id] = *session
    74  		}
    75  	} else {
    76  		towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
    77  		tower = &wtdb.Tower{
    78  			ID:          towerID,
    79  			IdentityKey: lnAddr.IdentityKey,
    80  			Addresses:   []net.Addr{lnAddr.Address},
    81  		}
    82  	}
    83  
    84  	m.towerIndex[towerPubKey] = towerID
    85  	m.towers[towerID] = tower
    86  
    87  	return copyTower(tower), nil
    88  }
    89  
    90  // RemoveTower modifies a tower's record within the database. If an address is
    91  // provided, then _only_ the address record should be removed from the tower's
    92  // persisted state. Otherwise, we'll attempt to mark the tower as inactive by
    93  // marking all of its sessions inactive. If any of its sessions has unacked
    94  // updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have
    95  // any sessions at all, it'll be completely removed from the database.
    96  //
    97  // NOTE: An error is not returned if the tower doesn't exist.
    98  func (m *ClientDB) RemoveTower(pubKey *secp256k1.PublicKey, addr net.Addr) error {
    99  	m.mu.Lock()
   100  	defer m.mu.Unlock()
   101  
   102  	tower, err := m.loadTower(pubKey)
   103  	if err == wtdb.ErrTowerNotFound {
   104  		return nil
   105  	}
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	if addr != nil {
   111  		tower.RemoveAddress(addr)
   112  		if len(tower.Addresses) == 0 {
   113  			return wtdb.ErrLastTowerAddr
   114  		}
   115  		m.towers[tower.ID] = tower
   116  		return nil
   117  	}
   118  
   119  	towerSessions, err := m.listClientSessions(&tower.ID)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	if len(towerSessions) == 0 {
   124  		var towerPK towerPK
   125  		copy(towerPK[:], pubKey.SerializeCompressed())
   126  		delete(m.towerIndex, towerPK)
   127  		delete(m.towers, tower.ID)
   128  		return nil
   129  	}
   130  
   131  	for id, session := range towerSessions {
   132  		if len(session.CommittedUpdates) > 0 {
   133  			return wtdb.ErrTowerUnackedUpdates
   134  		}
   135  		session.Status = wtdb.CSessionInactive
   136  		m.activeSessions[id] = *session
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  // LoadTower retrieves a tower by its public key.
   143  func (m *ClientDB) LoadTower(pubKey *secp256k1.PublicKey) (*wtdb.Tower, error) {
   144  	m.mu.Lock()
   145  	defer m.mu.Unlock()
   146  	return m.loadTower(pubKey)
   147  }
   148  
   149  // loadTower retrieves a tower by its public key.
   150  //
   151  // NOTE: This method requires the database's lock to be acquired.
   152  func (m *ClientDB) loadTower(pubKey *secp256k1.PublicKey) (*wtdb.Tower, error) {
   153  	var towerPK towerPK
   154  	copy(towerPK[:], pubKey.SerializeCompressed())
   155  
   156  	towerID, ok := m.towerIndex[towerPK]
   157  	if !ok {
   158  		return nil, wtdb.ErrTowerNotFound
   159  	}
   160  	tower, ok := m.towers[towerID]
   161  	if !ok {
   162  		return nil, wtdb.ErrTowerNotFound
   163  	}
   164  
   165  	return copyTower(tower), nil
   166  }
   167  
   168  // LoadTowerByID retrieves a tower by its tower ID.
   169  func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) {
   170  	m.mu.Lock()
   171  	defer m.mu.Unlock()
   172  
   173  	if tower, ok := m.towers[towerID]; ok {
   174  		return copyTower(tower), nil
   175  	}
   176  
   177  	return nil, wtdb.ErrTowerNotFound
   178  }
   179  
   180  // ListTowers retrieves the list of towers available within the database.
   181  func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
   182  	m.mu.Lock()
   183  	defer m.mu.Unlock()
   184  
   185  	towers := make([]*wtdb.Tower, 0, len(m.towers))
   186  	for _, tower := range m.towers {
   187  		towers = append(towers, copyTower(tower))
   188  	}
   189  
   190  	return towers, nil
   191  }
   192  
// MarkBackupIneligible records that a particular commit height is ineligible
// for backup. This allows the client to track which updates it should not
// attempt to retry after startup.
//
// NOTE: The mock does not record anything; both arguments are ignored and nil
// is always returned.
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error {
	return nil
}
   199  
   200  // ListClientSessions returns the set of all client sessions known to the db. An
   201  // optional tower ID can be used to filter out any client sessions in the
   202  // response that do not correspond to this tower.
   203  func (m *ClientDB) ListClientSessions(
   204  	tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
   205  
   206  	m.mu.Lock()
   207  	defer m.mu.Unlock()
   208  	return m.listClientSessions(tower)
   209  }
   210  
   211  // listClientSessions returns the set of all client sessions known to the db. An
   212  // optional tower ID can be used to filter out any client sessions in the
   213  // response that do not correspond to this tower.
   214  func (m *ClientDB) listClientSessions(
   215  	tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
   216  
   217  	sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
   218  	for _, session := range m.activeSessions {
   219  		session := session
   220  		if tower != nil && *tower != session.TowerID {
   221  			continue
   222  		}
   223  		sessions[session.ID] = &session
   224  	}
   225  
   226  	return sessions, nil
   227  }
   228  
   229  // CreateClientSession records a newly negotiated client session in the set of
   230  // active sessions. The session can be identified by its SessionID.
   231  func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
   232  	m.mu.Lock()
   233  	defer m.mu.Unlock()
   234  
   235  	// Ensure that we aren't overwriting an existing session.
   236  	if _, ok := m.activeSessions[session.ID]; ok {
   237  		return wtdb.ErrClientSessionAlreadyExists
   238  	}
   239  
   240  	key := keyIndexKey{
   241  		towerID:  session.TowerID,
   242  		blobType: session.Policy.BlobType,
   243  	}
   244  
   245  	// Ensure that a session key index has been reserved for this tower.
   246  	keyIndex, err := m.getSessionKeyIndex(key)
   247  	if err != nil {
   248  		return err
   249  	}
   250  
   251  	// Ensure that the session's index matches the reserved index.
   252  	if keyIndex != session.KeyIndex {
   253  		return wtdb.ErrIncorrectKeyIndex
   254  	}
   255  
   256  	// Remove the key index reservation for this tower. Once committed, this
   257  	// permits us to create another session with this tower.
   258  	delete(m.indexes, key)
   259  	if key.blobType == blob.TypeAltruistCommit {
   260  		delete(m.legacyIndexes, key.towerID)
   261  	}
   262  
   263  	m.activeSessions[session.ID] = wtdb.ClientSession{
   264  		ID: session.ID,
   265  		ClientSessionBody: wtdb.ClientSessionBody{
   266  			SeqNum:           session.SeqNum,
   267  			TowerLastApplied: session.TowerLastApplied,
   268  			TowerID:          session.TowerID,
   269  			KeyIndex:         session.KeyIndex,
   270  			Policy:           session.Policy,
   271  			RewardPkScript:   cloneBytes(session.RewardPkScript),
   272  		},
   273  		CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
   274  		AckedUpdates:     make(map[uint16]wtdb.BackupID),
   275  	}
   276  
   277  	return nil
   278  }
   279  
   280  // NextSessionKeyIndex reserves a new session key derivation index for a
   281  // particular tower id. The index is reserved for that tower until
   282  // CreateClientSession is invoked for that tower and index, at which point a new
   283  // index for that tower can be reserved. Multiple calls to this method before
   284  // CreateClientSession is invoked should return the same index.
   285  func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID,
   286  	blobType blob.Type) (uint32, error) {
   287  
   288  	m.mu.Lock()
   289  	defer m.mu.Unlock()
   290  
   291  	key := keyIndexKey{
   292  		towerID:  towerID,
   293  		blobType: blobType,
   294  	}
   295  
   296  	if index, err := m.getSessionKeyIndex(key); err == nil {
   297  		return index, nil
   298  	}
   299  
   300  	m.nextIndex++
   301  	index := m.nextIndex
   302  	m.indexes[key] = index
   303  
   304  	return index, nil
   305  }
   306  
   307  func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) {
   308  	if index, ok := m.indexes[key]; ok {
   309  		return index, nil
   310  	}
   311  
   312  	if key.blobType == blob.TypeAltruistCommit {
   313  		if index, ok := m.legacyIndexes[key.towerID]; ok {
   314  			return index, nil
   315  		}
   316  	}
   317  
   318  	return 0, wtdb.ErrNoReservedKeyIndex
   319  }
   320  
   321  // CommitUpdate persists the CommittedUpdate provided in the slot for (session,
   322  // seqNum). This allows the client to retransmit this update on startup.
   323  func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
   324  	update *wtdb.CommittedUpdate) (uint16, error) {
   325  
   326  	m.mu.Lock()
   327  	defer m.mu.Unlock()
   328  
   329  	// Fail if session doesn't exist.
   330  	session, ok := m.activeSessions[*id]
   331  	if !ok {
   332  		return 0, wtdb.ErrClientSessionNotFound
   333  	}
   334  
   335  	// Check if an update has already been committed for this state.
   336  	for _, dbUpdate := range session.CommittedUpdates {
   337  		if dbUpdate.SeqNum == update.SeqNum {
   338  			// If the breach hint matches, we'll just return the
   339  			// last applied value so the client can retransmit.
   340  			if dbUpdate.Hint == update.Hint {
   341  				return session.TowerLastApplied, nil
   342  			}
   343  
   344  			// Otherwise, fail since the breach hint doesn't match.
   345  			return 0, wtdb.ErrUpdateAlreadyCommitted
   346  		}
   347  	}
   348  
   349  	// Sequence number must increment.
   350  	if update.SeqNum != session.SeqNum+1 {
   351  		return 0, wtdb.ErrCommitUnorderedUpdate
   352  	}
   353  
   354  	// Save the update and increment the sequence number.
   355  	session.CommittedUpdates = append(session.CommittedUpdates, *update)
   356  	session.SeqNum++
   357  	m.activeSessions[*id] = session
   358  
   359  	return session.TowerLastApplied, nil
   360  }
   361  
   362  // AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
   363  // removes the update from the set of committed updates, and validates the
   364  // lastApplied value returned from the tower.
   365  func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
   366  	m.mu.Lock()
   367  	defer m.mu.Unlock()
   368  
   369  	// Fail if session doesn't exist.
   370  	session, ok := m.activeSessions[*id]
   371  	if !ok {
   372  		return wtdb.ErrClientSessionNotFound
   373  	}
   374  
   375  	// Ensure the returned last applied value does not exceed the highest
   376  	// allocated sequence number.
   377  	if lastApplied > session.SeqNum {
   378  		return wtdb.ErrUnallocatedLastApplied
   379  	}
   380  
   381  	// Ensure the last applied value isn't lower than a previous one sent by
   382  	// the tower.
   383  	if lastApplied < session.TowerLastApplied {
   384  		return wtdb.ErrLastAppliedReversion
   385  	}
   386  
   387  	// Retrieve the committed update, failing if none is found. We should
   388  	// only receive acks for state updates that we send.
   389  	updates := session.CommittedUpdates
   390  	for i, update := range updates {
   391  		if update.SeqNum != seqNum {
   392  			continue
   393  		}
   394  
   395  		// Remove the committed update from disk and mark the update as
   396  		// acked. The tower last applied value is also recorded to send
   397  		// along with the next update.
   398  		copy(updates[:i], updates[i+1:])
   399  		updates[len(updates)-1] = wtdb.CommittedUpdate{}
   400  		session.CommittedUpdates = updates[:len(updates)-1]
   401  
   402  		session.AckedUpdates[seqNum] = update.BackupID
   403  		session.TowerLastApplied = lastApplied
   404  
   405  		m.activeSessions[*id] = session
   406  		return nil
   407  	}
   408  
   409  	return wtdb.ErrCommittedUpdateNotFound
   410  }
   411  
   412  // FetchChanSummaries loads a mapping from all registered channels to their
   413  // channel summaries.
   414  func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
   415  	m.mu.Lock()
   416  	defer m.mu.Unlock()
   417  
   418  	summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
   419  	for chanID, summary := range m.summaries {
   420  		summaries[chanID] = wtdb.ClientChanSummary{
   421  			SweepPkScript: cloneBytes(summary.SweepPkScript),
   422  		}
   423  	}
   424  
   425  	return summaries, nil
   426  }
   427  
   428  // RegisterChannel registers a channel for use within the client database. For
   429  // now, all that is stored in the channel summary is the sweep pkscript that
   430  // we'd like any tower sweeps to pay into. In the future, this will be extended
   431  // to contain more info to allow the client efficiently request historical
   432  // states to be backed up under the client's active policy.
   433  func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
   434  	sweepPkScript []byte) error {
   435  
   436  	m.mu.Lock()
   437  	defer m.mu.Unlock()
   438  
   439  	if _, ok := m.summaries[chanID]; ok {
   440  		return wtdb.ErrChannelAlreadyRegistered
   441  	}
   442  
   443  	m.summaries[chanID] = wtdb.ClientChanSummary{
   444  		SweepPkScript: cloneBytes(sweepPkScript),
   445  	}
   446  
   447  	return nil
   448  }
   449  
   450  func cloneBytes(b []byte) []byte {
   451  	if b == nil {
   452  		return nil
   453  	}
   454  
   455  	bb := make([]byte, len(b))
   456  	copy(bb, b)
   457  
   458  	return bb
   459  }
   460  
   461  func copyTower(tower *wtdb.Tower) *wtdb.Tower {
   462  	t := &wtdb.Tower{
   463  		ID:          tower.ID,
   464  		IdentityKey: tower.IdentityKey,
   465  		Addresses:   make([]net.Addr, len(tower.Addresses)),
   466  	}
   467  	copy(t.Addresses, tower.Addresses)
   468  
   469  	return t
   470  }