// github.com/decred/dcrlnd@v0.7.6/watchtower/wtwire/init_test.go

     1  package wtwire_test
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/decred/dcrd/chaincfg/chainhash"
     7  	"github.com/decred/dcrd/chaincfg/v3"
     8  	"github.com/decred/dcrlnd/feature"
     9  	"github.com/decred/dcrlnd/lnwire"
    10  	"github.com/decred/dcrlnd/watchtower/wtwire"
    11  )
    12  
    13  var (
    14  	testnetChainHash = chaincfg.TestNet3Params().GenesisHash
    15  	mainnetChainHash = chaincfg.MainNetParams().GenesisHash
    16  )
    17  
    18  type checkRemoteInitTest struct {
    19  	name      string
    20  	lFeatures *lnwire.RawFeatureVector
    21  	lHash     chainhash.Hash
    22  	rFeatures *lnwire.RawFeatureVector
    23  	rHash     chainhash.Hash
    24  	expErr    error
    25  }
    26  
    27  var checkRemoteInitTests = []checkRemoteInitTest{
    28  	{
    29  		name:      "same chain, local-optional remote-required",
    30  		lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
    31  		lHash:     testnetChainHash,
    32  		rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
    33  		rHash:     testnetChainHash,
    34  	},
    35  	{
    36  		name:      "same chain, local-required remote-optional",
    37  		lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
    38  		lHash:     testnetChainHash,
    39  		rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
    40  		rHash:     testnetChainHash,
    41  	},
    42  	{
    43  		name:      "different chain, local-optional remote-required",
    44  		lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
    45  		lHash:     testnetChainHash,
    46  		rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
    47  		rHash:     mainnetChainHash,
    48  		expErr:    wtwire.NewErrUnknownChainHash(mainnetChainHash),
    49  	},
    50  	{
    51  		name:      "different chain, local-required remote-optional",
    52  		lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
    53  		lHash:     testnetChainHash,
    54  		rFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
    55  		rHash:     mainnetChainHash,
    56  		expErr:    wtwire.NewErrUnknownChainHash(mainnetChainHash),
    57  	},
    58  	{
    59  		name:      "same chain, remote-unknown-required",
    60  		lFeatures: lnwire.NewRawFeatureVector(wtwire.AltruistSessionsOptional),
    61  		lHash:     testnetChainHash,
    62  		rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired),
    63  		rHash:     testnetChainHash,
    64  		expErr: feature.NewErrUnknownRequired(
    65  			[]lnwire.FeatureBit{lnwire.GossipQueriesRequired},
    66  		),
    67  	},
    68  }
    69  
    70  // TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with
    71  // the remote party's Init message and the default wtwire.Features. We assert
    72  // the validity of advertised features from the perspective of both client and
    73  // server, as well as failure cases such as differing chain hashes or unknown
    74  // required features.
    75  func TestCheckRemoteInit(t *testing.T) {
    76  	for _, test := range checkRemoteInitTests {
    77  		t.Run(test.name, func(t *testing.T) {
    78  			testCheckRemoteInit(t, test)
    79  		})
    80  	}
    81  }
    82  
    83  func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) {
    84  	localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash)
    85  	remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash)
    86  
    87  	err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
    88  	switch {
    89  
    90  	// Both non-nil, pass.
    91  	case err == nil && test.expErr == nil:
    92  		return
    93  
    94  	// One is nil and one is non-nil, fail.
    95  	default:
    96  		t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err)
    97  
    98  	// Both non-nil, assert same error type.
    99  	case err != nil && test.expErr != nil:
   100  	}
   101  
   102  	// Compare error strings to assert same type.
   103  	if err.Error() != test.expErr.Error() {
   104  		t.Fatalf("error mismatch, want: %v, got: %v",
   105  			test.expErr.Error(), err.Error())
   106  	}
   107  }