github.com/decred/dcrlnd@v0.7.6/discovery/message_store_test.go (about)

     1  package discovery
     2  
     3  import (
     4  	"bytes"
     5  	"io/ioutil"
     6  	"math/rand"
     7  	"os"
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/davecgh/go-spew/spew"
    12  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    13  	"github.com/decred/dcrlnd/channeldb"
    14  	"github.com/decred/dcrlnd/kvdb"
    15  	"github.com/decred/dcrlnd/lnwire"
    16  )
    17  
    18  func createTestMessageStore(t *testing.T) (*MessageStore, func()) {
    19  	t.Helper()
    20  
    21  	tempDir, err := ioutil.TempDir("", "channeldb")
    22  	if err != nil {
    23  		t.Fatalf("unable to create temp dir: %v", err)
    24  	}
    25  	db, err := channeldb.Open(tempDir)
    26  	if err != nil {
    27  		os.RemoveAll(tempDir)
    28  		t.Fatalf("unable to open db: %v", err)
    29  	}
    30  
    31  	cleanUp := func() {
    32  		db.Close()
    33  		os.RemoveAll(tempDir)
    34  	}
    35  
    36  	store, err := NewMessageStore(db)
    37  	if err != nil {
    38  		cleanUp()
    39  		t.Fatalf("unable to initialize message store: %v", err)
    40  	}
    41  
    42  	return store, cleanUp
    43  }
    44  
    45  func randPubKey(t *testing.T) *secp256k1.PublicKey {
    46  	priv, err := secp256k1.GeneratePrivateKey()
    47  	if err != nil {
    48  		t.Fatalf("unable to create private key: %v", err)
    49  	}
    50  
    51  	return priv.PubKey()
    52  }
    53  
    54  func randCompressedPubKey(t *testing.T) [33]byte {
    55  	t.Helper()
    56  
    57  	pubKey := randPubKey(t)
    58  
    59  	var compressedPubKey [33]byte
    60  	copy(compressedPubKey[:], pubKey.SerializeCompressed())
    61  
    62  	return compressedPubKey
    63  }
    64  
    65  func randAnnounceSignatures() *lnwire.AnnounceSignatures {
    66  	return &lnwire.AnnounceSignatures{
    67  		ShortChannelID:  lnwire.NewShortChanIDFromInt(rand.Uint64()),
    68  		ExtraOpaqueData: make([]byte, 0),
    69  	}
    70  }
    71  
    72  func randChannelUpdate() *lnwire.ChannelUpdate {
    73  	return &lnwire.ChannelUpdate{
    74  		ShortChannelID:  lnwire.NewShortChanIDFromInt(rand.Uint64()),
    75  		ExtraOpaqueData: make([]byte, 0),
    76  	}
    77  }
    78  
    79  // TestMessageStoreMessages ensures that messages can be properly queried from
    80  // the store.
    81  func TestMessageStoreMessages(t *testing.T) {
    82  	t.Parallel()
    83  
    84  	// We'll start by creating our test message store.
    85  	msgStore, cleanUp := createTestMessageStore(t)
    86  	defer cleanUp()
    87  
    88  	// We'll then create some test messages for two test peers, and none for
    89  	// an additional test peer.
    90  	channelUpdate1 := randChannelUpdate()
    91  	announceSignatures1 := randAnnounceSignatures()
    92  	peer1 := randCompressedPubKey(t)
    93  	if err := msgStore.AddMessage(channelUpdate1, peer1); err != nil {
    94  		t.Fatalf("unable to add message: %v", err)
    95  	}
    96  	if err := msgStore.AddMessage(announceSignatures1, peer1); err != nil {
    97  		t.Fatalf("unable to add message: %v", err)
    98  	}
    99  	expectedPeerMsgs1 := map[uint64]lnwire.MessageType{
   100  		channelUpdate1.ShortChannelID.ToUint64():      channelUpdate1.MsgType(),
   101  		announceSignatures1.ShortChannelID.ToUint64(): announceSignatures1.MsgType(),
   102  	}
   103  
   104  	channelUpdate2 := randChannelUpdate()
   105  	peer2 := randCompressedPubKey(t)
   106  	if err := msgStore.AddMessage(channelUpdate2, peer2); err != nil {
   107  		t.Fatalf("unable to add message: %v", err)
   108  	}
   109  	expectedPeerMsgs2 := map[uint64]lnwire.MessageType{
   110  		channelUpdate2.ShortChannelID.ToUint64(): channelUpdate2.MsgType(),
   111  	}
   112  
   113  	peer3 := randCompressedPubKey(t)
   114  	expectedPeerMsgs3 := map[uint64]lnwire.MessageType{}
   115  
   116  	// assertPeerMsgs is a helper closure that we'll use to ensure we
   117  	// retrieve the correct set of messages for a given peer.
   118  	assertPeerMsgs := func(peerMsgs []lnwire.Message,
   119  		expected map[uint64]lnwire.MessageType) {
   120  
   121  		t.Helper()
   122  
   123  		if len(peerMsgs) != len(expected) {
   124  			t.Fatalf("expected %d pending messages, got %d",
   125  				len(expected), len(peerMsgs))
   126  		}
   127  		for _, msg := range peerMsgs {
   128  			var shortChanID uint64
   129  			switch msg := msg.(type) {
   130  			case *lnwire.AnnounceSignatures:
   131  				shortChanID = msg.ShortChannelID.ToUint64()
   132  			case *lnwire.ChannelUpdate:
   133  				shortChanID = msg.ShortChannelID.ToUint64()
   134  			default:
   135  				t.Fatalf("found unexpected message type %T", msg)
   136  			}
   137  
   138  			msgType, ok := expected[shortChanID]
   139  			if !ok {
   140  				t.Fatalf("retrieved message with unexpected ID "+
   141  					"%d from store", shortChanID)
   142  			}
   143  			if msgType != msg.MsgType() {
   144  				t.Fatalf("expected message of type %v, got %v",
   145  					msg.MsgType(), msgType)
   146  			}
   147  		}
   148  	}
   149  
   150  	// Then, we'll query the store for the set of messages for each peer and
   151  	// ensure it matches what we expect.
   152  	peers := [][33]byte{peer1, peer2, peer3}
   153  	expectedPeerMsgs := []map[uint64]lnwire.MessageType{
   154  		expectedPeerMsgs1, expectedPeerMsgs2, expectedPeerMsgs3,
   155  	}
   156  	for i, peer := range peers {
   157  		peerMsgs, err := msgStore.MessagesForPeer(peer)
   158  		if err != nil {
   159  			t.Fatalf("unable to retrieve messages: %v", err)
   160  		}
   161  		assertPeerMsgs(peerMsgs, expectedPeerMsgs[i])
   162  	}
   163  
   164  	// Finally, we'll query the store for all of its messages of every peer.
   165  	// Again, each peer should have a set of messages that match what we
   166  	// expect.
   167  	//
   168  	// We'll construct the expected response. Only the first two peers will
   169  	// have messages.
   170  	totalPeerMsgs := make(map[[33]byte]map[uint64]lnwire.MessageType, 2)
   171  	for i := 0; i < 2; i++ {
   172  		totalPeerMsgs[peers[i]] = expectedPeerMsgs[i]
   173  	}
   174  
   175  	msgs, err := msgStore.Messages()
   176  	if err != nil {
   177  		t.Fatalf("unable to retrieve all peers with pending messages: "+
   178  			"%v", err)
   179  	}
   180  	if len(msgs) != len(totalPeerMsgs) {
   181  		t.Fatalf("expected %d peers with messages, got %d",
   182  			len(totalPeerMsgs), len(msgs))
   183  	}
   184  	for peer, peerMsgs := range msgs {
   185  		expected, ok := totalPeerMsgs[peer]
   186  		if !ok {
   187  			t.Fatalf("expected to find pending messages for peer %x",
   188  				peer)
   189  		}
   190  
   191  		assertPeerMsgs(peerMsgs, expected)
   192  	}
   193  
   194  	peerPubKeys, err := msgStore.Peers()
   195  	if err != nil {
   196  		t.Fatalf("unable to retrieve all peers with pending messages: "+
   197  			"%v", err)
   198  	}
   199  	if len(peerPubKeys) != len(totalPeerMsgs) {
   200  		t.Fatalf("expected %d peers with messages, got %d",
   201  			len(totalPeerMsgs), len(peerPubKeys))
   202  	}
   203  	for peerPubKey := range peerPubKeys {
   204  		if _, ok := totalPeerMsgs[peerPubKey]; !ok {
   205  			t.Fatalf("expected to find peer %x", peerPubKey)
   206  		}
   207  	}
   208  }
   209  
   210  // TestMessageStoreUnsupportedMessage ensures that we are not able to add a
   211  // message which is unsupported, and if a message is found to be unsupported by
   212  // the current version of the store, that it is properly filtered out from the
   213  // response.
   214  func TestMessageStoreUnsupportedMessage(t *testing.T) {
   215  	t.Parallel()
   216  
   217  	// We'll start by creating our test message store.
   218  	msgStore, cleanUp := createTestMessageStore(t)
   219  	defer cleanUp()
   220  
   221  	// Create a message that is known to not be supported by the store.
   222  	peer := randCompressedPubKey(t)
   223  	unsupportedMsg := &lnwire.Error{}
   224  
   225  	// Attempting to add it to the store should result in
   226  	// ErrUnsupportedMessage.
   227  	err := msgStore.AddMessage(unsupportedMsg, peer)
   228  	if err != ErrUnsupportedMessage {
   229  		t.Fatalf("expected ErrUnsupportedMessage, got %v", err)
   230  	}
   231  
   232  	// We'll now pretend that the message is actually supported in a future
   233  	// version of the store, so it's able to be added successfully. To
   234  	// replicate this, we'll add the message manually rather than through
   235  	// the existing AddMessage method.
   236  	msgKey := peer[:]
   237  	var rawMsg bytes.Buffer
   238  	if _, err := lnwire.WriteMessage(&rawMsg, unsupportedMsg, 0); err != nil {
   239  		t.Fatalf("unable to serialize message: %v", err)
   240  	}
   241  	err = kvdb.Update(msgStore.db, func(tx kvdb.RwTx) error {
   242  		messageStore := tx.ReadWriteBucket(messageStoreBucket)
   243  		return messageStore.Put(msgKey, rawMsg.Bytes())
   244  	}, func() {})
   245  	if err != nil {
   246  		t.Fatalf("unable to add unsupported message to store: %v", err)
   247  	}
   248  
   249  	// Finally, we'll check that the store can properly filter out messages
   250  	// that are currently unknown to it. We'll make sure this is done for
   251  	// both Messages and MessagesForPeer.
   252  	totalMsgs, err := msgStore.Messages()
   253  	if err != nil {
   254  		t.Fatalf("unable to retrieve messages: %v", err)
   255  	}
   256  	if len(totalMsgs) != 0 {
   257  		t.Fatalf("expected to filter out unsupported message")
   258  	}
   259  	peerMsgs, err := msgStore.MessagesForPeer(peer)
   260  	if err != nil {
   261  		t.Fatalf("unable to retrieve peer messages: %v", err)
   262  	}
   263  	if len(peerMsgs) != 0 {
   264  		t.Fatalf("expected to filter out unsupported message")
   265  	}
   266  }
   267  
   268  // TestMessageStoreDeleteMessage ensures that we can properly delete messages
   269  // from the store.
   270  func TestMessageStoreDeleteMessage(t *testing.T) {
   271  	t.Parallel()
   272  
   273  	msgStore, cleanUp := createTestMessageStore(t)
   274  	defer cleanUp()
   275  
   276  	// assertMsg is a helper closure we'll use to ensure a message
   277  	// does/doesn't exist within the store.
   278  	assertMsg := func(msg lnwire.Message, peer [33]byte, exists bool) {
   279  		t.Helper()
   280  
   281  		storeMsgs, err := msgStore.MessagesForPeer(peer)
   282  		if err != nil {
   283  			t.Fatalf("unable to retrieve messages: %v", err)
   284  		}
   285  
   286  		found := false
   287  		for _, storeMsg := range storeMsgs {
   288  			if reflect.DeepEqual(msg, storeMsg) {
   289  				found = true
   290  			}
   291  		}
   292  
   293  		if found != exists {
   294  			str := "find"
   295  			if !exists {
   296  				str = "not find"
   297  			}
   298  			t.Fatalf("expected to %v message %v", str,
   299  				spew.Sdump(msg))
   300  		}
   301  	}
   302  
   303  	// An AnnounceSignatures message should exist within the store after
   304  	// adding it, and should no longer exists after deleting it.
   305  	peer := randCompressedPubKey(t)
   306  	annSig := randAnnounceSignatures()
   307  	if err := msgStore.AddMessage(annSig, peer); err != nil {
   308  		t.Fatalf("unable to add message: %v", err)
   309  	}
   310  	assertMsg(annSig, peer, true)
   311  	if err := msgStore.DeleteMessage(annSig, peer); err != nil {
   312  		t.Fatalf("unable to delete message: %v", err)
   313  	}
   314  	assertMsg(annSig, peer, false)
   315  
   316  	// The store allows overwriting ChannelUpdates, since there can be
   317  	// multiple versions, so we'll test things slightly different.
   318  	//
   319  	// The ChannelUpdate message should exist within the store after adding
   320  	// it.
   321  	chanUpdate := randChannelUpdate()
   322  	if err := msgStore.AddMessage(chanUpdate, peer); err != nil {
   323  		t.Fatalf("unable to add message: %v", err)
   324  	}
   325  	assertMsg(chanUpdate, peer, true)
   326  
   327  	// Now, we'll create a new version for the same ChannelUpdate message.
   328  	// Adding this one to the store will overwrite the previous one, so only
   329  	// the new one should exist.
   330  	newChanUpdate := randChannelUpdate()
   331  	newChanUpdate.ShortChannelID = chanUpdate.ShortChannelID
   332  	newChanUpdate.Timestamp = chanUpdate.Timestamp + 1
   333  	if err := msgStore.AddMessage(newChanUpdate, peer); err != nil {
   334  		t.Fatalf("unable to add message: %v", err)
   335  	}
   336  	assertMsg(chanUpdate, peer, false)
   337  	assertMsg(newChanUpdate, peer, true)
   338  
   339  	// Deleting the older message should act as a NOP and should NOT delete
   340  	// the newer version as the older no longer exists.
   341  	if err := msgStore.DeleteMessage(chanUpdate, peer); err != nil {
   342  		t.Fatalf("unable to delete message: %v", err)
   343  	}
   344  	assertMsg(chanUpdate, peer, false)
   345  	assertMsg(newChanUpdate, peer, true)
   346  
   347  	// The newer version should no longer exist within the store after
   348  	// deleting it.
   349  	if err := msgStore.DeleteMessage(newChanUpdate, peer); err != nil {
   350  		t.Fatalf("unable to delete message: %v", err)
   351  	}
   352  	assertMsg(newChanUpdate, peer, false)
   353  }