github.com/decred/dcrlnd@v0.7.6/watchtower/wtmock/client_db.go (about) 1 package wtmock 2 3 import ( 4 "net" 5 "sync" 6 "sync/atomic" 7 8 "github.com/decred/dcrd/dcrec/secp256k1/v4" 9 "github.com/decred/dcrlnd/lnwire" 10 "github.com/decred/dcrlnd/watchtower/blob" 11 "github.com/decred/dcrlnd/watchtower/wtdb" 12 ) 13 14 type towerPK [33]byte 15 16 type keyIndexKey struct { 17 towerID wtdb.TowerID 18 blobType blob.Type 19 } 20 21 // ClientDB is a mock, in-memory database or testing the watchtower client 22 // behavior. 23 type ClientDB struct { 24 nextTowerID uint64 // to be used atomically 25 26 mu sync.Mutex 27 summaries map[lnwire.ChannelID]wtdb.ClientChanSummary 28 activeSessions map[wtdb.SessionID]wtdb.ClientSession 29 towerIndex map[towerPK]wtdb.TowerID 30 towers map[wtdb.TowerID]*wtdb.Tower 31 32 nextIndex uint32 33 indexes map[keyIndexKey]uint32 34 legacyIndexes map[wtdb.TowerID]uint32 35 } 36 37 // NewClientDB initializes a new mock ClientDB. 38 func NewClientDB() *ClientDB { 39 return &ClientDB{ 40 summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), 41 activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), 42 towerIndex: make(map[towerPK]wtdb.TowerID), 43 towers: make(map[wtdb.TowerID]*wtdb.Tower), 44 indexes: make(map[keyIndexKey]uint32), 45 legacyIndexes: make(map[wtdb.TowerID]uint32), 46 } 47 } 48 49 // CreateTower initialize an address record used to communicate with a 50 // watchtower. Each Tower is assigned a unique ID, that is used to amortize 51 // storage costs of the public key when used by multiple sessions. If the tower 52 // already exists, the address is appended to the list of all addresses used to 53 // that tower previously and its corresponding sessions are marked as active. 54 func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { 55 m.mu.Lock() 56 defer m.mu.Unlock() 57 58 var towerPubKey towerPK 59 copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) 60 61 var tower *wtdb.Tower 62 towerID, ok := m.towerIndex[towerPubKey] 63 if ok { 64 tower = m.towers[towerID] 65 tower.AddAddress(lnAddr.Address) 66 67 towerSessions, err := m.listClientSessions(&towerID) 68 if err != nil { 69 return nil, err 70 } 71 for id, session := range towerSessions { 72 session.Status = wtdb.CSessionActive 73 m.activeSessions[id] = *session 74 } 75 } else { 76 towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) 77 tower = &wtdb.Tower{ 78 ID: towerID, 79 IdentityKey: lnAddr.IdentityKey, 80 Addresses: []net.Addr{lnAddr.Address}, 81 } 82 } 83 84 m.towerIndex[towerPubKey] = towerID 85 m.towers[towerID] = tower 86 87 return copyTower(tower), nil 88 } 89 90 // RemoveTower modifies a tower's record within the database. If an address is 91 // provided, then _only_ the address record should be removed from the tower's 92 // persisted state. Otherwise, we'll attempt to mark the tower as inactive by 93 // marking all of its sessions inactive. If any of its sessions has unacked 94 // updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have 95 // any sessions at all, it'll be completely removed from the database. 96 // 97 // NOTE: An error is not returned if the tower doesn't exist. 98 func (m *ClientDB) RemoveTower(pubKey *secp256k1.PublicKey, addr net.Addr) error { 99 m.mu.Lock() 100 defer m.mu.Unlock() 101 102 tower, err := m.loadTower(pubKey) 103 if err == wtdb.ErrTowerNotFound { 104 return nil 105 } 106 if err != nil { 107 return err 108 } 109 110 if addr != nil { 111 tower.RemoveAddress(addr) 112 if len(tower.Addresses) == 0 { 113 return wtdb.ErrLastTowerAddr 114 } 115 m.towers[tower.ID] = tower 116 return nil 117 } 118 119 towerSessions, err := m.listClientSessions(&tower.ID) 120 if err != nil { 121 return err 122 } 123 if len(towerSessions) == 0 { 124 var towerPK towerPK 125 copy(towerPK[:], pubKey.SerializeCompressed()) 126 delete(m.towerIndex, towerPK) 127 delete(m.towers, tower.ID) 128 return nil 129 } 130 131 for id, session := range towerSessions { 132 if len(session.CommittedUpdates) > 0 { 133 return wtdb.ErrTowerUnackedUpdates 134 } 135 session.Status = wtdb.CSessionInactive 136 m.activeSessions[id] = *session 137 } 138 139 return nil 140 } 141 142 // LoadTower retrieves a tower by its public key. 143 func (m *ClientDB) LoadTower(pubKey *secp256k1.PublicKey) (*wtdb.Tower, error) { 144 m.mu.Lock() 145 defer m.mu.Unlock() 146 return m.loadTower(pubKey) 147 } 148 149 // loadTower retrieves a tower by its public key. 150 // 151 // NOTE: This method requires the database's lock to be acquired. 152 func (m *ClientDB) loadTower(pubKey *secp256k1.PublicKey) (*wtdb.Tower, error) { 153 var towerPK towerPK 154 copy(towerPK[:], pubKey.SerializeCompressed()) 155 156 towerID, ok := m.towerIndex[towerPK] 157 if !ok { 158 return nil, wtdb.ErrTowerNotFound 159 } 160 tower, ok := m.towers[towerID] 161 if !ok { 162 return nil, wtdb.ErrTowerNotFound 163 } 164 165 return copyTower(tower), nil 166 } 167 168 // LoadTowerByID retrieves a tower by its tower ID. 169 func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) { 170 m.mu.Lock() 171 defer m.mu.Unlock() 172 173 if tower, ok := m.towers[towerID]; ok { 174 return copyTower(tower), nil 175 } 176 177 return nil, wtdb.ErrTowerNotFound 178 } 179 180 // ListTowers retrieves the list of towers available within the database. 181 func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) { 182 m.mu.Lock() 183 defer m.mu.Unlock() 184 185 towers := make([]*wtdb.Tower, 0, len(m.towers)) 186 for _, tower := range m.towers { 187 towers = append(towers, copyTower(tower)) 188 } 189 190 return towers, nil 191 } 192 193 // MarkBackupIneligible records that particular commit height is ineligible for 194 // backup. This allows the client to track which updates it should not attempt 195 // to retry after startup. 196 func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { 197 return nil 198 } 199 200 // ListClientSessions returns the set of all client sessions known to the db. An 201 // optional tower ID can be used to filter out any client sessions in the 202 // response that do not correspond to this tower. 203 func (m *ClientDB) ListClientSessions( 204 tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { 205 206 m.mu.Lock() 207 defer m.mu.Unlock() 208 return m.listClientSessions(tower) 209 } 210 211 // listClientSessions returns the set of all client sessions known to the db. An 212 // optional tower ID can be used to filter out any client sessions in the 213 // response that do not correspond to this tower. 214 func (m *ClientDB) listClientSessions( 215 tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { 216 217 sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) 218 for _, session := range m.activeSessions { 219 session := session 220 if tower != nil && *tower != session.TowerID { 221 continue 222 } 223 sessions[session.ID] = &session 224 } 225 226 return sessions, nil 227 } 228 229 // CreateClientSession records a newly negotiated client session in the set of 230 // active sessions. The session can be identified by its SessionID. 231 func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { 232 m.mu.Lock() 233 defer m.mu.Unlock() 234 235 // Ensure that we aren't overwriting an existing session. 236 if _, ok := m.activeSessions[session.ID]; ok { 237 return wtdb.ErrClientSessionAlreadyExists 238 } 239 240 key := keyIndexKey{ 241 towerID: session.TowerID, 242 blobType: session.Policy.BlobType, 243 } 244 245 // Ensure that a session key index has been reserved for this tower. 246 keyIndex, err := m.getSessionKeyIndex(key) 247 if err != nil { 248 return err 249 } 250 251 // Ensure that the session's index matches the reserved index. 252 if keyIndex != session.KeyIndex { 253 return wtdb.ErrIncorrectKeyIndex 254 } 255 256 // Remove the key index reservation for this tower. Once committed, this 257 // permits us to create another session with this tower. 258 delete(m.indexes, key) 259 if key.blobType == blob.TypeAltruistCommit { 260 delete(m.legacyIndexes, key.towerID) 261 } 262 263 m.activeSessions[session.ID] = wtdb.ClientSession{ 264 ID: session.ID, 265 ClientSessionBody: wtdb.ClientSessionBody{ 266 SeqNum: session.SeqNum, 267 TowerLastApplied: session.TowerLastApplied, 268 TowerID: session.TowerID, 269 KeyIndex: session.KeyIndex, 270 Policy: session.Policy, 271 RewardPkScript: cloneBytes(session.RewardPkScript), 272 }, 273 CommittedUpdates: make([]wtdb.CommittedUpdate, 0), 274 AckedUpdates: make(map[uint16]wtdb.BackupID), 275 } 276 277 return nil 278 } 279 280 // NextSessionKeyIndex reserves a new session key derivation index for a 281 // particular tower id. The index is reserved for that tower until 282 // CreateClientSession is invoked for that tower and index, at which point a new 283 // index for that tower can be reserved. Multiple calls to this method before 284 // CreateClientSession is invoked should return the same index. 285 func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, 286 blobType blob.Type) (uint32, error) { 287 288 m.mu.Lock() 289 defer m.mu.Unlock() 290 291 key := keyIndexKey{ 292 towerID: towerID, 293 blobType: blobType, 294 } 295 296 if index, err := m.getSessionKeyIndex(key); err == nil { 297 return index, nil 298 } 299 300 m.nextIndex++ 301 index := m.nextIndex 302 m.indexes[key] = index 303 304 return index, nil 305 } 306 307 func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) { 308 if index, ok := m.indexes[key]; ok { 309 return index, nil 310 } 311 312 if key.blobType == blob.TypeAltruistCommit { 313 if index, ok := m.legacyIndexes[key.towerID]; ok { 314 return index, nil 315 } 316 } 317 318 return 0, wtdb.ErrNoReservedKeyIndex 319 } 320 321 // CommitUpdate persists the CommittedUpdate provided in the slot for (session, 322 // seqNum). This allows the client to retransmit this update on startup. 323 func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, 324 update *wtdb.CommittedUpdate) (uint16, error) { 325 326 m.mu.Lock() 327 defer m.mu.Unlock() 328 329 // Fail if session doesn't exist. 330 session, ok := m.activeSessions[*id] 331 if !ok { 332 return 0, wtdb.ErrClientSessionNotFound 333 } 334 335 // Check if an update has already been committed for this state. 336 for _, dbUpdate := range session.CommittedUpdates { 337 if dbUpdate.SeqNum == update.SeqNum { 338 // If the breach hint matches, we'll just return the 339 // last applied value so the client can retransmit. 340 if dbUpdate.Hint == update.Hint { 341 return session.TowerLastApplied, nil 342 } 343 344 // Otherwise, fail since the breach hint doesn't match. 345 return 0, wtdb.ErrUpdateAlreadyCommitted 346 } 347 } 348 349 // Sequence number must increment. 350 if update.SeqNum != session.SeqNum+1 { 351 return 0, wtdb.ErrCommitUnorderedUpdate 352 } 353 354 // Save the update and increment the sequence number. 355 session.CommittedUpdates = append(session.CommittedUpdates, *update) 356 session.SeqNum++ 357 m.activeSessions[*id] = session 358 359 return session.TowerLastApplied, nil 360 } 361 362 // AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This 363 // removes the update from the set of committed updates, and validates the 364 // lastApplied value returned from the tower. 365 func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { 366 m.mu.Lock() 367 defer m.mu.Unlock() 368 369 // Fail if session doesn't exist. 370 session, ok := m.activeSessions[*id] 371 if !ok { 372 return wtdb.ErrClientSessionNotFound 373 } 374 375 // Ensure the returned last applied value does not exceed the highest 376 // allocated sequence number. 377 if lastApplied > session.SeqNum { 378 return wtdb.ErrUnallocatedLastApplied 379 } 380 381 // Ensure the last applied value isn't lower than a previous one sent by 382 // the tower. 383 if lastApplied < session.TowerLastApplied { 384 return wtdb.ErrLastAppliedReversion 385 } 386 387 // Retrieve the committed update, failing if none is found. We should 388 // only receive acks for state updates that we send. 389 updates := session.CommittedUpdates 390 for i, update := range updates { 391 if update.SeqNum != seqNum { 392 continue 393 } 394 395 // Remove the committed update from disk and mark the update as 396 // acked. The tower last applied value is also recorded to send 397 // along with the next update. 398 copy(updates[:i], updates[i+1:]) 399 updates[len(updates)-1] = wtdb.CommittedUpdate{} 400 session.CommittedUpdates = updates[:len(updates)-1] 401 402 session.AckedUpdates[seqNum] = update.BackupID 403 session.TowerLastApplied = lastApplied 404 405 m.activeSessions[*id] = session 406 return nil 407 } 408 409 return wtdb.ErrCommittedUpdateNotFound 410 } 411 412 // FetchChanSummaries loads a mapping from all registered channels to their 413 // channel summaries. 414 func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { 415 m.mu.Lock() 416 defer m.mu.Unlock() 417 418 summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) 419 for chanID, summary := range m.summaries { 420 summaries[chanID] = wtdb.ClientChanSummary{ 421 SweepPkScript: cloneBytes(summary.SweepPkScript), 422 } 423 } 424 425 return summaries, nil 426 } 427 428 // RegisterChannel registers a channel for use within the client database. For 429 // now, all that is stored in the channel summary is the sweep pkscript that 430 // we'd like any tower sweeps to pay into. In the future, this will be extended 431 // to contain more info to allow the client efficiently request historical 432 // states to be backed up under the client's active policy. 433 func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, 434 sweepPkScript []byte) error { 435 436 m.mu.Lock() 437 defer m.mu.Unlock() 438 439 if _, ok := m.summaries[chanID]; ok { 440 return wtdb.ErrChannelAlreadyRegistered 441 } 442 443 m.summaries[chanID] = wtdb.ClientChanSummary{ 444 SweepPkScript: cloneBytes(sweepPkScript), 445 } 446 447 return nil 448 } 449 450 func cloneBytes(b []byte) []byte { 451 if b == nil { 452 return nil 453 } 454 455 bb := make([]byte, len(b)) 456 copy(bb, b) 457 458 return bb 459 } 460 461 func copyTower(tower *wtdb.Tower) *wtdb.Tower { 462 t := &wtdb.Tower{ 463 ID: tower.ID, 464 IdentityKey: tower.IdentityKey, 465 Addresses: make([]net.Addr, len(tower.Addresses)), 466 } 467 copy(t.Addresses, tower.Addresses) 468 469 return t 470 }