github.com/decred/dcrlnd@v0.7.6/routing/localchans/manager_test.go (about)

     1  package localchans
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/decred/dcrlnd/kvdb"
     7  	"github.com/decred/dcrlnd/lnrpc"
     8  	"github.com/decred/dcrlnd/lnwire"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/decred/dcrd/chaincfg/chainhash"
    12  	"github.com/decred/dcrd/dcrutil/v4"
    13  
    14  	"github.com/decred/dcrd/wire"
    15  	"github.com/decred/dcrlnd/channeldb"
    16  	"github.com/decred/dcrlnd/discovery"
    17  	"github.com/decred/dcrlnd/htlcswitch"
    18  	"github.com/decred/dcrlnd/routing"
    19  )
    20  
    21  // TestManager tests that the local channel manager properly propagates fee
    22  // updates to gossiper and links.
    23  func TestManager(t *testing.T) {
    24  	t.Parallel()
    25  
    26  	type channel struct {
    27  		edgeInfo *channeldb.ChannelEdgeInfo
    28  	}
    29  
    30  	var (
    31  		chanPointValid     = wire.OutPoint{Hash: chainhash.Hash{1}, Index: 2}
    32  		chanCap            = dcrutil.Amount(1000)
    33  		chanPointMissing   = wire.OutPoint{Hash: chainhash.Hash{2}, Index: 2}
    34  		maxPendingAmount   = lnwire.MilliAtom(999000)
    35  		minHTLC            = lnwire.MilliAtom(2000)
    36  		expectedNumUpdates int
    37  		channelSet         []channel
    38  	)
    39  
    40  	newPolicy := routing.ChannelPolicy{
    41  		FeeSchema: routing.FeeSchema{
    42  			BaseFee: 100,
    43  			FeeRate: 200,
    44  		},
    45  		TimeLockDelta: 80,
    46  		MaxHTLC:       5000,
    47  	}
    48  
    49  	currentPolicy := channeldb.ChannelEdgePolicy{
    50  		MinHTLC:      minHTLC,
    51  		MessageFlags: lnwire.ChanUpdateOptionMaxHtlc,
    52  	}
    53  
    54  	updateForwardingPolicies := func(
    55  		chanPolicies map[wire.OutPoint]htlcswitch.ForwardingPolicy) {
    56  
    57  		if len(chanPolicies) == 0 {
    58  			return
    59  		}
    60  
    61  		if len(chanPolicies) != 1 {
    62  			t.Fatal("unexpected number of policies to apply")
    63  		}
    64  
    65  		policy := chanPolicies[chanPointValid]
    66  		if policy.TimeLockDelta != newPolicy.TimeLockDelta {
    67  			t.Fatal("unexpected time lock delta")
    68  		}
    69  		if policy.BaseFee != newPolicy.BaseFee {
    70  			t.Fatal("unexpected base fee")
    71  		}
    72  		if uint32(policy.FeeRate) != newPolicy.FeeRate {
    73  			t.Fatal("unexpected base fee")
    74  		}
    75  		if policy.MaxHTLC != newPolicy.MaxHTLC {
    76  			t.Fatal("unexpected max htlc")
    77  		}
    78  	}
    79  
    80  	propagateChanPolicyUpdate := func(
    81  		edgesToUpdate []discovery.EdgeWithInfo) error {
    82  
    83  		if len(edgesToUpdate) != expectedNumUpdates {
    84  			t.Fatalf("unexpected number of updates,"+
    85  				" expected %d got %d", expectedNumUpdates,
    86  				len(edgesToUpdate))
    87  		}
    88  
    89  		for _, edge := range edgesToUpdate {
    90  			policy := edge.Edge
    91  			if !policy.MessageFlags.HasMaxHtlc() {
    92  				t.Fatal("expected max htlc flag")
    93  			}
    94  			if policy.TimeLockDelta != uint16(newPolicy.TimeLockDelta) {
    95  				t.Fatal("unexpected time lock delta")
    96  			}
    97  			if policy.FeeBaseMAtoms != newPolicy.BaseFee {
    98  				t.Fatal("unexpected base fee")
    99  			}
   100  			if uint32(policy.FeeProportionalMillionths) != newPolicy.FeeRate {
   101  				t.Fatal("unexpected base fee")
   102  			}
   103  			if policy.MaxHTLC != newPolicy.MaxHTLC {
   104  				t.Fatal("unexpected max htlc")
   105  			}
   106  		}
   107  
   108  		return nil
   109  	}
   110  
   111  	forAllOutgoingChannels := func(cb func(kvdb.RTx,
   112  		*channeldb.ChannelEdgeInfo,
   113  		*channeldb.ChannelEdgePolicy) error) error {
   114  
   115  		for _, c := range channelSet {
   116  			if err := cb(nil, c.edgeInfo, &currentPolicy); err != nil {
   117  				return err
   118  			}
   119  		}
   120  		return nil
   121  	}
   122  
   123  	fetchChannel := func(tx kvdb.RTx, chanPoint wire.OutPoint) (
   124  		*channeldb.OpenChannel, error) {
   125  
   126  		if chanPoint == chanPointMissing {
   127  			return &channeldb.OpenChannel{}, channeldb.ErrChannelNotFound
   128  		}
   129  
   130  		constraints := channeldb.ChannelConstraints{
   131  			MaxPendingAmount: maxPendingAmount,
   132  			MinHTLC:          minHTLC,
   133  		}
   134  
   135  		return &channeldb.OpenChannel{
   136  			LocalChanCfg: channeldb.ChannelConfig{
   137  				ChannelConstraints: constraints,
   138  			},
   139  		}, nil
   140  	}
   141  
   142  	manager := Manager{
   143  		UpdateForwardingPolicies:  updateForwardingPolicies,
   144  		PropagateChanPolicyUpdate: propagateChanPolicyUpdate,
   145  		ForAllOutgoingChannels:    forAllOutgoingChannels,
   146  		FetchChannel:              fetchChannel,
   147  	}
   148  
   149  	// Policy with no max htlc value.
   150  	MaxHTLCPolicy := currentPolicy
   151  	MaxHTLCPolicy.MaxHTLC = newPolicy.MaxHTLC
   152  	noMaxHtlcPolicy := newPolicy
   153  	noMaxHtlcPolicy.MaxHTLC = 0
   154  
   155  	tests := []struct {
   156  		name                   string
   157  		currentPolicy          channeldb.ChannelEdgePolicy
   158  		newPolicy              routing.ChannelPolicy
   159  		channelSet             []channel
   160  		specifiedChanPoints    []wire.OutPoint
   161  		expectedNumUpdates     int
   162  		expectedUpdateFailures []lnrpc.UpdateFailure
   163  		expectErr              error
   164  	}{
   165  		{
   166  			name:          "valid channel",
   167  			currentPolicy: currentPolicy,
   168  			newPolicy:     newPolicy,
   169  			channelSet: []channel{
   170  				{
   171  					edgeInfo: &channeldb.ChannelEdgeInfo{
   172  						Capacity:     chanCap,
   173  						ChannelPoint: chanPointValid,
   174  					},
   175  				},
   176  			},
   177  			specifiedChanPoints:    []wire.OutPoint{chanPointValid},
   178  			expectedNumUpdates:     1,
   179  			expectedUpdateFailures: []lnrpc.UpdateFailure{},
   180  			expectErr:              nil,
   181  		},
   182  		{
   183  			name:          "all channels",
   184  			currentPolicy: currentPolicy,
   185  			newPolicy:     newPolicy,
   186  			channelSet: []channel{
   187  				{
   188  					edgeInfo: &channeldb.ChannelEdgeInfo{
   189  						Capacity:     chanCap,
   190  						ChannelPoint: chanPointValid,
   191  					},
   192  				},
   193  			},
   194  			specifiedChanPoints:    []wire.OutPoint{},
   195  			expectedNumUpdates:     1,
   196  			expectedUpdateFailures: []lnrpc.UpdateFailure{},
   197  			expectErr:              nil,
   198  		},
   199  		{
   200  			name:          "missing channel",
   201  			currentPolicy: currentPolicy,
   202  			newPolicy:     newPolicy,
   203  			channelSet: []channel{
   204  				{
   205  					edgeInfo: &channeldb.ChannelEdgeInfo{
   206  						Capacity:     chanCap,
   207  						ChannelPoint: chanPointValid,
   208  					},
   209  				},
   210  			},
   211  			specifiedChanPoints: []wire.OutPoint{chanPointMissing},
   212  			expectedNumUpdates:  0,
   213  			expectedUpdateFailures: []lnrpc.UpdateFailure{
   214  				lnrpc.UpdateFailure_UPDATE_FAILURE_NOT_FOUND,
   215  			},
   216  			expectErr: nil,
   217  		},
   218  		{
   219  			// Here, no max htlc is specified, the max htlc value
   220  			// should be kept unchanged.
   221  			name:          "no max htlc specified",
   222  			currentPolicy: MaxHTLCPolicy,
   223  			newPolicy:     noMaxHtlcPolicy,
   224  			channelSet: []channel{
   225  				{
   226  					edgeInfo: &channeldb.ChannelEdgeInfo{
   227  						Capacity:     chanCap,
   228  						ChannelPoint: chanPointValid,
   229  					},
   230  				},
   231  			},
   232  			specifiedChanPoints:    []wire.OutPoint{chanPointValid},
   233  			expectedNumUpdates:     1,
   234  			expectedUpdateFailures: []lnrpc.UpdateFailure{},
   235  			expectErr:              nil,
   236  		},
   237  	}
   238  
   239  	for _, test := range tests {
   240  		test := test
   241  		t.Run(test.name, func(t *testing.T) {
   242  			currentPolicy = test.currentPolicy
   243  			channelSet = test.channelSet
   244  			expectedNumUpdates = test.expectedNumUpdates
   245  
   246  			failedUpdates, err := manager.UpdatePolicy(test.newPolicy,
   247  				test.specifiedChanPoints...)
   248  
   249  			if len(failedUpdates) != len(test.expectedUpdateFailures) {
   250  				t.Fatalf("wrong number of failed policy updates")
   251  			}
   252  
   253  			if len(test.expectedUpdateFailures) > 0 {
   254  				if failedUpdates[0].Reason != test.expectedUpdateFailures[0] {
   255  					t.Fatalf("wrong expected policy update failure")
   256  				}
   257  			}
   258  
   259  			require.Equal(t, test.expectErr, err)
   260  		})
   261  	}
   262  }