github.com/decred/dcrlnd@v0.7.6/htlcswitch/payment_result_test.go (about)

     1  package htlcswitch
     2  
import (
	"bytes"
	"io/ioutil"
	"math/rand"
	"os"
	"reflect"
	"testing"
	"time"

	"github.com/davecgh/go-spew/spew"
	"github.com/decred/dcrlnd/channeldb"
	"github.com/decred/dcrlnd/lntypes"
	"github.com/decred/dcrlnd/lnwire"
	"github.com/stretchr/testify/require"
)
    17  
    18  // TestNetworkResultSerialization checks that NetworkResults are properly
    19  // (de)serialized.
    20  func TestNetworkResultSerialization(t *testing.T) {
    21  	t.Parallel()
    22  
    23  	var preimage lntypes.Preimage
    24  	if _, err := rand.Read(preimage[:]); err != nil {
    25  		t.Fatalf("unable gen rand preimag: %v", err)
    26  	}
    27  
    28  	var chanID lnwire.ChannelID
    29  	if _, err := rand.Read(chanID[:]); err != nil {
    30  		t.Fatalf("unable gen rand chanid: %v", err)
    31  	}
    32  
    33  	var reason [256]byte
    34  	if _, err := rand.Read(reason[:]); err != nil {
    35  		t.Fatalf("unable gen rand reason: %v", err)
    36  	}
    37  
    38  	settle := &lnwire.UpdateFulfillHTLC{
    39  		ChanID:          chanID,
    40  		ID:              2,
    41  		PaymentPreimage: preimage,
    42  		ExtraData:       make([]byte, 0),
    43  	}
    44  
    45  	fail := &lnwire.UpdateFailHTLC{
    46  		ChanID:    chanID,
    47  		ID:        1,
    48  		Reason:    []byte{},
    49  		ExtraData: make([]byte, 0),
    50  	}
    51  
    52  	fail2 := &lnwire.UpdateFailHTLC{
    53  		ChanID:    chanID,
    54  		ID:        1,
    55  		Reason:    reason[:],
    56  		ExtraData: make([]byte, 0),
    57  	}
    58  
    59  	testCases := []*networkResult{
    60  		{
    61  			msg: settle,
    62  		},
    63  		{
    64  			msg:          fail,
    65  			unencrypted:  false,
    66  			isResolution: false,
    67  		},
    68  		{
    69  			msg:          fail,
    70  			unencrypted:  false,
    71  			isResolution: true,
    72  		},
    73  		{
    74  			msg:          fail2,
    75  			unencrypted:  true,
    76  			isResolution: false,
    77  		},
    78  	}
    79  
    80  	for _, p := range testCases {
    81  		var buf bytes.Buffer
    82  		if err := serializeNetworkResult(&buf, p); err != nil {
    83  			t.Fatalf("serialize failed: %v", err)
    84  		}
    85  
    86  		r := bytes.NewReader(buf.Bytes())
    87  		p1, err := deserializeNetworkResult(r)
    88  		if err != nil {
    89  			t.Fatalf("unable to deserizlize: %v", err)
    90  		}
    91  
    92  		if !reflect.DeepEqual(p, p1) {
    93  			t.Fatalf("not equal. %v vs %v", spew.Sdump(p),
    94  				spew.Sdump(p1))
    95  		}
    96  	}
    97  }
    98  
    99  // TestNetworkResultStore tests that the networkResult store behaves as
   100  // expected, and that we can store, get and subscribe to results.
   101  func TestNetworkResultStore(t *testing.T) {
   102  	t.Parallel()
   103  
   104  	const numResults = 4
   105  
   106  	tempDir, err := ioutil.TempDir("", "testdb")
   107  	if err != nil {
   108  		t.Fatal(err)
   109  	}
   110  	db, err := channeldb.Open(tempDir)
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	store := newNetworkResultStore(db)
   116  
   117  	var results []*networkResult
   118  	for i := 0; i < numResults; i++ {
   119  		n := &networkResult{
   120  			msg:          &lnwire.UpdateAddHTLC{},
   121  			unencrypted:  true,
   122  			isResolution: true,
   123  		}
   124  		results = append(results, n)
   125  	}
   126  
   127  	// Subscribe to 2 of them.
   128  	var subs []<-chan *networkResult
   129  	for i := uint64(0); i < 2; i++ {
   130  		sub, err := store.subscribeResult(i)
   131  		if err != nil {
   132  			t.Fatalf("unable to subscribe: %v", err)
   133  		}
   134  		subs = append(subs, sub)
   135  	}
   136  
   137  	// Store three of them.
   138  	for i := uint64(0); i < 3; i++ {
   139  		err := store.storeResult(i, results[i])
   140  		if err != nil {
   141  			t.Fatalf("unable to store result: %v", err)
   142  		}
   143  	}
   144  
   145  	// The two subscribers should be notified.
   146  	for _, sub := range subs {
   147  		select {
   148  		case <-sub:
   149  		case <-time.After(1 * time.Second):
   150  			t.Fatalf("no result received")
   151  		}
   152  	}
   153  
   154  	// Let the third one subscribe now. THe result should be received
   155  	// immediately.
   156  	sub, err := store.subscribeResult(2)
   157  	if err != nil {
   158  		t.Fatalf("unable to subscribe: %v", err)
   159  	}
   160  	select {
   161  	case <-sub:
   162  	case <-time.After(1 * time.Second):
   163  		t.Fatalf("no result received")
   164  	}
   165  
   166  	// Try fetching the result directly for the non-stored one. This should
   167  	// fail.
   168  	_, err = store.getResult(3)
   169  	if err != ErrPaymentIDNotFound {
   170  		t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
   171  	}
   172  
   173  	// Add the result and try again.
   174  	err = store.storeResult(3, results[3])
   175  	if err != nil {
   176  		t.Fatalf("unable to store result: %v", err)
   177  	}
   178  
   179  	_, err = store.getResult(3)
   180  	if err != nil {
   181  		t.Fatalf("unable to get result: %v", err)
   182  	}
   183  
   184  	// Since we don't delete results from the store (yet), make sure we
   185  	// will get subscriptions for all of them.
   186  	for i := uint64(0); i < numResults; i++ {
   187  		sub, err := store.subscribeResult(i)
   188  		if err != nil {
   189  			t.Fatalf("unable to subscribe: %v", err)
   190  		}
   191  
   192  		select {
   193  		case <-sub:
   194  		case <-time.After(1 * time.Second):
   195  			t.Fatalf("no result received")
   196  		}
   197  	}
   198  
   199  	// Clean the store keeping the first two results.
   200  	toKeep := map[uint64]struct{}{
   201  		0: {},
   202  		1: {},
   203  	}
   204  	// Finally, delete the result.
   205  	err = store.cleanStore(toKeep)
   206  	require.NoError(t, err)
   207  
   208  	// Payment IDs 0 and 1 should be found, 2 and 3 should be deleted.
   209  	for i := uint64(0); i < numResults; i++ {
   210  		_, err = store.getResult(i)
   211  		if i <= 1 {
   212  			require.NoError(t, err, "unable to get result")
   213  		}
   214  		if i >= 2 && err != ErrPaymentIDNotFound {
   215  			t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
   216  		}
   217  
   218  	}
   219  }