github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rpc/connection/connection_test.go (about)

     1  package connection
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"math/big"
     8  	"net"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	lru "github.com/hashicorp/golang-lru/v2"
    14  	"github.com/onflow/flow/protobuf/go/flow/access"
    15  	"github.com/onflow/flow/protobuf/go/flow/execution"
    16  	"github.com/sony/gobreaker"
    17  	"github.com/stretchr/testify/assert"
    18  	testifymock "github.com/stretchr/testify/mock"
    19  	"github.com/stretchr/testify/require"
    20  	"go.uber.org/atomic"
    21  	"google.golang.org/grpc"
    22  	"google.golang.org/grpc/codes"
    23  	"google.golang.org/grpc/connectivity"
    24  	"google.golang.org/grpc/status"
    25  	"pgregory.net/rapid"
    26  
    27  	"github.com/onflow/flow-go/module/metrics"
    28  	"github.com/onflow/flow-go/utils/grpcutils"
    29  	"github.com/onflow/flow-go/utils/unittest"
    30  )
    31  
    32  func TestProxyAccessAPI(t *testing.T) {
    33  	logger := unittest.Logger()
    34  	metrics := metrics.NewNoopCollector()
    35  
    36  	// create a collection node
    37  	cn := new(collectionNode)
    38  	cn.start(t)
    39  	defer cn.stop(t)
    40  
    41  	req := &access.PingRequest{}
    42  	expected := &access.PingResponse{}
    43  	cn.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
    44  
    45  	// create the factory
    46  	connectionFactory := new(ConnectionFactoryImpl)
    47  	// set the collection grpc port
    48  	connectionFactory.CollectionGRPCPort = cn.port
    49  	// set metrics reporting
    50  	connectionFactory.AccessMetrics = metrics
    51  	connectionFactory.Manager = NewManager(
    52  		logger,
    53  		connectionFactory.AccessMetrics,
    54  		nil,
    55  		0,
    56  		CircuitBreakerConfig{},
    57  		grpcutils.NoCompressor,
    58  	)
    59  
    60  	proxyConnectionFactory := ProxyConnectionFactory{
    61  		ConnectionFactory: connectionFactory,
    62  		targetAddress:     cn.listener.Addr().String(),
    63  	}
    64  
    65  	// get a collection API client
    66  	client, conn, err := proxyConnectionFactory.GetAccessAPIClient("foo", nil)
    67  	defer conn.Close()
    68  	assert.NoError(t, err)
    69  
    70  	ctx := context.Background()
    71  	// make the call to the collection node
    72  	resp, err := client.Ping(ctx, req)
    73  	assert.NoError(t, err)
    74  	assert.Equal(t, resp, expected)
    75  }
    76  
    77  func TestProxyExecutionAPI(t *testing.T) {
    78  	logger := unittest.Logger()
    79  	metrics := metrics.NewNoopCollector()
    80  
    81  	// create an execution node
    82  	en := new(executionNode)
    83  	en.start(t)
    84  	defer en.stop(t)
    85  
    86  	req := &execution.PingRequest{}
    87  	expected := &execution.PingResponse{}
    88  	en.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
    89  
    90  	// create the factory
    91  	connectionFactory := new(ConnectionFactoryImpl)
    92  	// set the execution grpc port
    93  	connectionFactory.ExecutionGRPCPort = en.port
    94  
    95  	// set metrics reporting
    96  	connectionFactory.AccessMetrics = metrics
    97  	connectionFactory.Manager = NewManager(
    98  		logger,
    99  		connectionFactory.AccessMetrics,
   100  		nil,
   101  		0,
   102  		CircuitBreakerConfig{},
   103  		grpcutils.NoCompressor,
   104  	)
   105  
   106  	proxyConnectionFactory := ProxyConnectionFactory{
   107  		ConnectionFactory: connectionFactory,
   108  		targetAddress:     en.listener.Addr().String(),
   109  	}
   110  
   111  	// get an execution API client
   112  	client, _, err := proxyConnectionFactory.GetExecutionAPIClient("foo")
   113  	assert.NoError(t, err)
   114  
   115  	ctx := context.Background()
   116  	// make the call to the execution node
   117  	resp, err := client.Ping(ctx, req)
   118  	assert.NoError(t, err)
   119  	assert.Equal(t, resp, expected)
   120  }
   121  
   122  func TestProxyAccessAPIConnectionReuse(t *testing.T) {
   123  	logger := unittest.Logger()
   124  	metrics := metrics.NewNoopCollector()
   125  
   126  	// create a collection node
   127  	cn := new(collectionNode)
   128  	cn.start(t)
   129  	defer cn.stop(t)
   130  
   131  	req := &access.PingRequest{}
   132  	expected := &access.PingResponse{}
   133  	cn.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
   134  
   135  	// create the factory
   136  	connectionFactory := new(ConnectionFactoryImpl)
   137  	// set the collection grpc port
   138  	connectionFactory.CollectionGRPCPort = cn.port
   139  
   140  	// set the connection pool cache size
   141  	cacheSize := 1
   142  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   143  	require.NoError(t, err)
   144  
   145  	// set metrics reporting
   146  	connectionFactory.AccessMetrics = metrics
   147  	connectionFactory.Manager = NewManager(
   148  		logger,
   149  		connectionFactory.AccessMetrics,
   150  		connectionCache,
   151  		0,
   152  		CircuitBreakerConfig{},
   153  		grpcutils.NoCompressor,
   154  	)
   155  
   156  	proxyConnectionFactory := ProxyConnectionFactory{
   157  		ConnectionFactory: connectionFactory,
   158  		targetAddress:     cn.listener.Addr().String(),
   159  	}
   160  
   161  	// get a collection API client
   162  	_, closer, err := proxyConnectionFactory.GetAccessAPIClient("foo", nil)
   163  	assert.Equal(t, connectionCache.Len(), 1)
   164  	assert.NoError(t, err)
   165  	assert.Nil(t, closer.Close())
   166  
   167  	var conn *grpc.ClientConn
   168  	res, ok := connectionCache.cache.Get(proxyConnectionFactory.targetAddress)
   169  	assert.True(t, ok)
   170  	conn = res.ClientConn()
   171  
   172  	// check if api client can be rebuilt with retrieved connection
   173  	accessAPIClient := access.NewAccessAPIClient(conn)
   174  	ctx := context.Background()
   175  	resp, err := accessAPIClient.Ping(ctx, req)
   176  	assert.NoError(t, err)
   177  	assert.Equal(t, resp, expected)
   178  }
   179  
   180  func TestProxyExecutionAPIConnectionReuse(t *testing.T) {
   181  	logger := unittest.Logger()
   182  	metrics := metrics.NewNoopCollector()
   183  
   184  	// create an execution node
   185  	en := new(executionNode)
   186  	en.start(t)
   187  	defer en.stop(t)
   188  
   189  	req := &execution.PingRequest{}
   190  	expected := &execution.PingResponse{}
   191  	en.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
   192  
   193  	// create the factory
   194  	connectionFactory := new(ConnectionFactoryImpl)
   195  	// set the execution grpc port
   196  	connectionFactory.ExecutionGRPCPort = en.port
   197  
   198  	// set the connection pool cache size
   199  	cacheSize := 5
   200  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   201  	require.NoError(t, err)
   202  
   203  	// set metrics reporting
   204  	connectionFactory.AccessMetrics = metrics
   205  	connectionFactory.Manager = NewManager(
   206  		logger,
   207  		connectionFactory.AccessMetrics,
   208  		connectionCache,
   209  		0,
   210  		CircuitBreakerConfig{},
   211  		grpcutils.NoCompressor,
   212  	)
   213  
   214  	proxyConnectionFactory := ProxyConnectionFactory{
   215  		ConnectionFactory: connectionFactory,
   216  		targetAddress:     en.listener.Addr().String(),
   217  	}
   218  
   219  	// get an execution API client
   220  	_, closer, err := proxyConnectionFactory.GetExecutionAPIClient("foo")
   221  	assert.Equal(t, connectionCache.Len(), 1)
   222  	assert.NoError(t, err)
   223  	assert.Nil(t, closer.Close())
   224  
   225  	var conn *grpc.ClientConn
   226  	res, ok := connectionCache.cache.Get(proxyConnectionFactory.targetAddress)
   227  	assert.True(t, ok)
   228  	conn = res.ClientConn()
   229  
   230  	// check if api client can be rebuilt with retrieved connection
   231  	executionAPIClient := execution.NewExecutionAPIClient(conn)
   232  	ctx := context.Background()
   233  	resp, err := executionAPIClient.Ping(ctx, req)
   234  	assert.NoError(t, err)
   235  	assert.Equal(t, resp, expected)
   236  }
   237  
   238  // TestExecutionNodeClientTimeout tests that the execution API client times out after the timeout duration
   239  func TestExecutionNodeClientTimeout(t *testing.T) {
   240  	logger := unittest.Logger()
   241  	metrics := metrics.NewNoopCollector()
   242  
   243  	timeout := 10 * time.Millisecond
   244  
   245  	// create an execution node
   246  	en := new(executionNode)
   247  	en.start(t)
   248  	defer en.stop(t)
   249  
   250  	// setup the handler mock to not respond within the timeout
   251  	req := &execution.PingRequest{}
   252  	resp := &execution.PingResponse{}
   253  	en.handler.On("Ping", testifymock.Anything, req).After(timeout+time.Second).Return(resp, nil)
   254  
   255  	// create the factory
   256  	connectionFactory := new(ConnectionFactoryImpl)
   257  	// set the execution grpc port
   258  	connectionFactory.ExecutionGRPCPort = en.port
   259  	// set the execution grpc client timeout
   260  	connectionFactory.ExecutionNodeGRPCTimeout = timeout
   261  
   262  	// set the connection pool cache size
   263  	cacheSize := 5
   264  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   265  	require.NoError(t, err)
   266  
   267  	// set metrics reporting
   268  	connectionFactory.AccessMetrics = metrics
   269  	connectionFactory.Manager = NewManager(
   270  		logger,
   271  		connectionFactory.AccessMetrics,
   272  		connectionCache,
   273  		0,
   274  		CircuitBreakerConfig{},
   275  		grpcutils.NoCompressor,
   276  	)
   277  
   278  	// create the execution API client
   279  	client, _, err := connectionFactory.GetExecutionAPIClient(en.listener.Addr().String())
   280  	require.NoError(t, err)
   281  
   282  	ctx := context.Background()
   283  	// make the call to the execution node
   284  	_, err = client.Ping(ctx, req)
   285  
   286  	// assert that the client timed out
   287  	assert.Equal(t, codes.DeadlineExceeded, status.Code(err))
   288  }
   289  
   290  // TestCollectionNodeClientTimeout tests that the collection API client times out after the timeout duration
   291  func TestCollectionNodeClientTimeout(t *testing.T) {
   292  	logger := unittest.Logger()
   293  	metrics := metrics.NewNoopCollector()
   294  
   295  	timeout := 10 * time.Millisecond
   296  
   297  	// create a collection node
   298  	cn := new(collectionNode)
   299  	cn.start(t)
   300  	defer cn.stop(t)
   301  
   302  	// setup the handler mock to not respond within the timeout
   303  	req := &access.PingRequest{}
   304  	resp := &access.PingResponse{}
   305  	cn.handler.On("Ping", testifymock.Anything, req).After(timeout+time.Second).Return(resp, nil)
   306  
   307  	// create the factory
   308  	connectionFactory := new(ConnectionFactoryImpl)
   309  	// set the collection grpc port
   310  	connectionFactory.CollectionGRPCPort = cn.port
   311  	// set the collection grpc client timeout
   312  	connectionFactory.CollectionNodeGRPCTimeout = timeout
   313  
   314  	// set the connection pool cache size
   315  	cacheSize := 5
   316  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   317  	require.NoError(t, err)
   318  
   319  	// set metrics reporting
   320  	connectionFactory.AccessMetrics = metrics
   321  	connectionFactory.Manager = NewManager(
   322  		logger,
   323  		connectionFactory.AccessMetrics,
   324  		connectionCache,
   325  		0,
   326  		CircuitBreakerConfig{},
   327  		grpcutils.NoCompressor,
   328  	)
   329  
   330  	// create the collection API client
   331  	client, _, err := connectionFactory.GetAccessAPIClient(cn.listener.Addr().String(), nil)
   332  	assert.NoError(t, err)
   333  
   334  	ctx := context.Background()
   335  	// make the call to the execution node
   336  	_, err = client.Ping(ctx, req)
   337  
   338  	// assert that the client timed out
   339  	assert.Equal(t, codes.DeadlineExceeded, status.Code(err))
   340  }
   341  
   342  // TestConnectionPoolFull tests that the LRU cache replaces connections when full
   343  func TestConnectionPoolFull(t *testing.T) {
   344  	logger := unittest.Logger()
   345  	metrics := metrics.NewNoopCollector()
   346  
   347  	// create a collection node
   348  	cn1, cn2, cn3 := new(collectionNode), new(collectionNode), new(collectionNode)
   349  	cn1.start(t)
   350  	cn2.start(t)
   351  	cn3.start(t)
   352  	defer cn1.stop(t)
   353  	defer cn2.stop(t)
   354  	defer cn3.stop(t)
   355  
   356  	req := &access.PingRequest{}
   357  	expected := &access.PingResponse{}
   358  	cn1.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
   359  	cn2.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
   360  	cn3.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
   361  
   362  	// create the factory
   363  	connectionFactory := new(ConnectionFactoryImpl)
   364  	// set the collection grpc port
   365  	connectionFactory.CollectionGRPCPort = cn1.port
   366  
   367  	// set the connection pool cache size
   368  	cacheSize := 2
   369  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   370  	require.NoError(t, err)
   371  
   372  	// set metrics reporting
   373  	connectionFactory.AccessMetrics = metrics
   374  	connectionFactory.Manager = NewManager(
   375  		logger,
   376  		connectionFactory.AccessMetrics,
   377  		connectionCache,
   378  		0,
   379  		CircuitBreakerConfig{},
   380  		grpcutils.NoCompressor,
   381  	)
   382  
   383  	cn1Address := "foo1:123"
   384  	cn2Address := "foo2:123"
   385  	cn3Address := "foo3:123"
   386  
   387  	// get a collection API client
   388  	// Create and add first client to cache
   389  	_, _, err = connectionFactory.GetAccessAPIClient(cn1Address, nil)
   390  	assert.Equal(t, connectionCache.Len(), 1)
   391  	assert.NoError(t, err)
   392  
   393  	// Create and add second client to cache
   394  	_, _, err = connectionFactory.GetAccessAPIClient(cn2Address, nil)
   395  	assert.Equal(t, connectionCache.Len(), 2)
   396  	assert.NoError(t, err)
   397  
   398  	// Get the first client from cache.
   399  	_, _, err = connectionFactory.GetAccessAPIClient(cn1Address, nil)
   400  	assert.Equal(t, connectionCache.Len(), 2)
   401  	assert.NoError(t, err)
   402  
   403  	// Create and add third client to cache, second client will be removed from cache
   404  	_, _, err = connectionFactory.GetAccessAPIClient(cn3Address, nil)
   405  	assert.Equal(t, connectionCache.Len(), 2)
   406  	assert.NoError(t, err)
   407  
   408  	var hostnameOrIP string
   409  
   410  	hostnameOrIP, _, err = net.SplitHostPort(cn1Address)
   411  	require.NoError(t, err)
   412  	grpcAddress1 := fmt.Sprintf("%s:%d", hostnameOrIP, connectionFactory.CollectionGRPCPort)
   413  
   414  	hostnameOrIP, _, err = net.SplitHostPort(cn2Address)
   415  	require.NoError(t, err)
   416  	grpcAddress2 := fmt.Sprintf("%s:%d", hostnameOrIP, connectionFactory.CollectionGRPCPort)
   417  
   418  	hostnameOrIP, _, err = net.SplitHostPort(cn3Address)
   419  	require.NoError(t, err)
   420  	grpcAddress3 := fmt.Sprintf("%s:%d", hostnameOrIP, connectionFactory.CollectionGRPCPort)
   421  
   422  	assert.True(t, connectionCache.cache.Contains(grpcAddress1))
   423  	assert.False(t, connectionCache.cache.Contains(grpcAddress2))
   424  	assert.True(t, connectionCache.cache.Contains(grpcAddress3))
   425  }
   426  
   427  // TestConnectionPoolStale tests that a new connection will be established if the old one cached is stale
   428  func TestConnectionPoolStale(t *testing.T) {
   429  	logger := unittest.Logger()
   430  	metrics := metrics.NewNoopCollector()
   431  
   432  	// create a collection node
   433  	cn := new(collectionNode)
   434  	cn.start(t)
   435  	defer cn.stop(t)
   436  
   437  	req := &access.PingRequest{}
   438  	expected := &access.PingResponse{}
   439  	cn.handler.On("Ping", testifymock.Anything, req).Return(expected, nil)
   440  
   441  	// create the factory
   442  	connectionFactory := new(ConnectionFactoryImpl)
   443  	// set the collection grpc port
   444  	connectionFactory.CollectionGRPCPort = cn.port
   445  
   446  	// set the connection pool cache size
   447  	cacheSize := 5
   448  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   449  	require.NoError(t, err)
   450  
   451  	// set metrics reporting
   452  	connectionFactory.AccessMetrics = metrics
   453  	connectionFactory.Manager = NewManager(
   454  		logger,
   455  		connectionFactory.AccessMetrics,
   456  		connectionCache,
   457  		0,
   458  		CircuitBreakerConfig{},
   459  		grpcutils.NoCompressor,
   460  	)
   461  
   462  	proxyConnectionFactory := ProxyConnectionFactory{
   463  		ConnectionFactory: connectionFactory,
   464  		targetAddress:     cn.listener.Addr().String(),
   465  	}
   466  
   467  	// get a collection API client
   468  	client, _, err := proxyConnectionFactory.GetAccessAPIClient("foo", nil)
   469  	assert.Equal(t, connectionCache.Len(), 1)
   470  	assert.NoError(t, err)
   471  	// close connection to simulate something "going wrong" with our stored connection
   472  	cachedClient, _ := connectionCache.cache.Get(proxyConnectionFactory.targetAddress)
   473  
   474  	cachedClient.Invalidate()
   475  	cachedClient.Close()
   476  
   477  	ctx := context.Background()
   478  	// make the call to the collection node (should fail, connection closed)
   479  	_, err = client.Ping(ctx, req)
   480  	assert.Error(t, err)
   481  
   482  	// re-access, should replace stale connection in cache with new one
   483  	_, _, _ = proxyConnectionFactory.GetAccessAPIClient("foo", nil)
   484  	assert.Equal(t, connectionCache.Len(), 1)
   485  
   486  	var conn *grpc.ClientConn
   487  	res, ok := connectionCache.cache.Get(proxyConnectionFactory.targetAddress)
   488  	assert.True(t, ok)
   489  	conn = res.ClientConn()
   490  
   491  	// check if api client can be rebuilt with retrieved connection
   492  	accessAPIClient := access.NewAccessAPIClient(conn)
   493  	ctx = context.Background()
   494  	resp, err := accessAPIClient.Ping(ctx, req)
   495  	assert.NoError(t, err)
   496  	assert.Equal(t, resp, expected)
   497  }
   498  
   499  // TestExecutionNodeClientClosedGracefully tests the scenario where the execution node client is closed gracefully.
   500  //
   501  // Test Steps:
   502  // - Generate a random number of requests and start goroutines to handle each request.
   503  // - Invalidate the execution API client.
   504  // - Wait for all goroutines to finish.
   505  // - Verify that the number of completed requests matches the number of sent responses.
   506  func TestExecutionNodeClientClosedGracefully(t *testing.T) {
   507  	logger := unittest.Logger()
   508  	metrics := metrics.NewNoopCollector()
   509  
   510  	// Add createExecNode function to recreate it each time for rapid test
   511  	createExecNode := func() (*executionNode, func()) {
   512  		en := new(executionNode)
   513  		en.start(t)
   514  		return en, func() {
   515  			en.stop(t)
   516  		}
   517  	}
   518  
   519  	// Add rapid test, to check graceful close on different number of requests
   520  	rapid.Check(t, func(tt *rapid.T) {
   521  		en, closer := createExecNode()
   522  		defer closer()
   523  
   524  		// setup the handler mock
   525  		req := &execution.PingRequest{}
   526  		resp := &execution.PingResponse{}
   527  		respSent := atomic.NewUint64(0)
   528  		en.handler.On("Ping", testifymock.Anything, req).Run(func(_ testifymock.Arguments) {
   529  			respSent.Inc()
   530  		}).Return(resp, nil)
   531  
   532  		// create the factory
   533  		connectionFactory := new(ConnectionFactoryImpl)
   534  		// set the execution grpc port
   535  		connectionFactory.ExecutionGRPCPort = en.port
   536  		// set the execution grpc client timeout
   537  		connectionFactory.ExecutionNodeGRPCTimeout = time.Second
   538  
   539  		// set the connection pool cache size
   540  		cacheSize := 1
   541  		connectionCache, err := NewCache(logger, metrics, cacheSize)
   542  		require.NoError(t, err)
   543  
   544  		// set metrics reporting
   545  		connectionFactory.AccessMetrics = metrics
   546  		connectionFactory.Manager = NewManager(
   547  			logger,
   548  			connectionFactory.AccessMetrics,
   549  			connectionCache,
   550  			0,
   551  			CircuitBreakerConfig{},
   552  			grpcutils.NoCompressor,
   553  		)
   554  
   555  		clientAddress := en.listener.Addr().String()
   556  		// create the execution API client
   557  		client, _, err := connectionFactory.GetExecutionAPIClient(clientAddress)
   558  		assert.NoError(t, err)
   559  
   560  		ctx := context.Background()
   561  
   562  		// Generate random number of requests
   563  		nofRequests := rapid.IntRange(10, 100).Draw(tt, "nofRequests")
   564  		reqCompleted := atomic.NewUint64(0)
   565  
   566  		var waitGroup sync.WaitGroup
   567  
   568  		for i := 0; i < nofRequests; i++ {
   569  			waitGroup.Add(1)
   570  
   571  			// call Ping request from different goroutines
   572  			go func() {
   573  				defer waitGroup.Done()
   574  				_, err := client.Ping(ctx, req)
   575  
   576  				if err == nil {
   577  					reqCompleted.Inc()
   578  				} else {
   579  					require.Equalf(t, codes.Unavailable, status.Code(err), "unexpected error: %v", err)
   580  				}
   581  			}()
   582  		}
   583  
   584  		// Close connection
   585  		// connectionFactory.Manager.Remove(clientAddress)
   586  
   587  		waitGroup.Wait()
   588  
   589  		assert.Equal(t, reqCompleted.Load(), respSent.Load())
   590  	})
   591  }
   592  
   593  // TestEvictingCacheClients tests the eviction of cached clients.
   594  // It verifies that when a client is evicted from the cache, subsequent requests are handled correctly.
   595  //
   596  // Test Steps:
   597  //   - Call the gRPC method Ping
   598  //   - While the request is still in progress, remove the connection
   599  //   - Call the gRPC method GetNetworkParameters on the client immediately after eviction and assert the expected
   600  //     error response.
   601  //   - Wait for the client state to change from "Ready" to "Shutdown", indicating that the client connection was closed.
   602  func TestEvictingCacheClients(t *testing.T) {
   603  	logger := unittest.Logger()
   604  	metrics := metrics.NewNoopCollector()
   605  
   606  	// Create a new collection node for testing
   607  	cn := new(collectionNode)
   608  	cn.start(t)
   609  	defer cn.stop(t)
   610  
   611  	// Channels used to synchronize test with grpc calls
   612  	startPing := make(chan struct{})      // notify Ping in progress
   613  	returnFromPing := make(chan struct{}) // notify OK to return from Ping
   614  
   615  	// Set up mock handlers for Ping and GetNetworkParameters
   616  	pingReq := &access.PingRequest{}
   617  	pingResp := &access.PingResponse{}
   618  	cn.handler.On("Ping", testifymock.Anything, pingReq).Return(
   619  		func(context.Context, *access.PingRequest) *access.PingResponse {
   620  			close(startPing)
   621  			<-returnFromPing // keeps request open until returnFromPing is closed
   622  			return pingResp
   623  		},
   624  		func(context.Context, *access.PingRequest) error { return nil },
   625  	)
   626  
   627  	netReq := &access.GetNetworkParametersRequest{}
   628  	netResp := &access.GetNetworkParametersResponse{}
   629  	cn.handler.On("GetNetworkParameters", testifymock.Anything, netReq).Return(netResp, nil)
   630  
   631  	// Create the connection factory
   632  	connectionFactory := new(ConnectionFactoryImpl)
   633  	// Set the gRPC port
   634  	connectionFactory.CollectionGRPCPort = cn.port
   635  	// Set the gRPC client timeout
   636  	connectionFactory.CollectionNodeGRPCTimeout = 5 * time.Second
   637  	// Set the connection pool cache size
   638  	cacheSize := 1
   639  
   640  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   641  	require.NoError(t, err)
   642  
   643  	// create a non-blocking cache
   644  	connectionCache.cache, err = lru.NewWithEvict[string, *CachedClient](cacheSize, func(_ string, client *CachedClient) {
   645  		go client.Close()
   646  	})
   647  	require.NoError(t, err)
   648  
   649  	// set metrics reporting
   650  	connectionFactory.AccessMetrics = metrics
   651  	connectionFactory.Manager = NewManager(
   652  		logger,
   653  		connectionFactory.AccessMetrics,
   654  		connectionCache,
   655  		0,
   656  		CircuitBreakerConfig{},
   657  		grpcutils.NoCompressor,
   658  	)
   659  
   660  	clientAddress := cn.listener.Addr().String()
   661  	// Create the execution API client
   662  	client, _, err := connectionFactory.GetAccessAPIClient(clientAddress, nil)
   663  	require.NoError(t, err)
   664  
   665  	ctx := context.Background()
   666  
   667  	// Retrieve the cached client from the cache
   668  	cachedClient, ok := connectionCache.cache.Get(clientAddress)
   669  	require.True(t, ok)
   670  
   671  	// wait until the client connection is ready
   672  	require.Eventually(t, func() bool {
   673  		return cachedClient.ClientConn().GetState() == connectivity.Ready
   674  	}, 100*time.Millisecond, 10*time.Millisecond, "client timed out before ready")
   675  
   676  	// Schedule the invalidation of the access API client while the Ping call is in progress
   677  	wg := sync.WaitGroup{}
   678  	wg.Add(1)
   679  	go func() {
   680  		defer wg.Done()
   681  
   682  		<-startPing // wait until Ping is called
   683  
   684  		// Invalidate the access API client
   685  		cachedClient.Invalidate()
   686  
   687  		// Invalidate marks the connection for closure asynchronously, so give it some time to run
   688  		require.Eventually(t, func() bool {
   689  			return cachedClient.closeRequested.Load()
   690  		}, 100*time.Millisecond, 10*time.Millisecond, "client timed out closing connection")
   691  
   692  		// Call a gRPC method on the client, requests should be blocked since the connection is invalidated
   693  		resp, err := client.GetNetworkParameters(ctx, netReq)
   694  		assert.Equal(t, status.Errorf(codes.Unavailable, "the connection to %s was closed", clientAddress), err)
   695  		assert.Nil(t, resp)
   696  
   697  		close(returnFromPing) // signal it's ok to return from Ping
   698  	}()
   699  
   700  	// Call a gRPC method on the client
   701  	_, err = client.Ping(ctx, pingReq)
   702  	// Check that Ping was called
   703  	cn.handler.AssertCalled(t, "Ping", testifymock.Anything, pingReq)
   704  	assert.NoError(t, err)
   705  
   706  	// Wait for the client connection to change state from "Ready" to "Shutdown" as connection was closed.
   707  	require.Eventually(t, func() bool {
   708  		return cachedClient.ClientConn().WaitForStateChange(ctx, connectivity.Ready)
   709  	}, 100*time.Millisecond, 10*time.Millisecond, "client timed out transitioning state")
   710  
   711  	assert.Equal(t, connectivity.Shutdown, cachedClient.ClientConn().GetState())
   712  	assert.Equal(t, 0, connectionCache.Len())
   713  
   714  	wg.Wait() // wait until the move test routine is done
   715  }
   716  
   717  func TestConcurrentConnections(t *testing.T) {
   718  	logger := unittest.Logger()
   719  	metrics := metrics.NewNoopCollector()
   720  
   721  	// Add createExecNode function to recreate it each time for rapid test
   722  	createExecNode := func() (*executionNode, func()) {
   723  		en := new(executionNode)
   724  		en.start(t)
   725  		return en, func() {
   726  			en.stop(t)
   727  		}
   728  	}
   729  
   730  	// setup the handler mock
   731  	req := &execution.PingRequest{}
   732  	resp := &execution.PingResponse{}
   733  
   734  	// Note: rapid will randomly fail with an error: "group did not use any data from bitstream"
   735  	// See https://github.com/flyingmutant/rapid/issues/65
   736  	rapid.Check(t, func(tt *rapid.T) {
   737  		en, closer := createExecNode()
   738  		defer closer()
   739  
   740  		// Note: rapid does not support concurrent calls to Draw for a given T, so they must be serialized
   741  		mu := sync.Mutex{}
   742  		getSleep := func() time.Duration {
   743  			mu.Lock()
   744  			defer mu.Unlock()
   745  			return time.Duration(rapid.Int64Range(100, 10_000).Draw(tt, "s"))
   746  		}
   747  
   748  		requestCount := rapid.IntRange(50, 1000).Draw(tt, "r")
   749  		responsesSent := atomic.NewInt32(0)
   750  		en.handler.
   751  			On("Ping", testifymock.Anything, req).
   752  			Return(func(_ context.Context, _ *execution.PingRequest) (*execution.PingResponse, error) {
   753  				time.Sleep(getSleep() * time.Microsecond)
   754  
   755  				// randomly fail ~25% of the time to test that client connection and reuse logic
   756  				// handles concurrent connect/disconnects
   757  				fail, err := rand.Int(rand.Reader, big.NewInt(4))
   758  				require.NoError(tt, err)
   759  
   760  				if fail.Uint64()%4 == 0 {
   761  					err = status.Errorf(codes.Unavailable, "random error")
   762  				}
   763  
   764  				responsesSent.Inc()
   765  				return resp, err
   766  			})
   767  
   768  		connectionCache, err := NewCache(logger, metrics, 1)
   769  		require.NoError(tt, err)
   770  
   771  		connectionFactory := &ConnectionFactoryImpl{
   772  			ExecutionGRPCPort:        en.port,
   773  			ExecutionNodeGRPCTimeout: time.Second,
   774  			AccessMetrics:            metrics,
   775  			Manager: NewManager(
   776  				logger,
   777  				metrics,
   778  				connectionCache,
   779  				0,
   780  				CircuitBreakerConfig{},
   781  				grpcutils.NoCompressor,
   782  			),
   783  		}
   784  
   785  		clientAddress := en.listener.Addr().String()
   786  
   787  		ctx := context.Background()
   788  
   789  		// Generate random number of requests
   790  		var wg sync.WaitGroup
   791  		wg.Add(requestCount)
   792  
   793  		for i := 0; i < requestCount; i++ {
   794  			go func() {
   795  				defer wg.Done()
   796  
   797  				client, _, err := connectionFactory.GetExecutionAPIClient(clientAddress)
   798  				require.NoError(tt, err)
   799  
   800  				_, err = client.Ping(ctx, req)
   801  
   802  				if err != nil {
   803  					// Note: for some reason, when Unavailable is returned, the error message is
   804  					// changed to "the connection to 127.0.0.1:57753 was closed". Other error codes
   805  					// preserve the message.
   806  					require.Equalf(tt, codes.Unavailable, status.Code(err), "unexpected error: %v", err)
   807  				}
   808  			}()
   809  		}
   810  		wg.Wait()
   811  
   812  		// the grpc client seems to throttle requests to servers that return Unavailable, so not
   813  		// all of the requests make it through to the backend every test. Requiring that at least 1
   814  		// request is handled for these cases, but all should be handled in most runs.
   815  		assert.LessOrEqual(tt, responsesSent.Load(), int32(requestCount))
   816  		assert.Greater(tt, responsesSent.Load(), int32(0))
   817  	})
   818  }
   819  
   820  var successCodes = []codes.Code{
   821  	codes.Canceled,
   822  	codes.InvalidArgument,
   823  	codes.NotFound,
   824  	codes.Unimplemented,
   825  	codes.OutOfRange,
   826  }
   827  
   828  // TestCircuitBreakerExecutionNode tests the circuit breaker for execution nodes.
   829  func TestCircuitBreakerExecutionNode(t *testing.T) {
   830  	logger := unittest.Logger()
   831  	metrics := metrics.NewNoopCollector()
   832  
   833  	requestTimeout := 500 * time.Millisecond
   834  	circuitBreakerRestoreTimeout := 1500 * time.Millisecond
   835  
   836  	// Create an execution node for testing.
   837  	en := new(executionNode)
   838  	en.start(t)
   839  	defer en.stop(t)
   840  
   841  	// Create the connection factory.
   842  	connectionFactory := new(ConnectionFactoryImpl)
   843  
   844  	// Set the execution gRPC port.
   845  	connectionFactory.ExecutionGRPCPort = en.port
   846  
   847  	// Set the execution gRPC client requestTimeout.
   848  	connectionFactory.ExecutionNodeGRPCTimeout = requestTimeout
   849  
   850  	// Set the connection pool cache size.
   851  	cacheSize := 1
   852  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   853  	require.NoError(t, err)
   854  
   855  	connectionFactory.Manager = NewManager(
   856  		logger,
   857  		connectionFactory.AccessMetrics,
   858  		connectionCache,
   859  		0,
   860  		CircuitBreakerConfig{
   861  			Enabled:        true,
   862  			MaxFailures:    1,
   863  			MaxRequests:    1,
   864  			RestoreTimeout: circuitBreakerRestoreTimeout,
   865  		},
   866  		grpcutils.NoCompressor,
   867  	)
   868  
   869  	// Set metrics reporting.
   870  	connectionFactory.AccessMetrics = metrics
   871  
   872  	// Create the execution API client.
   873  	client, _, err := connectionFactory.GetExecutionAPIClient(en.listener.Addr().String())
   874  	require.NoError(t, err)
   875  
   876  	req := &execution.PingRequest{}
   877  	resp := &execution.PingResponse{}
   878  
   879  	// Helper function to make the Ping call to the execution node and measure the duration.
   880  	callAndMeasurePingDuration := func(ctx context.Context) (time.Duration, error) {
   881  		start := time.Now()
   882  
   883  		// Make the call to the execution node.
   884  		_, err = client.Ping(ctx, req)
   885  		en.handler.AssertCalled(t, "Ping", testifymock.Anything, req)
   886  
   887  		return time.Since(start), err
   888  	}
   889  
   890  	t.Run("test different states of the circuit breaker", func(t *testing.T) {
   891  		ctx := context.Background()
   892  
   893  		// Set up the handler mock to not respond within the requestTimeout.
   894  		en.handler.On("Ping", testifymock.Anything, req).After(2*requestTimeout).Return(resp, nil)
   895  
   896  		// Call and measure the duration for the first invocation.
   897  		duration, err := callAndMeasurePingDuration(ctx)
   898  		assert.Equal(t, codes.DeadlineExceeded, status.Code(err))
   899  		assert.LessOrEqual(t, requestTimeout, duration)
   900  
   901  		// Call and measure the duration for the second invocation (circuit breaker state is now "Open").
   902  		duration, err = callAndMeasurePingDuration(ctx)
   903  		assert.Equal(t, gobreaker.ErrOpenState, err)
   904  		assert.Greater(t, requestTimeout, duration)
   905  
   906  		// Reset the mock Ping for the next invocation to return response without delay
   907  		en.handler.On("Ping", testifymock.Anything, req).Unset()
   908  		en.handler.On("Ping", testifymock.Anything, req).Return(resp, nil)
   909  
   910  		// Wait until the circuit breaker transitions to the "HalfOpen" state.
   911  		time.Sleep(circuitBreakerRestoreTimeout + (500 * time.Millisecond))
   912  
   913  		// Call and measure the duration for the third invocation (circuit breaker state is now "HalfOpen").
   914  		duration, err = callAndMeasurePingDuration(ctx)
   915  		assert.Greater(t, requestTimeout, duration)
   916  		assert.Equal(t, nil, err)
   917  	})
   918  
   919  	for _, code := range successCodes {
   920  		t.Run(fmt.Sprintf("test error %s treated as a success for circuit breaker ", code.String()), func(t *testing.T) {
   921  			ctx := context.Background()
   922  
   923  			en.handler.On("Ping", testifymock.Anything, req).Unset()
   924  			en.handler.On("Ping", testifymock.Anything, req).Return(nil, status.Error(code, code.String()))
   925  
   926  			duration, err := callAndMeasurePingDuration(ctx)
   927  			require.Error(t, err)
   928  			require.Equal(t, code, status.Code(err))
   929  			require.Greater(t, requestTimeout, duration)
   930  		})
   931  	}
   932  }
   933  
   934  // TestCircuitBreakerCollectionNode tests the circuit breaker for collection nodes.
   935  func TestCircuitBreakerCollectionNode(t *testing.T) {
   936  	logger := unittest.Logger()
   937  	metrics := metrics.NewNoopCollector()
   938  
   939  	requestTimeout := 500 * time.Millisecond
   940  	circuitBreakerRestoreTimeout := 1500 * time.Millisecond
   941  
   942  	// Create a collection node for testing.
   943  	cn := new(collectionNode)
   944  	cn.start(t)
   945  	defer cn.stop(t)
   946  
   947  	// Create the connection factory.
   948  	connectionFactory := new(ConnectionFactoryImpl)
   949  
   950  	// Set the collection gRPC port.
   951  	connectionFactory.CollectionGRPCPort = cn.port
   952  
   953  	// Set the collection gRPC client requestTimeout.
   954  	connectionFactory.CollectionNodeGRPCTimeout = requestTimeout
   955  
   956  	// Set the connection pool cache size.
   957  	cacheSize := 1
   958  	connectionCache, err := NewCache(logger, metrics, cacheSize)
   959  	require.NoError(t, err)
   960  
   961  	connectionFactory.Manager = NewManager(
   962  		logger,
   963  		connectionFactory.AccessMetrics,
   964  		connectionCache,
   965  		0,
   966  		CircuitBreakerConfig{
   967  			Enabled:        true,
   968  			MaxFailures:    1,
   969  			MaxRequests:    1,
   970  			RestoreTimeout: circuitBreakerRestoreTimeout,
   971  		},
   972  		grpcutils.NoCompressor,
   973  	)
   974  
   975  	// Set metrics reporting.
   976  	connectionFactory.AccessMetrics = metrics
   977  
   978  	// Create the collection API client.
   979  	client, _, err := connectionFactory.GetAccessAPIClient(cn.listener.Addr().String(), nil)
   980  	assert.NoError(t, err)
   981  
   982  	req := &access.PingRequest{}
   983  	resp := &access.PingResponse{}
   984  
   985  	// Helper function to make the Ping call to the collection node and measure the duration.
   986  	callAndMeasurePingDuration := func(ctx context.Context) (time.Duration, error) {
   987  		start := time.Now()
   988  
   989  		// Make the call to the collection node.
   990  		_, err = client.Ping(ctx, req)
   991  		cn.handler.AssertCalled(t, "Ping", testifymock.Anything, req)
   992  
   993  		return time.Since(start), err
   994  	}
   995  
   996  	t.Run("test different states of the circuit breaker", func(t *testing.T) {
   997  		ctx := context.Background()
   998  
   999  		// Set up the handler mock to not respond within the requestTimeout.
  1000  		cn.handler.On("Ping", testifymock.Anything, req).After(2*requestTimeout).Return(resp, nil)
  1001  
  1002  		// Call and measure the duration for the first invocation.
  1003  		duration, err := callAndMeasurePingDuration(ctx)
  1004  		assert.Equal(t, codes.DeadlineExceeded, status.Code(err))
  1005  		assert.LessOrEqual(t, requestTimeout, duration)
  1006  
  1007  		// Call and measure the duration for the second invocation (circuit breaker state is now "Open").
  1008  		duration, err = callAndMeasurePingDuration(ctx)
  1009  		assert.Equal(t, gobreaker.ErrOpenState, err)
  1010  		assert.Greater(t, requestTimeout, duration)
  1011  
  1012  		// Reset the mock Ping for the next invocation to return response without delay
  1013  		cn.handler.On("Ping", testifymock.Anything, req).Unset()
  1014  		cn.handler.On("Ping", testifymock.Anything, req).Return(resp, nil)
  1015  
  1016  		// Wait until the circuit breaker transitions to the "HalfOpen" state.
  1017  		time.Sleep(circuitBreakerRestoreTimeout + (500 * time.Millisecond))
  1018  
  1019  		// Call and measure the duration for the third invocation (circuit breaker state is now "HalfOpen").
  1020  		duration, err = callAndMeasurePingDuration(ctx)
  1021  		assert.Greater(t, requestTimeout, duration)
  1022  		assert.Equal(t, nil, err)
  1023  	})
  1024  
  1025  	for _, code := range successCodes {
  1026  		t.Run(fmt.Sprintf("test error %s treated as a success for circuit breaker ", code.String()), func(t *testing.T) {
  1027  			ctx := context.Background()
  1028  
  1029  			cn.handler.On("Ping", testifymock.Anything, req).Unset()
  1030  			cn.handler.On("Ping", testifymock.Anything, req).Return(nil, status.Error(code, code.String()))
  1031  
  1032  			duration, err := callAndMeasurePingDuration(ctx)
  1033  			require.Error(t, err)
  1034  			require.Equal(t, code, status.Code(err))
  1035  			require.Greater(t, requestTimeout, duration)
  1036  		})
  1037  	}
  1038  }