github.com/koko1123/flow-go-1@v0.29.6/engine/common/provider/engine_test.go (about) 1 package provider_test 2 3 import ( 4 "context" 5 "math/rand" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/mock" 11 "github.com/stretchr/testify/require" 12 "github.com/vmihailenco/msgpack" 13 14 "github.com/koko1123/flow-go-1/engine/common/provider" 15 "github.com/koko1123/flow-go-1/model/flow" 16 "github.com/koko1123/flow-go-1/model/flow/filter" 17 "github.com/koko1123/flow-go-1/model/messages" 18 "github.com/koko1123/flow-go-1/module/irrecoverable" 19 "github.com/koko1123/flow-go-1/module/mempool/queue" 20 "github.com/koko1123/flow-go-1/module/metrics" 21 mockmodule "github.com/koko1123/flow-go-1/module/mock" 22 "github.com/koko1123/flow-go-1/network/channels" 23 "github.com/koko1123/flow-go-1/network/mocknetwork" 24 protocol "github.com/koko1123/flow-go-1/state/protocol/mock" 25 "github.com/koko1123/flow-go-1/storage" 26 "github.com/koko1123/flow-go-1/utils/unittest" 27 ) 28 29 func TestOnEntityRequestFull(t *testing.T) { 30 cancelCtx, cancel := context.WithCancel(context.Background()) 31 defer cancel() 32 ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx) 33 34 entities := make(map[flow.Identifier]flow.Entity) 35 36 identities := unittest.IdentityListFixture(8) 37 selector := filter.HasNodeID(identities.NodeIDs()...) 38 originID := identities[0].NodeID 39 40 coll1 := unittest.CollectionFixture(1) 41 coll2 := unittest.CollectionFixture(2) 42 coll3 := unittest.CollectionFixture(3) 43 coll4 := unittest.CollectionFixture(4) 44 coll5 := unittest.CollectionFixture(5) 45 46 entities[coll1.ID()] = coll1 47 entities[coll2.ID()] = coll2 48 entities[coll3.ID()] = coll3 49 entities[coll4.ID()] = coll4 50 entities[coll5.ID()] = coll5 51 52 retrieve := func(entityID flow.Identifier) (flow.Entity, error) { 53 entity, ok := entities[entityID] 54 if !ok { 55 return nil, storage.ErrNotFound 56 } 57 return entity, nil 58 } 59 60 final := protocol.NewSnapshot(t) 61 final.On("Identities", mock.Anything).Return( 62 func(selector flow.IdentityFilter) flow.IdentityList { 63 return identities.Filter(selector) 64 }, 65 nil, 66 ) 67 68 state := protocol.NewState(t) 69 state.On("Final").Return(final, nil) 70 71 net := mocknetwork.NewNetwork(t) 72 con := mocknetwork.NewConduit(t) 73 net.On("Register", mock.Anything, mock.Anything).Return(con, nil) 74 con.On("Unicast", mock.Anything, mock.Anything).Run( 75 func(args mock.Arguments) { 76 defer cancel() 77 78 response := args.Get(0).(*messages.EntityResponse) 79 nodeID := args.Get(1).(flow.Identifier) 80 assert.Equal(t, nodeID, originID) 81 var entities []flow.Entity 82 for _, blob := range response.Blobs { 83 coll := &flow.Collection{} 84 _ = msgpack.Unmarshal(blob, &coll) 85 entities = append(entities, coll) 86 } 87 assert.ElementsMatch(t, entities, []flow.Entity{&coll1, &coll2, &coll3, &coll4, &coll5}) 88 }, 89 ).Return(nil) 90 91 me := mockmodule.NewLocal(t) 92 me.On("NodeID").Return(unittest.IdentifierFixture()) 93 requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) 94 95 e, err := provider.New( 96 unittest.Logger(), 97 metrics.NewNoopCollector(), 98 net, 99 me, 100 state, 101 requestQueue, 102 provider.DefaultRequestProviderWorkers, 103 channels.TestNetworkChannel, 104 selector, 105 retrieve) 106 require.NoError(t, err) 107 108 request := &messages.EntityRequest{ 109 Nonce: rand.Uint64(), 110 EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()}, 111 } 112 113 e.Start(ctx) 114 115 unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine") 116 117 err = e.Process(channels.TestNetworkChannel, originID, request) 118 require.NoError(t, err, "should not error on full response") 119 120 unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine") 121 } 122 123 func TestOnEntityRequestPartial(t *testing.T) { 124 cancelCtx, cancel := context.WithCancel(context.Background()) 125 defer cancel() 126 ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx) 127 128 entities := make(map[flow.Identifier]flow.Entity) 129 130 identities := unittest.IdentityListFixture(8) 131 selector := filter.HasNodeID(identities.NodeIDs()...) 132 originID := identities[0].NodeID 133 134 coll1 := unittest.CollectionFixture(1) 135 coll2 := unittest.CollectionFixture(2) 136 coll3 := unittest.CollectionFixture(3) 137 coll4 := unittest.CollectionFixture(4) 138 coll5 := unittest.CollectionFixture(5) 139 140 entities[coll1.ID()] = coll1 141 // entities[coll2.ID()] = coll2 142 entities[coll3.ID()] = coll3 143 // entities[coll4.ID()] = coll4 144 entities[coll5.ID()] = coll5 145 146 retrieve := func(entityID flow.Identifier) (flow.Entity, error) { 147 entity, ok := entities[entityID] 148 if !ok { 149 return nil, storage.ErrNotFound 150 } 151 return entity, nil 152 } 153 154 final := protocol.NewSnapshot(t) 155 final.On("Identities", mock.Anything).Return( 156 func(selector flow.IdentityFilter) flow.IdentityList { 157 return identities.Filter(selector) 158 }, 159 nil, 160 ) 161 162 state := protocol.NewState(t) 163 state.On("Final").Return(final, nil) 164 165 net := mocknetwork.NewNetwork(t) 166 con := mocknetwork.NewConduit(t) 167 net.On("Register", mock.Anything, mock.Anything).Return(con, nil) 168 con.On("Unicast", mock.Anything, mock.Anything).Run( 169 func(args mock.Arguments) { 170 defer cancel() 171 172 response := args.Get(0).(*messages.EntityResponse) 173 nodeID := args.Get(1).(flow.Identifier) 174 assert.Equal(t, nodeID, originID) 175 var entities []flow.Entity 176 for _, blob := range response.Blobs { 177 coll := &flow.Collection{} 178 _ = msgpack.Unmarshal(blob, &coll) 179 entities = append(entities, coll) 180 } 181 assert.ElementsMatch(t, entities, []flow.Entity{&coll1, &coll3, &coll5}) 182 }, 183 ).Return(nil) 184 185 me := mockmodule.NewLocal(t) 186 me.On("NodeID").Return(unittest.IdentifierFixture()) 187 requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) 188 189 e, err := provider.New( 190 unittest.Logger(), 191 metrics.NewNoopCollector(), 192 net, 193 me, 194 state, 195 requestQueue, 196 provider.DefaultRequestProviderWorkers, 197 channels.TestNetworkChannel, 198 selector, 199 retrieve) 200 require.NoError(t, err) 201 202 request := &messages.EntityRequest{ 203 Nonce: rand.Uint64(), 204 EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()}, 205 } 206 207 e.Start(ctx) 208 209 unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine") 210 err = e.Process(channels.TestNetworkChannel, originID, request) 211 require.NoError(t, err, "should not error on full response") 212 unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine") 213 } 214 215 func TestOnEntityRequestDuplicates(t *testing.T) { 216 cancelCtx, cancel := context.WithCancel(context.Background()) 217 defer cancel() 218 ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx) 219 220 entities := make(map[flow.Identifier]flow.Entity) 221 222 identities := unittest.IdentityListFixture(8) 223 selector := filter.HasNodeID(identities.NodeIDs()...) 224 originID := identities[0].NodeID 225 226 coll1 := unittest.CollectionFixture(1) 227 coll2 := unittest.CollectionFixture(2) 228 coll3 := unittest.CollectionFixture(3) 229 230 entities[coll1.ID()] = coll1 231 entities[coll2.ID()] = coll2 232 entities[coll3.ID()] = coll3 233 234 retrieve := func(entityID flow.Identifier) (flow.Entity, error) { 235 entity, ok := entities[entityID] 236 if !ok { 237 return nil, storage.ErrNotFound 238 } 239 return entity, nil 240 } 241 242 final := protocol.NewSnapshot(t) 243 final.On("Identities", mock.Anything).Return( 244 func(selector flow.IdentityFilter) flow.IdentityList { 245 return identities.Filter(selector) 246 }, 247 nil, 248 ) 249 250 state := protocol.NewState(t) 251 state.On("Final").Return(final, nil) 252 253 net := mocknetwork.NewNetwork(t) 254 con := mocknetwork.NewConduit(t) 255 net.On("Register", mock.Anything, mock.Anything).Return(con, nil) 256 con.On("Unicast", mock.Anything, mock.Anything).Run( 257 func(args mock.Arguments) { 258 defer cancel() 259 260 response := args.Get(0).(*messages.EntityResponse) 261 nodeID := args.Get(1).(flow.Identifier) 262 assert.Equal(t, nodeID, originID) 263 var entities []flow.Entity 264 for _, blob := range response.Blobs { 265 coll := &flow.Collection{} 266 _ = msgpack.Unmarshal(blob, &coll) 267 entities = append(entities, coll) 268 } 269 assert.ElementsMatch(t, entities, []flow.Entity{&coll1, &coll2, &coll3}) 270 }, 271 ).Return(nil) 272 273 me := mockmodule.NewLocal(t) 274 me.On("NodeID").Return(unittest.IdentifierFixture()) 275 requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) 276 277 e, err := provider.New( 278 unittest.Logger(), 279 metrics.NewNoopCollector(), 280 net, 281 me, 282 state, 283 requestQueue, 284 provider.DefaultRequestProviderWorkers, 285 channels.TestNetworkChannel, 286 selector, 287 retrieve) 288 require.NoError(t, err) 289 290 // create entity requests with some duplicate entity IDs 291 request := &messages.EntityRequest{ 292 Nonce: rand.Uint64(), 293 EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll3.ID(), coll2.ID(), coll1.ID()}, 294 } 295 296 e.Start(ctx) 297 unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine") 298 err = e.Process(channels.TestNetworkChannel, originID, request) 299 require.NoError(t, err, "should not error on full response") 300 unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine") 301 } 302 303 func TestOnEntityRequestEmpty(t *testing.T) { 304 cancelCtx, cancel := context.WithCancel(context.Background()) 305 defer cancel() 306 ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx) 307 308 entities := make(map[flow.Identifier]flow.Entity) 309 identities := unittest.IdentityListFixture(8) 310 selector := filter.HasNodeID(identities.NodeIDs()...) 311 originID := identities[0].NodeID 312 313 coll1 := unittest.CollectionFixture(1) 314 coll2 := unittest.CollectionFixture(2) 315 coll3 := unittest.CollectionFixture(3) 316 coll4 := unittest.CollectionFixture(4) 317 coll5 := unittest.CollectionFixture(5) 318 319 retrieve := func(entityID flow.Identifier) (flow.Entity, error) { 320 entity, ok := entities[entityID] 321 if !ok { 322 return nil, storage.ErrNotFound 323 } 324 return entity, nil 325 } 326 327 final := protocol.NewSnapshot(t) 328 final.On("Identities", mock.Anything).Return( 329 func(selector flow.IdentityFilter) flow.IdentityList { 330 return identities.Filter(selector) 331 }, 332 nil, 333 ) 334 335 state := protocol.NewState(t) 336 state.On("Final").Return(final, nil) 337 338 net := mocknetwork.NewNetwork(t) 339 con := mocknetwork.NewConduit(t) 340 net.On("Register", mock.Anything, mock.Anything).Return(con, nil) 341 con.On("Unicast", mock.Anything, mock.Anything).Run( 342 func(args mock.Arguments) { 343 defer cancel() 344 345 response := args.Get(0).(*messages.EntityResponse) 346 nodeID := args.Get(1).(flow.Identifier) 347 assert.Equal(t, nodeID, originID) 348 assert.Empty(t, response.Blobs) 349 }, 350 ).Return(nil) 351 352 me := mockmodule.NewLocal(t) 353 me.On("NodeID").Return(unittest.IdentifierFixture()) 354 requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) 355 356 e, err := provider.New( 357 unittest.Logger(), 358 metrics.NewNoopCollector(), 359 net, 360 me, 361 state, 362 requestQueue, 363 provider.DefaultRequestProviderWorkers, 364 channels.TestNetworkChannel, 365 selector, 366 retrieve) 367 require.NoError(t, err) 368 369 request := &messages.EntityRequest{ 370 Nonce: rand.Uint64(), 371 EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()}, 372 } 373 374 e.Start(ctx) 375 unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine") 376 err = e.Process(channels.TestNetworkChannel, originID, request) 377 require.NoError(t, err, "should not error on full response") 378 unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine") 379 } 380 381 func TestOnEntityRequestInvalidOrigin(t *testing.T) { 382 cancelCtx, cancel := context.WithCancel(context.Background()) 383 defer cancel() 384 ctx := irrecoverable.NewMockSignalerContext(t, cancelCtx) 385 386 entities := make(map[flow.Identifier]flow.Entity) 387 identities := unittest.IdentityListFixture(8) 388 selector := filter.HasNodeID(identities.NodeIDs()...) 389 originID := unittest.IdentifierFixture() 390 391 coll1 := unittest.CollectionFixture(1) 392 coll2 := unittest.CollectionFixture(2) 393 coll3 := unittest.CollectionFixture(3) 394 coll4 := unittest.CollectionFixture(4) 395 coll5 := unittest.CollectionFixture(5) 396 397 entities[coll1.ID()] = coll1 398 entities[coll2.ID()] = coll2 399 entities[coll3.ID()] = coll3 400 entities[coll4.ID()] = coll4 401 entities[coll5.ID()] = coll5 402 403 retrieve := func(entityID flow.Identifier) (flow.Entity, error) { 404 entity, ok := entities[entityID] 405 if !ok { 406 return nil, storage.ErrNotFound 407 } 408 return entity, nil 409 } 410 411 final := protocol.NewSnapshot(t) 412 final.On("Identities", mock.Anything).Return( 413 func(selector flow.IdentityFilter) flow.IdentityList { 414 defer cancel() 415 return identities.Filter(selector) 416 }, 417 nil, 418 ) 419 420 state := protocol.NewState(t) 421 state.On("Final").Return(final, nil) 422 423 net := mocknetwork.NewNetwork(t) 424 con := mocknetwork.NewConduit(t) 425 net.On("Register", mock.Anything, mock.Anything).Return(con, nil) 426 me := mockmodule.NewLocal(t) 427 me.On("NodeID").Return(unittest.IdentifierFixture()) 428 requestQueue := queue.NewHeroStore(10, unittest.Logger(), metrics.NewNoopCollector()) 429 430 e, err := provider.New( 431 unittest.Logger(), 432 metrics.NewNoopCollector(), 433 net, 434 me, 435 state, 436 requestQueue, 437 provider.DefaultRequestProviderWorkers, 438 channels.TestNetworkChannel, 439 selector, 440 retrieve) 441 require.NoError(t, err) 442 443 request := &messages.EntityRequest{ 444 Nonce: rand.Uint64(), 445 EntityIDs: []flow.Identifier{coll1.ID(), coll2.ID(), coll3.ID(), coll4.ID(), coll5.ID()}, 446 } 447 448 e.Start(ctx) 449 unittest.RequireCloseBefore(t, e.Ready(), 100*time.Millisecond, "could not start engine") 450 err = e.Process(channels.TestNetworkChannel, originID, request) 451 require.NoError(t, err) 452 unittest.RequireCloseBefore(t, e.Done(), 100*time.Millisecond, "could not stop engine") 453 }