github.com/decred/dcrlnd@v0.7.6/channeldb/migration_01_to_11/migration_11_invoices_test.go (about)

     1  package migration_01_to_11
     2  
     3  import (
     4  	"bytes"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/decred/dcrd/chaincfg/v3"
     9  	"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
    10  	"github.com/decred/dcrlnd/kvdb"
    11  	"github.com/decred/dcrlnd/zpay32"
    12  )
    13  
    14  var (
    15  	testPrivKeyBytes = []byte{
    16  		0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
    17  		0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
    18  		0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
    19  		0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
    20  	}
    21  
    22  	testCltvDelta = int32(50)
    23  )
    24  
    25  // beforeMigrationFuncV11 insert the test invoices in the database.
    26  func beforeMigrationFuncV11(t *testing.T, d *DB, invoices []Invoice) {
    27  	err := kvdb.Update(d, func(tx kvdb.RwTx) error {
    28  		invoicesBucket, err := tx.CreateTopLevelBucket(
    29  			invoiceBucket,
    30  		)
    31  		if err != nil {
    32  			return err
    33  		}
    34  
    35  		invoiceNum := uint32(1)
    36  		for _, invoice := range invoices {
    37  			var invoiceKey [4]byte
    38  			byteOrder.PutUint32(invoiceKey[:], invoiceNum)
    39  			invoiceNum++
    40  
    41  			var buf bytes.Buffer
    42  			err := serializeInvoiceLegacy(&buf, &invoice) // nolint:scopelint
    43  			if err != nil {
    44  				return err
    45  			}
    46  
    47  			err = invoicesBucket.Put(
    48  				invoiceKey[:], buf.Bytes(),
    49  			)
    50  			if err != nil {
    51  				return err
    52  			}
    53  		}
    54  
    55  		return nil
    56  	}, func() {})
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  }
    61  
    62  // TestMigrateInvoices checks that invoices are migrated correctly.
    63  func TestMigrateInvoices(t *testing.T) {
    64  	t.Parallel()
    65  
    66  	payReqBtc, err := getPayReq(chaincfg.MainNetParams())
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  
    71  	invoices := []Invoice{
    72  		{
    73  			PaymentRequest: []byte(payReqBtc),
    74  		},
    75  	}
    76  
    77  	// Verify that all invoices were migrated.
    78  	afterMigrationFunc := func(d *DB) {
    79  		dbInvoices, err := d.FetchAllInvoices(false)
    80  		if err != nil {
    81  			t.Fatalf("unable to fetch invoices: %v", err)
    82  		}
    83  
    84  		if len(invoices) != len(dbInvoices) {
    85  			t.Fatalf("expected %d invoices, got %d", len(invoices),
    86  				len(dbInvoices))
    87  		}
    88  
    89  		for _, dbInvoice := range dbInvoices {
    90  			if dbInvoice.FinalCltvDelta != testCltvDelta {
    91  				t.Fatal("incorrect final cltv delta")
    92  			}
    93  			if dbInvoice.Expiry != 3600*time.Second {
    94  				t.Fatal("incorrect expiry")
    95  			}
    96  			if len(dbInvoice.Htlcs) != 0 {
    97  				t.Fatal("expected no htlcs after migration")
    98  			}
    99  		}
   100  	}
   101  
   102  	applyMigration(t,
   103  		func(d *DB) { beforeMigrationFuncV11(t, d, invoices) },
   104  		afterMigrationFunc,
   105  		MigrateInvoices,
   106  		false)
   107  }
   108  
   109  // TestMigrateInvoicesHodl checks that a hodl invoice in the accepted state
   110  // fails the migration.
   111  func TestMigrateInvoicesHodl(t *testing.T) {
   112  	t.Parallel()
   113  
   114  	payReqBtc, err := getPayReq(chaincfg.MainNetParams())
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  
   119  	invoices := []Invoice{
   120  		{
   121  			PaymentRequest: []byte(payReqBtc),
   122  			Terms: ContractTerm{
   123  				State: ContractAccepted,
   124  			},
   125  		},
   126  	}
   127  
   128  	applyMigration(t,
   129  		func(d *DB) { beforeMigrationFuncV11(t, d, invoices) },
   130  		func(d *DB) {},
   131  		MigrateInvoices,
   132  		true)
   133  }
   134  
   135  // signDigestCompact generates a test signature to be used in the generation of
   136  // test payment requests.
   137  func signDigestCompact(hash []byte) ([]byte, error) {
   138  	// Should the signature reference a compressed public key or not.
   139  	isCompressedKey := true
   140  
   141  	privKey, _ := privKeyFromBytes(testPrivKeyBytes)
   142  
   143  	// secp256k1.SignCompact returns a pubkey-recoverable signature
   144  	sig := ecdsa.SignCompact(
   145  		privKey, hash, isCompressedKey,
   146  	)
   147  
   148  	return sig, nil
   149  }
   150  
   151  // getPayReq creates a payment request for the given net.
   152  func getPayReq(net *chaincfg.Params) (string, error) {
   153  	options := []func(*zpay32.Invoice){
   154  		zpay32.CLTVExpiry(uint64(testCltvDelta)),
   155  		zpay32.Description("test"),
   156  	}
   157  
   158  	payReq, err := zpay32.NewInvoice(
   159  		net, [32]byte{}, time.Unix(1, 0), options...,
   160  	)
   161  	if err != nil {
   162  		return "", err
   163  	}
   164  	return payReq.Encode(
   165  		zpay32.MessageSigner{
   166  			SignCompact: signDigestCompact,
   167  		},
   168  	)
   169  }