github.com/decred/dcrlnd@v0.7.6/channeldb/graph_cache_test.go (about) 1 package channeldb 2 3 import ( 4 "encoding/hex" 5 "testing" 6 7 "github.com/decred/dcrlnd/kvdb" 8 "github.com/decred/dcrlnd/lnwire" 9 "github.com/decred/dcrlnd/routing/route" 10 "github.com/stretchr/testify/require" 11 ) 12 13 var ( 14 pubKey1Bytes, _ = hex.DecodeString( 15 "0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" + 16 "22c91d", 17 ) 18 pubKey2Bytes, _ = hex.DecodeString( 19 "038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" + 20 "f4484f", 21 ) 22 23 pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes) 24 pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes) 25 ) 26 27 type node struct { 28 pubKey route.Vertex 29 features *lnwire.FeatureVector 30 31 edgeInfos []*ChannelEdgeInfo 32 outPolicies []*ChannelEdgePolicy 33 inPolicies []*ChannelEdgePolicy 34 } 35 36 func (n *node) PubKey() route.Vertex { 37 return n.pubKey 38 } 39 func (n *node) Features() *lnwire.FeatureVector { 40 return n.features 41 } 42 43 func (n *node) ForEachChannel(tx kvdb.RTx, 44 cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, 45 *ChannelEdgePolicy) error) error { 46 47 for idx := range n.edgeInfos { 48 err := cb( 49 tx, n.edgeInfos[idx], n.outPolicies[idx], 50 n.inPolicies[idx], 51 ) 52 if err != nil { 53 return err 54 } 55 } 56 57 return nil 58 } 59 60 // TestGraphCacheAddNode tests that a channel going from node A to node B can be 61 // cached correctly, independent of the direction we add the channel as. 62 func TestGraphCacheAddNode(t *testing.T) { 63 t.Parallel() 64 65 runTest := func(nodeA, nodeB route.Vertex) { 66 t.Helper() 67 68 channelFlagA, channelFlagB := 0, 1 69 if nodeA == pubKey2 { 70 channelFlagA, channelFlagB = 1, 0 71 } 72 73 outPolicy1 := &ChannelEdgePolicy{ 74 ChannelID: 1000, 75 ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), 76 Node: &LightningNode{ 77 PubKeyBytes: nodeB, 78 Features: lnwire.EmptyFeatureVector(), 79 }, 80 } 81 inPolicy1 := &ChannelEdgePolicy{ 82 ChannelID: 1000, 83 ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), 84 Node: &LightningNode{ 85 PubKeyBytes: nodeA, 86 Features: lnwire.EmptyFeatureVector(), 87 }, 88 } 89 node := &node{ 90 pubKey: nodeA, 91 features: lnwire.EmptyFeatureVector(), 92 edgeInfos: []*ChannelEdgeInfo{{ 93 ChannelID: 1000, 94 // Those are direction independent! 95 NodeKey1Bytes: pubKey1, 96 NodeKey2Bytes: pubKey2, 97 Capacity: 500, 98 }}, 99 outPolicies: []*ChannelEdgePolicy{outPolicy1}, 100 inPolicies: []*ChannelEdgePolicy{inPolicy1}, 101 } 102 cache := NewGraphCache(10) 103 require.NoError(t, cache.AddNode(nil, node)) 104 105 var fromChannels, toChannels []*DirectedChannel 106 _ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error { 107 fromChannels = append(fromChannels, c) 108 return nil 109 }) 110 _ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error { 111 toChannels = append(toChannels, c) 112 return nil 113 }) 114 115 require.Len(t, fromChannels, 1) 116 require.Len(t, toChannels, 1) 117 118 require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet) 119 assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy) 120 121 require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet) 122 assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) 123 124 // Now that we've inserted two nodes into the graph, check that 125 // we'll recover the same set of channels during ForEachNode. 126 nodes := make(map[route.Vertex]struct{}) 127 chans := make(map[uint64]struct{}) 128 _ = cache.ForEachNode(func(node route.Vertex, 129 edges map[uint64]*DirectedChannel) error { 130 131 nodes[node] = struct{}{} 132 for chanID := range edges { 133 chans[chanID] = struct{}{} 134 } 135 136 return nil 137 }) 138 139 require.Len(t, nodes, 2) 140 require.Len(t, chans, 1) 141 } 142 143 runTest(pubKey1, pubKey2) 144 runTest(pubKey2, pubKey1) 145 } 146 147 func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy, 148 cached *CachedEdgePolicy) { 149 150 require.Equal(t, original.ChannelID, cached.ChannelID) 151 require.Equal(t, original.MessageFlags, cached.MessageFlags) 152 require.Equal(t, original.ChannelFlags, cached.ChannelFlags) 153 require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta) 154 require.Equal(t, original.MinHTLC, cached.MinHTLC) 155 require.Equal(t, original.MaxHTLC, cached.MaxHTLC) 156 require.Equal(t, original.FeeBaseMAtoms, cached.FeeBaseMAtoms) 157 require.Equal( 158 t, original.FeeProportionalMillionths, 159 cached.FeeProportionalMillionths, 160 ) 161 require.Equal( 162 t, 163 route.Vertex(original.Node.PubKeyBytes), 164 cached.ToNodePubKey(), 165 ) 166 require.Equal(t, original.Node.Features, cached.ToNodeFeatures) 167 }