github.com/decred/dcrlnd@v0.7.6/peer/test_utils.go (about) 1 package peer 2 3 import ( 4 "bytes" 5 crand "crypto/rand" 6 "encoding/binary" 7 "io" 8 "io/ioutil" 9 "math/rand" 10 "net" 11 "os" 12 "testing" 13 "time" 14 15 "github.com/decred/dcrd/chaincfg/chainhash" 16 "github.com/decred/dcrd/chaincfg/v3" 17 "github.com/decred/dcrd/dcrec/secp256k1/v4" 18 "github.com/decred/dcrd/dcrutil/v4" 19 "github.com/decred/dcrd/wire" 20 "github.com/decred/dcrlnd/chainntnfs" 21 "github.com/decred/dcrlnd/channeldb" 22 "github.com/decred/dcrlnd/htlcswitch" 23 "github.com/decred/dcrlnd/input" 24 "github.com/decred/dcrlnd/keychain" 25 "github.com/decred/dcrlnd/lntest/channels" 26 "github.com/decred/dcrlnd/lntest/mock" 27 "github.com/decred/dcrlnd/lnwallet" 28 "github.com/decred/dcrlnd/lnwallet/chainfee" 29 "github.com/decred/dcrlnd/lnwire" 30 "github.com/decred/dcrlnd/netann" 31 "github.com/decred/dcrlnd/queue" 32 "github.com/decred/dcrlnd/shachain" 33 "github.com/stretchr/testify/require" 34 ) 35 36 const ( 37 broadcastHeight = 100 38 39 // timeout is a timeout value to use for tests which need to wait for 40 // a return value on a channel. 41 timeout = time.Second * 5 42 ) 43 44 var ( 45 // Just use some arbitrary bytes as delivery script. 46 dummyDeliveryScript = channels.AlicesPrivKey 47 48 testKeyLoc = keychain.KeyLocator{Family: keychain.KeyFamilyNodeKey} 49 ) 50 51 // noUpdate is a function which can be used as a parameter in createTestPeer to 52 // call the setup code with no custom values on the channels set up. 53 var noUpdate = func(a, b *channeldb.OpenChannel) {} 54 55 // createTestPeer creates a channel between two nodes, and returns a peer for 56 // one of the nodes, together with the channel seen from both nodes. It takes 57 // an updateChan function which can be used to modify the default values on 58 // the channel states for each peer. 59 func createTestPeer(notifier chainntnfs.ChainNotifier, 60 publTx chan *wire.MsgTx, updateChan func(a, b *channeldb.OpenChannel), 61 mockSwitch *mockMessageSwitch) ( 62 *Brontide, *lnwallet.LightningChannel, func(), error) { 63 64 chainParams := chaincfg.RegNetParams() 65 66 nodeKeyLocator := keychain.KeyLocator{ 67 Family: keychain.KeyFamilyNodeKey, 68 } 69 aliceKeyPriv, aliceKeyPub := channels.PrivKeyFromBytes(channels.AlicesPrivKey) 70 aliceKeySigner := keychain.NewPrivKeyMessageSigner(aliceKeyPriv, nodeKeyLocator) 71 bobKeyPriv, bobKeyPub := channels.PrivKeyFromBytes(channels.BobsPrivKey) 72 73 channelCapacity := dcrutil.Amount(10 * 1e8) 74 channelBal := channelCapacity / 2 75 aliceDustLimit := dcrutil.Amount(200) 76 bobDustLimit := dcrutil.Amount(1300) 77 csvTimeoutAlice := uint32(5) 78 csvTimeoutBob := uint32(4) 79 isAliceInitiator := true 80 81 prevOut := &wire.OutPoint{ 82 Hash: channels.TestHdSeed, 83 Index: 0, 84 } 85 fundingTxIn := wire.NewTxIn(prevOut, 0, nil) // TODO(decred): Need correct input value 86 87 aliceCfg := channeldb.ChannelConfig{ 88 ChannelConstraints: channeldb.ChannelConstraints{ 89 DustLimit: aliceDustLimit, 90 MaxPendingAmount: lnwire.MilliAtom(rand.Int63()), 91 ChanReserve: dcrutil.Amount(rand.Int63()), 92 MinHTLC: lnwire.MilliAtom(rand.Int63()), 93 MaxAcceptedHtlcs: uint16(rand.Int31()), 94 CsvDelay: uint16(csvTimeoutAlice), 95 }, 96 MultiSigKey: keychain.KeyDescriptor{ 97 PubKey: aliceKeyPub, 98 }, 99 RevocationBasePoint: keychain.KeyDescriptor{ 100 PubKey: aliceKeyPub, 101 }, 102 PaymentBasePoint: keychain.KeyDescriptor{ 103 PubKey: aliceKeyPub, 104 }, 105 DelayBasePoint: keychain.KeyDescriptor{ 106 PubKey: aliceKeyPub, 107 }, 108 HtlcBasePoint: keychain.KeyDescriptor{ 109 PubKey: aliceKeyPub, 110 }, 111 } 112 bobCfg := channeldb.ChannelConfig{ 113 ChannelConstraints: channeldb.ChannelConstraints{ 114 DustLimit: bobDustLimit, 115 MaxPendingAmount: lnwire.MilliAtom(rand.Int63()), 116 ChanReserve: dcrutil.Amount(rand.Int63()), 117 MinHTLC: lnwire.MilliAtom(rand.Int63()), 118 MaxAcceptedHtlcs: uint16(rand.Int31()), 119 CsvDelay: uint16(csvTimeoutBob), 120 }, 121 MultiSigKey: keychain.KeyDescriptor{ 122 PubKey: bobKeyPub, 123 }, 124 RevocationBasePoint: keychain.KeyDescriptor{ 125 PubKey: bobKeyPub, 126 }, 127 PaymentBasePoint: keychain.KeyDescriptor{ 128 PubKey: bobKeyPub, 129 }, 130 DelayBasePoint: keychain.KeyDescriptor{ 131 PubKey: bobKeyPub, 132 }, 133 HtlcBasePoint: keychain.KeyDescriptor{ 134 PubKey: bobKeyPub, 135 }, 136 } 137 138 bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize()) 139 if err != nil { 140 return nil, nil, nil, err 141 } 142 bobPreimageProducer := shachain.NewRevocationProducer(shachain.ShaHash(*bobRoot)) 143 bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) 144 if err != nil { 145 return nil, nil, nil, err 146 } 147 bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:]) 148 149 aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize()) 150 if err != nil { 151 return nil, nil, nil, err 152 } 153 alicePreimageProducer := shachain.NewRevocationProducer(shachain.ShaHash(*aliceRoot)) 154 aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) 155 if err != nil { 156 return nil, nil, nil, err 157 } 158 aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:]) 159 160 aliceCommitTx, bobCommitTx, err := lnwallet.CreateCommitmentTxns( 161 channelBal, channelBal, &aliceCfg, &bobCfg, aliceCommitPoint, 162 bobCommitPoint, *fundingTxIn, channeldb.SingleFunderTweaklessBit, 163 isAliceInitiator, 0, chainParams, 164 ) 165 if err != nil { 166 return nil, nil, nil, err 167 } 168 169 alicePath, err := ioutil.TempDir("", "alicedb") 170 if err != nil { 171 return nil, nil, nil, err 172 } 173 174 dbAlice, err := channeldb.Open(alicePath) 175 if err != nil { 176 return nil, nil, nil, err 177 } 178 179 bobPath, err := ioutil.TempDir("", "bobdb") 180 if err != nil { 181 return nil, nil, nil, err 182 } 183 184 dbBob, err := channeldb.Open(bobPath) 185 if err != nil { 186 return nil, nil, nil, err 187 } 188 189 estimator := chainfee.NewStaticEstimator(12500, 0) 190 feePerKB, err := estimator.EstimateFeePerKB(1) 191 if err != nil { 192 return nil, nil, nil, err 193 } 194 195 // TODO(roasbeef): need to factor in commit fee? 196 aliceCommit := channeldb.ChannelCommitment{ 197 CommitHeight: 0, 198 LocalBalance: lnwire.NewMAtomsFromAtoms(channelBal), 199 RemoteBalance: lnwire.NewMAtomsFromAtoms(channelBal), 200 FeePerKB: dcrutil.Amount(feePerKB), 201 CommitFee: feePerKB.FeeForSize(input.CommitmentTxSize), 202 CommitTx: aliceCommitTx, 203 CommitSig: bytes.Repeat([]byte{1}, 71), 204 } 205 bobCommit := channeldb.ChannelCommitment{ 206 CommitHeight: 0, 207 LocalBalance: lnwire.NewMAtomsFromAtoms(channelBal), 208 RemoteBalance: lnwire.NewMAtomsFromAtoms(channelBal), 209 FeePerKB: dcrutil.Amount(feePerKB), 210 CommitFee: feePerKB.FeeForSize(input.CommitmentTxSize), 211 CommitTx: bobCommitTx, 212 CommitSig: bytes.Repeat([]byte{1}, 71), 213 } 214 215 var chanIDBytes [8]byte 216 if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil { 217 return nil, nil, nil, err 218 } 219 220 shortChanID := lnwire.NewShortChanIDFromInt( 221 binary.BigEndian.Uint64(chanIDBytes[:]), 222 ) 223 224 aliceChannelState := &channeldb.OpenChannel{ 225 LocalChanCfg: aliceCfg, 226 RemoteChanCfg: bobCfg, 227 IdentityPub: aliceKeyPub, 228 FundingOutpoint: *prevOut, 229 ShortChannelID: shortChanID, 230 ChanType: channeldb.SingleFunderTweaklessBit, 231 IsInitiator: isAliceInitiator, 232 Capacity: channelCapacity, 233 RemoteCurrentRevocation: bobCommitPoint, 234 RevocationProducer: alicePreimageProducer, 235 RevocationStore: shachain.NewRevocationStore(), 236 LocalCommitment: aliceCommit, 237 RemoteCommitment: aliceCommit, 238 Db: dbAlice.ChannelStateDB(), 239 Packager: channeldb.NewChannelPackager(shortChanID), 240 FundingTxn: channels.TestFundingTx, 241 } 242 bobChannelState := &channeldb.OpenChannel{ 243 LocalChanCfg: bobCfg, 244 RemoteChanCfg: aliceCfg, 245 IdentityPub: bobKeyPub, 246 FundingOutpoint: *prevOut, 247 ChanType: channeldb.SingleFunderTweaklessBit, 248 IsInitiator: !isAliceInitiator, 249 Capacity: channelCapacity, 250 RemoteCurrentRevocation: aliceCommitPoint, 251 RevocationProducer: bobPreimageProducer, 252 RevocationStore: shachain.NewRevocationStore(), 253 LocalCommitment: bobCommit, 254 RemoteCommitment: bobCommit, 255 Db: dbBob.ChannelStateDB(), 256 Packager: channeldb.NewChannelPackager(shortChanID), 257 } 258 259 // Set custom values on the channel states. 260 updateChan(aliceChannelState, bobChannelState) 261 262 aliceAddr := &net.TCPAddr{ 263 IP: net.ParseIP("127.0.0.1"), 264 Port: 18555, 265 } 266 267 if err := aliceChannelState.SyncPending(aliceAddr, 0); err != nil { 268 return nil, nil, nil, err 269 } 270 271 bobAddr := &net.TCPAddr{ 272 IP: net.ParseIP("127.0.0.1"), 273 Port: 18556, 274 } 275 276 if err := bobChannelState.SyncPending(bobAddr, 0); err != nil { 277 return nil, nil, nil, err 278 } 279 280 cleanUpFunc := func() { 281 os.RemoveAll(bobPath) 282 os.RemoveAll(alicePath) 283 } 284 285 aliceSigner := &mock.SingleSigner{Privkey: aliceKeyPriv} 286 bobSigner := &mock.SingleSigner{Privkey: bobKeyPriv} 287 288 alicePool := lnwallet.NewSigPool(1, aliceSigner) 289 channelAlice, err := lnwallet.NewLightningChannel( 290 aliceSigner, aliceChannelState, alicePool, chainParams, 291 ) 292 if err != nil { 293 return nil, nil, nil, err 294 } 295 _ = alicePool.Start() 296 297 bobPool := lnwallet.NewSigPool(1, bobSigner) 298 channelBob, err := lnwallet.NewLightningChannel( 299 bobSigner, bobChannelState, bobPool, chainParams, 300 ) 301 if err != nil { 302 return nil, nil, nil, err 303 } 304 _ = bobPool.Start() 305 306 chainIO := &mock.ChainIO{ 307 BestHeight: broadcastHeight, 308 } 309 wallet := &lnwallet.LightningWallet{ 310 WalletController: &mock.WalletController{ 311 RootKey: aliceKeyPriv, 312 PublishedTransactions: publTx, 313 }, 314 } 315 316 // If mockSwitch is not set by the caller, set it to the default as the 317 // caller does not need to control it. 318 if mockSwitch == nil { 319 mockSwitch = &mockMessageSwitch{} 320 } 321 322 nodeSignerAlice := netann.NewNodeSigner(aliceKeySigner) 323 324 const chanActiveTimeout = time.Minute 325 326 chanStatusMgr, err := netann.NewChanStatusManager(&netann.ChanStatusConfig{ 327 ChanStatusSampleInterval: 30 * time.Second, 328 ChanEnableTimeout: chanActiveTimeout, 329 ChanDisableTimeout: 2 * time.Minute, 330 DB: dbAlice.ChannelStateDB(), 331 Graph: dbAlice.ChannelGraph(), 332 MessageSigner: nodeSignerAlice, 333 OurPubKey: aliceKeyPub, 334 OurKeyLoc: testKeyLoc, 335 IsChannelActive: func(lnwire.ChannelID) bool { return true }, 336 ApplyChannelUpdate: func(*lnwire.ChannelUpdate) error { return nil }, 337 }) 338 if err != nil { 339 return nil, nil, nil, err 340 } 341 if err = chanStatusMgr.Start(); err != nil { 342 return nil, nil, nil, err 343 } 344 345 errBuffer, err := queue.NewCircularBuffer(ErrorBufferSize) 346 if err != nil { 347 return nil, nil, nil, err 348 } 349 350 var pubKey [33]byte 351 copy(pubKey[:], aliceKeyPub.SerializeCompressed()) 352 353 cfgAddr := &lnwire.NetAddress{ 354 IdentityKey: aliceKeyPub, 355 Address: aliceAddr, 356 ChainNet: wire.SimNet, 357 } 358 359 cfg := &Config{ 360 Addr: cfgAddr, 361 PubKeyBytes: pubKey, 362 ErrorBuffer: errBuffer, 363 ChainIO: chainIO, 364 Switch: mockSwitch, 365 366 ChanActiveTimeout: chanActiveTimeout, 367 InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), 368 369 ChannelDB: dbAlice.ChannelStateDB(), 370 FeeEstimator: estimator, 371 Wallet: wallet, 372 ChainNotifier: notifier, 373 ChanStatusMgr: chanStatusMgr, 374 DisconnectPeer: func(b *secp256k1.PublicKey) error { return nil }, 375 ChainParams: chainParams, 376 } 377 378 alicePeer := NewBrontide(*cfg) 379 380 chanID := lnwire.NewChanIDFromOutPoint(channelAlice.ChannelPoint()) 381 alicePeer.activeChannels[chanID] = channelAlice 382 383 alicePeer.wg.Add(1) 384 go alicePeer.channelManager() 385 386 return alicePeer, channelBob, cleanUpFunc, nil 387 } 388 389 // mockMessageSwitch is a mock implementation of the messageSwitch interface 390 // used for testing without relying on a *htlcswitch.Switch in unit tests. 391 type mockMessageSwitch struct { 392 links []htlcswitch.ChannelUpdateHandler 393 } 394 395 // BestHeight currently returns a dummy value. 396 func (m *mockMessageSwitch) BestHeight() uint32 { 397 return 0 398 } 399 400 // CircuitModifier currently returns a dummy value. 401 func (m *mockMessageSwitch) CircuitModifier() htlcswitch.CircuitModifier { 402 return nil 403 } 404 405 // RemoveLink currently does nothing. 406 func (m *mockMessageSwitch) RemoveLink(cid lnwire.ChannelID) {} 407 408 // CreateAndAddLink currently returns a dummy value. 409 func (m *mockMessageSwitch) CreateAndAddLink(cfg htlcswitch.ChannelLinkConfig, 410 lnChan *lnwallet.LightningChannel) error { 411 412 return nil 413 } 414 415 // GetLinksByInterface returns the active links. 416 func (m *mockMessageSwitch) GetLinksByInterface(pub [33]byte) ( 417 []htlcswitch.ChannelUpdateHandler, error) { 418 419 return m.links, nil 420 } 421 422 // mockUpdateHandler is a mock implementation of the ChannelUpdateHandler 423 // interface. It is used in mockMessageSwitch's GetLinksByInterface method. 424 type mockUpdateHandler struct { 425 cid lnwire.ChannelID 426 } 427 428 // newMockUpdateHandler creates a new mockUpdateHandler. 429 func newMockUpdateHandler(cid lnwire.ChannelID) *mockUpdateHandler { 430 return &mockUpdateHandler{ 431 cid: cid, 432 } 433 } 434 435 // HandleChannelUpdate currently does nothing. 436 func (m *mockUpdateHandler) HandleChannelUpdate(msg lnwire.Message) {} 437 438 // ChanID returns the mockUpdateHandler's cid. 439 func (m *mockUpdateHandler) ChanID() lnwire.ChannelID { return m.cid } 440 441 // Bandwidth currently returns a dummy value. 442 func (m *mockUpdateHandler) Bandwidth() lnwire.MilliAtom { return 0 } 443 444 // EligibleToForward currently returns a dummy value. 445 func (m *mockUpdateHandler) EligibleToForward() bool { return false } 446 447 // MayAddOutgoingHtlc currently returns nil. 448 func (m *mockUpdateHandler) MayAddOutgoingHtlc(lnwire.MilliAtom) error { return nil } 449 450 // ShutdownIfChannelClean currently returns nil. 451 func (m *mockUpdateHandler) ShutdownIfChannelClean() error { return nil } 452 453 type mockMessageConn struct { 454 t *testing.T 455 456 // MessageConn embeds our interface so that the mock does not need to 457 // implement every function. The mock will panic if an unspecified function 458 // is called. 459 MessageConn 460 461 // writtenMessages is a channel that our mock pushes written messages into. 462 writtenMessages chan []byte 463 464 readMessages chan []byte 465 curReadMessage []byte 466 } 467 468 func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn { 469 return &mockMessageConn{ 470 t: t, 471 writtenMessages: make(chan []byte, expectedMessages), 472 readMessages: make(chan []byte, 1), 473 } 474 } 475 476 // SetWriteDeadline mocks setting write deadline for our conn. 477 func (m *mockMessageConn) SetWriteDeadline(time.Time) error { 478 return nil 479 } 480 481 // Flush mocks a message conn flush. 482 func (m *mockMessageConn) Flush() (int, error) { 483 return 0, nil 484 } 485 486 // WriteMessage mocks sending of a message on our connection. It will push 487 // the bytes sent into the mock's writtenMessages channel. 488 func (m *mockMessageConn) WriteMessage(msg []byte) error { 489 select { 490 case m.writtenMessages <- msg: 491 case <-time.After(timeout): 492 m.t.Fatalf("timeout sending message: %v", msg) 493 } 494 495 return nil 496 } 497 498 // assertWrite asserts that our mock as had WriteMessage called with the byte 499 // slice we expect. 500 func (m *mockMessageConn) assertWrite(expected []byte) { 501 select { 502 case actual := <-m.writtenMessages: 503 require.Equal(m.t, expected, actual) 504 505 case <-time.After(timeout): 506 m.t.Fatalf("timeout waiting for write: %v", expected) 507 } 508 } 509 510 func (m *mockMessageConn) SetReadDeadline(t time.Time) error { 511 return nil 512 } 513 514 func (m *mockMessageConn) ReadNextHeader() (uint32, error) { 515 m.curReadMessage = <-m.readMessages 516 return uint32(len(m.curReadMessage)), nil 517 } 518 519 func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) { 520 return m.curReadMessage, nil 521 } 522 523 func (m *mockMessageConn) RemoteAddr() net.Addr { 524 return nil 525 } 526 527 func (m *mockMessageConn) LocalAddr() net.Addr { 528 return nil 529 }