github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rest/routes/subscribe_events_test.go (about) 1 package routes 2 3 import ( 4 "crypto/rand" 5 "encoding/base64" 6 "encoding/json" 7 "fmt" 8 "net/http" 9 "net/url" 10 "regexp" 11 "strings" 12 "testing" 13 "time" 14 15 "golang.org/x/exp/slices" 16 17 jsoncdc "github.com/onflow/cadence/encoding/json" 18 "github.com/onflow/flow/protobuf/go/flow/entities" 19 mocks "github.com/stretchr/testify/mock" 20 "github.com/stretchr/testify/require" 21 "github.com/stretchr/testify/suite" 22 23 "github.com/onflow/flow-go/engine/access/rest/request" 24 "github.com/onflow/flow-go/engine/access/state_stream" 25 "github.com/onflow/flow-go/engine/access/state_stream/backend" 26 mockstatestream "github.com/onflow/flow-go/engine/access/state_stream/mock" 27 "github.com/onflow/flow-go/model/flow" 28 "github.com/onflow/flow-go/utils/unittest" 29 "github.com/onflow/flow-go/utils/unittest/generator" 30 ) 31 32 type testType struct { 33 name string 34 startBlockID flow.Identifier 35 startHeight uint64 36 37 eventTypes []string 38 addresses []string 39 contracts []string 40 41 heartbeatInterval uint64 42 43 headers http.Header 44 } 45 46 var chainID = flow.Testnet 47 var testEventTypes = []flow.EventType{ 48 unittest.EventTypeFixture(chainID), 49 unittest.EventTypeFixture(chainID), 50 unittest.EventTypeFixture(chainID), 51 } 52 53 type SubscribeEventsSuite struct { 54 suite.Suite 55 56 blocks []*flow.Block 57 blockEvents map[flow.Identifier]flow.EventsList 58 } 59 60 func TestSubscribeEventsSuite(t *testing.T) { 61 suite.Run(t, new(SubscribeEventsSuite)) 62 } 63 64 func (s *SubscribeEventsSuite) SetupTest() { 65 rootBlock := unittest.BlockFixture() 66 parent := rootBlock.Header 67 68 blockCount := 5 69 70 s.blocks = make([]*flow.Block, 0, blockCount) 71 s.blockEvents = make(map[flow.Identifier]flow.EventsList, blockCount) 72 73 // by default, events are in CCF encoding 74 eventsGenerator := generator.EventGenerator(generator.WithEncoding(entities.EventEncodingVersion_CCF_V0)) 75 76 for i := 0; i < blockCount; i++ { 77 block := unittest.BlockWithParentFixture(parent) 78 // update for next iteration 79 parent = block.Header 80 81 result := unittest.ExecutionResultFixture() 82 blockEvents := unittest.BlockEventsFixture(block.Header, (i%len(testEventTypes))*3+1, testEventTypes...) 83 84 // update payloads with valid CCF encoded data 85 for i := range blockEvents.Events { 86 blockEvents.Events[i].Payload = eventsGenerator.New().Payload 87 88 s.T().Logf("block events %d %v => %v", block.Header.Height, block.ID(), blockEvents.Events[i].Type) 89 } 90 91 s.blocks = append(s.blocks, block) 92 s.blockEvents[block.ID()] = blockEvents.Events 93 94 s.T().Logf("adding exec data for block %d %d %v => %v", i, block.Header.Height, block.ID(), result.ExecutionDataID) 95 } 96 } 97 98 // TestSubscribeEvents is a happy cases tests for the SubscribeEvents functionality. 99 // This test function covers various scenarios for subscribing to events via WebSocket. 100 // 101 // It tests scenarios: 102 // - Subscribing to events from the root height. 103 // - Subscribing to events from a specific start height. 104 // - Subscribing to events from a specific start block ID. 105 // - Subscribing to events from the root height with custom heartbeat interval. 106 // 107 // Every scenario covers the following aspects: 108 // - Subscribing to all events. 109 // - Subscribing to events of a specific type (some events). 110 // 111 // For each scenario, this test function creates WebSocket requests, simulates WebSocket responses with mock data, 112 // and validates that the received WebSocket response matches the expected EventsResponses. 113 func (s *SubscribeEventsSuite) TestSubscribeEvents() { 114 testVectors := []testType{ 115 { 116 name: "happy path - all events from root height", 117 startBlockID: flow.ZeroID, 118 startHeight: request.EmptyHeight, 119 heartbeatInterval: 1, 120 }, 121 { 122 name: "happy path - all events from startHeight", 123 startBlockID: flow.ZeroID, 124 startHeight: s.blocks[0].Header.Height, 125 heartbeatInterval: 1, 126 }, 127 { 128 name: "happy path - all events from startBlockID", 129 startBlockID: s.blocks[0].ID(), 130 startHeight: request.EmptyHeight, 131 heartbeatInterval: 1, 132 }, 133 { 134 name: "happy path - events from root height with custom heartbeat", 135 startBlockID: flow.ZeroID, 136 startHeight: request.EmptyHeight, 137 heartbeatInterval: 2, 138 }, 139 { 140 name: "happy path - all origins allowed", 141 startBlockID: flow.ZeroID, 142 startHeight: request.EmptyHeight, 143 heartbeatInterval: 1, 144 headers: http.Header{ 145 "Origin": []string{"https://example.com"}, 146 }, 147 }, 148 } 149 150 // create variations for each of the base test 151 tests := make([]testType, 0, len(testVectors)*2) 152 for _, test := range testVectors { 153 t1 := test 154 t1.name = fmt.Sprintf("%s - all events", test.name) 155 tests = append(tests, t1) 156 157 t2 := test 158 t2.name = fmt.Sprintf("%s - some events", test.name) 159 t2.eventTypes = []string{string(testEventTypes[0])} 160 tests = append(tests, t2) 161 162 t3 := test 163 t3.name = fmt.Sprintf("%s - non existing events", test.name) 164 t3.eventTypes = []string{fmt.Sprintf("%s_new", testEventTypes[0])} 165 tests = append(tests, t3) 166 } 167 168 for _, test := range tests { 169 s.Run(test.name, func() { 170 stateStreamBackend := mockstatestream.NewAPI(s.T()) 171 subscription := mockstatestream.NewSubscription(s.T()) 172 173 filter, err := state_stream.NewEventFilter( 174 state_stream.DefaultEventFilterConfig, 175 chainID.Chain(), 176 test.eventTypes, 177 test.addresses, 178 test.contracts) 179 require.NoError(s.T(), err) 180 181 var expectedEventsResponses []*backend.EventsResponse 182 var subscriptionEventsResponses []*backend.EventsResponse 183 startBlockFound := test.startBlockID == flow.ZeroID 184 185 // construct expected event responses based on the provided test configuration 186 for i, block := range s.blocks { 187 blockID := block.ID() 188 if startBlockFound || blockID == test.startBlockID { 189 startBlockFound = true 190 if test.startHeight == request.EmptyHeight || block.Header.Height >= test.startHeight { 191 // track 2 lists, one for the expected results and one that is passed back 192 // from the subscription to the handler. These cannot be shared since the 193 // response struct is passed by reference from the mock to the handler, so 194 // a bug within the handler could go unnoticed 195 expectedEvents := flow.EventsList{} 196 subscriptionEvents := flow.EventsList{} 197 for _, event := range s.blockEvents[blockID] { 198 if slices.Contains(test.eventTypes, string(event.Type)) || 199 len(test.eventTypes) == 0 { // Include all events 200 expectedEvents = append(expectedEvents, event) 201 subscriptionEvents = append(subscriptionEvents, event) 202 } 203 } 204 if len(expectedEvents) > 0 || (i+1)%int(test.heartbeatInterval) == 0 { 205 expectedEventsResponses = append(expectedEventsResponses, &backend.EventsResponse{ 206 Height: block.Header.Height, 207 BlockID: blockID, 208 Events: expectedEvents, 209 BlockTimestamp: block.Header.Timestamp, 210 }) 211 } 212 subscriptionEventsResponses = append(subscriptionEventsResponses, &backend.EventsResponse{ 213 Height: block.Header.Height, 214 BlockID: blockID, 215 Events: subscriptionEvents, 216 BlockTimestamp: block.Header.Timestamp, 217 }) 218 } 219 } 220 } 221 222 // Create a channel to receive mock EventsResponse objects 223 ch := make(chan interface{}) 224 var chReadOnly <-chan interface{} 225 // Simulate sending a mock EventsResponse 226 go func() { 227 for _, eventResponse := range subscriptionEventsResponses { 228 // Send the mock EventsResponse through the channel 229 ch <- eventResponse 230 } 231 }() 232 233 chReadOnly = ch 234 subscription.Mock.On("Channel").Return(chReadOnly) 235 236 var startHeight uint64 237 if test.startHeight == request.EmptyHeight { 238 startHeight = uint64(0) 239 } else { 240 startHeight = test.startHeight 241 } 242 stateStreamBackend.Mock. 243 On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter). 244 Return(subscription) 245 246 req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts, test.heartbeatInterval, test.headers) 247 require.NoError(s.T(), err) 248 respRecorder := newTestHijackResponseRecorder() 249 // closing the connection after 1 second 250 go func() { 251 time.Sleep(1 * time.Second) 252 respRecorder.Close() 253 }() 254 executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) 255 requireResponse(s.T(), respRecorder, expectedEventsResponses) 256 }) 257 } 258 } 259 260 func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() { 261 s.Run("returns error for block id and height", func() { 262 stateStreamBackend := mockstatestream.NewAPI(s.T()) 263 req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil, 1, nil) 264 require.NoError(s.T(), err) 265 respRecorder := newTestHijackResponseRecorder() 266 executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) 267 requireError(s.T(), respRecorder, "can only provide either block ID or start height") 268 }) 269 270 s.Run("returns error for invalid block id", func() { 271 stateStreamBackend := mockstatestream.NewAPI(s.T()) 272 invalidBlock := unittest.BlockFixture() 273 subscription := mockstatestream.NewSubscription(s.T()) 274 275 ch := make(chan interface{}) 276 var chReadOnly <-chan interface{} 277 go func() { 278 close(ch) 279 }() 280 chReadOnly = ch 281 282 subscription.Mock.On("Channel").Return(chReadOnly) 283 subscription.Mock.On("Err").Return(fmt.Errorf("subscription error")) 284 stateStreamBackend.Mock. 285 On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), uint64(0), mocks.Anything). 286 Return(subscription) 287 288 req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil, 1, nil) 289 require.NoError(s.T(), err) 290 respRecorder := newTestHijackResponseRecorder() 291 executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) 292 requireError(s.T(), respRecorder, "stream encountered an error: subscription error") 293 }) 294 295 s.Run("returns error for invalid event filter", func() { 296 stateStreamBackend := mockstatestream.NewAPI(s.T()) 297 req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, []string{"foo"}, nil, nil, 1, nil) 298 require.NoError(s.T(), err) 299 respRecorder := newTestHijackResponseRecorder() 300 executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) 301 requireError(s.T(), respRecorder, "invalid event type format") 302 }) 303 304 s.Run("returns error when channel closed", func() { 305 stateStreamBackend := mockstatestream.NewAPI(s.T()) 306 subscription := mockstatestream.NewSubscription(s.T()) 307 308 ch := make(chan interface{}) 309 var chReadOnly <-chan interface{} 310 311 go func() { 312 close(ch) 313 }() 314 chReadOnly = ch 315 316 subscription.Mock.On("Channel").Return(chReadOnly) 317 subscription.Mock.On("Err").Return(nil) 318 stateStreamBackend.Mock. 319 On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything). 320 Return(subscription) 321 322 req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil, 1, nil) 323 require.NoError(s.T(), err) 324 respRecorder := newTestHijackResponseRecorder() 325 executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain()) 326 requireError(s.T(), respRecorder, "subscription channel closed") 327 }) 328 } 329 330 func getSubscribeEventsRequest(t *testing.T, 331 startBlockId flow.Identifier, 332 startHeight uint64, 333 eventTypes []string, 334 addresses []string, 335 contracts []string, 336 heartbeatInterval uint64, 337 header http.Header, 338 ) (*http.Request, error) { 339 u, _ := url.Parse("/v1/subscribe_events") 340 q := u.Query() 341 342 if startBlockId != flow.ZeroID { 343 q.Add(startBlockIdQueryParam, startBlockId.String()) 344 } 345 346 if startHeight != request.EmptyHeight { 347 q.Add(startHeightQueryParam, fmt.Sprintf("%d", startHeight)) 348 } 349 350 if len(eventTypes) > 0 { 351 q.Add(eventTypesQueryParams, strings.Join(eventTypes, ",")) 352 } 353 if len(addresses) > 0 { 354 q.Add(addressesQueryParams, strings.Join(addresses, ",")) 355 } 356 if len(contracts) > 0 { 357 q.Add(contractsQueryParams, strings.Join(contracts, ",")) 358 } 359 360 q.Add(heartbeatIntervalQueryParam, fmt.Sprintf("%d", heartbeatInterval)) 361 362 u.RawQuery = q.Encode() 363 key, err := generateWebSocketKey() 364 if err != nil { 365 err := fmt.Errorf("error generating websocket key: %v", err) 366 return nil, err 367 } 368 369 req, err := http.NewRequest("GET", u.String(), nil) 370 require.NoError(t, err) 371 372 req.Header.Set("Connection", "upgrade") 373 req.Header.Set("Upgrade", "websocket") 374 req.Header.Set("Sec-Websocket-Version", "13") 375 req.Header.Set("Sec-Websocket-Key", key) 376 377 for k, v := range header { 378 req.Header.Set(k, v[0]) 379 } 380 381 return req, nil 382 } 383 384 func generateWebSocketKey() (string, error) { 385 // Generate 16 random bytes. 386 keyBytes := make([]byte, 16) 387 if _, err := rand.Read(keyBytes); err != nil { 388 return "", err 389 } 390 391 // Encode the bytes to base64 and return the key as a string. 392 return base64.StdEncoding.EncodeToString(keyBytes), nil 393 } 394 395 func requireError(t *testing.T, recorder *testHijackResponseRecorder, expected string) { 396 <-recorder.closed 397 require.Contains(t, recorder.responseBuff.String(), expected) 398 } 399 400 // requireResponse validates that the response received from WebSocket communication matches the expected EventsResponse. 401 // This function compares the BlockID, Events count, and individual event properties for each expected and actual 402 // EventsResponse. It ensures that the response received from WebSocket matches the expected structure and content. 403 func requireResponse(t *testing.T, recorder *testHijackResponseRecorder, expected []*backend.EventsResponse) { 404 <-recorder.closed 405 // Convert the actual response from respRecorder to JSON bytes 406 actualJSON := recorder.responseBuff.Bytes() 407 // Define a regular expression pattern to match JSON objects 408 pattern := `\{"BlockID":".*?","Height":\d+,"Events":\[(\{.*?})*\],"BlockTimestamp":".*?"\}` 409 matches := regexp.MustCompile(pattern).FindAll(actualJSON, -1) 410 411 // Unmarshal each matched JSON into []state_stream.EventsResponse 412 var actual []backend.EventsResponse 413 for _, match := range matches { 414 var response backend.EventsResponse 415 if err := json.Unmarshal(match, &response); err == nil { 416 actual = append(actual, response) 417 } 418 } 419 420 // Compare the count of expected and actual responses 421 require.Equal(t, len(expected), len(actual)) 422 423 // Compare the BlockID and Events count for each response 424 for responseIndex := range expected { 425 expectedEventsResponse := expected[responseIndex] 426 actualEventsResponse := actual[responseIndex] 427 428 require.Equal(t, expectedEventsResponse.BlockID, actualEventsResponse.BlockID) 429 require.Equal(t, len(expectedEventsResponse.Events), len(actualEventsResponse.Events)) 430 431 for eventIndex, expectedEvent := range expectedEventsResponse.Events { 432 actualEvent := actualEventsResponse.Events[eventIndex] 433 require.Equal(t, expectedEvent.Type, actualEvent.Type) 434 require.Equal(t, expectedEvent.TransactionID, actualEvent.TransactionID) 435 require.Equal(t, expectedEvent.TransactionIndex, actualEvent.TransactionIndex) 436 require.Equal(t, expectedEvent.EventIndex, actualEvent.EventIndex) 437 // payload is not expected to match, but it should decode 438 439 // payload must decode to valid json-cdc encoded data 440 _, err := jsoncdc.Decode(nil, actualEvent.Payload) 441 require.NoError(t, err) 442 } 443 } 444 }