github.com/decred/dcrlnd@v0.7.6/record/record_test.go (about) 1 package record_test 2 3 import ( 4 "bytes" 5 "testing" 6 7 "github.com/decred/dcrlnd/lnwire" 8 "github.com/decred/dcrlnd/record" 9 "github.com/decred/dcrlnd/tlv" 10 ) 11 12 type recordEncDecTest struct { 13 name string 14 encRecord func() tlv.RecordProducer 15 decRecord func() tlv.RecordProducer 16 assert func(*testing.T, interface{}) 17 } 18 19 var ( 20 testTotal = lnwire.MilliAtom(45) 21 testAddr = [32]byte{0x01, 0x02} 22 testShare = [32]byte{0x03, 0x04} 23 testSetID = [32]byte{0x05, 0x06} 24 testChildIndex = uint32(17) 25 ) 26 27 var recordEncDecTests = []recordEncDecTest{ 28 { 29 name: "mpp", 30 encRecord: func() tlv.RecordProducer { 31 return record.NewMPP(testTotal, testAddr) 32 }, 33 decRecord: func() tlv.RecordProducer { 34 return new(record.MPP) 35 }, 36 assert: func(t *testing.T, r interface{}) { 37 mpp := r.(*record.MPP) 38 if mpp.TotalMAtoms() != testTotal { 39 t.Fatal("incorrect total msat") 40 } 41 if mpp.PaymentAddr() != testAddr { 42 t.Fatal("incorrect payment addr") 43 } 44 }, 45 }, 46 { 47 name: "amp", 48 encRecord: func() tlv.RecordProducer { 49 return record.NewAMP( 50 testShare, testSetID, testChildIndex, 51 ) 52 }, 53 decRecord: func() tlv.RecordProducer { 54 return new(record.AMP) 55 }, 56 assert: func(t *testing.T, r interface{}) { 57 amp := r.(*record.AMP) 58 if amp.RootShare() != testShare { 59 t.Fatal("incorrect root share") 60 } 61 if amp.SetID() != testSetID { 62 t.Fatal("incorrect set id") 63 } 64 if amp.ChildIndex() != testChildIndex { 65 t.Fatal("incorrect child index") 66 } 67 }, 68 }, 69 } 70 71 // TestRecordEncodeDecode is a generic test framework for custom TLV records. It 72 // asserts that records can encode and decode themselves, and that the value of 73 // the original record matches the decoded record. 74 func TestRecordEncodeDecode(t *testing.T) { 75 for _, test := range recordEncDecTests { 76 test := test 77 t.Run(test.name, func(t *testing.T) { 78 r := test.encRecord() 79 r2 := test.decRecord() 80 encStream := tlv.MustNewStream(r.Record()) 81 decStream := tlv.MustNewStream(r2.Record()) 82 83 test.assert(t, r) 84 85 var b bytes.Buffer 86 err := encStream.Encode(&b) 87 if err != nil { 88 t.Fatalf("unable to encode record: %v", err) 89 } 90 91 err = decStream.Decode(bytes.NewReader(b.Bytes())) 92 if err != nil { 93 t.Fatalf("unable to decode record: %v", err) 94 } 95 96 test.assert(t, r2) 97 }) 98 } 99 }