github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rest/routes/subscribe_events_test.go (about)

     1  package routes
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net/http"
     9  	"net/url"
    10  	"regexp"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"golang.org/x/exp/slices"
    16  
    17  	jsoncdc "github.com/onflow/cadence/encoding/json"
    18  	"github.com/onflow/flow/protobuf/go/flow/entities"
    19  	mocks "github.com/stretchr/testify/mock"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/stretchr/testify/suite"
    22  
    23  	"github.com/onflow/flow-go/engine/access/rest/request"
    24  	"github.com/onflow/flow-go/engine/access/state_stream"
    25  	"github.com/onflow/flow-go/engine/access/state_stream/backend"
    26  	mockstatestream "github.com/onflow/flow-go/engine/access/state_stream/mock"
    27  	"github.com/onflow/flow-go/model/flow"
    28  	"github.com/onflow/flow-go/utils/unittest"
    29  	"github.com/onflow/flow-go/utils/unittest/generator"
    30  )
    31  
    32  type testType struct {
    33  	name         string
    34  	startBlockID flow.Identifier
    35  	startHeight  uint64
    36  
    37  	eventTypes []string
    38  	addresses  []string
    39  	contracts  []string
    40  
    41  	heartbeatInterval uint64
    42  
    43  	headers http.Header
    44  }
    45  
    46  var chainID = flow.Testnet
    47  var testEventTypes = []flow.EventType{
    48  	unittest.EventTypeFixture(chainID),
    49  	unittest.EventTypeFixture(chainID),
    50  	unittest.EventTypeFixture(chainID),
    51  }
    52  
    53  type SubscribeEventsSuite struct {
    54  	suite.Suite
    55  
    56  	blocks      []*flow.Block
    57  	blockEvents map[flow.Identifier]flow.EventsList
    58  }
    59  
    60  func TestSubscribeEventsSuite(t *testing.T) {
    61  	suite.Run(t, new(SubscribeEventsSuite))
    62  }
    63  
    64  func (s *SubscribeEventsSuite) SetupTest() {
    65  	rootBlock := unittest.BlockFixture()
    66  	parent := rootBlock.Header
    67  
    68  	blockCount := 5
    69  
    70  	s.blocks = make([]*flow.Block, 0, blockCount)
    71  	s.blockEvents = make(map[flow.Identifier]flow.EventsList, blockCount)
    72  
    73  	// by default, events are in CCF encoding
    74  	eventsGenerator := generator.EventGenerator(generator.WithEncoding(entities.EventEncodingVersion_CCF_V0))
    75  
    76  	for i := 0; i < blockCount; i++ {
    77  		block := unittest.BlockWithParentFixture(parent)
    78  		// update for next iteration
    79  		parent = block.Header
    80  
    81  		result := unittest.ExecutionResultFixture()
    82  		blockEvents := unittest.BlockEventsFixture(block.Header, (i%len(testEventTypes))*3+1, testEventTypes...)
    83  
    84  		// update payloads with valid CCF encoded data
    85  		for i := range blockEvents.Events {
    86  			blockEvents.Events[i].Payload = eventsGenerator.New().Payload
    87  
    88  			s.T().Logf("block events %d %v => %v", block.Header.Height, block.ID(), blockEvents.Events[i].Type)
    89  		}
    90  
    91  		s.blocks = append(s.blocks, block)
    92  		s.blockEvents[block.ID()] = blockEvents.Events
    93  
    94  		s.T().Logf("adding exec data for block %d %d %v => %v", i, block.Header.Height, block.ID(), result.ExecutionDataID)
    95  	}
    96  }
    97  
    98  // TestSubscribeEvents is a happy cases tests for the SubscribeEvents functionality.
    99  // This test function covers various scenarios for subscribing to events via WebSocket.
   100  //
   101  // It tests scenarios:
   102  //   - Subscribing to events from the root height.
   103  //   - Subscribing to events from a specific start height.
   104  //   - Subscribing to events from a specific start block ID.
   105  //   - Subscribing to events from the root height with custom heartbeat interval.
   106  //
   107  // Every scenario covers the following aspects:
   108  //   - Subscribing to all events.
   109  //   - Subscribing to events of a specific type (some events).
   110  //
   111  // For each scenario, this test function creates WebSocket requests, simulates WebSocket responses with mock data,
   112  // and validates that the received WebSocket response matches the expected EventsResponses.
   113  func (s *SubscribeEventsSuite) TestSubscribeEvents() {
   114  	testVectors := []testType{
   115  		{
   116  			name:              "happy path - all events from root height",
   117  			startBlockID:      flow.ZeroID,
   118  			startHeight:       request.EmptyHeight,
   119  			heartbeatInterval: 1,
   120  		},
   121  		{
   122  			name:              "happy path - all events from startHeight",
   123  			startBlockID:      flow.ZeroID,
   124  			startHeight:       s.blocks[0].Header.Height,
   125  			heartbeatInterval: 1,
   126  		},
   127  		{
   128  			name:              "happy path - all events from startBlockID",
   129  			startBlockID:      s.blocks[0].ID(),
   130  			startHeight:       request.EmptyHeight,
   131  			heartbeatInterval: 1,
   132  		},
   133  		{
   134  			name:              "happy path - events from root height with custom heartbeat",
   135  			startBlockID:      flow.ZeroID,
   136  			startHeight:       request.EmptyHeight,
   137  			heartbeatInterval: 2,
   138  		},
   139  		{
   140  			name:              "happy path - all origins allowed",
   141  			startBlockID:      flow.ZeroID,
   142  			startHeight:       request.EmptyHeight,
   143  			heartbeatInterval: 1,
   144  			headers: http.Header{
   145  				"Origin": []string{"https://example.com"},
   146  			},
   147  		},
   148  	}
   149  
   150  	// create variations for each of the base test
   151  	tests := make([]testType, 0, len(testVectors)*2)
   152  	for _, test := range testVectors {
   153  		t1 := test
   154  		t1.name = fmt.Sprintf("%s - all events", test.name)
   155  		tests = append(tests, t1)
   156  
   157  		t2 := test
   158  		t2.name = fmt.Sprintf("%s - some events", test.name)
   159  		t2.eventTypes = []string{string(testEventTypes[0])}
   160  		tests = append(tests, t2)
   161  
   162  		t3 := test
   163  		t3.name = fmt.Sprintf("%s - non existing events", test.name)
   164  		t3.eventTypes = []string{fmt.Sprintf("%s_new", testEventTypes[0])}
   165  		tests = append(tests, t3)
   166  	}
   167  
   168  	for _, test := range tests {
   169  		s.Run(test.name, func() {
   170  			stateStreamBackend := mockstatestream.NewAPI(s.T())
   171  			subscription := mockstatestream.NewSubscription(s.T())
   172  
   173  			filter, err := state_stream.NewEventFilter(
   174  				state_stream.DefaultEventFilterConfig,
   175  				chainID.Chain(),
   176  				test.eventTypes,
   177  				test.addresses,
   178  				test.contracts)
   179  			require.NoError(s.T(), err)
   180  
   181  			var expectedEventsResponses []*backend.EventsResponse
   182  			var subscriptionEventsResponses []*backend.EventsResponse
   183  			startBlockFound := test.startBlockID == flow.ZeroID
   184  
   185  			// construct expected event responses based on the provided test configuration
   186  			for i, block := range s.blocks {
   187  				blockID := block.ID()
   188  				if startBlockFound || blockID == test.startBlockID {
   189  					startBlockFound = true
   190  					if test.startHeight == request.EmptyHeight || block.Header.Height >= test.startHeight {
   191  						// track 2 lists, one for the expected results and one that is passed back
   192  						// from the subscription to the handler. These cannot be shared since the
   193  						// response struct is passed by reference from the mock to the handler, so
   194  						// a bug within the handler could go unnoticed
   195  						expectedEvents := flow.EventsList{}
   196  						subscriptionEvents := flow.EventsList{}
   197  						for _, event := range s.blockEvents[blockID] {
   198  							if slices.Contains(test.eventTypes, string(event.Type)) ||
   199  								len(test.eventTypes) == 0 { // Include all events
   200  								expectedEvents = append(expectedEvents, event)
   201  								subscriptionEvents = append(subscriptionEvents, event)
   202  							}
   203  						}
   204  						if len(expectedEvents) > 0 || (i+1)%int(test.heartbeatInterval) == 0 {
   205  							expectedEventsResponses = append(expectedEventsResponses, &backend.EventsResponse{
   206  								Height:         block.Header.Height,
   207  								BlockID:        blockID,
   208  								Events:         expectedEvents,
   209  								BlockTimestamp: block.Header.Timestamp,
   210  							})
   211  						}
   212  						subscriptionEventsResponses = append(subscriptionEventsResponses, &backend.EventsResponse{
   213  							Height:         block.Header.Height,
   214  							BlockID:        blockID,
   215  							Events:         subscriptionEvents,
   216  							BlockTimestamp: block.Header.Timestamp,
   217  						})
   218  					}
   219  				}
   220  			}
   221  
   222  			// Create a channel to receive mock EventsResponse objects
   223  			ch := make(chan interface{})
   224  			var chReadOnly <-chan interface{}
   225  			// Simulate sending a mock EventsResponse
   226  			go func() {
   227  				for _, eventResponse := range subscriptionEventsResponses {
   228  					// Send the mock EventsResponse through the channel
   229  					ch <- eventResponse
   230  				}
   231  			}()
   232  
   233  			chReadOnly = ch
   234  			subscription.Mock.On("Channel").Return(chReadOnly)
   235  
   236  			var startHeight uint64
   237  			if test.startHeight == request.EmptyHeight {
   238  				startHeight = uint64(0)
   239  			} else {
   240  				startHeight = test.startHeight
   241  			}
   242  			stateStreamBackend.Mock.
   243  				On("SubscribeEvents", mocks.Anything, test.startBlockID, startHeight, filter).
   244  				Return(subscription)
   245  
   246  			req, err := getSubscribeEventsRequest(s.T(), test.startBlockID, test.startHeight, test.eventTypes, test.addresses, test.contracts, test.heartbeatInterval, test.headers)
   247  			require.NoError(s.T(), err)
   248  			respRecorder := newTestHijackResponseRecorder()
   249  			// closing the connection after 1 second
   250  			go func() {
   251  				time.Sleep(1 * time.Second)
   252  				respRecorder.Close()
   253  			}()
   254  			executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain())
   255  			requireResponse(s.T(), respRecorder, expectedEventsResponses)
   256  		})
   257  	}
   258  }
   259  
   260  func (s *SubscribeEventsSuite) TestSubscribeEventsHandlesErrors() {
   261  	s.Run("returns error for block id and height", func() {
   262  		stateStreamBackend := mockstatestream.NewAPI(s.T())
   263  		req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), s.blocks[0].Header.Height, nil, nil, nil, 1, nil)
   264  		require.NoError(s.T(), err)
   265  		respRecorder := newTestHijackResponseRecorder()
   266  		executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain())
   267  		requireError(s.T(), respRecorder, "can only provide either block ID or start height")
   268  	})
   269  
   270  	s.Run("returns error for invalid block id", func() {
   271  		stateStreamBackend := mockstatestream.NewAPI(s.T())
   272  		invalidBlock := unittest.BlockFixture()
   273  		subscription := mockstatestream.NewSubscription(s.T())
   274  
   275  		ch := make(chan interface{})
   276  		var chReadOnly <-chan interface{}
   277  		go func() {
   278  			close(ch)
   279  		}()
   280  		chReadOnly = ch
   281  
   282  		subscription.Mock.On("Channel").Return(chReadOnly)
   283  		subscription.Mock.On("Err").Return(fmt.Errorf("subscription error"))
   284  		stateStreamBackend.Mock.
   285  			On("SubscribeEvents", mocks.Anything, invalidBlock.ID(), uint64(0), mocks.Anything).
   286  			Return(subscription)
   287  
   288  		req, err := getSubscribeEventsRequest(s.T(), invalidBlock.ID(), request.EmptyHeight, nil, nil, nil, 1, nil)
   289  		require.NoError(s.T(), err)
   290  		respRecorder := newTestHijackResponseRecorder()
   291  		executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain())
   292  		requireError(s.T(), respRecorder, "stream encountered an error: subscription error")
   293  	})
   294  
   295  	s.Run("returns error for invalid event filter", func() {
   296  		stateStreamBackend := mockstatestream.NewAPI(s.T())
   297  		req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, []string{"foo"}, nil, nil, 1, nil)
   298  		require.NoError(s.T(), err)
   299  		respRecorder := newTestHijackResponseRecorder()
   300  		executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain())
   301  		requireError(s.T(), respRecorder, "invalid event type format")
   302  	})
   303  
   304  	s.Run("returns error when channel closed", func() {
   305  		stateStreamBackend := mockstatestream.NewAPI(s.T())
   306  		subscription := mockstatestream.NewSubscription(s.T())
   307  
   308  		ch := make(chan interface{})
   309  		var chReadOnly <-chan interface{}
   310  
   311  		go func() {
   312  			close(ch)
   313  		}()
   314  		chReadOnly = ch
   315  
   316  		subscription.Mock.On("Channel").Return(chReadOnly)
   317  		subscription.Mock.On("Err").Return(nil)
   318  		stateStreamBackend.Mock.
   319  			On("SubscribeEvents", mocks.Anything, s.blocks[0].ID(), uint64(0), mocks.Anything).
   320  			Return(subscription)
   321  
   322  		req, err := getSubscribeEventsRequest(s.T(), s.blocks[0].ID(), request.EmptyHeight, nil, nil, nil, 1, nil)
   323  		require.NoError(s.T(), err)
   324  		respRecorder := newTestHijackResponseRecorder()
   325  		executeWsRequest(req, stateStreamBackend, respRecorder, chainID.Chain())
   326  		requireError(s.T(), respRecorder, "subscription channel closed")
   327  	})
   328  }
   329  
   330  func getSubscribeEventsRequest(t *testing.T,
   331  	startBlockId flow.Identifier,
   332  	startHeight uint64,
   333  	eventTypes []string,
   334  	addresses []string,
   335  	contracts []string,
   336  	heartbeatInterval uint64,
   337  	header http.Header,
   338  ) (*http.Request, error) {
   339  	u, _ := url.Parse("/v1/subscribe_events")
   340  	q := u.Query()
   341  
   342  	if startBlockId != flow.ZeroID {
   343  		q.Add(startBlockIdQueryParam, startBlockId.String())
   344  	}
   345  
   346  	if startHeight != request.EmptyHeight {
   347  		q.Add(startHeightQueryParam, fmt.Sprintf("%d", startHeight))
   348  	}
   349  
   350  	if len(eventTypes) > 0 {
   351  		q.Add(eventTypesQueryParams, strings.Join(eventTypes, ","))
   352  	}
   353  	if len(addresses) > 0 {
   354  		q.Add(addressesQueryParams, strings.Join(addresses, ","))
   355  	}
   356  	if len(contracts) > 0 {
   357  		q.Add(contractsQueryParams, strings.Join(contracts, ","))
   358  	}
   359  
   360  	q.Add(heartbeatIntervalQueryParam, fmt.Sprintf("%d", heartbeatInterval))
   361  
   362  	u.RawQuery = q.Encode()
   363  	key, err := generateWebSocketKey()
   364  	if err != nil {
   365  		err := fmt.Errorf("error generating websocket key: %v", err)
   366  		return nil, err
   367  	}
   368  
   369  	req, err := http.NewRequest("GET", u.String(), nil)
   370  	require.NoError(t, err)
   371  
   372  	req.Header.Set("Connection", "upgrade")
   373  	req.Header.Set("Upgrade", "websocket")
   374  	req.Header.Set("Sec-Websocket-Version", "13")
   375  	req.Header.Set("Sec-Websocket-Key", key)
   376  
   377  	for k, v := range header {
   378  		req.Header.Set(k, v[0])
   379  	}
   380  
   381  	return req, nil
   382  }
   383  
   384  func generateWebSocketKey() (string, error) {
   385  	// Generate 16 random bytes.
   386  	keyBytes := make([]byte, 16)
   387  	if _, err := rand.Read(keyBytes); err != nil {
   388  		return "", err
   389  	}
   390  
   391  	// Encode the bytes to base64 and return the key as a string.
   392  	return base64.StdEncoding.EncodeToString(keyBytes), nil
   393  }
   394  
   395  func requireError(t *testing.T, recorder *testHijackResponseRecorder, expected string) {
   396  	<-recorder.closed
   397  	require.Contains(t, recorder.responseBuff.String(), expected)
   398  }
   399  
   400  // requireResponse validates that the response received from WebSocket communication matches the expected EventsResponse.
   401  // This function compares the BlockID, Events count, and individual event properties for each expected and actual
   402  // EventsResponse. It ensures that the response received from WebSocket matches the expected structure and content.
   403  func requireResponse(t *testing.T, recorder *testHijackResponseRecorder, expected []*backend.EventsResponse) {
   404  	<-recorder.closed
   405  	// Convert the actual response from respRecorder to JSON bytes
   406  	actualJSON := recorder.responseBuff.Bytes()
   407  	// Define a regular expression pattern to match JSON objects
   408  	pattern := `\{"BlockID":".*?","Height":\d+,"Events":\[(\{.*?})*\],"BlockTimestamp":".*?"\}`
   409  	matches := regexp.MustCompile(pattern).FindAll(actualJSON, -1)
   410  
   411  	// Unmarshal each matched JSON into []state_stream.EventsResponse
   412  	var actual []backend.EventsResponse
   413  	for _, match := range matches {
   414  		var response backend.EventsResponse
   415  		if err := json.Unmarshal(match, &response); err == nil {
   416  			actual = append(actual, response)
   417  		}
   418  	}
   419  
   420  	// Compare the count of expected and actual responses
   421  	require.Equal(t, len(expected), len(actual))
   422  
   423  	// Compare the BlockID and Events count for each response
   424  	for responseIndex := range expected {
   425  		expectedEventsResponse := expected[responseIndex]
   426  		actualEventsResponse := actual[responseIndex]
   427  
   428  		require.Equal(t, expectedEventsResponse.BlockID, actualEventsResponse.BlockID)
   429  		require.Equal(t, len(expectedEventsResponse.Events), len(actualEventsResponse.Events))
   430  
   431  		for eventIndex, expectedEvent := range expectedEventsResponse.Events {
   432  			actualEvent := actualEventsResponse.Events[eventIndex]
   433  			require.Equal(t, expectedEvent.Type, actualEvent.Type)
   434  			require.Equal(t, expectedEvent.TransactionID, actualEvent.TransactionID)
   435  			require.Equal(t, expectedEvent.TransactionIndex, actualEvent.TransactionIndex)
   436  			require.Equal(t, expectedEvent.EventIndex, actualEvent.EventIndex)
   437  			// payload is not expected to match, but it should decode
   438  
   439  			// payload must decode to valid json-cdc encoded data
   440  			_, err := jsoncdc.Decode(nil, actualEvent.Payload)
   441  			require.NoError(t, err)
   442  		}
   443  	}
   444  }