github.com/koko1123/flow-go-1@v0.29.6/engine/access/rpc/rate_limit_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"testing"
     9  	"time"
    10  
    11  	accessproto "github.com/onflow/flow/protobuf/go/flow/access"
    12  	"github.com/rs/zerolog"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/mock"
    15  	"github.com/stretchr/testify/suite"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/status"
    19  
    20  	accessmock "github.com/koko1123/flow-go-1/engine/access/mock"
    21  	"github.com/koko1123/flow-go-1/model/flow"
    22  	"github.com/koko1123/flow-go-1/module/metrics"
    23  	module "github.com/koko1123/flow-go-1/module/mock"
    24  	"github.com/koko1123/flow-go-1/network"
    25  	protocol "github.com/koko1123/flow-go-1/state/protocol/mock"
    26  	storagemock "github.com/koko1123/flow-go-1/storage/mock"
    27  	"github.com/koko1123/flow-go-1/utils/grpcutils"
    28  	"github.com/koko1123/flow-go-1/utils/unittest"
    29  )
    30  
    31  type RateLimitTestSuite struct {
    32  	suite.Suite
    33  	state      *protocol.State
    34  	snapshot   *protocol.Snapshot
    35  	epochQuery *protocol.EpochQuery
    36  	log        zerolog.Logger
    37  	net        *network.Network
    38  	request    *module.Requester
    39  	collClient *accessmock.AccessAPIClient
    40  	execClient *accessmock.ExecutionAPIClient
    41  	me         *module.Local
    42  	chainID    flow.ChainID
    43  	metrics    *metrics.NoopCollector
    44  	rpcEng     *Engine
    45  	client     accessproto.AccessAPIClient
    46  	closer     io.Closer
    47  
    48  	// storage
    49  	blocks       *storagemock.Blocks
    50  	headers      *storagemock.Headers
    51  	collections  *storagemock.Collections
    52  	transactions *storagemock.Transactions
    53  	receipts     *storagemock.ExecutionReceipts
    54  
    55  	// test rate limit
    56  	rateLimit  int
    57  	burstLimit int
    58  }
    59  
    60  func (suite *RateLimitTestSuite) SetupTest() {
    61  	suite.log = zerolog.New(os.Stdout)
    62  	suite.net = new(network.Network)
    63  	suite.state = new(protocol.State)
    64  	suite.snapshot = new(protocol.Snapshot)
    65  
    66  	suite.epochQuery = new(protocol.EpochQuery)
    67  	suite.state.On("Sealed").Return(suite.snapshot, nil).Maybe()
    68  	suite.state.On("Final").Return(suite.snapshot, nil).Maybe()
    69  	suite.snapshot.On("Epochs").Return(suite.epochQuery).Maybe()
    70  	suite.blocks = new(storagemock.Blocks)
    71  	suite.headers = new(storagemock.Headers)
    72  	suite.transactions = new(storagemock.Transactions)
    73  	suite.collections = new(storagemock.Collections)
    74  	suite.receipts = new(storagemock.ExecutionReceipts)
    75  
    76  	suite.collClient = new(accessmock.AccessAPIClient)
    77  	suite.execClient = new(accessmock.ExecutionAPIClient)
    78  
    79  	suite.request = new(module.Requester)
    80  	suite.request.On("EntityByID", mock.Anything, mock.Anything)
    81  
    82  	suite.me = new(module.Local)
    83  
    84  	accessIdentity := unittest.IdentityFixture(unittest.WithRole(flow.RoleAccess))
    85  	suite.me.
    86  		On("NodeID").
    87  		Return(accessIdentity.NodeID)
    88  
    89  	suite.chainID = flow.Testnet
    90  	suite.metrics = metrics.NewNoopCollector()
    91  
    92  	config := Config{
    93  		UnsecureGRPCListenAddr: unittest.DefaultAddress,
    94  		SecureGRPCListenAddr:   unittest.DefaultAddress,
    95  		HTTPListenAddr:         unittest.DefaultAddress,
    96  	}
    97  
    98  	// set the rate limit to test with
    99  	suite.rateLimit = 2
   100  	// set the burst limit to test with
   101  	suite.burstLimit = 2
   102  
   103  	apiRateLimt := map[string]int{
   104  		"Ping": suite.rateLimit,
   105  	}
   106  
   107  	apiBurstLimt := map[string]int{
   108  		"Ping": suite.rateLimit,
   109  	}
   110  
   111  	rpcEngBuilder, err := NewBuilder(suite.log, suite.state, config, suite.collClient, nil, suite.blocks, suite.headers, suite.collections, suite.transactions, nil,
   112  		nil, suite.chainID, suite.metrics, suite.metrics, 0, 0, false, false, apiRateLimt, apiBurstLimt)
   113  	assert.NoError(suite.T(), err)
   114  	suite.rpcEng, err = rpcEngBuilder.WithLegacy().Build()
   115  	assert.NoError(suite.T(), err)
   116  	unittest.AssertClosesBefore(suite.T(), suite.rpcEng.Ready(), 2*time.Second)
   117  
   118  	// wait for the server to startup
   119  	assert.Eventually(suite.T(), func() bool {
   120  		return suite.rpcEng.UnsecureGRPCAddress() != nil
   121  	}, 5*time.Second, 10*time.Millisecond)
   122  
   123  	// create the access api client
   124  	suite.client, suite.closer, err = accessAPIClient(suite.rpcEng.UnsecureGRPCAddress().String())
   125  	assert.NoError(suite.T(), err)
   126  }
   127  
   128  func (suite *RateLimitTestSuite) TearDownTest() {
   129  	// close the client
   130  	if suite.closer != nil {
   131  		suite.closer.Close()
   132  	}
   133  	// close the server
   134  	if suite.rpcEng != nil {
   135  		unittest.AssertClosesBefore(suite.T(), suite.rpcEng.Done(), 2*time.Second)
   136  	}
   137  }
   138  
   139  func TestRateLimit(t *testing.T) {
   140  	suite.Run(t, new(RateLimitTestSuite))
   141  }
   142  
   143  // TestRatelimitingWithoutBurst tests that rate limit is correctly applied to an Access API call
   144  func (suite *RateLimitTestSuite) TestRatelimitingWithoutBurst() {
   145  
   146  	req := &accessproto.PingRequest{}
   147  	ctx := context.Background()
   148  
   149  	// expect 2 upstream calls
   150  	suite.execClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.rateLimit)
   151  	suite.collClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.rateLimit)
   152  
   153  	requestCnt := 0
   154  	// requests within the burst should succeed
   155  	for requestCnt < suite.rateLimit {
   156  		resp, err := suite.client.Ping(ctx, req)
   157  		assert.NoError(suite.T(), err)
   158  		assert.NotNil(suite.T(), resp)
   159  		// sleep to prevent burst
   160  		time.Sleep(100 * time.Millisecond)
   161  		requestCnt++
   162  	}
   163  
   164  	// request more than the limit should fail
   165  	_, err := suite.client.Ping(ctx, req)
   166  	suite.assertRateLimitError(err)
   167  }
   168  
   169  // TestRatelimitingWithBurst tests that burst limit is correctly applied to an Access API call
   170  func (suite *RateLimitTestSuite) TestRatelimitingWithBurst() {
   171  
   172  	req := &accessproto.PingRequest{}
   173  	ctx := context.Background()
   174  
   175  	// expect rpc.defaultBurst number of upstream calls
   176  	suite.execClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.burstLimit)
   177  	suite.collClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.burstLimit)
   178  
   179  	requestCnt := 0
   180  	// generate a permissible burst of request and assert that they succeed
   181  	for requestCnt < suite.burstLimit {
   182  		resp, err := suite.client.Ping(ctx, req)
   183  		assert.NoError(suite.T(), err)
   184  		assert.NotNil(suite.T(), resp)
   185  		requestCnt++
   186  	}
   187  
   188  	// request more than the permissible burst and assert that it fails
   189  	_, err := suite.client.Ping(ctx, req)
   190  	suite.assertRateLimitError(err)
   191  }
   192  
   193  func (suite *RateLimitTestSuite) assertRateLimitError(err error) {
   194  	assert.Error(suite.T(), err)
   195  	status, ok := status.FromError(err)
   196  	assert.True(suite.T(), ok)
   197  	assert.Equal(suite.T(), codes.ResourceExhausted, status.Code())
   198  }
   199  
   200  func accessAPIClient(address string) (accessproto.AccessAPIClient, io.Closer, error) {
   201  	conn, err := grpc.Dial(
   202  		address,
   203  		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(grpcutils.DefaultMaxMsgSize)),
   204  		grpc.WithInsecure()) //nolint:staticcheck
   205  	if err != nil {
   206  		return nil, nil, fmt.Errorf("failed to connect to address %s: %w", address, err)
   207  	}
   208  	client := accessproto.NewAccessAPIClient(conn)
   209  	closer := io.Closer(conn)
   210  	return client, closer, nil
   211  }