github.com/decred/dcrlnd@v0.7.6/watchtower/wtdb/client_db_test.go (about) 1 package wtdb_test 2 3 import ( 4 "bytes" 5 crand "crypto/rand" 6 "io" 7 "io/ioutil" 8 "net" 9 "os" 10 "reflect" 11 "testing" 12 13 "github.com/decred/dcrd/dcrec/secp256k1/v4" 14 "github.com/decred/dcrlnd/kvdb" 15 "github.com/decred/dcrlnd/lnwire" 16 "github.com/decred/dcrlnd/watchtower/blob" 17 "github.com/decred/dcrlnd/watchtower/wtclient" 18 "github.com/decred/dcrlnd/watchtower/wtdb" 19 "github.com/decred/dcrlnd/watchtower/wtmock" 20 "github.com/decred/dcrlnd/watchtower/wtpolicy" 21 ) 22 23 // clientDBInit is a closure used to initialize a wtclient.DB instance its 24 // cleanup function. 25 type clientDBInit func(t *testing.T) (wtclient.DB, func()) 26 27 type clientDBHarness struct { 28 t *testing.T 29 db wtclient.DB 30 } 31 32 func newClientDBHarness(t *testing.T, init clientDBInit) (*clientDBHarness, func()) { 33 db, cleanup := init(t) 34 35 h := &clientDBHarness{ 36 t: t, 37 db: db, 38 } 39 40 return h, cleanup 41 } 42 43 func (h *clientDBHarness) insertSession(session *wtdb.ClientSession, expErr error) { 44 h.t.Helper() 45 46 err := h.db.CreateClientSession(session) 47 if err != expErr { 48 h.t.Fatalf("expected create client session error: %v, got: %v", 49 expErr, err) 50 } 51 } 52 53 func (h *clientDBHarness) listSessions(id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { 54 h.t.Helper() 55 56 sessions, err := h.db.ListClientSessions(id) 57 if err != nil { 58 h.t.Fatalf("unable to list client sessions: %v", err) 59 } 60 61 return sessions 62 } 63 64 func (h *clientDBHarness) nextKeyIndex(id wtdb.TowerID, 65 blobType blob.Type) uint32 { 66 67 h.t.Helper() 68 69 index, err := h.db.NextSessionKeyIndex(id, blobType) 70 if err != nil { 71 h.t.Fatalf("unable to create next session key index: %v", err) 72 } 73 74 if index == 0 { 75 h.t.Fatalf("next key index should never be 0") 76 } 77 78 return index 79 } 80 81 func (h *clientDBHarness) createTower(lnAddr *lnwire.NetAddress, 82 expErr error) *wtdb.Tower { 83 84 h.t.Helper() 85 86 tower, err := h.db.CreateTower(lnAddr) 87 if err != expErr { 88 h.t.Fatalf("expected create tower error: %v, got: %v", expErr, err) 89 } 90 91 if tower.ID == 0 { 92 h.t.Fatalf("tower id should never be 0") 93 } 94 95 for _, session := range h.listSessions(&tower.ID) { 96 if session.Status != wtdb.CSessionActive { 97 h.t.Fatalf("expected status for session %v to be %v, "+ 98 "got %v", session.ID, wtdb.CSessionActive, 99 session.Status) 100 } 101 } 102 103 return tower 104 } 105 106 func (h *clientDBHarness) removeTower(pubKey *secp256k1.PublicKey, addr net.Addr, 107 hasSessions bool, expErr error) { 108 109 h.t.Helper() 110 111 if err := h.db.RemoveTower(pubKey, addr); err != expErr { 112 h.t.Fatalf("expected remove tower error: %v, got %v", expErr, err) 113 } 114 if expErr != nil { 115 return 116 } 117 118 if addr != nil { 119 tower, err := h.db.LoadTower(pubKey) 120 if err != nil { 121 h.t.Fatalf("expected tower %x to still exist", 122 pubKey.SerializeCompressed()) 123 } 124 125 removedAddr := addr.String() 126 for _, towerAddr := range tower.Addresses { 127 if towerAddr.String() == removedAddr { 128 h.t.Fatalf("address %v not removed for tower %x", 129 removedAddr, pubKey.SerializeCompressed()) 130 } 131 } 132 } else { 133 tower, err := h.db.LoadTower(pubKey) 134 if hasSessions && err != nil { 135 h.t.Fatalf("expected tower %x with sessions to still "+ 136 "exist", pubKey.SerializeCompressed()) 137 } 138 if !hasSessions && err == nil { 139 h.t.Fatalf("expected tower %x with no sessions to not "+ 140 "exist", pubKey.SerializeCompressed()) 141 } 142 if !hasSessions { 143 return 144 } 145 for _, session := range h.listSessions(&tower.ID) { 146 if session.Status != wtdb.CSessionInactive { 147 h.t.Fatalf("expected status for session %v to "+ 148 "be %v, got %v", session.ID, 149 wtdb.CSessionInactive, session.Status) 150 } 151 } 152 } 153 } 154 155 func (h *clientDBHarness) loadTower(pubKey *secp256k1.PublicKey, expErr error) *wtdb.Tower { 156 h.t.Helper() 157 158 tower, err := h.db.LoadTower(pubKey) 159 if err != expErr { 160 h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) 161 } 162 163 return tower 164 } 165 166 func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, expErr error) *wtdb.Tower { 167 h.t.Helper() 168 169 tower, err := h.db.LoadTowerByID(id) 170 if err != expErr { 171 h.t.Fatalf("expected load tower error: %v, got: %v", expErr, err) 172 } 173 174 return tower 175 } 176 177 func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary { 178 h.t.Helper() 179 180 summaries, err := h.db.FetchChanSummaries() 181 if err != nil { 182 h.t.Fatalf("unable to fetch chan summaries: %v", err) 183 } 184 185 return summaries 186 } 187 188 func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, 189 sweepPkScript []byte, expErr error) { 190 191 h.t.Helper() 192 193 err := h.db.RegisterChannel(chanID, sweepPkScript) 194 if err != expErr { 195 h.t.Fatalf("expected register channel error: %v, got: %v", 196 expErr, err) 197 } 198 } 199 200 func (h *clientDBHarness) commitUpdate(id *wtdb.SessionID, 201 update *wtdb.CommittedUpdate, expErr error) uint16 { 202 203 h.t.Helper() 204 205 lastApplied, err := h.db.CommitUpdate(id, update) 206 if err != expErr { 207 h.t.Fatalf("expected commit update error: %v, got: %v", 208 expErr, err) 209 } 210 211 return lastApplied 212 } 213 214 func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, 215 lastApplied uint16, expErr error) { 216 217 h.t.Helper() 218 219 err := h.db.AckUpdate(id, seqNum, lastApplied) 220 if err != expErr { 221 h.t.Fatalf("expected commit update error: %v, got: %v", 222 expErr, err) 223 } 224 } 225 226 // testCreateClientSession asserts various conditions regarding the creation of 227 // a new ClientSession. The test asserts: 228 // - client sessions can only be created if a session key index is reserved. 229 // - client sessions cannot be created with an incorrect session key index . 230 // - inserting duplicate sessions fails. 231 func testCreateClientSession(h *clientDBHarness) { 232 const blobType = blob.TypeAltruistAnchorCommit 233 234 // Create a test client session to insert. 235 session := &wtdb.ClientSession{ 236 ClientSessionBody: wtdb.ClientSessionBody{ 237 TowerID: wtdb.TowerID(3), 238 Policy: wtpolicy.Policy{ 239 TxPolicy: wtpolicy.TxPolicy{ 240 BlobType: blobType, 241 }, 242 MaxUpdates: 100, 243 }, 244 RewardPkScript: []byte{0x01, 0x02, 0x03}, 245 }, 246 ID: wtdb.SessionID([33]byte{0x01}), 247 } 248 249 // First, assert that this session is not already present in the 250 // database. 251 if _, ok := h.listSessions(nil)[session.ID]; ok { 252 h.t.Fatalf("session for id %x should not exist yet", session.ID) 253 } 254 255 // Attempting to insert the client session without reserving a session 256 // key index should fail. 257 h.insertSession(session, wtdb.ErrNoReservedKeyIndex) 258 259 // Now, reserve a session key for this tower. 260 keyIndex := h.nextKeyIndex(session.TowerID, blobType) 261 262 // The client session hasn't been updated with the reserved key index 263 // (since it's still zero). Inserting should fail due to the mismatch. 264 h.insertSession(session, wtdb.ErrIncorrectKeyIndex) 265 266 // Reserve another key for the same index. Since no session has been 267 // successfully created, it should return the same index to maintain 268 // idempotency across restarts. 269 keyIndex2 := h.nextKeyIndex(session.TowerID, blobType) 270 if keyIndex != keyIndex2 { 271 h.t.Fatalf("next key index should be idempotent: want: %v, "+ 272 "got %v", keyIndex, keyIndex2) 273 } 274 275 // Now, set the client session's key index so that it is proper and 276 // insert it. This should succeed. 277 session.KeyIndex = keyIndex 278 h.insertSession(session, nil) 279 280 // Verify that the session now exists in the database. 281 if _, ok := h.listSessions(nil)[session.ID]; !ok { 282 h.t.Fatalf("session for id %x should exist now", session.ID) 283 } 284 285 // Attempt to insert the session again, which should fail due to the 286 // session already existing. 287 h.insertSession(session, wtdb.ErrClientSessionAlreadyExists) 288 289 // Finally, assert that reserving another key index succeeds with a 290 // different key index, now that the first one has been finalized. 291 keyIndex3 := h.nextKeyIndex(session.TowerID, blobType) 292 if keyIndex == keyIndex3 { 293 h.t.Fatalf("key index still reserved after creating session") 294 } 295 } 296 297 // testFilterClientSessions asserts that we can correctly filter client sessions 298 // for a specific tower. 299 func testFilterClientSessions(h *clientDBHarness) { 300 // We'll create three client sessions, the first two belonging to one 301 // tower, and the last belonging to another one. 302 const numSessions = 3 303 const blobType = blob.TypeAltruistCommit 304 towerSessions := make(map[wtdb.TowerID][]wtdb.SessionID) 305 for i := 0; i < numSessions; i++ { 306 towerID := wtdb.TowerID(1) 307 if i == numSessions-1 { 308 towerID = wtdb.TowerID(2) 309 } 310 keyIndex := h.nextKeyIndex(towerID, blobType) 311 sessionID := wtdb.SessionID([33]byte{byte(i)}) 312 h.insertSession(&wtdb.ClientSession{ 313 ClientSessionBody: wtdb.ClientSessionBody{ 314 TowerID: towerID, 315 Policy: wtpolicy.Policy{ 316 TxPolicy: wtpolicy.TxPolicy{ 317 BlobType: blobType, 318 }, 319 MaxUpdates: 100, 320 }, 321 RewardPkScript: []byte{0x01, 0x02, 0x03}, 322 KeyIndex: keyIndex, 323 }, 324 ID: sessionID, 325 }, nil) 326 towerSessions[towerID] = append(towerSessions[towerID], sessionID) 327 } 328 329 // We should see the expected sessions for each tower when filtering 330 // them. 331 for towerID, expectedSessions := range towerSessions { 332 sessions := h.listSessions(&towerID) 333 if len(sessions) != len(expectedSessions) { 334 h.t.Fatalf("expected %v sessions for tower %v, got %v", 335 len(expectedSessions), towerID, len(sessions)) 336 } 337 for _, expectedSession := range expectedSessions { 338 if _, ok := sessions[expectedSession]; !ok { 339 h.t.Fatalf("expected session %v for tower %v", 340 expectedSession, towerID) 341 } 342 } 343 } 344 } 345 346 // testCreateTower asserts the behavior of creating new Tower objects within the 347 // database, and that the latest address is always prepended to the list of 348 // known addresses for the tower. 349 func testCreateTower(h *clientDBHarness) { 350 // Test that loading a tower with an arbitrary tower id fails. 351 h.loadTowerByID(20, wtdb.ErrTowerNotFound) 352 353 pk, err := randPubKey() 354 if err != nil { 355 h.t.Fatalf("unable to generate pubkey: %v", err) 356 } 357 358 addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} 359 lnAddr := &lnwire.NetAddress{ 360 IdentityKey: pk, 361 Address: addr1, 362 } 363 364 // Insert a random tower into the database. 365 tower := h.createTower(lnAddr, nil) 366 367 // Load the tower from the database and assert that it matches the tower 368 // we created. 369 tower2 := h.loadTowerByID(tower.ID, nil) 370 if !reflect.DeepEqual(tower, tower2) { 371 h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", 372 tower, tower2) 373 } 374 tower2 = h.loadTower(pk, err) 375 if !reflect.DeepEqual(tower, tower2) { 376 h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", 377 tower, tower2) 378 } 379 380 // Insert the address again into the database. Since the address is the 381 // same, this should result in an unmodified tower record. 382 towerDupAddr := h.createTower(lnAddr, nil) 383 if len(towerDupAddr.Addresses) != 1 { 384 h.t.Fatalf("duplicate address should be deduped") 385 } 386 if !reflect.DeepEqual(tower, towerDupAddr) { 387 h.t.Fatalf("mismatch towers, want: %v, got: %v", 388 tower, towerDupAddr) 389 } 390 391 // Generate a new address for this tower. 392 addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} 393 394 lnAddr2 := &lnwire.NetAddress{ 395 IdentityKey: pk, 396 Address: addr2, 397 } 398 399 // Insert the updated address, which should produce a tower with a new 400 // address. 401 towerNewAddr := h.createTower(lnAddr2, nil) 402 403 // Load the tower from the database, and assert that it matches the 404 // tower returned from creation. 405 towerNewAddr2 := h.loadTowerByID(tower.ID, nil) 406 if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { 407 h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", 408 towerNewAddr, towerNewAddr2) 409 } 410 towerNewAddr2 = h.loadTower(pk, nil) 411 if !reflect.DeepEqual(towerNewAddr, towerNewAddr2) { 412 h.t.Fatalf("loaded tower mismatch, want: %v, got: %v", 413 towerNewAddr, towerNewAddr2) 414 } 415 416 // Assert that there are now two addresses on the tower object. 417 if len(towerNewAddr.Addresses) != 2 { 418 h.t.Fatalf("new address should be added") 419 } 420 421 // Finally, assert that the new address was prepended since it is deemed 422 // fresher. 423 if !reflect.DeepEqual(tower.Addresses, towerNewAddr.Addresses[1:]) { 424 h.t.Fatalf("new address should be prepended") 425 } 426 } 427 428 // testRemoveTower asserts the behavior of removing Tower objects as a whole and 429 // removing addresses from Tower objects within the database. 430 func testRemoveTower(h *clientDBHarness) { 431 // Generate a random public key we'll use for our tower. 432 pk, err := randPubKey() 433 if err != nil { 434 h.t.Fatalf("unable to generate pubkey: %v", err) 435 } 436 437 // Removing a tower that does not exist within the database should 438 // result in a NOP. 439 h.removeTower(pk, nil, false, nil) 440 441 // We'll create a tower with two addresses. 442 addr1 := &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} 443 addr2 := &net.TCPAddr{IP: []byte{0x02, 0x00, 0x00, 0x00}, Port: 9911} 444 h.createTower(&lnwire.NetAddress{ 445 IdentityKey: pk, 446 Address: addr1, 447 }, nil) 448 h.createTower(&lnwire.NetAddress{ 449 IdentityKey: pk, 450 Address: addr2, 451 }, nil) 452 453 // We'll then remove the second address. We should now only see the 454 // first. 455 h.removeTower(pk, addr2, false, nil) 456 457 // We'll then remove the first address. We should now see that the tower 458 // has no addresses left. 459 h.removeTower(pk, addr1, false, wtdb.ErrLastTowerAddr) 460 461 // Removing the tower as a whole from the database should succeed since 462 // there aren't any active sessions for it. 463 h.removeTower(pk, nil, false, nil) 464 465 // We'll then recreate the tower, but this time we'll create a session 466 // for it. 467 tower := h.createTower(&lnwire.NetAddress{ 468 IdentityKey: pk, 469 Address: addr1, 470 }, nil) 471 472 const blobType = blob.TypeAltruistCommit 473 session := &wtdb.ClientSession{ 474 ClientSessionBody: wtdb.ClientSessionBody{ 475 TowerID: tower.ID, 476 Policy: wtpolicy.Policy{ 477 TxPolicy: wtpolicy.TxPolicy{ 478 BlobType: blobType, 479 }, 480 MaxUpdates: 100, 481 }, 482 RewardPkScript: []byte{0x01, 0x02, 0x03}, 483 KeyIndex: h.nextKeyIndex(tower.ID, blobType), 484 }, 485 ID: wtdb.SessionID([33]byte{0x01}), 486 } 487 h.insertSession(session, nil) 488 update := randCommittedUpdate(h.t, 1) 489 h.commitUpdate(&session.ID, update, nil) 490 491 // We should not be able to fully remove it from the database since 492 // there's a session and it has unacked updates. 493 h.removeTower(pk, nil, true, wtdb.ErrTowerUnackedUpdates) 494 495 // Removing the tower after all sessions no longer have unacked updates 496 // should result in the sessions becoming inactive. 497 h.ackUpdate(&session.ID, 1, 1, nil) 498 h.removeTower(pk, nil, true, nil) 499 500 // Creating the tower again should mark all of the sessions active once 501 // again. 502 h.createTower(&lnwire.NetAddress{ 503 IdentityKey: pk, 504 Address: addr1, 505 }, nil) 506 } 507 508 // testChanSummaries tests the process of a registering a channel and its 509 // associated sweep pkscript. 510 func testChanSummaries(h *clientDBHarness) { 511 // First, assert that this channel is not already registered. 512 var chanID lnwire.ChannelID 513 if _, ok := h.fetchChanSummaries()[chanID]; ok { 514 h.t.Fatalf("pkscript for channel %x should not exist yet", 515 chanID) 516 } 517 518 // Generate a random sweep pkscript and register it for this channel. 519 expPkScript := make([]byte, 22) 520 if _, err := io.ReadFull(crand.Reader, expPkScript); err != nil { 521 h.t.Fatalf("unable to generate pkscript: %v", err) 522 } 523 h.registerChan(chanID, expPkScript, nil) 524 525 // Assert that the channel exists and that its sweep pkscript matches 526 // the one we registered. 527 summary, ok := h.fetchChanSummaries()[chanID] 528 if !ok { 529 h.t.Fatalf("pkscript for channel %x should not exist yet", 530 chanID) 531 } else if !bytes.Equal(expPkScript, summary.SweepPkScript) { 532 h.t.Fatalf("pkscript mismatch, want: %x, got: %x", 533 expPkScript, summary.SweepPkScript) 534 } 535 536 // Finally, assert that re-registering the same channel produces a 537 // failure. 538 h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered) 539 } 540 541 // testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can 542 func testCommitUpdate(h *clientDBHarness) { 543 const blobType = blob.TypeAltruistCommit 544 session := &wtdb.ClientSession{ 545 ClientSessionBody: wtdb.ClientSessionBody{ 546 TowerID: wtdb.TowerID(3), 547 Policy: wtpolicy.Policy{ 548 TxPolicy: wtpolicy.TxPolicy{ 549 BlobType: blobType, 550 }, 551 MaxUpdates: 100, 552 }, 553 RewardPkScript: []byte{0x01, 0x02, 0x03}, 554 }, 555 ID: wtdb.SessionID([33]byte{0x02}), 556 } 557 558 // Generate a random update and try to commit before inserting the 559 // session, which should fail. 560 update1 := randCommittedUpdate(h.t, 1) 561 h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound) 562 563 // Reserve a session key index and insert the session. 564 session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType) 565 h.insertSession(session, nil) 566 567 // Now, try to commit the update that failed initially which should 568 // succeed. The lastApplied value should be 0 since we have not received 569 // an ack from the tower. 570 lastApplied := h.commitUpdate(&session.ID, update1, nil) 571 if lastApplied != 0 { 572 h.t.Fatalf("last applied mismatch, want: 0, got: %v", 573 lastApplied) 574 } 575 576 // Assert that the committed update appears in the client session's 577 // CommittedUpdates map when loaded from disk and that there are no 578 // AckedUpdates. 579 dbSession := h.listSessions(nil)[session.ID] 580 checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ 581 *update1, 582 }) 583 checkAckedUpdates(h.t, dbSession, nil) 584 585 // Try to commit the same update, which should succeed due to 586 // idempotency (which is preserved when the breach hint is identical to 587 // the on-disk update's hint). The lastApplied value should remain 588 // unchanged. 589 lastApplied2 := h.commitUpdate(&session.ID, update1, nil) 590 if lastApplied2 != lastApplied { 591 h.t.Fatalf("last applied should not have changed, got %v", 592 lastApplied2) 593 } 594 595 // Assert that the loaded ClientSession is the same as before. 596 dbSession = h.listSessions(nil)[session.ID] 597 checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ 598 *update1, 599 }) 600 checkAckedUpdates(h.t, dbSession, nil) 601 602 // Generate another random update and try to commit it at the identical 603 // sequence number. Since the breach hint has changed, this should fail. 604 update2 := randCommittedUpdate(h.t, 1) 605 h.commitUpdate(&session.ID, update2, wtdb.ErrUpdateAlreadyCommitted) 606 607 // Next, insert the new update at the next unallocated sequence number 608 // which should succeed. 609 update2.SeqNum = 2 610 lastApplied3 := h.commitUpdate(&session.ID, update2, nil) 611 if lastApplied3 != lastApplied { 612 h.t.Fatalf("last applied should not have changed, got %v", 613 lastApplied3) 614 } 615 616 // Check that both updates now appear as committed on the ClientSession 617 // loaded from disk. 618 dbSession = h.listSessions(nil)[session.ID] 619 checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ 620 *update1, 621 *update2, 622 }) 623 checkAckedUpdates(h.t, dbSession, nil) 624 625 // Finally, create one more random update and try to commit it at index 626 // 4, which should be rejected since 3 is the next slot the database 627 // expects. 628 update4 := randCommittedUpdate(h.t, 4) 629 h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) 630 631 // Assert that the ClientSession loaded from disk remains unchanged. 632 dbSession = h.listSessions(nil)[session.ID] 633 checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{ 634 *update1, 635 *update2, 636 }) 637 checkAckedUpdates(h.t, dbSession, nil) 638 } 639 640 // testAckUpdate asserts the behavior of AckUpdate. 641 func testAckUpdate(h *clientDBHarness) { 642 const blobType = blob.TypeAltruistCommit 643 644 // Create a new session that the updates in this will be tied to. 645 session := &wtdb.ClientSession{ 646 ClientSessionBody: wtdb.ClientSessionBody{ 647 TowerID: wtdb.TowerID(3), 648 Policy: wtpolicy.Policy{ 649 TxPolicy: wtpolicy.TxPolicy{ 650 BlobType: blobType, 651 }, 652 MaxUpdates: 100, 653 }, 654 RewardPkScript: []byte{0x01, 0x02, 0x03}, 655 }, 656 ID: wtdb.SessionID([33]byte{0x03}), 657 } 658 659 // Try to ack an update before inserting the client session, which 660 // should fail. 661 h.ackUpdate(&session.ID, 1, 0, wtdb.ErrClientSessionNotFound) 662 663 // Reserve a session key and insert the client session. 664 session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType) 665 h.insertSession(session, nil) 666 667 // Now, try to ack update 1. This should fail since update 1 was never 668 // committed. 669 h.ackUpdate(&session.ID, 1, 0, wtdb.ErrCommittedUpdateNotFound) 670 671 // Commit to a random update at seqnum 1. 672 update1 := randCommittedUpdate(h.t, 1) 673 lastApplied := h.commitUpdate(&session.ID, update1, nil) 674 if lastApplied != 0 { 675 h.t.Fatalf("last applied mismatch, want: 0, got: %v", 676 lastApplied) 677 } 678 679 // Acking seqnum 1 should succeed. 680 h.ackUpdate(&session.ID, 1, 1, nil) 681 682 // Acking seqnum 1 again should fail. 683 h.ackUpdate(&session.ID, 1, 1, wtdb.ErrCommittedUpdateNotFound) 684 685 // Acking a valid seqnum with a reverted last applied value should fail. 686 h.ackUpdate(&session.ID, 1, 0, wtdb.ErrLastAppliedReversion) 687 688 // Acking with a last applied greater than any allocated seqnum should 689 // fail. 690 h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) 691 692 // Assert that the ClientSession loaded from disk has one update in it's 693 // AckedUpdates map, and that the committed update has been removed. 694 dbSession := h.listSessions(nil)[session.ID] 695 checkCommittedUpdates(h.t, dbSession, nil) 696 checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ 697 1: update1.BackupID, 698 }) 699 700 // Commit to another random update, and assert that the last applied 701 // value is 1, since this was what was provided in the last successful 702 // ack. 703 update2 := randCommittedUpdate(h.t, 2) 704 lastApplied = h.commitUpdate(&session.ID, update2, nil) 705 if lastApplied != 1 { 706 h.t.Fatalf("last applied mismatch, want: 1, got: %v", 707 lastApplied) 708 } 709 710 // Ack seqnum 2. 711 h.ackUpdate(&session.ID, 2, 2, nil) 712 713 // Assert that both updates exist as AckedUpdates when loaded from disk. 714 dbSession = h.listSessions(nil)[session.ID] 715 checkCommittedUpdates(h.t, dbSession, nil) 716 checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{ 717 1: update1.BackupID, 718 2: update2.BackupID, 719 }) 720 721 // Acking again with a lower last applied should fail. 722 h.ackUpdate(&session.ID, 2, 1, wtdb.ErrLastAppliedReversion) 723 724 // Acking an unallocated seqnum should fail. 725 h.ackUpdate(&session.ID, 4, 2, wtdb.ErrCommittedUpdateNotFound) 726 727 // Acking with a last applied greater than any allocated seqnum should 728 // fail. 729 h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) 730 } 731 732 // checkCommittedUpdates asserts that the CommittedUpdates on session match the 733 // expUpdates provided. 734 func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, 735 expUpdates []wtdb.CommittedUpdate) { 736 737 t.Helper() 738 739 // We promote nil expUpdates to an initialized slice since the database 740 // should never return a nil slice. This promotion is done purely out of 741 // convenience for the testing framework. 742 if expUpdates == nil { 743 expUpdates = make([]wtdb.CommittedUpdate, 0) 744 } 745 746 if !reflect.DeepEqual(session.CommittedUpdates, expUpdates) { 747 t.Fatalf("committed updates mismatch, want: %v, got: %v", 748 expUpdates, session.CommittedUpdates) 749 } 750 } 751 752 // checkAckedUpdates asserts that the AckedUpdates on a sessio match the 753 // expUpdates provided. 754 func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, 755 expUpdates map[uint16]wtdb.BackupID) { 756 757 // We promote nil expUpdates to an initialized map since the database 758 // should never return a nil map. This promotion is done purely out of 759 // convenience for the testing framework. 760 if expUpdates == nil { 761 expUpdates = make(map[uint16]wtdb.BackupID) 762 } 763 764 if !reflect.DeepEqual(session.AckedUpdates, expUpdates) { 765 t.Fatalf("acked updates mismatch, want: %v, got: %v", 766 expUpdates, session.AckedUpdates) 767 } 768 } 769 770 // TestClientDB asserts the behavior of a fresh client db, a reopened client db, 771 // and the mock implementation. This ensures that all databases function 772 // identically, especially in the negative paths. 773 func TestClientDB(t *testing.T) { 774 dbCfg := &kvdb.BoltConfig{DBTimeout: kvdb.DefaultDBTimeout} 775 dbs := []struct { 776 name string 777 init clientDBInit 778 }{ 779 { 780 name: "fresh clientdb", 781 init: func(t *testing.T) (wtclient.DB, func()) { 782 path, err := ioutil.TempDir("", "clientdb") 783 if err != nil { 784 t.Fatalf("unable to make temp dir: %v", 785 err) 786 } 787 788 bdb, err := wtdb.NewBoltBackendCreator( 789 true, path, "wtclient.db", 790 )(dbCfg) 791 if err != nil { 792 os.RemoveAll(path) 793 t.Fatalf("unable to open db: %v", err) 794 } 795 796 db, err := wtdb.OpenClientDB(bdb) 797 if err != nil { 798 os.RemoveAll(path) 799 t.Fatalf("unable to open db: %v", err) 800 } 801 802 cleanup := func() { 803 db.Close() 804 os.RemoveAll(path) 805 } 806 807 return db, cleanup 808 }, 809 }, 810 { 811 name: "reopened clientdb", 812 init: func(t *testing.T) (wtclient.DB, func()) { 813 path, err := ioutil.TempDir("", "clientdb") 814 if err != nil { 815 t.Fatalf("unable to make temp dir: %v", 816 err) 817 } 818 819 bdb, err := wtdb.NewBoltBackendCreator( 820 true, path, "wtclient.db", 821 )(dbCfg) 822 if err != nil { 823 os.RemoveAll(path) 824 t.Fatalf("unable to open db: %v", err) 825 } 826 827 db, err := wtdb.OpenClientDB(bdb) 828 if err != nil { 829 os.RemoveAll(path) 830 t.Fatalf("unable to open db: %v", err) 831 } 832 db.Close() 833 834 bdb, err = wtdb.NewBoltBackendCreator( 835 true, path, "wtclient.db", 836 )(dbCfg) 837 if err != nil { 838 os.RemoveAll(path) 839 t.Fatalf("unable to open db: %v", err) 840 } 841 842 db, err = wtdb.OpenClientDB(bdb) 843 if err != nil { 844 os.RemoveAll(path) 845 t.Fatalf("unable to reopen db: %v", err) 846 } 847 848 cleanup := func() { 849 db.Close() 850 os.RemoveAll(path) 851 } 852 853 return db, cleanup 854 }, 855 }, 856 { 857 name: "mock", 858 init: func(t *testing.T) (wtclient.DB, func()) { 859 return wtmock.NewClientDB(), func() {} 860 }, 861 }, 862 } 863 864 tests := []struct { 865 name string 866 run func(*clientDBHarness) 867 }{ 868 { 869 name: "create client session", 870 run: testCreateClientSession, 871 }, 872 { 873 name: "filter client sessions", 874 run: testFilterClientSessions, 875 }, 876 { 877 name: "create tower", 878 run: testCreateTower, 879 }, 880 { 881 name: "remove tower", 882 run: testRemoveTower, 883 }, 884 { 885 name: "chan summaries", 886 run: testChanSummaries, 887 }, 888 { 889 name: "commit update", 890 run: testCommitUpdate, 891 }, 892 { 893 name: "ack update", 894 run: testAckUpdate, 895 }, 896 } 897 898 for _, database := range dbs { 899 db := database 900 t.Run(db.name, func(t *testing.T) { 901 t.Parallel() 902 903 for _, test := range tests { 904 t.Run(test.name, func(t *testing.T) { 905 h, cleanup := newClientDBHarness( 906 t, db.init, 907 ) 908 defer cleanup() 909 910 test.run(h) 911 }) 912 } 913 }) 914 } 915 } 916 917 // randCommittedUpdate generates a random committed update. 918 func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { 919 var chanID lnwire.ChannelID 920 if _, err := io.ReadFull(crand.Reader, chanID[:]); err != nil { 921 t.Fatalf("unable to generate chan id: %v", err) 922 } 923 924 var hint blob.BreachHint 925 if _, err := io.ReadFull(crand.Reader, hint[:]); err != nil { 926 t.Fatalf("unable to generate breach hint: %v", err) 927 } 928 929 encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) 930 if _, err := io.ReadFull(crand.Reader, encBlob); err != nil { 931 t.Fatalf("unable to generate encrypted blob: %v", err) 932 } 933 934 return &wtdb.CommittedUpdate{ 935 SeqNum: seqNum, 936 CommittedUpdateBody: wtdb.CommittedUpdateBody{ 937 BackupID: wtdb.BackupID{ 938 ChanID: chanID, 939 CommitHeight: 666, 940 }, 941 Hint: hint, 942 EncryptedBlob: encBlob, 943 }, 944 } 945 }