github.com/decred/dcrlnd@v0.7.6/record/record_test.go (about)

     1  package record_test
     2  
     3  import (
     4  	"bytes"
     5  	"testing"
     6  
     7  	"github.com/decred/dcrlnd/lnwire"
     8  	"github.com/decred/dcrlnd/record"
     9  	"github.com/decred/dcrlnd/tlv"
    10  )
    11  
    12  type recordEncDecTest struct {
    13  	name      string
    14  	encRecord func() tlv.RecordProducer
    15  	decRecord func() tlv.RecordProducer
    16  	assert    func(*testing.T, interface{})
    17  }
    18  
    19  var (
    20  	testTotal      = lnwire.MilliAtom(45)
    21  	testAddr       = [32]byte{0x01, 0x02}
    22  	testShare      = [32]byte{0x03, 0x04}
    23  	testSetID      = [32]byte{0x05, 0x06}
    24  	testChildIndex = uint32(17)
    25  )
    26  
    27  var recordEncDecTests = []recordEncDecTest{
    28  	{
    29  		name: "mpp",
    30  		encRecord: func() tlv.RecordProducer {
    31  			return record.NewMPP(testTotal, testAddr)
    32  		},
    33  		decRecord: func() tlv.RecordProducer {
    34  			return new(record.MPP)
    35  		},
    36  		assert: func(t *testing.T, r interface{}) {
    37  			mpp := r.(*record.MPP)
    38  			if mpp.TotalMAtoms() != testTotal {
    39  				t.Fatal("incorrect total msat")
    40  			}
    41  			if mpp.PaymentAddr() != testAddr {
    42  				t.Fatal("incorrect payment addr")
    43  			}
    44  		},
    45  	},
    46  	{
    47  		name: "amp",
    48  		encRecord: func() tlv.RecordProducer {
    49  			return record.NewAMP(
    50  				testShare, testSetID, testChildIndex,
    51  			)
    52  		},
    53  		decRecord: func() tlv.RecordProducer {
    54  			return new(record.AMP)
    55  		},
    56  		assert: func(t *testing.T, r interface{}) {
    57  			amp := r.(*record.AMP)
    58  			if amp.RootShare() != testShare {
    59  				t.Fatal("incorrect root share")
    60  			}
    61  			if amp.SetID() != testSetID {
    62  				t.Fatal("incorrect set id")
    63  			}
    64  			if amp.ChildIndex() != testChildIndex {
    65  				t.Fatal("incorrect child index")
    66  			}
    67  		},
    68  	},
    69  }
    70  
    71  // TestRecordEncodeDecode is a generic test framework for custom TLV records. It
    72  // asserts that records can encode and decode themselves, and that the value of
    73  // the original record matches the decoded record.
    74  func TestRecordEncodeDecode(t *testing.T) {
    75  	for _, test := range recordEncDecTests {
    76  		test := test
    77  		t.Run(test.name, func(t *testing.T) {
    78  			r := test.encRecord()
    79  			r2 := test.decRecord()
    80  			encStream := tlv.MustNewStream(r.Record())
    81  			decStream := tlv.MustNewStream(r2.Record())
    82  
    83  			test.assert(t, r)
    84  
    85  			var b bytes.Buffer
    86  			err := encStream.Encode(&b)
    87  			if err != nil {
    88  				t.Fatalf("unable to encode record: %v", err)
    89  			}
    90  
    91  			err = decStream.Decode(bytes.NewReader(b.Bytes()))
    92  			if err != nil {
    93  				t.Fatalf("unable to decode record: %v", err)
    94  			}
    95  
    96  			test.assert(t, r2)
    97  		})
    98  	}
    99  }