github.com/decred/dcrlnd@v0.7.6/channeldb/channel_test.go

package channeldb

import (
	"bytes"
	"math/rand"
	"net"
	"reflect"
	"runtime"
	"testing"

	"github.com/davecgh/go-spew/spew"
	"github.com/stretchr/testify/require"

	"github.com/decred/dcrd/chaincfg/chainhash"
	"github.com/decred/dcrd/dcrec/secp256k1/v4"
	"github.com/decred/dcrd/dcrutil/v4"
	"github.com/decred/dcrd/wire"
	"github.com/decred/dcrlnd/clock"
	"github.com/decred/dcrlnd/keychain"
	"github.com/decred/dcrlnd/kvdb"
	"github.com/decred/dcrlnd/lntest/channels"
	"github.com/decred/dcrlnd/lnwire"
	"github.com/decred/dcrlnd/shachain"
)

func privKeyFromBytes(b []byte) (*secp256k1.PrivateKey, *secp256k1.PublicKey) {
	k := secp256k1.PrivKeyFromBytes(b)
	return k, k.PubKey()
}

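// privKeyFromBytes is how the package-level test key pair below is derived; a
// minimal usage sketch (key is the test vector defined next):
//
//	priv, pub := privKeyFromBytes(key[:])
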
var (
	key = [chainhash.HashSize]byte{
		0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
		0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
		0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d,
		0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9,
	}
	rev = [chainhash.HashSize]byte{
		0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
		0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
		0x2d, 0xe7, 0x93, 0xe4,
	}

	privKey, pubKey = privKeyFromBytes(key[:])

	wireSig, _ = lnwire.NewSigFromSignature(testSig)

	testClock = clock.NewTestClock(testNow)

	// defaultPendingHeight is the default height at which we set
	// channels to pending.
	defaultPendingHeight = 100

	// defaultAddr is the default address that we mark test channels pending
	// with.
	defaultAddr = &net.TCPAddr{
		IP:   net.ParseIP("127.0.0.1"),
		Port: 18555,
	}

	// keyLocIndex is the KeyLocator Index we use for TestKeyLocatorEncoding.
	keyLocIndex = uint32(2049)
)

// testChannelParams is a struct which details the specifics of how a channel
// should be created.
type testChannelParams struct {
	// channel is the channel that will be written to disk.
	channel *OpenChannel

	// addr is the address that the channel will be synced pending with.
	addr *net.TCPAddr

	// pendingHeight is the height that the channel should be recorded as
	// pending.
	pendingHeight uint32

	// openChannel is set to true if the channel should be fully marked as
	// open. If this is false, the channel will be left in a pending state.
	openChannel bool
}

// testChannelOption is a functional option which can be used to alter the
// default channel that is created for testing.
type testChannelOption func(params *testChannelParams)

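// Any new option is just a closure over testChannelParams. As a sketch, a
// hypothetical capacityOption (not defined elsewhere in this file) that
// overrides the default capacity set by createTestChannelState would look
// like:
//
//	func capacityOption(capacity dcrutil.Amount) testChannelOption {
//		return func(params *testChannelParams) {
//			params.channel.Capacity = capacity
//		}
//	}
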
// channelCommitmentOption is an option which allows overwriting of the default
// commitment height and balances. The local boolean can be used to set these
// balances on the local or remote commit.
func channelCommitmentOption(height uint64, localBalance,
	remoteBalance lnwire.MilliAtom, local bool) testChannelOption {

	return func(params *testChannelParams) {
		if local {
			params.channel.LocalCommitment.CommitHeight = height
			params.channel.LocalCommitment.LocalBalance = localBalance
			params.channel.LocalCommitment.RemoteBalance = remoteBalance
		} else {
			params.channel.RemoteCommitment.CommitHeight = height
			params.channel.RemoteCommitment.LocalBalance = localBalance
			params.channel.RemoteCommitment.RemoteBalance = remoteBalance
		}
	}
}

// pendingHeightOption is an option which can be used to set the height the
// channel is marked as pending at.
func pendingHeightOption(height uint32) testChannelOption {
	return func(params *testChannelParams) {
		params.pendingHeight = height
	}
}

// openChannelOption is an option which can be used to create a test channel
// that is open.
func openChannelOption() testChannelOption {
	return func(params *testChannelParams) {
		params.openChannel = true
	}
}

// localHtlcsOption is an option which allows setting of htlcs on the local
// commitment.
func localHtlcsOption(htlcs []HTLC) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.LocalCommitment.Htlcs = htlcs
	}
}

// remoteHtlcsOption is an option which allows setting of htlcs on the remote
// commitment.
func remoteHtlcsOption(htlcs []HTLC) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.RemoteCommitment.Htlcs = htlcs
	}
}

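// A minimal usage sketch for the two HTLC options above (testSig and key are
// the package-level test values; this mirrors TestOpenChannelPutGetDelete
// further down):
//
//	htlcs := []HTLC{{
//		Signature: testSig.Serialize(),
//		Incoming:  true,
//		Amt:       10,
//		RHash:     key,
//	}}
//	channel := createTestChannel(
//		t, cdb, localHtlcsOption(htlcs), remoteHtlcsOption(htlcs),
//	)
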
// loadFwdPkgs is a helper method that reads all forwarding packages for a
// particular packager.
func loadFwdPkgs(t *testing.T, db kvdb.Backend,
	packager FwdPackager) []*FwdPkg {

	var (
		fwdPkgs []*FwdPkg
		err     error
	)

	err = kvdb.View(db, func(tx kvdb.RTx) error {
		fwdPkgs, err = packager.LoadFwdPkgs(tx)
		return err
	}, func() {})
	require.NoError(t, err, "unable to load fwd pkgs")

	return fwdPkgs
}

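// The tests below use loadFwdPkgs roughly as follows (a sketch mirroring
// TestChannelStateTransition):
//
//	fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager)
//	require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages")
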
// localShutdownOption is an option which sets the local upfront shutdown
// script for the channel.
func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.LocalShutdownScript = addr
	}
}

// remoteShutdownOption is an option which sets the remote upfront shutdown
// script for the channel.
func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.RemoteShutdownScript = addr
	}
}

// fundingPointOption is an option which sets the funding outpoint of the
// channel.
func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.FundingOutpoint = chanPoint
	}
}

// channelIDOption is an option which sets the short channel ID of the channel.
func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.ShortChannelID = chanID
	}
}

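// A sketch of combining the two identification options above when a test needs
// a deterministic channel locator (the values here are arbitrary):
//
//	chanPoint := wire.OutPoint{Index: 1}
//	chanID := lnwire.NewShortChanIDFromInt(1)
//	channel := createTestChannel(
//		t, cdb, fundingPointOption(chanPoint), channelIDOption(chanID),
//	)
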
// createTestChannel writes a test channel to the database. It takes a set of
// functional options which can be used to overwrite the default of creating
// a pending channel that was broadcast at height 100.
func createTestChannel(t *testing.T, cdb *ChannelStateDB,
	opts ...testChannelOption) *OpenChannel {

	// Create a default set of parameters.
	params := &testChannelParams{
		channel:       createTestChannelState(t, cdb),
		addr:          defaultAddr,
		openChannel:   false,
		pendingHeight: uint32(defaultPendingHeight),
	}

	// Apply all functional options to the test channel params.
	for _, o := range opts {
		o(params)
	}

	// Mark the channel as pending.
	err := params.channel.SyncPending(params.addr, params.pendingHeight)
	if err != nil {
		t.Fatalf("unable to save and serialize channel "+
			"state: %v", err)
	}

	// If the parameters do not specify that we should open the channel
	// fully, we return the pending channel.
	if !params.openChannel {
		return params.channel
	}

	// Mark the channel as open with the short channel id provided.
	err = params.channel.MarkAsOpen(params.channel.ShortChannelID)
	if err != nil {
		t.Fatalf("unable to mark channel open: %v", err)
	}

	return params.channel
}

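// For example, the default invocation below creates a channel that is still
// pending at height 100, while adding openChannelOption yields a fully open
// channel (both forms appear in the tests that follow):
//
//	pending := createTestChannel(t, cdb)
//	open := createTestChannel(t, cdb, openChannelOption())
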
func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
	// Prime the revocation store with the first preimage from the
	// producer, simulating an initial channel update.
	producer, err := shachain.NewRevocationProducerFromBytes(key[:])
	if err != nil {
		t.Fatalf("could not get producer: %v", err)
	}
	store := shachain.NewRevocationStore()
	for i := 0; i < 1; i++ {
		preImage, err := producer.AtIndex(uint64(i))
		if err != nil {
			t.Fatalf("could not get "+
				"preimage: %v", err)
		}

		if err := store.AddNextEntry(preImage); err != nil {
			t.Fatalf("could not add entry: %v", err)
		}
	}

	localCfg := ChannelConfig{
		ChannelConstraints: ChannelConstraints{
			DustLimit:        dcrutil.Amount(rand.Int63()),
			MaxPendingAmount: lnwire.MilliAtom(rand.Int63()),
			ChanReserve:      dcrutil.Amount(rand.Int63()),
			MinHTLC:          lnwire.MilliAtom(rand.Int63()),
			MaxAcceptedHtlcs: uint16(rand.Int31()),
			CsvDelay:         uint16(rand.Int31()),
		},
		MultiSigKey: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		RevocationBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		PaymentBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		DelayBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		HtlcBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
	}
	remoteCfg := ChannelConfig{
		ChannelConstraints: ChannelConstraints{
			DustLimit:        dcrutil.Amount(rand.Int63()),
			MaxPendingAmount: lnwire.MilliAtom(rand.Int63()),
			ChanReserve:      dcrutil.Amount(rand.Int63()),
			MinHTLC:          lnwire.MilliAtom(rand.Int63()),
			MaxAcceptedHtlcs: uint16(rand.Int31()),
			CsvDelay:         uint16(rand.Int31()),
		},
		MultiSigKey: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyMultiSig,
				Index:  9,
			},
		},
		RevocationBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyRevocationBase,
				Index:  8,
			},
		},
		PaymentBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyPaymentBase,
				Index:  7,
			},
		},
		DelayBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyDelayBase,
				Index:  6,
			},
		},
		HtlcBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyHtlcBase,
				Index:  5,
			},
		},
	}

	chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63()))

	return &OpenChannel{
		ChanType:            SingleFunderBit | FrozenBit,
		ChainHash:           key,
		FundingOutpoint:     wire.OutPoint{Hash: key, Index: rand.Uint32()},
		ShortChannelID:      chanID,
		IsInitiator:         true,
		IsPending:           true,
		IdentityPub:         pubKey,
		Capacity:            dcrutil.Amount(10000),
		LocalChanCfg:        localCfg,
		RemoteChanCfg:       remoteCfg,
		TotalMAtomsSent:     8,
		TotalMAtomsReceived: 2,
		LocalCommitment: ChannelCommitment{
			CommitHeight:  0,
			LocalBalance:  lnwire.MilliAtom(9000),
			RemoteBalance: lnwire.MilliAtom(3000),
			CommitFee:     dcrutil.Amount(rand.Int63()),
			FeePerKB:      dcrutil.Amount(5000),
			CommitTx:      channels.TestFundingTx,
			CommitSig:     bytes.Repeat([]byte{1}, 71),
		},
		RemoteCommitment: ChannelCommitment{
			CommitHeight:  0,
			LocalBalance:  lnwire.MilliAtom(3000),
			RemoteBalance: lnwire.MilliAtom(9000),
			CommitFee:     dcrutil.Amount(rand.Int63()),
			FeePerKB:      dcrutil.Amount(5000),
			CommitTx:      channels.TestFundingTx,
			CommitSig:     bytes.Repeat([]byte{1}, 71),
		},
		NumConfsRequired:        4,
		RemoteCurrentRevocation: privKey.PubKey(),
		RemoteNextRevocation:    privKey.PubKey(),
		RevocationProducer:      producer,
		RevocationStore:         store,
		Db:                      cdb,
		Packager:                NewChannelPackager(chanID),
		FundingTxn:              channels.TestFundingTx,
		ThawHeight:              uint32(defaultPendingHeight),
	}
}

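// For reference, the revocation producer/store pair wired into the channel
// above operate in lockstep; a minimal sketch of one derive/consume round
// (using the same shachain calls as createTestChannelState):
//
//	producer, _ := shachain.NewRevocationProducerFromBytes(key[:])
//	preImage, _ := producer.AtIndex(0)
//	store := shachain.NewRevocationStore()
//	_ = store.AddNextEntry(preImage)
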
func TestOpenChannelPutGetDelete(t *testing.T) {
	t.Parallel()

	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	// Create the test channel state, with additional htlcs on the local
	// and remote commitment.
	localHtlcs := []HTLC{
		{
			Signature:     testSig.Serialize(),
			Incoming:      true,
			Amt:           10,
			RHash:         key,
			RefundTimeout: 1,
			OnionBlob:     []byte("onionblob"),
		},
	}

	remoteHtlcs := []HTLC{
		{
			Signature:     testSig.Serialize(),
			Incoming:      false,
			Amt:           10,
			RHash:         key,
			RefundTimeout: 1,
			OnionBlob:     []byte("onionblob"),
		},
	}

	state := createTestChannel(
		t, cdb,
		remoteHtlcsOption(remoteHtlcs),
		localHtlcsOption(localHtlcs),
	)

	openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch open channel: %v", err)
	}

	newState := openChannels[0]

	// The decoded channel state should be identical to what we stored
	// above.
	if !reflect.DeepEqual(state, newState) {
		t.Fatalf("channel state doesn't match: %v vs %v",
			spew.Sdump(state), spew.Sdump(newState))
	}

	// We'll also test that the channel is properly able to hot swap the
	// next revocation for the state machine. This tests the initial
	// post-funding revocation exchange.
	nextRevKey, err := secp256k1.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("unable to create new private key: %v", err)
	}
	nextRevKeyPub := nextRevKey.PubKey()
	if err := state.InsertNextRevocation(nextRevKeyPub); err != nil {
		t.Fatalf("unable to update revocation: %v", err)
	}

	openChannels, err = cdb.FetchOpenChannels(state.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch open channel: %v", err)
	}
	updatedChan := openChannels[0]

	// Ensure that the revocation was set properly.
	if !nextRevKeyPub.IsEqual(updatedChan.RemoteNextRevocation) {
		t.Fatalf("next revocation wasn't updated")
	}

	// Finally, to wrap up the test, delete the state of the channel within
	// the database. This involves "closing" the channel, which removes all
	// written state and creates a small "summary" elsewhere within the
	// database.
	closeSummary := &ChannelCloseSummary{
		ChanPoint:         state.FundingOutpoint,
		RemotePub:         state.IdentityPub,
		SettledBalance:    dcrutil.Amount(500),
		TimeLockedBalance: dcrutil.Amount(10000),
		IsPending:         false,
		CloseType:         CooperativeClose,
	}
	if err := state.CloseChannel(closeSummary); err != nil {
		t.Fatalf("unable to close channel: %v", err)
	}

	// As the channel is now closed, attempting to fetch all open channels
	// for our fake node ID should return an empty slice.
	openChans, err := cdb.FetchOpenChannels(state.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch open channels: %v", err)
	}
	if len(openChans) != 0 {
		t.Fatalf("all channels not deleted, found %v", len(openChans))
	}

	// Additionally, attempting to fetch all the open channels globally
	// should yield no results.
	openChans, err = cdb.FetchAllChannels()
	if err != nil {
		t.Fatal("unable to fetch all open chans")
	}
	if len(openChans) != 0 {
		t.Fatalf("all channels not deleted, found %v", len(openChans))
	}
}

// TestOptionalShutdown tests the reading and writing of channels with and
// without optional shutdown script fields.
func TestOptionalShutdown(t *testing.T) {
	local := lnwire.DeliveryAddress([]byte("local shutdown script"))
	remote := lnwire.DeliveryAddress([]byte("remote shutdown script"))

	if _, err := rand.Read(remote); err != nil {
		t.Fatalf("Could not create random script: %v", err)
	}

	tests := []struct {
		name           string
		localShutdown  lnwire.DeliveryAddress
		remoteShutdown lnwire.DeliveryAddress
	}{
		{
			name:           "no shutdown scripts",
			localShutdown:  nil,
			remoteShutdown: nil,
		},
		{
			name:           "local shutdown script",
			localShutdown:  local,
			remoteShutdown: nil,
		},
		{
			name:           "remote shutdown script",
			localShutdown:  nil,
			remoteShutdown: remote,
		},
		{
			name:           "both scripts set",
			localShutdown:  local,
			remoteShutdown: remote,
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			fullDB, cleanUp, err := MakeTestDB()
			if err != nil {
				t.Fatalf("unable to make test database: %v", err)
			}
			defer cleanUp()

			cdb := fullDB.ChannelStateDB()

			// Create a channel with upfront scripts set as
			// specified in the test.
			state := createTestChannel(
				t, cdb,
				localShutdownOption(test.localShutdown),
				remoteShutdownOption(test.remoteShutdown),
			)

			openChannels, err := cdb.FetchOpenChannels(
				state.IdentityPub,
			)
			if err != nil {
				t.Fatalf("unable to fetch open"+
					" channel: %v", err)
			}

			if len(openChannels) != 1 {
				t.Fatalf("Expected one channel open,"+
					" got: %v", len(openChannels))
			}

			if !bytes.Equal(openChannels[0].LocalShutdownScript,
				test.localShutdown) {

				t.Fatalf("Expected local: %x, got: %x",
					test.localShutdown,
					openChannels[0].LocalShutdownScript)
			}

			if !bytes.Equal(openChannels[0].RemoteShutdownScript,
				test.remoteShutdown) {

				t.Fatalf("Expected remote: %x, got: %x",
					test.remoteShutdown,
					openChannels[0].RemoteShutdownScript)
			}
		})
	}
}

func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
	if !reflect.DeepEqual(a, b) {
		_, _, line, _ := runtime.Caller(1)
		t.Fatalf("line %v: commitments don't match: %v vs %v",
			line, spew.Sdump(a), spew.Sdump(b))
	}
}

func TestChannelStateTransition(t *testing.T) {
	t.Parallel()

	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	// First create a minimal channel, then perform a full sync in order to
	// persist the data.
	channel := createTestChannel(t, cdb)

	// Add some HTLCs which were added during this new state transition.
	// Some of the HTLCs are incoming, while the rest are outgoing.
	var (
		htlcs   []HTLC
		htlcAmt lnwire.MilliAtom
	)
	for i := uint32(0); i < 10; i++ {
		var incoming bool
		if i > 5 {
			incoming = true
		}
		htlc := HTLC{
			Signature:     testSig.Serialize(),
			Incoming:      incoming,
			Amt:           10,
			RHash:         key,
			RefundTimeout: i,
			OutputIndex:   int32(i * 3),
			LogIndex:      uint64(i * 2),
			HtlcIndex:     uint64(i),
		}
		htlc.OnionBlob = make([]byte, 10)
		copy(htlc.OnionBlob, bytes.Repeat([]byte{2}, 10))
		htlcs = append(htlcs, htlc)
		htlcAmt += htlc.Amt
	}

	// Create a new channel delta which includes the above HTLCs, some
	// balance updates, and an increment of the current commitment height.
	// Additionally, modify the signature and commitment transaction.
	newSequence := uint32(129498)
	newSig := bytes.Repeat([]byte{3}, 71)
	newTx := channel.LocalCommitment.CommitTx.Copy()
	newTx.TxIn[0].Sequence = newSequence
	commitment := ChannelCommitment{
		CommitHeight:    1,
		LocalLogIndex:   2,
		LocalHtlcIndex:  1,
		RemoteLogIndex:  2,
		RemoteHtlcIndex: 1,
		LocalBalance:    lnwire.MilliAtom(1e8),
		RemoteBalance:   lnwire.MilliAtom(1e8),
		CommitFee:       55,
		FeePerKB:        99,
		CommitTx:        newTx,
		CommitSig:       newSig,
		Htlcs:           htlcs,
	}

	// First update the local node's broadcastable state, and also add a
	// CommitDiff for the remote node as well, in order to simulate a
	// proper state transition.
	unsignedAckedUpdates := []LogUpdate{
		{
			LogIndex: 2,
			UpdateMsg: &lnwire.UpdateAddHTLC{
				ChanID:    lnwire.ChannelID{1, 2, 3},
				ExtraData: make([]byte, 0),
			},
		},
	}

	err = channel.UpdateCommitment(&commitment, unsignedAckedUpdates)
	if err != nil {
		t.Fatalf("unable to update commitment: %v", err)
	}

	// Assert that the update is correctly written to the database.
	dbUnsignedAckedUpdates, err := channel.UnsignedAckedUpdates()
	if err != nil {
		t.Fatalf("unable to fetch dangling remote updates: %v", err)
	}
	if len(dbUnsignedAckedUpdates) != 1 {
		t.Fatalf("unexpected number of dangling remote updates")
	}
	if !reflect.DeepEqual(
		dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0],
	) {
		t.Fatalf("unexpected update: expected %v, got %v",
			spew.Sdump(unsignedAckedUpdates[0]),
			spew.Sdump(dbUnsignedAckedUpdates[0]))
	}

	// The balances, new update, the HTLCs and the changes to the fake
	// commitment transaction along with the modified signature should all
	// have been updated.
	updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch updated channel: %v", err)
	}
	assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
	numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
	if err != nil {
		t.Fatalf("unable to read commitment height from disk: %v", err)
	}
	if numDiskUpdates != commitment.CommitHeight {
		t.Fatalf("num disk updates doesn't match: %v vs %v",
			numDiskUpdates, commitment.CommitHeight)
	}

	// Attempting to query for a commitment diff should return
	// ErrNoPendingCommit as we haven't yet created a new state for them.
	_, err = channel.RemoteCommitChainTip()
	if err != ErrNoPendingCommit {
		t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
	}

	// To simulate us extending a new state to the remote party, we'll also
	// create a new commit diff for them.
	remoteCommit := commitment
	remoteCommit.LocalBalance = lnwire.MilliAtom(2e8)
	remoteCommit.RemoteBalance = lnwire.MilliAtom(3e8)
	remoteCommit.CommitHeight = 1
	commitDiff := &CommitDiff{
		Commitment: remoteCommit,
		CommitSig: &lnwire.CommitSig{
			ChanID:    lnwire.ChannelID(key),
			CommitSig: wireSig,
			HtlcSigs: []lnwire.Sig{
				wireSig,
				wireSig,
			},
			ExtraData: make([]byte, 0),
		},
		LogUpdates: []LogUpdate{
			{
				LogIndex: 1,
				UpdateMsg: &lnwire.UpdateAddHTLC{
					ID:        1,
					Amount:    lnwire.NewMAtomsFromAtoms(100),
					Expiry:    25,
					ExtraData: make([]byte, 0),
				},
			},
			{
				LogIndex: 2,
				UpdateMsg: &lnwire.UpdateAddHTLC{
					ID:        2,
					Amount:    lnwire.NewMAtomsFromAtoms(200),
					Expiry:    50,
					ExtraData: make([]byte, 0),
				},
			},
		},
		OpenedCircuitKeys: []CircuitKey{},
		ClosedCircuitKeys: []CircuitKey{},
	}
	copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
		bytes.Repeat([]byte{1}, 32))
	copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
		bytes.Repeat([]byte{2}, 32))
	if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
		t.Fatalf("unable to add to commit chain: %v", err)
	}

	// The commitment tip should now match the commitment that we just
	// inserted.
	diskCommitDiff, err := channel.RemoteCommitChainTip()
	if err != nil {
		t.Fatalf("unable to fetch commit diff: %v", err)
	}
	if !reflect.DeepEqual(commitDiff, diskCommitDiff) {
		t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(commitDiff),
			spew.Sdump(diskCommitDiff))
	}

	// We'll save the old remote commitment as this will be added to the
	// revocation log shortly.
	oldRemoteCommit := channel.RemoteCommitment

	// Next, write to the log which tracks the necessary revocation state
	// needed to rectify any fishy behavior by the remote party. Modify the
	// current uncollapsed revocation state to simulate a state transition
	// by the remote party.
	channel.RemoteCurrentRevocation = channel.RemoteNextRevocation
	newPriv, err := secp256k1.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("unable to generate key: %v", err)
	}
	channel.RemoteNextRevocation = newPriv.PubKey()

	fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
		diskCommitDiff.LogUpdates, nil)

	err = channel.AdvanceCommitChainTail(fwdPkg, nil)
	if err != nil {
		t.Fatalf("unable to append to revocation log: %v", err)
	}

	// At this point, the remote commit chain should be nil, and the posted
	// remote commitment should match the one we added as a diff above.
	if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit {
		t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
	}

	// We should be able to fetch the channel delta created above by its
	// update number with all the state properly reconstructed.
	diskPrevCommit, err := channel.FindPreviousState(
		oldRemoteCommit.CommitHeight,
	)
	if err != nil {
		t.Fatalf("unable to fetch past delta: %v", err)
	}

	// The two deltas (the original vs the on-disk version) should be
	// identical, and all HTLC data should properly be retained.
	assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit)

	// The state number recovered from the tail of the revocation log
	// should be identical to this current state.
	logTail, err := channel.RevocationLogTail()
	if err != nil {
		t.Fatalf("unable to retrieve log: %v", err)
	}
	if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
		t.Fatal("update number doesn't match")
	}

	oldRemoteCommit = channel.RemoteCommitment

	// Next, modify the posted diff commitment slightly, then create a new
	// commitment diff and advance the tail.
	commitDiff.Commitment.CommitHeight = 2
	commitDiff.Commitment.LocalBalance -= htlcAmt
	commitDiff.Commitment.RemoteBalance += htlcAmt
	commitDiff.LogUpdates = []LogUpdate{}
	if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
		t.Fatalf("unable to add to commit chain: %v", err)
	}

	fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)

	err = channel.AdvanceCommitChainTail(fwdPkg, nil)
	if err != nil {
		t.Fatalf("unable to append to revocation log: %v", err)
	}

	// Once again, fetch the state and ensure it has been properly updated.
	prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight)
	if err != nil {
		t.Fatalf("unable to fetch past delta: %v", err)
	}
	assertCommitmentEqual(t, &oldRemoteCommit, prevCommit)

	// Once again, the state number recovered from the tail of the
	// revocation log should be identical to this current state.
	logTail, err = channel.RevocationLogTail()
	if err != nil {
		t.Fatalf("unable to retrieve log: %v", err)
	}
	if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
		t.Fatal("update number doesn't match")
	}

	// The revocation state stored on-disk should now also be identical.
	updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch updated channel: %v", err)
	}
	if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) {
		t.Fatalf("revocation state was not synced")
	}
	if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) {
		t.Fatalf("revocation state was not synced")
	}

	// At this point, we should have 2 forwarding packages added.
	fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager)
	require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages")

	// Now attempt to delete the channel from the database.
	closeSummary := &ChannelCloseSummary{
		ChanPoint:         channel.FundingOutpoint,
		RemotePub:         channel.IdentityPub,
		SettledBalance:    dcrutil.Amount(500),
		TimeLockedBalance: dcrutil.Amount(10000),
		IsPending:         false,
		CloseType:         RemoteForceClose,
	}
	if err := updatedChannel[0].CloseChannel(closeSummary); err != nil {
		t.Fatalf("unable to delete updated channel: %v", err)
	}

	// If we attempt to fetch the target channel again, it shouldn't be
	// found.
	channels, err := cdb.FetchOpenChannels(channel.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch updated channels: %v", err)
	}
	if len(channels) != 0 {
		t.Fatalf("%v channels found, but none should be",
			len(channels))
	}

	// Attempting to find previous states on the channel should fail as the
	// revocation log has been deleted.
	_, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight)
	if err == nil {
		t.Fatal("revocation log search should have failed")
	}

	// All forwarding packages of this channel have been deleted too.
	fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager)
	require.Empty(t, fwdPkgs, "no forwarding packages should exist")
}

func TestFetchPendingChannels(t *testing.T) {
	t.Parallel()

	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	// Create a pending channel that was broadcast at height 99.
	const broadcastHeight = 99
	createTestChannel(t, cdb, pendingHeightOption(broadcastHeight))

	pendingChannels, err := cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to list pending channels: %v", err)
	}

	if len(pendingChannels) != 1 {
		t.Fatalf("incorrect number of pending channels: expecting %v, "+
			"got %v", 1, len(pendingChannels))
	}

	// The broadcast height of the pending channel should have been set
	// properly.
	if pendingChannels[0].FundingBroadcastHeight != broadcastHeight {
		t.Fatalf("broadcast height mismatch: expected %v, got %v",
			broadcastHeight,
			pendingChannels[0].FundingBroadcastHeight)
	}

	chanOpenLoc := lnwire.ShortChannelID{
		BlockHeight: 5,
		TxIndex:     10,
		TxPosition:  15,
	}
	err = pendingChannels[0].MarkAsOpen(chanOpenLoc)
	if err != nil {
		t.Fatalf("unable to mark channel as open: %v", err)
	}

	if pendingChannels[0].IsPending {
		t.Fatalf("channel marked open should no longer be pending")
	}

	if pendingChannels[0].ShortChanID() != chanOpenLoc {
		t.Fatalf("channel opening height not updated: expected %v, "+
			"got %v", chanOpenLoc,
			spew.Sdump(pendingChannels[0].ShortChanID()))
	}

	// Next, we'll re-fetch the channel to ensure that the open height was
	// properly set.
	openChans, err := cdb.FetchAllChannels()
	if err != nil {
		t.Fatalf("unable to fetch channels: %v", err)
	}
	if openChans[0].ShortChanID() != chanOpenLoc {
		t.Fatalf("channel opening heights don't match: expected %v, "+
			"got %v", chanOpenLoc,
			spew.Sdump(openChans[0].ShortChanID()))
	}
	if openChans[0].FundingBroadcastHeight != broadcastHeight {
		t.Fatalf("broadcast height mismatch: expected %v, got %v",
			broadcastHeight,
			openChans[0].FundingBroadcastHeight)
	}

	pendingChannels, err = cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to list pending channels: %v", err)
	}

	if len(pendingChannels) != 0 {
		t.Fatalf("incorrect number of pending channels: expecting %v, "+
			"got %v", 0, len(pendingChannels))
	}
}

func TestFetchClosedChannels(t *testing.T) {
	t.Parallel()

	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	// Create an open channel in the database.
	state := createTestChannel(t, cdb, openChannelOption())

	// Next, close the channel by including a close channel summary in the
	// database.
	summary := &ChannelCloseSummary{
		ChanPoint:         state.FundingOutpoint,
		ClosingTXID:       rev,
		RemotePub:         state.IdentityPub,
		Capacity:          state.Capacity,
		SettledBalance:    state.LocalCommitment.LocalBalance.ToAtoms(),
		TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToAtoms() + 10000,
		CloseType:         RemoteForceClose,
		IsPending:         true,
		LocalChanConfig:   state.LocalChanCfg,
	}
	if err := state.CloseChannel(summary); err != nil {
		t.Fatalf("unable to close channel: %v", err)
	}

	// Query the database to ensure that the channel has now been properly
	// closed. We should get the same result whether querying for pending
	// channels only, or not.
	pendingClosed, err := cdb.FetchClosedChannels(true)
	if err != nil {
		t.Fatalf("failed fetching closed channels: %v", err)
	}
	if len(pendingClosed) != 1 {
		t.Fatalf("incorrect number of pending closed channels: expecting %v, "+
			"got %v", 1, len(pendingClosed))
	}
	if !reflect.DeepEqual(summary, pendingClosed[0]) {
		t.Fatalf("database summaries don't match: expected %v got %v",
			spew.Sdump(summary), spew.Sdump(pendingClosed[0]))
	}
	closed, err := cdb.FetchClosedChannels(false)
	if err != nil {
		t.Fatalf("failed fetching all closed channels: %v", err)
	}
	if len(closed) != 1 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 1, len(closed))
	}
	if !reflect.DeepEqual(summary, closed[0]) {
		t.Fatalf("database summaries don't match: expected %v got %v",
			spew.Sdump(summary), spew.Sdump(closed[0]))
	}

	// Mark the channel as fully closed.
	err = cdb.MarkChanFullyClosed(&state.FundingOutpoint)
	if err != nil {
		t.Fatalf("failed fully closing channel: %v", err)
	}

	// The channel should no longer be considered pending, but should still
	// be retrieved when fetching all the closed channels.
	closed, err = cdb.FetchClosedChannels(false)
	if err != nil {
		t.Fatalf("failed fetching closed channels: %v", err)
	}
	if len(closed) != 1 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 1, len(closed))
	}
	pendingClose, err := cdb.FetchClosedChannels(true)
	if err != nil {
		t.Fatalf("failed fetching channels pending close: %v", err)
	}
	if len(pendingClose) != 0 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 0, len(pendingClose))
	}
}

// TestFetchWaitingCloseChannels ensures that the correct channels that are
// waiting to be closed are returned.
func TestFetchWaitingCloseChannels(t *testing.T) {
	t.Parallel()

	const numChannels = 2
	const broadcastHeight = 99

	// We'll start by creating two channels within our test database. One of
	// them will have their funding transaction confirmed on-chain, while
	// the other one will remain unconfirmed.
	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	channels := make([]*OpenChannel, numChannels)
	for i := 0; i < numChannels; i++ {
		// Create a pending channel in the database at the broadcast
		// height.
		channels[i] = createTestChannel(
			t, cdb, pendingHeightOption(broadcastHeight),
		)
	}

	// We'll only confirm the first one.
	channelConf := lnwire.ShortChannelID{
		BlockHeight: broadcastHeight + 1,
		TxIndex:     10,
		TxPosition:  15,
	}
	if err := channels[0].MarkAsOpen(channelConf); err != nil {
		t.Fatalf("unable to mark channel as open: %v", err)
	}

	// Then, we'll mark the channels as if their commitments were broadcast.
	// This would happen in the event of a force close and should make the
	// channels enter a state of waiting close.
	for _, channel := range channels {
		closeTx := wire.NewMsgTx()
		closeTx.Version = 2
		closeTx.AddTxIn(
			&wire.TxIn{
				PreviousOutPoint: channel.FundingOutpoint,
			},
		)

		if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil {
			t.Fatalf("unable to mark commitment broadcast: %v", err)
		}

		// Now try marking a coop close with a nil tx. This should
		// succeed, but the close tx shouldn't exist when queried.
		if err = channel.MarkCoopBroadcasted(nil, true); err != nil {
			t.Fatalf("unable to mark nil coop broadcast: %v", err)
		}
		_, err := channel.BroadcastedCooperative()
		if err != ErrNoCloseTx {
			t.Fatalf("expected no closing tx error, got: %v", err)
		}

		// Finally, modify the close tx deterministically and also mark
		// it as coop closed. Later we will test that distinct
		// transactions are returned for both coop and force closes.
		closeTx.TxIn[0].PreviousOutPoint.Index ^= 1
		if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil {
			t.Fatalf("unable to mark coop broadcast: %v", err)
		}
	}

	// Now, we'll fetch all the channels waiting to be closed from the
	// database. We should expect to see both channels above, even if any of
	// them haven't had their funding transaction confirm on-chain.
	waitingCloseChannels, err := cdb.FetchWaitingCloseChannels()
	if err != nil {
		t.Fatalf("unable to fetch all waiting close channels: %v", err)
	}
	if len(waitingCloseChannels) != numChannels {
		t.Fatalf("expected %d channels waiting to be closed, got %d",
			numChannels, len(waitingCloseChannels))
	}
	expectedChannels := make(map[wire.OutPoint]struct{})
	for _, channel := range channels {
		expectedChannels[channel.FundingOutpoint] = struct{}{}
	}
	for _, channel := range waitingCloseChannels {
		if _, ok := expectedChannels[channel.FundingOutpoint]; !ok {
			t.Fatalf("expected channel %v to be waiting close",
				channel.FundingOutpoint)
		}

		chanPoint := channel.FundingOutpoint

		// Assert that the force close transaction is retrievable.
		forceCloseTx, err := channel.BroadcastedCommitment()
		if err != nil {
			t.Fatalf("unable to retrieve commitment: %v", err)
		}

		if forceCloseTx.TxIn[0].PreviousOutPoint != chanPoint {
			t.Fatalf("expected outpoint %v, got %v",
				chanPoint,
				forceCloseTx.TxIn[0].PreviousOutPoint)
		}

		// Assert that the coop close transaction is retrievable.
		coopCloseTx, err := channel.BroadcastedCooperative()
		if err != nil {
			t.Fatalf("unable to retrieve coop close: %v", err)
		}

		chanPoint.Index ^= 1
		if coopCloseTx.TxIn[0].PreviousOutPoint != chanPoint {
			t.Fatalf("expected outpoint %v, got %v",
				chanPoint,
				coopCloseTx.TxIn[0].PreviousOutPoint)
		}
	}
}

// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory
// state of another OpenChannel to reflect a preceding call to MarkAsOpen on a
// different OpenChannel.
func TestRefreshShortChanID(t *testing.T) {
	t.Parallel()

	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	// First create a test channel.
	state := createTestChannel(t, cdb)

	// Next, locate the pending channel within the database.
	pendingChannels, err := cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to load pending channels: %v", err)
	}

	var pendingChannel *OpenChannel
	for _, channel := range pendingChannels {
		if channel.FundingOutpoint == state.FundingOutpoint {
			pendingChannel = channel
			break
		}
	}
	if pendingChannel == nil {
		t.Fatalf("unable to find pending channel with funding "+
			"outpoint=%v: %v", state.FundingOutpoint, err)
	}

	// Next, simulate the confirmation of the channel by marking it as open
	// within the database.
	chanOpenLoc := lnwire.ShortChannelID{
		BlockHeight: 105,
		TxIndex:     10,
		TxPosition:  15,
	}

	err = state.MarkAsOpen(chanOpenLoc)
	if err != nil {
		t.Fatalf("unable to mark channel open: %v", err)
	}

	// The short_chan_id of the OpenChannel that MarkAsOpen was called on
	// should reflect the open location, but the other pending channel
	// should remain unchanged.
	if state.ShortChanID() == pendingChannel.ShortChanID() {
		t.Fatalf("pending channel short_chan_id should not have been " +
			"updated before refreshing short_chan_id")
	}

	// Now that the receiver's short channel id has been updated, check to
	// ensure that the channel packager's source has been updated as well.
	// This ensures that the packager will read and write to buckets
	// corresponding to the new short chan id, instead of the prior.
	if state.Packager.(*ChannelPackager).source != chanOpenLoc {
		t.Fatalf("channel packager source was not updated: want %v, "+
			"got %v", chanOpenLoc,
			state.Packager.(*ChannelPackager).source)
	}

	// Now, refresh the short channel ID of the pending channel.
	err = pendingChannel.RefreshShortChanID()
	if err != nil {
		t.Fatalf("unable to refresh short_chan_id: %v", err)
	}

	// This should result in both OpenChannels now having the same
	// ShortChanID.
	if state.ShortChanID() != pendingChannel.ShortChanID() {
		t.Fatalf("expected pending channel short_chan_id to be "+
			"refreshed: want %v, got %v", state.ShortChanID(),
			pendingChannel.ShortChanID())
	}

	// Check to ensure that the _other_ OpenChannel's packager source has
	// also been updated after the refresh. This ensures that the other
	// packagers will read and write to buckets corresponding to the
	// updated short chan id.
	if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
		t.Fatalf("channel packager source was not updated: want %v, "+
			"got %v", chanOpenLoc,
			pendingChannel.Packager.(*ChannelPackager).source)
	}

	// Check to ensure that this channel is no longer pending and this field
	// is up to date.
	if pendingChannel.IsPending {
		t.Fatalf("channel pending state wasn't updated: want false, " +
			"got true")
	}
}

// TestCloseInitiator tests the setting of close initiator statuses for
// cooperative closes and local force closes.
func TestCloseInitiator(t *testing.T) {
	tests := []struct {
		name string
		// updateChannel is called to update the channel as broadcast,
		// cooperatively or not, based on the test's requirements.
		updateChannel    func(c *OpenChannel) error
		expectedStatuses []ChannelStatus
	}{
		{
			name: "local coop close",
			// Mark the channel as cooperatively closed, initiated
			// by the local party.
			updateChannel: func(c *OpenChannel) error {
				return c.MarkCoopBroadcasted(
					&wire.MsgTx{}, true,
				)
			},
			expectedStatuses: []ChannelStatus{
				ChanStatusLocalCloseInitiator,
				ChanStatusCoopBroadcasted,
			},
		},
		{
			name: "remote coop close",
			// Mark the channel as cooperatively closed, initiated
			// by the remote party.
			updateChannel: func(c *OpenChannel) error {
				return c.MarkCoopBroadcasted(
					&wire.MsgTx{}, false,
				)
			},
			expectedStatuses: []ChannelStatus{
				ChanStatusRemoteCloseInitiator,
				ChanStatusCoopBroadcasted,
			},
		},
		{
			name: "local force close",
			// Mark the channel's commitment as broadcast with
			// local initiator.
			updateChannel: func(c *OpenChannel) error {
				return c.MarkCommitmentBroadcasted(
					&wire.MsgTx{}, true,
				)
			},
			expectedStatuses: []ChannelStatus{
				ChanStatusLocalCloseInitiator,
				ChanStatusCommitBroadcasted,
			},
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			fullDB, cleanUp, err := MakeTestDB()
			if err != nil {
				t.Fatalf("unable to make test database: %v",
					err)
			}
			defer cleanUp()

			cdb := fullDB.ChannelStateDB()

			// Create an open channel.
			channel := createTestChannel(
				t, cdb, openChannelOption(),
			)

			err = test.updateChannel(channel)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			// Lookup open channels in the database.
			dbChans, err := fetchChannels(
				cdb, pendingChannelFilter(false),
			)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(dbChans) != 1 {
				t.Fatalf("expected 1 channel, got: %v",
					len(dbChans))
			}

			// Check that the statuses that we expect were written
			// to disk.
			for _, status := range test.expectedStatuses {
				if !dbChans[0].HasChanStatus(status) {
					t.Fatalf("expected channel to have "+
						"status: %v, has status: %v",
						status, dbChans[0].chanStatus)
				}
			}
		})
	}
}

// TestCloseChannelStatus tests setting of a channel status on the historical
// channel on channel close.
func TestCloseChannelStatus(t *testing.T) {
	fullDB, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	cdb := fullDB.ChannelStateDB()

	// Create an open channel.
	channel := createTestChannel(
		t, cdb, openChannelOption(),
	)

	if err := channel.CloseChannel(
		&ChannelCloseSummary{
			ChanPoint: channel.FundingOutpoint,
			RemotePub: channel.IdentityPub,
		}, ChanStatusRemoteCloseInitiator,
	); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	histChan, err := channel.Db.FetchHistoricalChannel(
		&channel.FundingOutpoint,
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if !histChan.HasChanStatus(ChanStatusRemoteCloseInitiator) {
		t.Fatalf("channel should have status")
	}
}

// TestBalanceAtHeight tests lookup of our local and remote balance at a given
// height.
func TestBalanceAtHeight(t *testing.T) {
	const (
		// Values that will be set on our current local commit in
		// memory.
		localHeight        = 2
		localLocalBalance  = 1000
		localRemoteBalance = 1500

		// Values that will be set on our current remote commit in
		// memory.
		remoteHeight        = 3
		remoteLocalBalance  = 2000
		remoteRemoteBalance = 2500

		// Values that will be written to disk in the revocation log.
		oldHeight        = 0
		oldLocalBalance  = 200
		oldRemoteBalance = 300

		// Heights to test error cases.
		unknownHeight   = 1
		unreachedHeight = 4
	)

	// putRevokedState is a helper function used to put commitments in
	// the revocation log bucket to test lookup of balances at heights that
	// are not our current height.
	putRevokedState := func(c *OpenChannel, height uint64, local,
		remote lnwire.MilliAtom) error {

		err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
			chanBucket, err := fetchChanBucketRw(
				tx, c.IdentityPub, &c.FundingOutpoint,
				c.ChainHash,
			)
			if err != nil {
				return err
			}

			logKey := revocationLogBucket
			logBucket, err := chanBucket.CreateBucketIfNotExists(
				logKey,
			)
			if err != nil {
				return err
			}

			// Make a copy of our current commitment so we do not
			// need to re-fill all the required fields, then copy
			// in our new desired values.
			commit := c.LocalCommitment
			commit.CommitHeight = height
			commit.LocalBalance = local
			commit.RemoteBalance = remote

			return appendChannelLogEntry(logBucket, &commit)
		}, func() {})

		return err
	}

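	// A usage sketch for the helper above, mirroring the actual call made
	// in the subtests below:
	//
	//	err := putRevokedState(channel, oldHeight, oldLocalBalance,
	//		oldRemoteBalance)
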
	tests := []struct {
		name                  string
		targetHeight          uint64
		expectedLocalBalance  lnwire.MilliAtom
		expectedRemoteBalance lnwire.MilliAtom
		expectedError         error
	}{
		{
			name:                  "target is current local height",
			targetHeight:          localHeight,
			expectedLocalBalance:  localLocalBalance,
			expectedRemoteBalance: localRemoteBalance,
			expectedError:         nil,
		},
		{
			name:                  "target is current remote height",
			targetHeight:          remoteHeight,
			expectedLocalBalance:  remoteLocalBalance,
			expectedRemoteBalance: remoteRemoteBalance,
			expectedError:         nil,
		},
		{
			name:                  "need to lookup commit",
			targetHeight:          oldHeight,
			expectedLocalBalance:  oldLocalBalance,
			expectedRemoteBalance: oldRemoteBalance,
			expectedError:         nil,
		},
		{
			name:                  "height not found",
			targetHeight:          unknownHeight,
			expectedLocalBalance:  0,
			expectedRemoteBalance: 0,
			expectedError:         ErrLogEntryNotFound,
		},
		{
			name:                  "height not reached",
			targetHeight:          unreachedHeight,
			expectedLocalBalance:  0,
			expectedRemoteBalance: 0,
			expectedError:         errHeightNotReached,
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			fullDB, cleanUp, err := MakeTestDB()
			if err != nil {
				t.Fatalf("unable to make test database: %v",
					err)
			}
			defer cleanUp()

			cdb := fullDB.ChannelStateDB()

			// Create options to set the heights and balances of
			// our local and remote commitments.
			localCommitOpt := channelCommitmentOption(
				localHeight, localLocalBalance,
				localRemoteBalance, true,
			)

			remoteCommitOpt := channelCommitmentOption(
				remoteHeight, remoteLocalBalance,
				remoteRemoteBalance, false,
			)

			// Create an open channel.
			channel := createTestChannel(
				t, cdb, openChannelOption(),
				localCommitOpt, remoteCommitOpt,
			)

			// Write an older commit to disk.
			err = putRevokedState(channel, oldHeight,
				oldLocalBalance, oldRemoteBalance)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			local, remote, err := channel.BalancesAtHeight(
				test.targetHeight,
			)
			if err != test.expectedError {
				t.Fatalf("expected: %v, got: %v",
					test.expectedError, err)
			}

			if local != test.expectedLocalBalance {
				t.Fatalf("expected local: %v, got: %v",
					test.expectedLocalBalance, local)
			}

			if remote != test.expectedRemoteBalance {
				t.Fatalf("expected remote: %v, got: %v",
					test.expectedRemoteBalance, remote)
			}
		})
	}
}

// TestHasChanStatus asserts the behavior of HasChanStatus by checking the
// behavior of various status flags, in addition to the special case of
// ChanStatusDefault, which is treated like a flag in the code base even
// though it isn't one.
func TestHasChanStatus(t *testing.T) {
	tests := []struct {
		name   string
		status ChannelStatus
		expHas map[ChannelStatus]bool
	}{
		{
			name:   "default",
			status: ChanStatusDefault,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault: true,
				ChanStatusBorked:  false,
			},
		},
		{
			name:   "single flag",
			status: ChanStatusBorked,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault: false,
				ChanStatusBorked:  true,
			},
		},
		{
			name:   "multiple flags",
			status: ChanStatusBorked | ChanStatusLocalDataLoss,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault:       false,
				ChanStatusBorked:        true,
				ChanStatusLocalDataLoss: true,
			},
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			c := &OpenChannel{
				chanStatus: test.status,
			}

			for status, expHas := range test.expHas {
				has := c.HasChanStatus(status)
				if has == expHas {
					continue
				}

				t.Fatalf("expected chan status to "+
					"have %s? %t, got: %t",
					status, expHas, has)
			}
		})
	}
}

// TestKeyLocatorEncoding tests that we are able to serialize a given
// keychain.KeyLocator. After successfully encoding, we check that the decoded
// output arrives at the same initial KeyLocator.
func TestKeyLocatorEncoding(t *testing.T) {
	keyLoc := keychain.KeyLocator{
		Family: keychain.KeyFamilyRevocationRoot,
		Index:  keyLocIndex,
	}

	// First, we'll encode the KeyLocator into a buffer.
	var (
		b   bytes.Buffer
		buf [8]byte
	)

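	// Note: the 8-byte scratch buffer above matches the record length
	// asserted below; a KeyLocator is assumed to serialize as two uint32
	// fields (Family, then Index), which is why DKeyLocator is handed a
	// length of 8.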
	err := EKeyLocator(&b, &keyLoc, &buf)
	require.NoError(t, err, "unable to encode key locator")

	// Next, we'll attempt to decode the bytes into a new KeyLocator.
	r := bytes.NewReader(b.Bytes())
	var decodedKeyLoc keychain.KeyLocator

	err = DKeyLocator(r, &decodedKeyLoc, &buf, 8)
	require.NoError(t, err, "unable to decode key locator")

	// Finally, we'll compare that the original KeyLocator and the decoded
	// version are equal.
	require.Equal(t, keyLoc, decodedKeyLoc)
}