github.com/decred/dcrlnd@v0.7.6/channeldb/graph_cache_test.go (about)

     1  package channeldb
     2  
     3  import (
     4  	"encoding/hex"
     5  	"testing"
     6  
     7  	"github.com/decred/dcrlnd/kvdb"
     8  	"github.com/decred/dcrlnd/lnwire"
     9  	"github.com/decred/dcrlnd/routing/route"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  var (
    14  	pubKey1Bytes, _ = hex.DecodeString(
    15  		"0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" +
    16  			"22c91d",
    17  	)
    18  	pubKey2Bytes, _ = hex.DecodeString(
    19  		"038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" +
    20  			"f4484f",
    21  	)
    22  
    23  	pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes)
    24  	pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes)
    25  )
    26  
    27  type node struct {
    28  	pubKey   route.Vertex
    29  	features *lnwire.FeatureVector
    30  
    31  	edgeInfos   []*ChannelEdgeInfo
    32  	outPolicies []*ChannelEdgePolicy
    33  	inPolicies  []*ChannelEdgePolicy
    34  }
    35  
    36  func (n *node) PubKey() route.Vertex {
    37  	return n.pubKey
    38  }
    39  func (n *node) Features() *lnwire.FeatureVector {
    40  	return n.features
    41  }
    42  
    43  func (n *node) ForEachChannel(tx kvdb.RTx,
    44  	cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
    45  		*ChannelEdgePolicy) error) error {
    46  
    47  	for idx := range n.edgeInfos {
    48  		err := cb(
    49  			tx, n.edgeInfos[idx], n.outPolicies[idx],
    50  			n.inPolicies[idx],
    51  		)
    52  		if err != nil {
    53  			return err
    54  		}
    55  	}
    56  
    57  	return nil
    58  }
    59  
    60  // TestGraphCacheAddNode tests that a channel going from node A to node B can be
    61  // cached correctly, independent of the direction we add the channel as.
    62  func TestGraphCacheAddNode(t *testing.T) {
    63  	t.Parallel()
    64  
    65  	runTest := func(nodeA, nodeB route.Vertex) {
    66  		t.Helper()
    67  
    68  		channelFlagA, channelFlagB := 0, 1
    69  		if nodeA == pubKey2 {
    70  			channelFlagA, channelFlagB = 1, 0
    71  		}
    72  
    73  		outPolicy1 := &ChannelEdgePolicy{
    74  			ChannelID:    1000,
    75  			ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA),
    76  			Node: &LightningNode{
    77  				PubKeyBytes: nodeB,
    78  				Features:    lnwire.EmptyFeatureVector(),
    79  			},
    80  		}
    81  		inPolicy1 := &ChannelEdgePolicy{
    82  			ChannelID:    1000,
    83  			ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB),
    84  			Node: &LightningNode{
    85  				PubKeyBytes: nodeA,
    86  				Features:    lnwire.EmptyFeatureVector(),
    87  			},
    88  		}
    89  		node := &node{
    90  			pubKey:   nodeA,
    91  			features: lnwire.EmptyFeatureVector(),
    92  			edgeInfos: []*ChannelEdgeInfo{{
    93  				ChannelID: 1000,
    94  				// Those are direction independent!
    95  				NodeKey1Bytes: pubKey1,
    96  				NodeKey2Bytes: pubKey2,
    97  				Capacity:      500,
    98  			}},
    99  			outPolicies: []*ChannelEdgePolicy{outPolicy1},
   100  			inPolicies:  []*ChannelEdgePolicy{inPolicy1},
   101  		}
   102  		cache := NewGraphCache(10)
   103  		require.NoError(t, cache.AddNode(nil, node))
   104  
   105  		var fromChannels, toChannels []*DirectedChannel
   106  		_ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error {
   107  			fromChannels = append(fromChannels, c)
   108  			return nil
   109  		})
   110  		_ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error {
   111  			toChannels = append(toChannels, c)
   112  			return nil
   113  		})
   114  
   115  		require.Len(t, fromChannels, 1)
   116  		require.Len(t, toChannels, 1)
   117  
   118  		require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet)
   119  		assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy)
   120  
   121  		require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet)
   122  		assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
   123  
   124  		// Now that we've inserted two nodes into the graph, check that
   125  		// we'll recover the same set of channels during ForEachNode.
   126  		nodes := make(map[route.Vertex]struct{})
   127  		chans := make(map[uint64]struct{})
   128  		_ = cache.ForEachNode(func(node route.Vertex,
   129  			edges map[uint64]*DirectedChannel) error {
   130  
   131  			nodes[node] = struct{}{}
   132  			for chanID := range edges {
   133  				chans[chanID] = struct{}{}
   134  			}
   135  
   136  			return nil
   137  		})
   138  
   139  		require.Len(t, nodes, 2)
   140  		require.Len(t, chans, 1)
   141  	}
   142  
   143  	runTest(pubKey1, pubKey2)
   144  	runTest(pubKey2, pubKey1)
   145  }
   146  
   147  func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy,
   148  	cached *CachedEdgePolicy) {
   149  
   150  	require.Equal(t, original.ChannelID, cached.ChannelID)
   151  	require.Equal(t, original.MessageFlags, cached.MessageFlags)
   152  	require.Equal(t, original.ChannelFlags, cached.ChannelFlags)
   153  	require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta)
   154  	require.Equal(t, original.MinHTLC, cached.MinHTLC)
   155  	require.Equal(t, original.MaxHTLC, cached.MaxHTLC)
   156  	require.Equal(t, original.FeeBaseMAtoms, cached.FeeBaseMAtoms)
   157  	require.Equal(
   158  		t, original.FeeProportionalMillionths,
   159  		cached.FeeProportionalMillionths,
   160  	)
   161  	require.Equal(
   162  		t,
   163  		route.Vertex(original.Node.PubKeyBytes),
   164  		cached.ToNodePubKey(),
   165  	)
   166  	require.Equal(t, original.Node.Features, cached.ToNodeFeatures)
   167  }