github.com/decred/dcrlnd@v0.7.6/channeldb/channel_test.go (about) 1 package channeldb 2 3 import ( 4 "bytes" 5 "math/rand" 6 "net" 7 "reflect" 8 "runtime" 9 "testing" 10 11 "github.com/davecgh/go-spew/spew" 12 "github.com/stretchr/testify/require" 13 14 "github.com/decred/dcrd/chaincfg/chainhash" 15 "github.com/decred/dcrd/dcrec/secp256k1/v4" 16 "github.com/decred/dcrd/dcrutil/v4" 17 "github.com/decred/dcrd/wire" 18 "github.com/decred/dcrlnd/clock" 19 "github.com/decred/dcrlnd/keychain" 20 "github.com/decred/dcrlnd/kvdb" 21 "github.com/decred/dcrlnd/lntest/channels" 22 "github.com/decred/dcrlnd/lnwire" 23 "github.com/decred/dcrlnd/shachain" 24 ) 25 26 func privKeyFromBytes(b []byte) (*secp256k1.PrivateKey, *secp256k1.PublicKey) { 27 k := secp256k1.PrivKeyFromBytes(b) 28 return k, k.PubKey() 29 } 30 31 var ( 32 key = [chainhash.HashSize]byte{ 33 0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, 34 0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, 35 0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d, 36 0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9, 37 } 38 rev = [chainhash.HashSize]byte{ 39 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, 40 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, 41 0x2d, 0xe7, 0x93, 0xe4, 42 } 43 44 privKey, pubKey = privKeyFromBytes(key[:]) 45 46 wireSig, _ = lnwire.NewSigFromSignature(testSig) 47 48 testClock = clock.NewTestClock(testNow) 49 50 // defaultPendingHeight is the default height at which we set 51 // channels to pending. 52 defaultPendingHeight = 100 53 54 // defaultAddr is the default address that we mark test channels pending 55 // with. 56 defaultAddr = &net.TCPAddr{ 57 IP: net.ParseIP("127.0.0.1"), 58 Port: 18555, 59 } 60 61 // keyLocIndex is the KeyLocator Index we use for TestKeyLocatorEncoding. 62 keyLocIndex = uint32(2049) 63 ) 64 65 // testChannelParams is a struct which details the specifics of how a channel 66 // should be created. 67 type testChannelParams struct { 68 // channel is the channel that will be written to disk. 69 channel *OpenChannel 70 71 // addr is the address that the channel will be synced pending with. 72 addr *net.TCPAddr 73 74 // pendingHeight is the height that the channel should be recorded as 75 // pending. 76 pendingHeight uint32 77 78 // openChannel is set to true if the channel should be fully marked as 79 // open if this is false, the channel will be left in pending state. 80 openChannel bool 81 } 82 83 // testChannelOption is a functional option which can be used to alter the 84 // default channel that is creates for testing. 85 type testChannelOption func(params *testChannelParams) 86 87 // channelCommitmentOption is an option which allows overwriting of the default 88 // commitment height and balances. The local boolean can be used to set these 89 // balances on the local or remote commit. 90 func channelCommitmentOption(height uint64, localBalance, 91 remoteBalance lnwire.MilliAtom, local bool) testChannelOption { 92 93 return func(params *testChannelParams) { 94 if local { 95 params.channel.LocalCommitment.CommitHeight = height 96 params.channel.LocalCommitment.LocalBalance = localBalance 97 params.channel.LocalCommitment.RemoteBalance = remoteBalance 98 } else { 99 params.channel.RemoteCommitment.CommitHeight = height 100 params.channel.RemoteCommitment.LocalBalance = localBalance 101 params.channel.RemoteCommitment.RemoteBalance = remoteBalance 102 } 103 } 104 } 105 106 // pendingHeightOption is an option which can be used to set the height the 107 // channel is marked as pending at. 108 func pendingHeightOption(height uint32) testChannelOption { 109 return func(params *testChannelParams) { 110 params.pendingHeight = height 111 } 112 } 113 114 // openChannelOption is an option which can be used to create a test channel 115 // that is open. 116 func openChannelOption() testChannelOption { 117 return func(params *testChannelParams) { 118 params.openChannel = true 119 } 120 } 121 122 // localHtlcsOption is an option which allows setting of htlcs on the local 123 // commitment. 124 func localHtlcsOption(htlcs []HTLC) testChannelOption { 125 return func(params *testChannelParams) { 126 params.channel.LocalCommitment.Htlcs = htlcs 127 } 128 } 129 130 // remoteHtlcsOption is an option which allows setting of htlcs on the remote 131 // commitment. 132 func remoteHtlcsOption(htlcs []HTLC) testChannelOption { 133 return func(params *testChannelParams) { 134 params.channel.RemoteCommitment.Htlcs = htlcs 135 } 136 } 137 138 // loadFwdPkgs is a helper method that reads all forwarding packages for a 139 // particular packager. 140 func loadFwdPkgs(t *testing.T, db kvdb.Backend, 141 packager FwdPackager) []*FwdPkg { 142 143 var ( 144 fwdPkgs []*FwdPkg 145 err error 146 ) 147 148 err = kvdb.View(db, func(tx kvdb.RTx) error { 149 fwdPkgs, err = packager.LoadFwdPkgs(tx) 150 return err 151 }, func() {}) 152 require.NoError(t, err, "unable to load fwd pkgs") 153 154 return fwdPkgs 155 } 156 157 // localShutdownOption is an option which sets the local upfront shutdown 158 // script for the channel. 159 func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption { 160 return func(params *testChannelParams) { 161 params.channel.LocalShutdownScript = addr 162 } 163 } 164 165 // remoteShutdownOption is an option which sets the remote upfront shutdown 166 // script for the channel. 167 func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption { 168 return func(params *testChannelParams) { 169 params.channel.RemoteShutdownScript = addr 170 } 171 } 172 173 // fundingPointOption is an option which sets the funding outpoint of the 174 // channel. 175 func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { 176 return func(params *testChannelParams) { 177 params.channel.FundingOutpoint = chanPoint 178 } 179 } 180 181 // channelIDOption is an option which sets the short channel ID of the channel. 182 var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { 183 return func(params *testChannelParams) { 184 params.channel.ShortChannelID = chanID 185 } 186 } 187 188 // createTestChannel writes a test channel to the database. It takes a set of 189 // functional options which can be used to overwrite the default of creating 190 // a pending channel that was broadcast at height 100. 191 func createTestChannel(t *testing.T, cdb *ChannelStateDB, 192 opts ...testChannelOption) *OpenChannel { 193 194 // Create a default set of parameters. 195 params := &testChannelParams{ 196 channel: createTestChannelState(t, cdb), 197 addr: defaultAddr, 198 openChannel: false, 199 pendingHeight: uint32(defaultPendingHeight), 200 } 201 202 // Apply all functional options to the test channel params. 203 for _, o := range opts { 204 o(params) 205 } 206 207 // Mark the channel as pending. 208 err := params.channel.SyncPending(params.addr, params.pendingHeight) 209 if err != nil { 210 t.Fatalf("unable to save and serialize channel "+ 211 "state: %v", err) 212 } 213 214 // If the parameters do not specify that we should open the channel 215 // fully, we return the pending channel. 216 if !params.openChannel { 217 return params.channel 218 } 219 220 // Mark the channel as open with the short channel id provided. 221 err = params.channel.MarkAsOpen(params.channel.ShortChannelID) 222 if err != nil { 223 t.Fatalf("unable to mark channel open: %v", err) 224 } 225 226 return params.channel 227 } 228 229 func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { 230 // Simulate 1000 channel updates. 231 producer, err := shachain.NewRevocationProducerFromBytes(key[:]) 232 if err != nil { 233 t.Fatalf("could not get producer: %v", err) 234 } 235 store := shachain.NewRevocationStore() 236 for i := 0; i < 1; i++ { 237 preImage, err := producer.AtIndex(uint64(i)) 238 if err != nil { 239 t.Fatalf("could not get "+ 240 "preimage: %v", err) 241 } 242 243 if err := store.AddNextEntry(preImage); err != nil { 244 t.Fatalf("could not add entry: %v", err) 245 } 246 } 247 248 localCfg := ChannelConfig{ 249 ChannelConstraints: ChannelConstraints{ 250 DustLimit: dcrutil.Amount(rand.Int63()), 251 MaxPendingAmount: lnwire.MilliAtom(rand.Int63()), 252 ChanReserve: dcrutil.Amount(rand.Int63()), 253 MinHTLC: lnwire.MilliAtom(rand.Int63()), 254 MaxAcceptedHtlcs: uint16(rand.Int31()), 255 CsvDelay: uint16(rand.Int31()), 256 }, 257 MultiSigKey: keychain.KeyDescriptor{ 258 PubKey: privKey.PubKey(), 259 }, 260 RevocationBasePoint: keychain.KeyDescriptor{ 261 PubKey: privKey.PubKey(), 262 }, 263 PaymentBasePoint: keychain.KeyDescriptor{ 264 PubKey: privKey.PubKey(), 265 }, 266 DelayBasePoint: keychain.KeyDescriptor{ 267 PubKey: privKey.PubKey(), 268 }, 269 HtlcBasePoint: keychain.KeyDescriptor{ 270 PubKey: privKey.PubKey(), 271 }, 272 } 273 remoteCfg := ChannelConfig{ 274 ChannelConstraints: ChannelConstraints{ 275 DustLimit: dcrutil.Amount(rand.Int63()), 276 MaxPendingAmount: lnwire.MilliAtom(rand.Int63()), 277 ChanReserve: dcrutil.Amount(rand.Int63()), 278 MinHTLC: lnwire.MilliAtom(rand.Int63()), 279 MaxAcceptedHtlcs: uint16(rand.Int31()), 280 CsvDelay: uint16(rand.Int31()), 281 }, 282 MultiSigKey: keychain.KeyDescriptor{ 283 PubKey: privKey.PubKey(), 284 KeyLocator: keychain.KeyLocator{ 285 Family: keychain.KeyFamilyMultiSig, 286 Index: 9, 287 }, 288 }, 289 RevocationBasePoint: keychain.KeyDescriptor{ 290 PubKey: privKey.PubKey(), 291 KeyLocator: keychain.KeyLocator{ 292 Family: keychain.KeyFamilyRevocationBase, 293 Index: 8, 294 }, 295 }, 296 PaymentBasePoint: keychain.KeyDescriptor{ 297 PubKey: privKey.PubKey(), 298 KeyLocator: keychain.KeyLocator{ 299 Family: keychain.KeyFamilyPaymentBase, 300 Index: 7, 301 }, 302 }, 303 DelayBasePoint: keychain.KeyDescriptor{ 304 PubKey: privKey.PubKey(), 305 KeyLocator: keychain.KeyLocator{ 306 Family: keychain.KeyFamilyDelayBase, 307 Index: 6, 308 }, 309 }, 310 HtlcBasePoint: keychain.KeyDescriptor{ 311 PubKey: privKey.PubKey(), 312 KeyLocator: keychain.KeyLocator{ 313 Family: keychain.KeyFamilyHtlcBase, 314 Index: 5, 315 }, 316 }, 317 } 318 319 chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63())) 320 321 return &OpenChannel{ 322 ChanType: SingleFunderBit | FrozenBit, 323 ChainHash: key, 324 FundingOutpoint: wire.OutPoint{Hash: key, Index: rand.Uint32()}, 325 ShortChannelID: chanID, 326 IsInitiator: true, 327 IsPending: true, 328 IdentityPub: pubKey, 329 Capacity: dcrutil.Amount(10000), 330 LocalChanCfg: localCfg, 331 RemoteChanCfg: remoteCfg, 332 TotalMAtomsSent: 8, 333 TotalMAtomsReceived: 2, 334 LocalCommitment: ChannelCommitment{ 335 CommitHeight: 0, 336 LocalBalance: lnwire.MilliAtom(9000), 337 RemoteBalance: lnwire.MilliAtom(3000), 338 CommitFee: dcrutil.Amount(rand.Int63()), 339 FeePerKB: dcrutil.Amount(5000), 340 CommitTx: channels.TestFundingTx, 341 CommitSig: bytes.Repeat([]byte{1}, 71), 342 }, 343 RemoteCommitment: ChannelCommitment{ 344 CommitHeight: 0, 345 LocalBalance: lnwire.MilliAtom(3000), 346 RemoteBalance: lnwire.MilliAtom(9000), 347 CommitFee: dcrutil.Amount(rand.Int63()), 348 FeePerKB: dcrutil.Amount(5000), 349 CommitTx: channels.TestFundingTx, 350 CommitSig: bytes.Repeat([]byte{1}, 71), 351 }, 352 NumConfsRequired: 4, 353 RemoteCurrentRevocation: privKey.PubKey(), 354 RemoteNextRevocation: privKey.PubKey(), 355 RevocationProducer: producer, 356 RevocationStore: store, 357 Db: cdb, 358 Packager: NewChannelPackager(chanID), 359 FundingTxn: channels.TestFundingTx, 360 ThawHeight: uint32(defaultPendingHeight), 361 } 362 } 363 364 func TestOpenChannelPutGetDelete(t *testing.T) { 365 t.Parallel() 366 367 fullDB, cleanUp, err := MakeTestDB() 368 if err != nil { 369 t.Fatalf("unable to make test database: %v", err) 370 } 371 defer cleanUp() 372 373 cdb := fullDB.ChannelStateDB() 374 375 // Create the test channel state, with additional htlcs on the local 376 // and remote commitment. 377 localHtlcs := []HTLC{ 378 {Signature: testSig.Serialize(), 379 Incoming: true, 380 Amt: 10, 381 RHash: key, 382 RefundTimeout: 1, 383 OnionBlob: []byte("onionblob"), 384 }, 385 } 386 387 remoteHtlcs := []HTLC{ 388 { 389 Signature: testSig.Serialize(), 390 Incoming: false, 391 Amt: 10, 392 RHash: key, 393 RefundTimeout: 1, 394 OnionBlob: []byte("onionblob"), 395 }, 396 } 397 398 state := createTestChannel( 399 t, cdb, 400 remoteHtlcsOption(remoteHtlcs), 401 localHtlcsOption(localHtlcs), 402 ) 403 404 openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) 405 if err != nil { 406 t.Fatalf("unable to fetch open channel: %v", err) 407 } 408 409 newState := openChannels[0] 410 411 // The decoded channel state should be identical to what we stored 412 // above. 413 if !reflect.DeepEqual(state, newState) { 414 t.Fatalf("channel state doesn't match:: %v vs %v", 415 spew.Sdump(state), spew.Sdump(newState)) 416 } 417 418 // We'll also test that the channel is properly able to hot swap the 419 // next revocation for the state machine. This tests the initial 420 // post-funding revocation exchange. 421 nextRevKey, err := secp256k1.GeneratePrivateKey() 422 if err != nil { 423 t.Fatalf("unable to create new private key: %v", err) 424 } 425 nextRevKeyPub := nextRevKey.PubKey() 426 if err := state.InsertNextRevocation(nextRevKeyPub); err != nil { 427 t.Fatalf("unable to update revocation: %v", err) 428 } 429 430 openChannels, err = cdb.FetchOpenChannels(state.IdentityPub) 431 if err != nil { 432 t.Fatalf("unable to fetch open channel: %v", err) 433 } 434 updatedChan := openChannels[0] 435 436 // Ensure that the revocation was set properly. 437 if !nextRevKeyPub.IsEqual(updatedChan.RemoteNextRevocation) { 438 t.Fatalf("next revocation wasn't updated") 439 } 440 441 // Finally to wrap up the test, delete the state of the channel within 442 // the database. This involves "closing" the channel which removes all 443 // written state, and creates a small "summary" elsewhere within the 444 // database. 445 closeSummary := &ChannelCloseSummary{ 446 ChanPoint: state.FundingOutpoint, 447 RemotePub: state.IdentityPub, 448 SettledBalance: dcrutil.Amount(500), 449 TimeLockedBalance: dcrutil.Amount(10000), 450 IsPending: false, 451 CloseType: CooperativeClose, 452 } 453 if err := state.CloseChannel(closeSummary); err != nil { 454 t.Fatalf("unable to close channel: %v", err) 455 } 456 457 // As the channel is now closed, attempting to fetch all open channels 458 // for our fake node ID should return an empty slice. 459 openChans, err := cdb.FetchOpenChannels(state.IdentityPub) 460 if err != nil { 461 t.Fatalf("unable to fetch open channels: %v", err) 462 } 463 if len(openChans) != 0 { 464 t.Fatalf("all channels not deleted, found %v", len(openChans)) 465 } 466 467 // Additionally, attempting to fetch all the open channels globally 468 // should yield no results. 469 openChans, err = cdb.FetchAllChannels() 470 if err != nil { 471 t.Fatal("unable to fetch all open chans") 472 } 473 if len(openChans) != 0 { 474 t.Fatalf("all channels not deleted, found %v", len(openChans)) 475 } 476 } 477 478 // TestOptionalShutdown tests the reading and writing of channels with and 479 // without optional shutdown script fields. 480 func TestOptionalShutdown(t *testing.T) { 481 local := lnwire.DeliveryAddress([]byte("local shutdown script")) 482 remote := lnwire.DeliveryAddress([]byte("remote shutdown script")) 483 484 if _, err := rand.Read(remote); err != nil { 485 t.Fatalf("Could not create random script: %v", err) 486 } 487 488 tests := []struct { 489 name string 490 localShutdown lnwire.DeliveryAddress 491 remoteShutdown lnwire.DeliveryAddress 492 }{ 493 { 494 name: "no shutdown scripts", 495 localShutdown: nil, 496 remoteShutdown: nil, 497 }, 498 { 499 name: "local shutdown script", 500 localShutdown: local, 501 remoteShutdown: nil, 502 }, 503 { 504 name: "remote shutdown script", 505 localShutdown: nil, 506 remoteShutdown: remote, 507 }, 508 { 509 name: "both scripts set", 510 localShutdown: local, 511 remoteShutdown: remote, 512 }, 513 } 514 515 for _, test := range tests { 516 test := test 517 518 t.Run(test.name, func(t *testing.T) { 519 fullDB, cleanUp, err := MakeTestDB() 520 if err != nil { 521 t.Fatalf("unable to make test database: %v", err) 522 } 523 defer cleanUp() 524 525 cdb := fullDB.ChannelStateDB() 526 527 // Create a channel with upfront scripts set as 528 // specified in the test. 529 state := createTestChannel( 530 t, cdb, 531 localShutdownOption(test.localShutdown), 532 remoteShutdownOption(test.remoteShutdown), 533 ) 534 535 openChannels, err := cdb.FetchOpenChannels( 536 state.IdentityPub, 537 ) 538 if err != nil { 539 t.Fatalf("unable to fetch open"+ 540 " channel: %v", err) 541 } 542 543 if len(openChannels) != 1 { 544 t.Fatalf("Expected one channel open,"+ 545 " got: %v", len(openChannels)) 546 } 547 548 if !bytes.Equal(openChannels[0].LocalShutdownScript, 549 test.localShutdown) { 550 551 t.Fatalf("Expected local: %x, got: %x", 552 test.localShutdown, 553 openChannels[0].LocalShutdownScript) 554 } 555 556 if !bytes.Equal(openChannels[0].RemoteShutdownScript, 557 test.remoteShutdown) { 558 559 t.Fatalf("Expected remote: %x, got: %x", 560 test.remoteShutdown, 561 openChannels[0].RemoteShutdownScript) 562 } 563 }) 564 } 565 } 566 567 func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { 568 if !reflect.DeepEqual(a, b) { 569 _, _, line, _ := runtime.Caller(1) 570 t.Fatalf("line %v: commitments don't match: %v vs %v", 571 line, spew.Sdump(a), spew.Sdump(b)) 572 } 573 } 574 575 func TestChannelStateTransition(t *testing.T) { 576 t.Parallel() 577 578 fullDB, cleanUp, err := MakeTestDB() 579 if err != nil { 580 t.Fatalf("unable to make test database: %v", err) 581 } 582 defer cleanUp() 583 584 cdb := fullDB.ChannelStateDB() 585 586 // First create a minimal channel, then perform a full sync in order to 587 // persist the data. 588 channel := createTestChannel(t, cdb) 589 590 // Add some HTLCs which were added during this new state transition. 591 // Half of the HTLCs are incoming, while the other half are outgoing. 592 var ( 593 htlcs []HTLC 594 htlcAmt lnwire.MilliAtom 595 ) 596 for i := uint32(0); i < 10; i++ { 597 var incoming bool 598 if i > 5 { 599 incoming = true 600 } 601 htlc := HTLC{ 602 Signature: testSig.Serialize(), 603 Incoming: incoming, 604 Amt: 10, 605 RHash: key, 606 RefundTimeout: i, 607 OutputIndex: int32(i * 3), 608 LogIndex: uint64(i * 2), 609 HtlcIndex: uint64(i), 610 } 611 htlc.OnionBlob = make([]byte, 10) 612 copy(htlc.OnionBlob, bytes.Repeat([]byte{2}, 10)) 613 htlcs = append(htlcs, htlc) 614 htlcAmt += htlc.Amt 615 } 616 617 // Create a new channel delta which includes the above HTLCs, some 618 // balance updates, and an increment of the current commitment height. 619 // Additionally, modify the signature and commitment transaction. 620 newSequence := uint32(129498) 621 newSig := bytes.Repeat([]byte{3}, 71) 622 newTx := channel.LocalCommitment.CommitTx.Copy() 623 newTx.TxIn[0].Sequence = newSequence 624 commitment := ChannelCommitment{ 625 CommitHeight: 1, 626 LocalLogIndex: 2, 627 LocalHtlcIndex: 1, 628 RemoteLogIndex: 2, 629 RemoteHtlcIndex: 1, 630 LocalBalance: lnwire.MilliAtom(1e8), 631 RemoteBalance: lnwire.MilliAtom(1e8), 632 CommitFee: 55, 633 FeePerKB: 99, 634 CommitTx: newTx, 635 CommitSig: newSig, 636 Htlcs: htlcs, 637 } 638 639 // First update the local node's broadcastable state and also add a 640 // CommitDiff remote node's as well in order to simulate a proper state 641 // transition. 642 unsignedAckedUpdates := []LogUpdate{ 643 { 644 LogIndex: 2, 645 UpdateMsg: &lnwire.UpdateAddHTLC{ 646 ChanID: lnwire.ChannelID{1, 2, 3}, 647 ExtraData: make([]byte, 0), 648 }, 649 }, 650 } 651 652 err = channel.UpdateCommitment(&commitment, unsignedAckedUpdates) 653 if err != nil { 654 t.Fatalf("unable to update commitment: %v", err) 655 } 656 657 // Assert that update is correctly written to the database. 658 dbUnsignedAckedUpdates, err := channel.UnsignedAckedUpdates() 659 if err != nil { 660 t.Fatalf("unable to fetch dangling remote updates: %v", err) 661 } 662 if len(dbUnsignedAckedUpdates) != 1 { 663 t.Fatalf("unexpected number of dangling remote updates") 664 } 665 if !reflect.DeepEqual( 666 dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0], 667 ) { 668 t.Fatalf("unexpected update: expected %v, got %v", 669 spew.Sdump(unsignedAckedUpdates[0]), 670 spew.Sdump(dbUnsignedAckedUpdates)) 671 } 672 673 // The balances, new update, the HTLCs and the changes to the fake 674 // commitment transaction along with the modified signature should all 675 // have been updated. 676 updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) 677 if err != nil { 678 t.Fatalf("unable to fetch updated channel: %v", err) 679 } 680 assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) 681 numDiskUpdates, err := updatedChannel[0].CommitmentHeight() 682 if err != nil { 683 t.Fatalf("unable to read commitment height from disk: %v", err) 684 } 685 if numDiskUpdates != commitment.CommitHeight { 686 t.Fatalf("num disk updates doesn't match: %v vs %v", 687 numDiskUpdates, commitment.CommitHeight) 688 } 689 690 // Attempting to query for a commitment diff should return 691 // ErrNoPendingCommit as we haven't yet created a new state for them. 692 _, err = channel.RemoteCommitChainTip() 693 if err != ErrNoPendingCommit { 694 t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) 695 } 696 697 // To simulate us extending a new state to the remote party, we'll also 698 // create a new commit diff for them. 699 remoteCommit := commitment 700 remoteCommit.LocalBalance = lnwire.MilliAtom(2e8) 701 remoteCommit.RemoteBalance = lnwire.MilliAtom(3e8) 702 remoteCommit.CommitHeight = 1 703 commitDiff := &CommitDiff{ 704 Commitment: remoteCommit, 705 CommitSig: &lnwire.CommitSig{ 706 ChanID: lnwire.ChannelID(key), 707 CommitSig: wireSig, 708 HtlcSigs: []lnwire.Sig{ 709 wireSig, 710 wireSig, 711 }, 712 ExtraData: make([]byte, 0), 713 }, 714 LogUpdates: []LogUpdate{ 715 { 716 LogIndex: 1, 717 UpdateMsg: &lnwire.UpdateAddHTLC{ 718 ID: 1, 719 Amount: lnwire.NewMAtomsFromAtoms(100), 720 Expiry: 25, 721 ExtraData: make([]byte, 0), 722 }, 723 }, 724 { 725 LogIndex: 2, 726 UpdateMsg: &lnwire.UpdateAddHTLC{ 727 ID: 2, 728 Amount: lnwire.NewMAtomsFromAtoms(200), 729 Expiry: 50, 730 ExtraData: make([]byte, 0), 731 }, 732 }, 733 }, 734 OpenedCircuitKeys: []CircuitKey{}, 735 ClosedCircuitKeys: []CircuitKey{}, 736 } 737 copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], 738 bytes.Repeat([]byte{1}, 32)) 739 copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], 740 bytes.Repeat([]byte{2}, 32)) 741 if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { 742 t.Fatalf("unable to add to commit chain: %v", err) 743 } 744 745 // The commitment tip should now match the commitment that we just 746 // inserted. 747 diskCommitDiff, err := channel.RemoteCommitChainTip() 748 if err != nil { 749 t.Fatalf("unable to fetch commit diff: %v", err) 750 } 751 if !reflect.DeepEqual(commitDiff, diskCommitDiff) { 752 t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), 753 spew.Sdump(diskCommitDiff)) 754 } 755 756 // We'll save the old remote commitment as this will be added to the 757 // revocation log shortly. 758 oldRemoteCommit := channel.RemoteCommitment 759 760 // Next, write to the log which tracks the necessary revocation state 761 // needed to rectify any fishy behavior by the remote party. Modify the 762 // current uncollapsed revocation state to simulate a state transition 763 // by the remote party. 764 channel.RemoteCurrentRevocation = channel.RemoteNextRevocation 765 newPriv, err := secp256k1.GeneratePrivateKey() 766 if err != nil { 767 t.Fatalf("unable to generate key: %v", err) 768 } 769 channel.RemoteNextRevocation = newPriv.PubKey() 770 771 fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, 772 diskCommitDiff.LogUpdates, nil) 773 774 err = channel.AdvanceCommitChainTail(fwdPkg, nil) 775 if err != nil { 776 t.Fatalf("unable to append to revocation log: %v", err) 777 } 778 779 // At this point, the remote commit chain should be nil, and the posted 780 // remote commitment should match the one we added as a diff above. 781 if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { 782 t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) 783 } 784 785 // We should be able to fetch the channel delta created above by its 786 // update number with all the state properly reconstructed. 787 diskPrevCommit, err := channel.FindPreviousState( 788 oldRemoteCommit.CommitHeight, 789 ) 790 if err != nil { 791 t.Fatalf("unable to fetch past delta: %v", err) 792 } 793 794 // The two deltas (the original vs the on-disk version) should 795 // identical, and all HTLC data should properly be retained. 796 assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) 797 798 // The state number recovered from the tail of the revocation log 799 // should be identical to this current state. 800 logTail, err := channel.RevocationLogTail() 801 if err != nil { 802 t.Fatalf("unable to retrieve log: %v", err) 803 } 804 if logTail.CommitHeight != oldRemoteCommit.CommitHeight { 805 t.Fatal("update number doesn't match") 806 } 807 808 oldRemoteCommit = channel.RemoteCommitment 809 810 // Next modify the posted diff commitment slightly, then create a new 811 // commitment diff and advance the tail. 812 commitDiff.Commitment.CommitHeight = 2 813 commitDiff.Commitment.LocalBalance -= htlcAmt 814 commitDiff.Commitment.RemoteBalance += htlcAmt 815 commitDiff.LogUpdates = []LogUpdate{} 816 if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { 817 t.Fatalf("unable to add to commit chain: %v", err) 818 } 819 820 fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) 821 822 err = channel.AdvanceCommitChainTail(fwdPkg, nil) 823 if err != nil { 824 t.Fatalf("unable to append to revocation log: %v", err) 825 } 826 827 // Once again, fetch the state and ensure it has been properly updated. 828 prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) 829 if err != nil { 830 t.Fatalf("unable to fetch past delta: %v", err) 831 } 832 assertCommitmentEqual(t, &oldRemoteCommit, prevCommit) 833 834 // Once again, state number recovered from the tail of the revocation 835 // log should be identical to this current state. 836 logTail, err = channel.RevocationLogTail() 837 if err != nil { 838 t.Fatalf("unable to retrieve log: %v", err) 839 } 840 if logTail.CommitHeight != oldRemoteCommit.CommitHeight { 841 t.Fatal("update number doesn't match") 842 } 843 844 // The revocation state stored on-disk should now also be identical. 845 updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub) 846 if err != nil { 847 t.Fatalf("unable to fetch updated channel: %v", err) 848 } 849 if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) { 850 t.Fatalf("revocation state was not synced") 851 } 852 if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) { 853 t.Fatalf("revocation state was not synced") 854 } 855 856 // At this point, we should have 2 forwarding packages added. 857 fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager) 858 require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") 859 860 // Now attempt to delete the channel from the database. 861 closeSummary := &ChannelCloseSummary{ 862 ChanPoint: channel.FundingOutpoint, 863 RemotePub: channel.IdentityPub, 864 SettledBalance: dcrutil.Amount(500), 865 TimeLockedBalance: dcrutil.Amount(10000), 866 IsPending: false, 867 CloseType: RemoteForceClose, 868 } 869 if err := updatedChannel[0].CloseChannel(closeSummary); err != nil { 870 t.Fatalf("unable to delete updated channel: %v", err) 871 } 872 873 // If we attempt to fetch the target channel again, it shouldn't be 874 // found. 875 channels, err := cdb.FetchOpenChannels(channel.IdentityPub) 876 if err != nil { 877 t.Fatalf("unable to fetch updated channels: %v", err) 878 } 879 if len(channels) != 0 { 880 t.Fatalf("%v channels, found, but none should be", 881 len(channels)) 882 } 883 884 // Attempting to find previous states on the channel should fail as the 885 // revocation log has been deleted. 886 _, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) 887 if err == nil { 888 t.Fatal("revocation log search should have failed") 889 } 890 891 // All forwarding packages of this channel has been deleted too. 892 fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager) 893 require.Empty(t, fwdPkgs, "no forwarding packages should exist") 894 } 895 896 func TestFetchPendingChannels(t *testing.T) { 897 t.Parallel() 898 899 fullDB, cleanUp, err := MakeTestDB() 900 if err != nil { 901 t.Fatalf("unable to make test database: %v", err) 902 } 903 defer cleanUp() 904 905 cdb := fullDB.ChannelStateDB() 906 907 // Create a pending channel that was broadcast at height 99. 908 const broadcastHeight = 99 909 createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) 910 911 pendingChannels, err := cdb.FetchPendingChannels() 912 if err != nil { 913 t.Fatalf("unable to list pending channels: %v", err) 914 } 915 916 if len(pendingChannels) != 1 { 917 t.Fatalf("incorrect number of pending channels: expecting %v,"+ 918 "got %v", 1, len(pendingChannels)) 919 } 920 921 // The broadcast height of the pending channel should have been set 922 // properly. 923 if pendingChannels[0].FundingBroadcastHeight != broadcastHeight { 924 t.Fatalf("broadcast height mismatch: expected %v, got %v", 925 pendingChannels[0].FundingBroadcastHeight, 926 broadcastHeight) 927 } 928 929 chanOpenLoc := lnwire.ShortChannelID{ 930 BlockHeight: 5, 931 TxIndex: 10, 932 TxPosition: 15, 933 } 934 err = pendingChannels[0].MarkAsOpen(chanOpenLoc) 935 if err != nil { 936 t.Fatalf("unable to mark channel as open: %v", err) 937 } 938 939 if pendingChannels[0].IsPending { 940 t.Fatalf("channel marked open should no longer be pending") 941 } 942 943 if pendingChannels[0].ShortChanID() != chanOpenLoc { 944 t.Fatalf("channel opening height not updated: expected %v, "+ 945 "got %v", spew.Sdump(pendingChannels[0].ShortChanID()), 946 chanOpenLoc) 947 } 948 949 // Next, we'll re-fetch the channel to ensure that the open height was 950 // properly set. 951 openChans, err := cdb.FetchAllChannels() 952 if err != nil { 953 t.Fatalf("unable to fetch channels: %v", err) 954 } 955 if openChans[0].ShortChanID() != chanOpenLoc { 956 t.Fatalf("channel opening heights don't match: expected %v, "+ 957 "got %v", spew.Sdump(openChans[0].ShortChanID()), 958 chanOpenLoc) 959 } 960 if openChans[0].FundingBroadcastHeight != broadcastHeight { 961 t.Fatalf("broadcast height mismatch: expected %v, got %v", 962 openChans[0].FundingBroadcastHeight, 963 broadcastHeight) 964 } 965 966 pendingChannels, err = cdb.FetchPendingChannels() 967 if err != nil { 968 t.Fatalf("unable to list pending channels: %v", err) 969 } 970 971 if len(pendingChannels) != 0 { 972 t.Fatalf("incorrect number of pending channels: expecting %v,"+ 973 "got %v", 0, len(pendingChannels)) 974 } 975 } 976 977 func TestFetchClosedChannels(t *testing.T) { 978 t.Parallel() 979 980 fullDB, cleanUp, err := MakeTestDB() 981 if err != nil { 982 t.Fatalf("unable to make test database: %v", err) 983 } 984 defer cleanUp() 985 986 cdb := fullDB.ChannelStateDB() 987 988 // Create an open channel in the database. 989 state := createTestChannel(t, cdb, openChannelOption()) 990 991 // Next, close the channel by including a close channel summary in the 992 // database. 993 summary := &ChannelCloseSummary{ 994 ChanPoint: state.FundingOutpoint, 995 ClosingTXID: rev, 996 RemotePub: state.IdentityPub, 997 Capacity: state.Capacity, 998 SettledBalance: state.LocalCommitment.LocalBalance.ToAtoms(), 999 TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToAtoms() + 10000, 1000 CloseType: RemoteForceClose, 1001 IsPending: true, 1002 LocalChanConfig: state.LocalChanCfg, 1003 } 1004 if err := state.CloseChannel(summary); err != nil { 1005 t.Fatalf("unable to close channel: %v", err) 1006 } 1007 1008 // Query the database to ensure that the channel has now been properly 1009 // closed. We should get the same result whether querying for pending 1010 // channels only, or not. 1011 pendingClosed, err := cdb.FetchClosedChannels(true) 1012 if err != nil { 1013 t.Fatalf("failed fetching closed channels: %v", err) 1014 } 1015 if len(pendingClosed) != 1 { 1016 t.Fatalf("incorrect number of pending closed channels: expecting %v,"+ 1017 "got %v", 1, len(pendingClosed)) 1018 } 1019 if !reflect.DeepEqual(summary, pendingClosed[0]) { 1020 t.Fatalf("database summaries don't match: expected %v got %v", 1021 spew.Sdump(summary), spew.Sdump(pendingClosed[0])) 1022 } 1023 closed, err := cdb.FetchClosedChannels(false) 1024 if err != nil { 1025 t.Fatalf("failed fetching all closed channels: %v", err) 1026 } 1027 if len(closed) != 1 { 1028 t.Fatalf("incorrect number of closed channels: expecting %v, "+ 1029 "got %v", 1, len(closed)) 1030 } 1031 if !reflect.DeepEqual(summary, closed[0]) { 1032 t.Fatalf("database summaries don't match: expected %v got %v", 1033 spew.Sdump(summary), spew.Sdump(closed[0])) 1034 } 1035 1036 // Mark the channel as fully closed. 1037 err = cdb.MarkChanFullyClosed(&state.FundingOutpoint) 1038 if err != nil { 1039 t.Fatalf("failed fully closing channel: %v", err) 1040 } 1041 1042 // The channel should no longer be considered pending, but should still 1043 // be retrieved when fetching all the closed channels. 1044 closed, err = cdb.FetchClosedChannels(false) 1045 if err != nil { 1046 t.Fatalf("failed fetching closed channels: %v", err) 1047 } 1048 if len(closed) != 1 { 1049 t.Fatalf("incorrect number of closed channels: expecting %v, "+ 1050 "got %v", 1, len(closed)) 1051 } 1052 pendingClose, err := cdb.FetchClosedChannels(true) 1053 if err != nil { 1054 t.Fatalf("failed fetching channels pending close: %v", err) 1055 } 1056 if len(pendingClose) != 0 { 1057 t.Fatalf("incorrect number of closed channels: expecting %v, "+ 1058 "got %v", 0, len(closed)) 1059 } 1060 } 1061 1062 // TestFetchWaitingCloseChannels ensures that the correct channels that are 1063 // waiting to be closed are returned. 1064 func TestFetchWaitingCloseChannels(t *testing.T) { 1065 t.Parallel() 1066 1067 const numChannels = 2 1068 const broadcastHeight = 99 1069 1070 // We'll start by creating two channels within our test database. One of 1071 // them will have their funding transaction confirmed on-chain, while 1072 // the other one will remain unconfirmed. 1073 fullDB, cleanUp, err := MakeTestDB() 1074 if err != nil { 1075 t.Fatalf("unable to make test database: %v", err) 1076 } 1077 defer cleanUp() 1078 1079 cdb := fullDB.ChannelStateDB() 1080 1081 channels := make([]*OpenChannel, numChannels) 1082 for i := 0; i < numChannels; i++ { 1083 // Create a pending channel in the database at the broadcast 1084 // height. 1085 channels[i] = createTestChannel( 1086 t, cdb, pendingHeightOption(broadcastHeight), 1087 ) 1088 } 1089 1090 // We'll only confirm the first one. 1091 channelConf := lnwire.ShortChannelID{ 1092 BlockHeight: broadcastHeight + 1, 1093 TxIndex: 10, 1094 TxPosition: 15, 1095 } 1096 if err := channels[0].MarkAsOpen(channelConf); err != nil { 1097 t.Fatalf("unable to mark channel as open: %v", err) 1098 } 1099 1100 // Then, we'll mark the channels as if their commitments were broadcast. 1101 // This would happen in the event of a force close and should make the 1102 // channels enter a state of waiting close. 1103 for _, channel := range channels { 1104 closeTx := wire.NewMsgTx() 1105 closeTx.Version = 2 1106 closeTx.AddTxIn( 1107 &wire.TxIn{ 1108 PreviousOutPoint: channel.FundingOutpoint, 1109 }, 1110 ) 1111 1112 if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil { 1113 t.Fatalf("unable to mark commitment broadcast: %v", err) 1114 } 1115 1116 // Now try to marking a coop close with a nil tx. This should 1117 // succeed, but it shouldn't exit when queried. 1118 if err = channel.MarkCoopBroadcasted(nil, true); err != nil { 1119 t.Fatalf("unable to mark nil coop broadcast: %v", err) 1120 } 1121 _, err := channel.BroadcastedCooperative() 1122 if err != ErrNoCloseTx { 1123 t.Fatalf("expected no closing tx error, got: %v", err) 1124 } 1125 1126 // Finally, modify the close tx deterministically and also mark 1127 // it as coop closed. Later we will test that distinct 1128 // transactions are returned for both coop and force closes. 1129 closeTx.TxIn[0].PreviousOutPoint.Index ^= 1 1130 if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil { 1131 t.Fatalf("unable to mark coop broadcast: %v", err) 1132 } 1133 } 1134 1135 // Now, we'll fetch all the channels waiting to be closed from the 1136 // database. We should expect to see both channels above, even if any of 1137 // them haven't had their funding transaction confirm on-chain. 1138 waitingCloseChannels, err := cdb.FetchWaitingCloseChannels() 1139 if err != nil { 1140 t.Fatalf("unable to fetch all waiting close channels: %v", err) 1141 } 1142 if len(waitingCloseChannels) != numChannels { 1143 t.Fatalf("expected %d channels waiting to be closed, got %d", 2, 1144 len(waitingCloseChannels)) 1145 } 1146 expectedChannels := make(map[wire.OutPoint]struct{}) 1147 for _, channel := range channels { 1148 expectedChannels[channel.FundingOutpoint] = struct{}{} 1149 } 1150 for _, channel := range waitingCloseChannels { 1151 if _, ok := expectedChannels[channel.FundingOutpoint]; !ok { 1152 t.Fatalf("expected channel %v to be waiting close", 1153 channel.FundingOutpoint) 1154 } 1155 1156 chanPoint := channel.FundingOutpoint 1157 1158 // Assert that the force close transaction is retrievable. 1159 forceCloseTx, err := channel.BroadcastedCommitment() 1160 if err != nil { 1161 t.Fatalf("Unable to retrieve commitment: %v", err) 1162 } 1163 1164 if forceCloseTx.TxIn[0].PreviousOutPoint != chanPoint { 1165 t.Fatalf("expected outpoint %v, got %v", 1166 chanPoint, 1167 forceCloseTx.TxIn[0].PreviousOutPoint) 1168 } 1169 1170 // Assert that the coop close transaction is retrievable. 1171 coopCloseTx, err := channel.BroadcastedCooperative() 1172 if err != nil { 1173 t.Fatalf("unable to retrieve coop close: %v", err) 1174 } 1175 1176 chanPoint.Index ^= 1 1177 if coopCloseTx.TxIn[0].PreviousOutPoint != chanPoint { 1178 t.Fatalf("expected outpoint %v, got %v", 1179 chanPoint, 1180 coopCloseTx.TxIn[0].PreviousOutPoint) 1181 } 1182 } 1183 } 1184 1185 // TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory 1186 // state of another OpenChannel to reflect a preceding call to MarkOpen on a 1187 // different OpenChannel. 1188 func TestRefreshShortChanID(t *testing.T) { 1189 t.Parallel() 1190 1191 fullDB, cleanUp, err := MakeTestDB() 1192 if err != nil { 1193 t.Fatalf("unable to make test database: %v", err) 1194 } 1195 defer cleanUp() 1196 1197 cdb := fullDB.ChannelStateDB() 1198 1199 // First create a test channel. 1200 state := createTestChannel(t, cdb) 1201 1202 // Next, locate the pending channel with the database. 1203 pendingChannels, err := cdb.FetchPendingChannels() 1204 if err != nil { 1205 t.Fatalf("unable to load pending channels; %v", err) 1206 } 1207 1208 var pendingChannel *OpenChannel 1209 for _, channel := range pendingChannels { 1210 if channel.FundingOutpoint == state.FundingOutpoint { 1211 pendingChannel = channel 1212 break 1213 } 1214 } 1215 if pendingChannel == nil { 1216 t.Fatalf("unable to find pending channel with funding "+ 1217 "outpoint=%v: %v", state.FundingOutpoint, err) 1218 } 1219 1220 // Next, simulate the confirmation of the channel by marking it as 1221 // pending within the database. 1222 chanOpenLoc := lnwire.ShortChannelID{ 1223 BlockHeight: 105, 1224 TxIndex: 10, 1225 TxPosition: 15, 1226 } 1227 1228 err = state.MarkAsOpen(chanOpenLoc) 1229 if err != nil { 1230 t.Fatalf("unable to mark channel open: %v", err) 1231 } 1232 1233 // The short_chan_id of the receiver to MarkAsOpen should reflect the 1234 // open location, but the other pending channel should remain unchanged. 1235 if state.ShortChanID() == pendingChannel.ShortChanID() { 1236 t.Fatalf("pending channel short_chan_ID should not have been " + 1237 "updated before refreshing short_chan_id") 1238 } 1239 1240 // Now that the receiver's short channel id has been updated, check to 1241 // ensure that the channel packager's source has been updated as well. 1242 // This ensures that the packager will read and write to buckets 1243 // corresponding to the new short chan id, instead of the prior. 1244 if state.Packager.(*ChannelPackager).source != chanOpenLoc { 1245 t.Fatalf("channel packager source was not updated: want %v, "+ 1246 "got %v", chanOpenLoc, 1247 state.Packager.(*ChannelPackager).source) 1248 } 1249 1250 // Now, refresh the short channel ID of the pending channel. 1251 err = pendingChannel.RefreshShortChanID() 1252 if err != nil { 1253 t.Fatalf("unable to refresh short_chan_id: %v", err) 1254 } 1255 1256 // This should result in both OpenChannel's now having the same 1257 // ShortChanID. 1258 if state.ShortChanID() != pendingChannel.ShortChanID() { 1259 t.Fatalf("expected pending channel short_chan_id to be "+ 1260 "refreshed: want %v, got %v", state.ShortChanID(), 1261 pendingChannel.ShortChanID()) 1262 } 1263 1264 // Check to ensure that the _other_ OpenChannel channel packager's 1265 // source has also been updated after the refresh. This ensures that the 1266 // other packagers will read and write to buckets corresponding to the 1267 // updated short chan id. 1268 if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { 1269 t.Fatalf("channel packager source was not updated: want %v, "+ 1270 "got %v", chanOpenLoc, 1271 pendingChannel.Packager.(*ChannelPackager).source) 1272 } 1273 1274 // Check to ensure that this channel is no longer pending and this field 1275 // is up to date. 1276 if pendingChannel.IsPending { 1277 t.Fatalf("channel pending state wasn't updated: want false got true") 1278 } 1279 } 1280 1281 // TestCloseInitiator tests the setting of close initiator statuses for 1282 // cooperative closes and local force closes. 1283 func TestCloseInitiator(t *testing.T) { 1284 tests := []struct { 1285 name string 1286 // updateChannel is called to update the channel as broadcast, 1287 // cooperatively or not, based on the test's requirements. 1288 updateChannel func(c *OpenChannel) error 1289 expectedStatuses []ChannelStatus 1290 }{ 1291 { 1292 name: "local coop close", 1293 // Mark the channel as cooperatively closed, initiated 1294 // by the local party. 1295 updateChannel: func(c *OpenChannel) error { 1296 return c.MarkCoopBroadcasted( 1297 &wire.MsgTx{}, true, 1298 ) 1299 }, 1300 expectedStatuses: []ChannelStatus{ 1301 ChanStatusLocalCloseInitiator, 1302 ChanStatusCoopBroadcasted, 1303 }, 1304 }, 1305 { 1306 name: "remote coop close", 1307 // Mark the channel as cooperatively closed, initiated 1308 // by the remote party. 1309 updateChannel: func(c *OpenChannel) error { 1310 return c.MarkCoopBroadcasted( 1311 &wire.MsgTx{}, false, 1312 ) 1313 }, 1314 expectedStatuses: []ChannelStatus{ 1315 ChanStatusRemoteCloseInitiator, 1316 ChanStatusCoopBroadcasted, 1317 }, 1318 }, 1319 { 1320 name: "local force close", 1321 // Mark the channel's commitment as broadcast with 1322 // local initiator. 1323 updateChannel: func(c *OpenChannel) error { 1324 return c.MarkCommitmentBroadcasted( 1325 &wire.MsgTx{}, true, 1326 ) 1327 }, 1328 expectedStatuses: []ChannelStatus{ 1329 ChanStatusLocalCloseInitiator, 1330 ChanStatusCommitBroadcasted, 1331 }, 1332 }, 1333 } 1334 1335 for _, test := range tests { 1336 test := test 1337 1338 t.Run(test.name, func(t *testing.T) { 1339 t.Parallel() 1340 1341 fullDB, cleanUp, err := MakeTestDB() 1342 if err != nil { 1343 t.Fatalf("unable to make test database: %v", 1344 err) 1345 } 1346 defer cleanUp() 1347 1348 cdb := fullDB.ChannelStateDB() 1349 1350 // Create an open channel. 1351 channel := createTestChannel( 1352 t, cdb, openChannelOption(), 1353 ) 1354 1355 err = test.updateChannel(channel) 1356 if err != nil { 1357 t.Fatalf("unexpected error: %v", err) 1358 } 1359 1360 // Lookup open channels in the database. 1361 dbChans, err := fetchChannels( 1362 cdb, pendingChannelFilter(false), 1363 ) 1364 if err != nil { 1365 t.Fatalf("unexpected error: %v", err) 1366 } 1367 if len(dbChans) != 1 { 1368 t.Fatalf("expected 1 channel, got: %v", 1369 len(dbChans)) 1370 } 1371 1372 // Check that the statuses that we expect were written 1373 // to disk. 1374 for _, status := range test.expectedStatuses { 1375 if !dbChans[0].HasChanStatus(status) { 1376 t.Fatalf("expected channel to have "+ 1377 "status: %v, has status: %v", 1378 status, dbChans[0].chanStatus) 1379 } 1380 } 1381 }) 1382 } 1383 } 1384 1385 // TestCloseChannelStatus tests setting of a channel status on the historical 1386 // channel on channel close. 1387 func TestCloseChannelStatus(t *testing.T) { 1388 fullDB, cleanUp, err := MakeTestDB() 1389 if err != nil { 1390 t.Fatalf("unable to make test database: %v", 1391 err) 1392 } 1393 defer cleanUp() 1394 1395 cdb := fullDB.ChannelStateDB() 1396 1397 // Create an open channel. 1398 channel := createTestChannel( 1399 t, cdb, openChannelOption(), 1400 ) 1401 1402 if err := channel.CloseChannel( 1403 &ChannelCloseSummary{ 1404 ChanPoint: channel.FundingOutpoint, 1405 RemotePub: channel.IdentityPub, 1406 }, ChanStatusRemoteCloseInitiator, 1407 ); err != nil { 1408 t.Fatalf("unexpected error: %v", err) 1409 } 1410 1411 histChan, err := channel.Db.FetchHistoricalChannel( 1412 &channel.FundingOutpoint, 1413 ) 1414 if err != nil { 1415 t.Fatalf("unexpected error: %v", err) 1416 } 1417 1418 if !histChan.HasChanStatus(ChanStatusRemoteCloseInitiator) { 1419 t.Fatalf("channel should have status") 1420 } 1421 } 1422 1423 // TestBalanceAtHeight tests lookup of our local and remote balance at a given 1424 // height. 1425 func TestBalanceAtHeight(t *testing.T) { 1426 const ( 1427 // Values that will be set on our current local commit in 1428 // memory. 1429 localHeight = 2 1430 localLocalBalance = 1000 1431 localRemoteBalance = 1500 1432 1433 // Values that will be set on our current remote commit in 1434 // memory. 1435 remoteHeight = 3 1436 remoteLocalBalance = 2000 1437 remoteRemoteBalance = 2500 1438 1439 // Values that will be written to disk in the revocation log. 1440 oldHeight = 0 1441 oldLocalBalance = 200 1442 oldRemoteBalance = 300 1443 1444 // Heights to test error cases. 1445 unknownHeight = 1 1446 unreachedHeight = 4 1447 ) 1448 1449 // putRevokedState is a helper function used to put commitments is 1450 // the revocation log bucket to test lookup of balances at heights that 1451 // are not our current height. 1452 putRevokedState := func(c *OpenChannel, height uint64, local, 1453 remote lnwire.MilliAtom) error { 1454 1455 err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { 1456 chanBucket, err := fetchChanBucketRw( 1457 tx, c.IdentityPub, &c.FundingOutpoint, 1458 c.ChainHash, 1459 ) 1460 if err != nil { 1461 return err 1462 } 1463 1464 logKey := revocationLogBucket 1465 logBucket, err := chanBucket.CreateBucketIfNotExists( 1466 logKey, 1467 ) 1468 if err != nil { 1469 return err 1470 } 1471 1472 // Make a copy of our current commitment so we do not 1473 // need to re-fill all the required fields and copy in 1474 // our new desired values. 1475 commit := c.LocalCommitment 1476 commit.CommitHeight = height 1477 commit.LocalBalance = local 1478 commit.RemoteBalance = remote 1479 1480 return appendChannelLogEntry(logBucket, &commit) 1481 }, func() {}) 1482 1483 return err 1484 } 1485 1486 tests := []struct { 1487 name string 1488 targetHeight uint64 1489 expectedLocalBalance lnwire.MilliAtom 1490 expectedRemoteBalance lnwire.MilliAtom 1491 expectedError error 1492 }{ 1493 { 1494 name: "target is current local height", 1495 targetHeight: localHeight, 1496 expectedLocalBalance: localLocalBalance, 1497 expectedRemoteBalance: localRemoteBalance, 1498 expectedError: nil, 1499 }, 1500 { 1501 name: "target is current remote height", 1502 targetHeight: remoteHeight, 1503 expectedLocalBalance: remoteLocalBalance, 1504 expectedRemoteBalance: remoteRemoteBalance, 1505 expectedError: nil, 1506 }, 1507 { 1508 name: "need to lookup commit", 1509 targetHeight: oldHeight, 1510 expectedLocalBalance: oldLocalBalance, 1511 expectedRemoteBalance: oldRemoteBalance, 1512 expectedError: nil, 1513 }, 1514 { 1515 name: "height not found", 1516 targetHeight: unknownHeight, 1517 expectedLocalBalance: 0, 1518 expectedRemoteBalance: 0, 1519 expectedError: ErrLogEntryNotFound, 1520 }, 1521 { 1522 name: "height not reached", 1523 targetHeight: unreachedHeight, 1524 expectedLocalBalance: 0, 1525 expectedRemoteBalance: 0, 1526 expectedError: errHeightNotReached, 1527 }, 1528 } 1529 1530 for _, test := range tests { 1531 test := test 1532 1533 t.Run(test.name, func(t *testing.T) { 1534 t.Parallel() 1535 1536 fullDB, cleanUp, err := MakeTestDB() 1537 if err != nil { 1538 t.Fatalf("unable to make test database: %v", 1539 err) 1540 } 1541 defer cleanUp() 1542 1543 cdb := fullDB.ChannelStateDB() 1544 1545 // Create options to set the heights and balances of 1546 // our local and remote commitments. 1547 localCommitOpt := channelCommitmentOption( 1548 localHeight, localLocalBalance, 1549 localRemoteBalance, true, 1550 ) 1551 1552 remoteCommitOpt := channelCommitmentOption( 1553 remoteHeight, remoteLocalBalance, 1554 remoteRemoteBalance, false, 1555 ) 1556 1557 // Create an open channel. 1558 channel := createTestChannel( 1559 t, cdb, openChannelOption(), 1560 localCommitOpt, remoteCommitOpt, 1561 ) 1562 1563 // Write an older commit to disk. 1564 err = putRevokedState(channel, oldHeight, 1565 oldLocalBalance, oldRemoteBalance) 1566 if err != nil { 1567 t.Fatalf("unexpected error: %v", err) 1568 } 1569 1570 local, remote, err := channel.BalancesAtHeight( 1571 test.targetHeight, 1572 ) 1573 if err != test.expectedError { 1574 t.Fatalf("expected: %v, got: %v", 1575 test.expectedError, err) 1576 } 1577 1578 if local != test.expectedLocalBalance { 1579 t.Fatalf("expected local: %v, got: %v", 1580 test.expectedLocalBalance, local) 1581 } 1582 1583 if remote != test.expectedRemoteBalance { 1584 t.Fatalf("expected remote: %v, got: %v", 1585 test.expectedRemoteBalance, remote) 1586 } 1587 }) 1588 } 1589 } 1590 1591 // TestHasChanStatus asserts the behavior of HasChanStatus by checking the 1592 // behavior of various status flags in addition to the special case of 1593 // ChanStatusDefault which is treated like a flag in the code base even though 1594 // it isn't. 1595 func TestHasChanStatus(t *testing.T) { 1596 tests := []struct { 1597 name string 1598 status ChannelStatus 1599 expHas map[ChannelStatus]bool 1600 }{ 1601 { 1602 name: "default", 1603 status: ChanStatusDefault, 1604 expHas: map[ChannelStatus]bool{ 1605 ChanStatusDefault: true, 1606 ChanStatusBorked: false, 1607 }, 1608 }, 1609 { 1610 name: "single flag", 1611 status: ChanStatusBorked, 1612 expHas: map[ChannelStatus]bool{ 1613 ChanStatusDefault: false, 1614 ChanStatusBorked: true, 1615 }, 1616 }, 1617 { 1618 name: "multiple flags", 1619 status: ChanStatusBorked | ChanStatusLocalDataLoss, 1620 expHas: map[ChannelStatus]bool{ 1621 ChanStatusDefault: false, 1622 ChanStatusBorked: true, 1623 ChanStatusLocalDataLoss: true, 1624 }, 1625 }, 1626 } 1627 1628 for _, test := range tests { 1629 test := test 1630 1631 t.Run(test.name, func(t *testing.T) { 1632 c := &OpenChannel{ 1633 chanStatus: test.status, 1634 } 1635 1636 for status, expHas := range test.expHas { 1637 has := c.HasChanStatus(status) 1638 if has == expHas { 1639 continue 1640 } 1641 1642 t.Fatalf("expected chan status to "+ 1643 "have %s? %t, got: %t", 1644 status, expHas, has) 1645 } 1646 }) 1647 } 1648 } 1649 1650 // TestKeyLocatorEncoding tests that we are able to serialize a given 1651 // keychain.KeyLocator. After successfully encoding, we check that the decode 1652 // output arrives at the same initial KeyLocator. 1653 func TestKeyLocatorEncoding(t *testing.T) { 1654 keyLoc := keychain.KeyLocator{ 1655 Family: keychain.KeyFamilyRevocationRoot, 1656 Index: keyLocIndex, 1657 } 1658 1659 // First, we'll encode the KeyLocator into a buffer. 1660 var ( 1661 b bytes.Buffer 1662 buf [8]byte 1663 ) 1664 1665 err := EKeyLocator(&b, &keyLoc, &buf) 1666 require.NoError(t, err, "unable to encode key locator") 1667 1668 // Next, we'll attempt to decode the bytes into a new KeyLocator. 1669 r := bytes.NewReader(b.Bytes()) 1670 var decodedKeyLoc keychain.KeyLocator 1671 1672 err = DKeyLocator(r, &decodedKeyLoc, &buf, 8) 1673 require.NoError(t, err, "unable to decode key locator") 1674 1675 // Finally, we'll compare that the original KeyLocator and the decoded 1676 // version are equal. 1677 require.Equal(t, keyLoc, decodedKeyLoc) 1678 }