github.com/decred/dcrlnd@v0.7.6/watchtower/wtmock/tower_db.go (about)

     1  package wtmock
     2  
     3  import (
     4  	"sync"
     5  
     6  	"github.com/decred/dcrlnd/chainntnfs"
     7  	"github.com/decred/dcrlnd/watchtower/blob"
     8  	"github.com/decred/dcrlnd/watchtower/wtdb"
     9  )
    10  
    11  // TowerDB is a mock, in-memory implementation of a watchtower.DB.
    12  type TowerDB struct {
    13  	mu        sync.Mutex
    14  	lastEpoch *chainntnfs.BlockEpoch
    15  	sessions  map[wtdb.SessionID]*wtdb.SessionInfo
    16  	blobs     map[blob.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate
    17  }
    18  
    19  // NewTowerDB initializes a fresh mock TowerDB.
    20  func NewTowerDB() *TowerDB {
    21  	return &TowerDB{
    22  		sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo),
    23  		blobs:    make(map[blob.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate),
    24  	}
    25  }
    26  
    27  // InsertStateUpdate stores an update sent by the client after validating that
    28  // the update is well-formed in the context of other updates sent for the same
    29  // session. This include verifying that the sequence number is incremented
    30  // properly and the last applied values echoed by the client are sane.
    31  func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, error) {
    32  	db.mu.Lock()
    33  	defer db.mu.Unlock()
    34  
    35  	info, ok := db.sessions[update.ID]
    36  	if !ok {
    37  		return 0, wtdb.ErrSessionNotFound
    38  	}
    39  
    40  	// Assert that the blob is the correct size for the session's blob type.
    41  	if len(update.EncryptedBlob) != blob.Size(info.Policy.BlobType) {
    42  		return 0, wtdb.ErrInvalidBlobSize
    43  	}
    44  
    45  	err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied)
    46  	if err != nil {
    47  		return info.LastApplied, err
    48  	}
    49  
    50  	sessionsToUpdates, ok := db.blobs[update.Hint]
    51  	if !ok {
    52  		sessionsToUpdates = make(map[wtdb.SessionID]*wtdb.SessionStateUpdate)
    53  		db.blobs[update.Hint] = sessionsToUpdates
    54  	}
    55  	sessionsToUpdates[update.ID] = update
    56  
    57  	return info.LastApplied, nil
    58  }
    59  
    60  // GetSessionInfo retrieves the session for the passed session id. An error is
    61  // returned if the session could not be found.
    62  func (db *TowerDB) GetSessionInfo(id *wtdb.SessionID) (*wtdb.SessionInfo, error) {
    63  	db.mu.Lock()
    64  	defer db.mu.Unlock()
    65  
    66  	if info, ok := db.sessions[*id]; ok {
    67  		return info, nil
    68  	}
    69  
    70  	return nil, wtdb.ErrSessionNotFound
    71  }
    72  
    73  // InsertSessionInfo records a negotiated session in the tower database. An
    74  // error is returned if the session already exists.
    75  func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error {
    76  	db.mu.Lock()
    77  	defer db.mu.Unlock()
    78  
    79  	dbInfo, ok := db.sessions[info.ID]
    80  	if ok && dbInfo.LastApplied > 0 {
    81  		return wtdb.ErrSessionAlreadyExists
    82  	}
    83  
    84  	// Perform a quick sanity check on the session policy before accepting.
    85  	if err := info.Policy.Validate(); err != nil {
    86  		return err
    87  	}
    88  
    89  	db.sessions[info.ID] = info
    90  
    91  	return nil
    92  }
    93  
    94  // DeleteSession removes all data associated with a particular session id from
    95  // the tower's database.
    96  func (db *TowerDB) DeleteSession(target wtdb.SessionID) error {
    97  	db.mu.Lock()
    98  	defer db.mu.Unlock()
    99  
   100  	// Fail if the session doesn't exit.
   101  	if _, ok := db.sessions[target]; !ok {
   102  		return wtdb.ErrSessionNotFound
   103  	}
   104  
   105  	// Remove the target session.
   106  	delete(db.sessions, target)
   107  
   108  	// Remove the state updates for any blobs stored under the target
   109  	// session identifier.
   110  	for hint, sessionUpdates := range db.blobs {
   111  		delete(sessionUpdates, target)
   112  
   113  		// If this was the last state update, we can also remove the
   114  		// hint that would map to an empty set.
   115  		if len(sessionUpdates) == 0 {
   116  			delete(db.blobs, hint)
   117  		}
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  // QueryMatches searches against all known state updates for any that match the
   124  // passed breachHints. More than one Match will be returned for a given hint if
   125  // they exist in the database.
   126  func (db *TowerDB) QueryMatches(
   127  	breachHints []blob.BreachHint) ([]wtdb.Match, error) {
   128  
   129  	db.mu.Lock()
   130  	defer db.mu.Unlock()
   131  
   132  	var matches []wtdb.Match
   133  	for _, hint := range breachHints {
   134  		sessionsToUpdates, ok := db.blobs[hint]
   135  		if !ok {
   136  			continue
   137  		}
   138  
   139  		for id, update := range sessionsToUpdates {
   140  			info, ok := db.sessions[id]
   141  			if !ok {
   142  				panic("session not found")
   143  			}
   144  
   145  			match := wtdb.Match{
   146  				ID:            id,
   147  				SeqNum:        update.SeqNum,
   148  				Hint:          hint,
   149  				EncryptedBlob: update.EncryptedBlob,
   150  				SessionInfo:   info,
   151  			}
   152  			matches = append(matches, match)
   153  		}
   154  	}
   155  
   156  	return matches, nil
   157  }
   158  
   159  // SetLookoutTip stores the provided epoch as the latest lookout tip epoch in
   160  // the tower database.
   161  func (db *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error {
   162  	db.lastEpoch = epoch
   163  	return nil
   164  }
   165  
   166  // GetLookoutTip retrieves the current lookout tip block epoch from the tower
   167  // database.
   168  func (db *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) {
   169  	db.mu.Lock()
   170  	defer db.mu.Unlock()
   171  
   172  	return db.lastEpoch, nil
   173  }