github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rpc/rate_limit_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"testing"
     9  	"time"
    10  
    11  	accessproto "github.com/onflow/flow/protobuf/go/flow/access"
    12  	"github.com/rs/zerolog"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/mock"
    15  	"github.com/stretchr/testify/require"
    16  	"github.com/stretchr/testify/suite"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/codes"
    19  	"google.golang.org/grpc/credentials"
    20  	"google.golang.org/grpc/credentials/insecure"
    21  	"google.golang.org/grpc/status"
    22  
    23  	accessmock "github.com/onflow/flow-go/engine/access/mock"
    24  	"github.com/onflow/flow-go/engine/access/rpc/backend"
    25  	statestreambackend "github.com/onflow/flow-go/engine/access/state_stream/backend"
    26  	"github.com/onflow/flow-go/model/flow"
    27  	"github.com/onflow/flow-go/module/grpcserver"
    28  	"github.com/onflow/flow-go/module/irrecoverable"
    29  	"github.com/onflow/flow-go/module/metrics"
    30  	module "github.com/onflow/flow-go/module/mock"
    31  	"github.com/onflow/flow-go/network"
    32  	protocol "github.com/onflow/flow-go/state/protocol/mock"
    33  	storagemock "github.com/onflow/flow-go/storage/mock"
    34  	"github.com/onflow/flow-go/utils/grpcutils"
    35  	"github.com/onflow/flow-go/utils/unittest"
    36  )
    37  
    38  type RateLimitTestSuite struct {
    39  	suite.Suite
    40  	state      *protocol.State
    41  	snapshot   *protocol.Snapshot
    42  	epochQuery *protocol.EpochQuery
    43  	log        zerolog.Logger
    44  	net        *network.EngineRegistry
    45  	request    *module.Requester
    46  	collClient *accessmock.AccessAPIClient
    47  	execClient *accessmock.ExecutionAPIClient
    48  	me         *module.Local
    49  	chainID    flow.ChainID
    50  	metrics    *metrics.NoopCollector
    51  	rpcEng     *Engine
    52  	client     accessproto.AccessAPIClient
    53  	closer     io.Closer
    54  
    55  	// storage
    56  	blocks       *storagemock.Blocks
    57  	headers      *storagemock.Headers
    58  	collections  *storagemock.Collections
    59  	transactions *storagemock.Transactions
    60  	receipts     *storagemock.ExecutionReceipts
    61  
    62  	// test rate limit
    63  	rateLimit  int
    64  	burstLimit int
    65  
    66  	ctx    irrecoverable.SignalerContext
    67  	cancel context.CancelFunc
    68  
    69  	// grpc servers
    70  	secureGrpcServer   *grpcserver.GrpcServer
    71  	unsecureGrpcServer *grpcserver.GrpcServer
    72  }
    73  
    74  func (suite *RateLimitTestSuite) SetupTest() {
    75  	suite.log = zerolog.New(os.Stdout)
    76  	suite.net = new(network.EngineRegistry)
    77  	suite.state = new(protocol.State)
    78  	suite.snapshot = new(protocol.Snapshot)
    79  
    80  	rootHeader := unittest.BlockHeaderFixture()
    81  	params := new(protocol.Params)
    82  	params.On("SporkID").Return(unittest.IdentifierFixture(), nil)
    83  	params.On("ProtocolVersion").Return(uint(unittest.Uint64InRange(10, 30)), nil)
    84  	params.On("SporkRootBlockHeight").Return(rootHeader.Height, nil)
    85  	params.On("SealedRoot").Return(rootHeader, nil)
    86  
    87  	suite.epochQuery = new(protocol.EpochQuery)
    88  	suite.state.On("Sealed").Return(suite.snapshot, nil).Maybe()
    89  	suite.state.On("Final").Return(suite.snapshot, nil).Maybe()
    90  	suite.state.On("Params").Return(params, nil).Maybe()
    91  	suite.snapshot.On("Epochs").Return(suite.epochQuery).Maybe()
    92  	suite.blocks = new(storagemock.Blocks)
    93  	suite.headers = new(storagemock.Headers)
    94  	suite.transactions = new(storagemock.Transactions)
    95  	suite.collections = new(storagemock.Collections)
    96  	suite.receipts = new(storagemock.ExecutionReceipts)
    97  
    98  	suite.collClient = new(accessmock.AccessAPIClient)
    99  	suite.execClient = new(accessmock.ExecutionAPIClient)
   100  
   101  	suite.request = new(module.Requester)
   102  	suite.request.On("EntityByID", mock.Anything, mock.Anything)
   103  
   104  	suite.me = new(module.Local)
   105  
   106  	accessIdentity := unittest.IdentityFixture(unittest.WithRole(flow.RoleAccess))
   107  	suite.me.
   108  		On("NodeID").
   109  		Return(accessIdentity.NodeID)
   110  
   111  	suite.chainID = flow.Testnet
   112  	suite.metrics = metrics.NewNoopCollector()
   113  
   114  	config := Config{
   115  		UnsecureGRPCListenAddr: unittest.DefaultAddress,
   116  		SecureGRPCListenAddr:   unittest.DefaultAddress,
   117  		HTTPListenAddr:         unittest.DefaultAddress,
   118  	}
   119  
   120  	// generate a server certificate that will be served by the GRPC server
   121  	networkingKey := unittest.NetworkingPrivKeyFixture()
   122  	x509Certificate, err := grpcutils.X509Certificate(networkingKey)
   123  	assert.NoError(suite.T(), err)
   124  	tlsConfig := grpcutils.DefaultServerTLSConfig(x509Certificate)
   125  	// set the transport credentials for the server to use
   126  	config.TransportCredentials = credentials.NewTLS(tlsConfig)
   127  
   128  	// set the rate limit to test with
   129  	suite.rateLimit = 2
   130  	// set the burst limit to test with
   131  	suite.burstLimit = 2
   132  
   133  	apiRateLimt := map[string]int{
   134  		"Ping": suite.rateLimit,
   135  	}
   136  
   137  	apiBurstLimt := map[string]int{
   138  		"Ping": suite.rateLimit,
   139  	}
   140  
   141  	suite.secureGrpcServer = grpcserver.NewGrpcServerBuilder(suite.log,
   142  		config.SecureGRPCListenAddr,
   143  		grpcutils.DefaultMaxMsgSize,
   144  		false,
   145  		apiRateLimt,
   146  		apiBurstLimt,
   147  		grpcserver.WithTransportCredentials(config.TransportCredentials)).Build()
   148  
   149  	suite.unsecureGrpcServer = grpcserver.NewGrpcServerBuilder(suite.log,
   150  		config.UnsecureGRPCListenAddr,
   151  		grpcutils.DefaultMaxMsgSize,
   152  		false,
   153  		apiRateLimt,
   154  		apiBurstLimt).Build()
   155  
   156  	block := unittest.BlockHeaderFixture()
   157  	suite.snapshot.On("Head").Return(block, nil)
   158  
   159  	bnd, err := backend.New(backend.Params{
   160  		State:                suite.state,
   161  		CollectionRPC:        suite.collClient,
   162  		Blocks:               suite.blocks,
   163  		Headers:              suite.headers,
   164  		Collections:          suite.collections,
   165  		Transactions:         suite.transactions,
   166  		ChainID:              suite.chainID,
   167  		AccessMetrics:        suite.metrics,
   168  		MaxHeightRange:       0,
   169  		Log:                  suite.log,
   170  		SnapshotHistoryLimit: 0,
   171  		Communicator:         backend.NewNodeCommunicator(false),
   172  	})
   173  	suite.Require().NoError(err)
   174  
   175  	stateStreamConfig := statestreambackend.Config{}
   176  	rpcEngBuilder, err := NewBuilder(
   177  		suite.log,
   178  		suite.state,
   179  		config,
   180  		suite.chainID,
   181  		suite.metrics,
   182  		false,
   183  		suite.me,
   184  		bnd,
   185  		bnd,
   186  		suite.secureGrpcServer,
   187  		suite.unsecureGrpcServer,
   188  		nil,
   189  		stateStreamConfig)
   190  	require.NoError(suite.T(), err)
   191  	suite.rpcEng, err = rpcEngBuilder.WithLegacy().Build()
   192  	require.NoError(suite.T(), err)
   193  	suite.ctx, suite.cancel = irrecoverable.NewMockSignalerContextWithCancel(suite.T(), context.Background())
   194  
   195  	suite.rpcEng.Start(suite.ctx)
   196  
   197  	suite.secureGrpcServer.Start(suite.ctx)
   198  	suite.unsecureGrpcServer.Start(suite.ctx)
   199  
   200  	// wait for the servers to startup
   201  	unittest.AssertClosesBefore(suite.T(), suite.secureGrpcServer.Ready(), 2*time.Second)
   202  	unittest.AssertClosesBefore(suite.T(), suite.unsecureGrpcServer.Ready(), 2*time.Second)
   203  
   204  	// wait for the engine to startup
   205  	unittest.RequireCloseBefore(suite.T(), suite.rpcEng.Ready(), 2*time.Second, "engine not ready at startup")
   206  
   207  	// create the access api client
   208  	suite.client, suite.closer, err = accessAPIClient(suite.unsecureGrpcServer.GRPCAddress().String())
   209  	require.NoError(suite.T(), err)
   210  }
   211  
   212  func (suite *RateLimitTestSuite) TearDownTest() {
   213  	if suite.cancel != nil {
   214  		suite.cancel()
   215  	}
   216  	// close the client
   217  	if suite.closer != nil {
   218  		suite.closer.Close()
   219  	}
   220  	// close servers
   221  	unittest.AssertClosesBefore(suite.T(), suite.secureGrpcServer.Done(), 2*time.Second)
   222  	unittest.AssertClosesBefore(suite.T(), suite.unsecureGrpcServer.Done(), 2*time.Second)
   223  }
   224  
   225  func TestRateLimit(t *testing.T) {
   226  	suite.Run(t, new(RateLimitTestSuite))
   227  }
   228  
   229  // TestRatelimitingWithoutBurst tests that rate limit is correctly applied to an Access API call
   230  func (suite *RateLimitTestSuite) TestRatelimitingWithoutBurst() {
   231  
   232  	req := &accessproto.PingRequest{}
   233  	ctx := context.Background()
   234  
   235  	// expect 2 upstream calls
   236  	suite.execClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.rateLimit)
   237  	suite.collClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.rateLimit)
   238  
   239  	requestCnt := 0
   240  	// requests within the burst should succeed
   241  	for requestCnt < suite.rateLimit {
   242  		resp, err := suite.client.Ping(ctx, req)
   243  		assert.NoError(suite.T(), err)
   244  		assert.NotNil(suite.T(), resp)
   245  		// sleep to prevent burst
   246  		time.Sleep(100 * time.Millisecond)
   247  		requestCnt++
   248  	}
   249  
   250  	// request more than the limit should fail
   251  	_, err := suite.client.Ping(ctx, req)
   252  	suite.assertRateLimitError(err)
   253  }
   254  
   255  // TestRatelimitingWithBurst tests that burst limit is correctly applied to an Access API call
   256  func (suite *RateLimitTestSuite) TestRatelimitingWithBurst() {
   257  
   258  	req := &accessproto.PingRequest{}
   259  	ctx := context.Background()
   260  
   261  	// expect rpc.defaultBurst number of upstream calls
   262  	suite.execClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.burstLimit)
   263  	suite.collClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.burstLimit)
   264  
   265  	requestCnt := 0
   266  	// generate a permissible burst of request and assert that they succeed
   267  	for requestCnt < suite.burstLimit {
   268  		resp, err := suite.client.Ping(ctx, req)
   269  		assert.NoError(suite.T(), err)
   270  		assert.NotNil(suite.T(), resp)
   271  		requestCnt++
   272  	}
   273  
   274  	// request more than the permissible burst and assert that it fails
   275  	_, err := suite.client.Ping(ctx, req)
   276  	suite.assertRateLimitError(err)
   277  }
   278  
   279  func (suite *RateLimitTestSuite) assertRateLimitError(err error) {
   280  	assert.Error(suite.T(), err)
   281  	status, ok := status.FromError(err)
   282  	assert.True(suite.T(), ok)
   283  	assert.Equal(suite.T(), codes.ResourceExhausted, status.Code())
   284  }
   285  
   286  func accessAPIClient(address string) (accessproto.AccessAPIClient, io.Closer, error) {
   287  	conn, err := grpc.Dial(
   288  		address,
   289  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(grpcutils.DefaultMaxMsgSize)),
   290  		grpc.WithTransportCredentials(insecure.NewCredentials()))
   291  	if err != nil {
   292  		return nil, nil, fmt.Errorf("failed to connect to address %s: %w", address, err)
   293  	}
   294  	client := accessproto.NewAccessAPIClient(conn)
   295  	closer := io.Closer(conn)
   296  	return client, closer, nil
   297  }