github.com/koko1123/flow-go-1@v0.29.6/engine/common/provider/engine_test.go (about)

     1  package provider_test
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/mock"
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/vmihailenco/msgpack"
    13  
    14  	"github.com/koko1123/flow-go-1/engine/common/provider"
    15  	"github.com/koko1123/flow-go-1/model/flow"
    16  	"github.com/koko1123/flow-go-1/model/flow/filter"
    17  	"github.com/koko1123/flow-go-1/model/messages"
    18  	"github.com/koko1123/flow-go-1/module/irrecoverable"
    19  	"github.com/koko1123/flow-go-1/module/mempool/queue"
    20  	"github.com/koko1123/flow-go-1/module/metrics"
    21  	mockmodule "github.com/koko1123/flow-go-1/module/mock"
    22  	"github.com/koko1123/flow-go-1/network/channels"
    23  	"github.com/koko1123/flow-go-1/network/mocknetwork"
    24  	protocol "github.com/koko1123/flow-go-1/state/protocol/mock"
    25  	"github.com/koko1123/flow-go-1/storage"
    26  	"github.com/koko1123/flow-go-1/utils/unittest"
    27  )
    28  
    29  func TestOnEntityRequestFull(t *testing.T) {
    30  	cancelCtx, cancel := context.WithCancel(context.Background())
    31  	defer cancel()
    32  	ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx)
    33  
    34  	entities := make(map[flow.Identifier]flow.Entity)
    35  
    36  	identities := unittest.IdentityListFixture(8)
    37  	selector := filter.HasNodeID(identities.NodeIDs()...)
    38  	originID := identities[0].NodeID
    39  
    40  	coll1 := unittest.CollectionFixture(1)
    41  	coll2 := unittest.CollectionFixture(2)
    42  	coll3 := unittest.CollectionFixture(3)
    43  	coll4 := unittest.CollectionFixture(4)
    44  	coll5 := unittest.CollectionFixture(5)
    45  
    46  	entities[coll1.ID()] = coll1
    47  	entities[coll2.ID()] = coll2
    48  	entities[coll3.ID()] = coll3
    49  	entities[coll4.ID()] = coll4
    50  	entities[coll5.ID()] = coll5
    51  
    52  	retrieve := func(entityID flow.Identifier) (flow.Entity, error) {
    53  		entity, ok := entities[entityID]
    54  		if !ok {
    55  			return nil, storage.ErrNotFound
    56  		}
    57  		return entity, nil
    58  	}
    59  
    60  	final := protocol.NewSnapshot(t)
    61  	final.On("Identities", mock.Anything).Return(
    62  		func(selector flow.IdentityFilter) flow.IdentityList {
    63  			return identities.Filter(selector)
    64  		},
    65  		nil,
    66  	)
    67  
    68  	state := protocol.NewState(t)
    69  	state.On("Final").Return(final, nil)
    70  
    71  	net := mocknetwork.NewNetwork(t)
    72  	con := mocknetwork.NewConduit(t)
    73  	net.On("Register", mock.Anything, mock.Anything).Return(con, nil)
    74  	con.On("Unicast", mock.Anything, mock.Anything).Run(
    75  		func(args mock.Arguments) {
    76  			defer cancel()
    77  
    78  			response := args.Get(0).(*messages.EntityResponse)
    79  			nodeID := args.Get(1).(flow.Identifier)
    80  			assert.Equal(t, nodeID, originID)
    81  			var entities []flow.Entity
    82  			for _, blob := range response.Blobs {
    83  				coll := &flow.Collection{}
    84  				_ = msgpack.Unmarshal(blob, &coll)
    85  				entities = append(entities, coll)
    86  			}
    87  			assert.ElementsMatch(t, entities, []flow.Entity{&coll1, &coll2, &coll3, &coll4, &coll5})
    88  		},
    89  	).Return(nil)
    90  
    91  	me := mockmodule.NewLocal(t)
    92  	me.On("NodeID").Return(unittest.IdentifierFixture())
    93  	requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector())
    94  
    95  	e, err := provider.New(
    96  		unittest.Logger(),
    97  		metrics.NewNoopCollector(),
    98  		net,
    99  		me,
   100  		state,
   101  		requestQueue,
   102  		provider.DefaultRequestProviderWorkers,
   103  		channels.TestNetworkChannel,
   104  		selector,
   105  		retrieve)
   106  	require.NoError(t, err)
   107  
   108  	request := &messages.EntityRequest{
   109  		Nonce:     rand.Uint64(),
   110  		EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()},
   111  	}
   112  
   113  	e.Start(ctx)
   114  
   115  	unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine")
   116  
   117  	err = e.Process(channels.TestNetworkChannel, originID, request)
   118  	require.NoError(t, err, "should not error on full response")
   119  
   120  	unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine")
   121  }
   122  
   123  func TestOnEntityRequestPartial(t *testing.T) {
   124  	cancelCtx, cancel := context.WithCancel(context.Background())
   125  	defer cancel()
   126  	ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx)
   127  
   128  	entities := make(map[flow.Identifier]flow.Entity)
   129  
   130  	identities := unittest.IdentityListFixture(8)
   131  	selector := filter.HasNodeID(identities.NodeIDs()...)
   132  	originID := identities[0].NodeID
   133  
   134  	coll1 := unittest.CollectionFixture(1)
   135  	coll2 := unittest.CollectionFixture(2)
   136  	coll3 := unittest.CollectionFixture(3)
   137  	coll4 := unittest.CollectionFixture(4)
   138  	coll5 := unittest.CollectionFixture(5)
   139  
   140  	entities[coll1.ID()] = coll1
   141  	// entities[coll2.ID()] = coll2
   142  	entities[coll3.ID()] = coll3
   143  	// entities[coll4.ID()] = coll4
   144  	entities[coll5.ID()] = coll5
   145  
   146  	retrieve := func(entityID flow.Identifier) (flow.Entity, error) {
   147  		entity, ok := entities[entityID]
   148  		if !ok {
   149  			return nil, storage.ErrNotFound
   150  		}
   151  		return entity, nil
   152  	}
   153  
   154  	final := protocol.NewSnapshot(t)
   155  	final.On("Identities", mock.Anything).Return(
   156  		func(selector flow.IdentityFilter) flow.IdentityList {
   157  			return identities.Filter(selector)
   158  		},
   159  		nil,
   160  	)
   161  
   162  	state := protocol.NewState(t)
   163  	state.On("Final").Return(final, nil)
   164  
   165  	net := mocknetwork.NewNetwork(t)
   166  	con := mocknetwork.NewConduit(t)
   167  	net.On("Register", mock.Anything, mock.Anything).Return(con, nil)
   168  	con.On("Unicast", mock.Anything, mock.Anything).Run(
   169  		func(args mock.Arguments) {
   170  			defer cancel()
   171  
   172  			response := args.Get(0).(*messages.EntityResponse)
   173  			nodeID := args.Get(1).(flow.Identifier)
   174  			assert.Equal(t, nodeID, originID)
   175  			var entities []flow.Entity
   176  			for _, blob := range response.Blobs {
   177  				coll := &flow.Collection{}
   178  				_ = msgpack.Unmarshal(blob, &coll)
   179  				entities = append(entities, coll)
   180  			}
   181  			assert.ElementsMatch(t, entities, []flow.Entity{&coll1, &coll3, &coll5})
   182  		},
   183  	).Return(nil)
   184  
   185  	me := mockmodule.NewLocal(t)
   186  	me.On("NodeID").Return(unittest.IdentifierFixture())
   187  	requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector())
   188  
   189  	e, err := provider.New(
   190  		unittest.Logger(),
   191  		metrics.NewNoopCollector(),
   192  		net,
   193  		me,
   194  		state,
   195  		requestQueue,
   196  		provider.DefaultRequestProviderWorkers,
   197  		channels.TestNetworkChannel,
   198  		selector,
   199  		retrieve)
   200  	require.NoError(t, err)
   201  
   202  	request := &messages.EntityRequest{
   203  		Nonce:     rand.Uint64(),
   204  		EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()},
   205  	}
   206  
   207  	e.Start(ctx)
   208  
   209  	unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine")
   210  	err = e.Process(channels.TestNetworkChannel, originID, request)
   211  	require.NoError(t, err, "should not error on full response")
   212  	unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine")
   213  }
   214  
   215  func TestOnEntityRequestDuplicates(t *testing.T) {
   216  	cancelCtx, cancel := context.WithCancel(context.Background())
   217  	defer cancel()
   218  	ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx)
   219  
   220  	entities := make(map[flow.Identifier]flow.Entity)
   221  
   222  	identities := unittest.IdentityListFixture(8)
   223  	selector := filter.HasNodeID(identities.NodeIDs()...)
   224  	originID := identities[0].NodeID
   225  
   226  	coll1 := unittest.CollectionFixture(1)
   227  	coll2 := unittest.CollectionFixture(2)
   228  	coll3 := unittest.CollectionFixture(3)
   229  
   230  	entities[coll1.ID()] = coll1
   231  	entities[coll2.ID()] = coll2
   232  	entities[coll3.ID()] = coll3
   233  
   234  	retrieve := func(entityID flow.Identifier) (flow.Entity, error) {
   235  		entity, ok := entities[entityID]
   236  		if !ok {
   237  			return nil, storage.ErrNotFound
   238  		}
   239  		return entity, nil
   240  	}
   241  
   242  	final := protocol.NewSnapshot(t)
   243  	final.On("Identities", mock.Anything).Return(
   244  		func(selector flow.IdentityFilter) flow.IdentityList {
   245  			return identities.Filter(selector)
   246  		},
   247  		nil,
   248  	)
   249  
   250  	state := protocol.NewState(t)
   251  	state.On("Final").Return(final, nil)
   252  
   253  	net := mocknetwork.NewNetwork(t)
   254  	con := mocknetwork.NewConduit(t)
   255  	net.On("Register", mock.Anything, mock.Anything).Return(con, nil)
   256  	con.On("Unicast", mock.Anything, mock.Anything).Run(
   257  		func(args mock.Arguments) {
   258  			defer cancel()
   259  
   260  			response := args.Get(0).(*messages.EntityResponse)
   261  			nodeID := args.Get(1).(flow.Identifier)
   262  			assert.Equal(t, nodeID, originID)
   263  			var entities []flow.Entity
   264  			for _, blob := range response.Blobs {
   265  				coll := &flow.Collection{}
   266  				_ = msgpack.Unmarshal(blob, &coll)
   267  				entities = append(entities, coll)
   268  			}
   269  			assert.ElementsMatch(t, entities, []flow.Entity{&coll1, &coll2, &coll3})
   270  		},
   271  	).Return(nil)
   272  
   273  	me := mockmodule.NewLocal(t)
   274  	me.On("NodeID").Return(unittest.IdentifierFixture())
   275  	requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector())
   276  
   277  	e, err := provider.New(
   278  		unittest.Logger(),
   279  		metrics.NewNoopCollector(),
   280  		net,
   281  		me,
   282  		state,
   283  		requestQueue,
   284  		provider.DefaultRequestProviderWorkers,
   285  		channels.TestNetworkChannel,
   286  		selector,
   287  		retrieve)
   288  	require.NoError(t, err)
   289  
   290  	// create entity requests with some duplicate entity IDs
   291  	request := &messages.EntityRequest{
   292  		Nonce:     rand.Uint64(),
   293  		EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll3.ID(), coll2.ID(), coll1.ID()},
   294  	}
   295  
   296  	e.Start(ctx)
   297  	unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine")
   298  	err = e.Process(channels.TestNetworkChannel, originID, request)
   299  	require.NoError(t, err, "should not error on full response")
   300  	unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine")
   301  }
   302  
   303  func TestOnEntityRequestEmpty(t *testing.T) {
   304  	cancelCtx, cancel := context.WithCancel(context.Background())
   305  	defer cancel()
   306  	ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx)
   307  
   308  	entities := make(map[flow.Identifier]flow.Entity)
   309  	identities := unittest.IdentityListFixture(8)
   310  	selector := filter.HasNodeID(identities.NodeIDs()...)
   311  	originID := identities[0].NodeID
   312  
   313  	coll1 := unittest.CollectionFixture(1)
   314  	coll2 := unittest.CollectionFixture(2)
   315  	coll3 := unittest.CollectionFixture(3)
   316  	coll4 := unittest.CollectionFixture(4)
   317  	coll5 := unittest.CollectionFixture(5)
   318  
   319  	retrieve := func(entityID flow.Identifier) (flow.Entity, error) {
   320  		entity, ok := entities[entityID]
   321  		if !ok {
   322  			return nil, storage.ErrNotFound
   323  		}
   324  		return entity, nil
   325  	}
   326  
   327  	final := protocol.NewSnapshot(t)
   328  	final.On("Identities", mock.Anything).Return(
   329  		func(selector flow.IdentityFilter) flow.IdentityList {
   330  			return identities.Filter(selector)
   331  		},
   332  		nil,
   333  	)
   334  
   335  	state := protocol.NewState(t)
   336  	state.On("Final").Return(final, nil)
   337  
   338  	net := mocknetwork.NewNetwork(t)
   339  	con := mocknetwork.NewConduit(t)
   340  	net.On("Register", mock.Anything, mock.Anything).Return(con, nil)
   341  	con.On("Unicast", mock.Anything, mock.Anything).Run(
   342  		func(args mock.Arguments) {
   343  			defer cancel()
   344  
   345  			response := args.Get(0).(*messages.EntityResponse)
   346  			nodeID := args.Get(1).(flow.Identifier)
   347  			assert.Equal(t, nodeID, originID)
   348  			assert.Empty(t, response.Blobs)
   349  		},
   350  	).Return(nil)
   351  
   352  	me := mockmodule.NewLocal(t)
   353  	me.On("NodeID").Return(unittest.IdentifierFixture())
   354  	requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector())
   355  
   356  	e, err := provider.New(
   357  		unittest.Logger(),
   358  		metrics.NewNoopCollector(),
   359  		net,
   360  		me,
   361  		state,
   362  		requestQueue,
   363  		provider.DefaultRequestProviderWorkers,
   364  		channels.TestNetworkChannel,
   365  		selector,
   366  		retrieve)
   367  	require.NoError(t, err)
   368  
   369  	request := &messages.EntityRequest{
   370  		Nonce:     rand.Uint64(),
   371  		EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()},
   372  	}
   373  
   374  	e.Start(ctx)
   375  	unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine")
   376  	err = e.Process(channels.TestNetworkChannel, originID, request)
   377  	require.NoError(t, err, "should not error on full response")
   378  	unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine")
   379  }
   380  
   381  func TestOnEntityRequestInvalidOrigin(t *testing.T) {
   382  	cancelCtx, cancel := context.WithCancel(context.Background())
   383  	defer cancel()
   384  	ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx)
   385  
   386  	entities := make(map[flow.Identifier]flow.Entity)
   387  	identities := unittest.IdentityListFixture(8)
   388  	selector := filter.HasNodeID(identities.NodeIDs()...)
   389  	originID := unittest.IdentifierFixture()
   390  
   391  	coll1 := unittest.CollectionFixture(1)
   392  	coll2 := unittest.CollectionFixture(2)
   393  	coll3 := unittest.CollectionFixture(3)
   394  	coll4 := unittest.CollectionFixture(4)
   395  	coll5 := unittest.CollectionFixture(5)
   396  
   397  	entities[coll1.ID()] = coll1
   398  	entities[coll2.ID()] = coll2
   399  	entities[coll3.ID()] = coll3
   400  	entities[coll4.ID()] = coll4
   401  	entities[coll5.ID()] = coll5
   402  
   403  	retrieve := func(entityID flow.Identifier) (flow.Entity, error) {
   404  		entity, ok := entities[entityID]
   405  		if !ok {
   406  			return nil, storage.ErrNotFound
   407  		}
   408  		return entity, nil
   409  	}
   410  
   411  	final := protocol.NewSnapshot(t)
   412  	final.On("Identities", mock.Anything).Return(
   413  		func(selector flow.IdentityFilter) flow.IdentityList {
   414  			defer cancel()
   415  			return identities.Filter(selector)
   416  		},
   417  		nil,
   418  	)
   419  
   420  	state := protocol.NewState(t)
   421  	state.On("Final").Return(final, nil)
   422  
   423  	net := mocknetwork.NewNetwork(t)
   424  	con := mocknetwork.NewConduit(t)
   425  	net.On("Register", mock.Anything, mock.Anything).Return(con, nil)
   426  	me := mockmodule.NewLocal(t)
   427  	me.On("NodeID").Return(unittest.IdentifierFixture())
   428  	requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector())
   429  
   430  	e, err := provider.New(
   431  		unittest.Logger(),
   432  		metrics.NewNoopCollector(),
   433  		net,
   434  		me,
   435  		state,
   436  		requestQueue,
   437  		provider.DefaultRequestProviderWorkers,
   438  		channels.TestNetworkChannel,
   439  		selector,
   440  		retrieve)
   441  	require.NoError(t, err)
   442  
   443  	request := &messages.EntityRequest{
   444  		Nonce:     rand.Uint64(),
   445  		EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()},
   446  	}
   447  
   448  	e.Start(ctx)
   449  	unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine")
   450  	err = e.Process(channels.TestNetworkChannel, originID, request)
   451  	require.NoError(t, err)
   452  	unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine")
   453  }