github.com/koko1123/flow-go-1@v0.29.6/engine/access/rpc/rate_limit_test.go (about) 1 package rpc 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "os" 8 "testing" 9 "time" 10 11 accessproto "github.com/onflow/flow/protobuf/go/flow/access" 12 "github.com/rs/zerolog" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/mock" 15 "github.com/stretchr/testify/suite" 16 "google.golang.org/grpc" 17 "google.golang.org/grpc/codes" 18 "google.golang.org/grpc/status" 19 20 accessmock "github.com/koko1123/flow-go-1/engine/access/mock" 21 "github.com/koko1123/flow-go-1/model/flow" 22 "github.com/koko1123/flow-go-1/module/metrics" 23 module "github.com/koko1123/flow-go-1/module/mock" 24 "github.com/koko1123/flow-go-1/network" 25 protocol "github.com/koko1123/flow-go-1/state/protocol/mock" 26 storagemock "github.com/koko1123/flow-go-1/storage/mock" 27 "github.com/koko1123/flow-go-1/utils/grpcutils" 28 "github.com/koko1123/flow-go-1/utils/unittest" 29 ) 30 31 type RateLimitTestSuite struct { 32 suite.Suite 33 state *protocol.State 34 snapshot *protocol.Snapshot 35 epochQuery *protocol.EpochQuery 36 log zerolog.Logger 37 net *network.Network 38 request *module.Requester 39 collClient *accessmock.AccessAPIClient 40 execClient *accessmock.ExecutionAPIClient 41 me *module.Local 42 chainID flow.ChainID 43 metrics *metrics.NoopCollector 44 rpcEng *Engine 45 client accessproto.AccessAPIClient 46 closer io.Closer 47 48 // storage 49 blocks *storagemock.Blocks 50 headers *storagemock.Headers 51 collections *storagemock.Collections 52 transactions *storagemock.Transactions 53 receipts *storagemock.ExecutionReceipts 54 55 // test rate limit 56 rateLimit int 57 burstLimit int 58 } 59 60 func (suite *RateLimitTestSuite) SetupTest() { 61 suite.log = zerolog.New(os.Stdout) 62 suite.net = new(network.Network) 63 suite.state = new(protocol.State) 64 suite.snapshot = new(protocol.Snapshot) 65 66 suite.epochQuery = new(protocol.EpochQuery) 67 suite.state.On("Sealed").Return(suite.snapshot, nil).Maybe() 68 suite.state.On("Final").Return(suite.snapshot, nil).Maybe() 69 suite.snapshot.On("Epochs").Return(suite.epochQuery).Maybe() 70 suite.blocks = new(storagemock.Blocks) 71 suite.headers = new(storagemock.Headers) 72 suite.transactions = new(storagemock.Transactions) 73 suite.collections = new(storagemock.Collections) 74 suite.receipts = new(storagemock.ExecutionReceipts) 75 76 suite.collClient = new(accessmock.AccessAPIClient) 77 suite.execClient = new(accessmock.ExecutionAPIClient) 78 79 suite.request = new(module.Requester) 80 suite.request.On("EntityByID", mock.Anything, mock.Anything) 81 82 suite.me = new(module.Local) 83 84 accessIdentity := unittest.IdentityFixture(unittest.WithRole(flow.RoleAccess)) 85 suite.me. 86 On("NodeID"). 87 Return(accessIdentity.NodeID) 88 89 suite.chainID = flow.Testnet 90 suite.metrics = metrics.NewNoopCollector() 91 92 config := Config{ 93 UnsecureGRPCListenAddr: unittest.DefaultAddress, 94 SecureGRPCListenAddr: unittest.DefaultAddress, 95 HTTPListenAddr: unittest.DefaultAddress, 96 } 97 98 // set the rate limit to test with 99 suite.rateLimit = 2 100 // set the burst limit to test with 101 suite.burstLimit = 2 102 103 apiRateLimt := map[string]int{ 104 "Ping": suite.rateLimit, 105 } 106 107 apiBurstLimt := map[string]int{ 108 "Ping": suite.rateLimit, 109 } 110 111 rpcEngBuilder, err := NewBuilder(suite.log, suite.state, config, suite.collClient, nil, suite.blocks, suite.headers, suite.collections, suite.transactions, nil, 112 nil, suite.chainID, suite.metrics, suite.metrics, 0, 0, false, false, apiRateLimt, apiBurstLimt) 113 assert.NoError(suite.T(), err) 114 suite.rpcEng, err = rpcEngBuilder.WithLegacy().Build() 115 assert.NoError(suite.T(), err) 116 unittest.AssertClosesBefore(suite.T(), suite.rpcEng.Ready(), 2*time.Second) 117 118 // wait for the server to startup 119 assert.Eventually(suite.T(), func() bool { 120 return suite.rpcEng.UnsecureGRPCAddress() != nil 121 }, 5*time.Second, 10*time.Millisecond) 122 123 // create the access api client 124 suite.client, suite.closer, err = accessAPIClient(suite.rpcEng.UnsecureGRPCAddress().String()) 125 assert.NoError(suite.T(), err) 126 } 127 128 func (suite *RateLimitTestSuite) TearDownTest() { 129 // close the client 130 if suite.closer != nil { 131 suite.closer.Close() 132 } 133 // close the server 134 if suite.rpcEng != nil { 135 unittest.AssertClosesBefore(suite.T(), suite.rpcEng.Done(), 2*time.Second) 136 } 137 } 138 139 func TestRateLimit(t *testing.T) { 140 suite.Run(t, new(RateLimitTestSuite)) 141 } 142 143 // TestRatelimitingWithoutBurst tests that rate limit is correctly applied to an Access API call 144 func (suite *RateLimitTestSuite) TestRatelimitingWithoutBurst() { 145 146 req := &accessproto.PingRequest{} 147 ctx := context.Background() 148 149 // expect 2 upstream calls 150 suite.execClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.rateLimit) 151 suite.collClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.rateLimit) 152 153 requestCnt := 0 154 // requests within the burst should succeed 155 for requestCnt < suite.rateLimit { 156 resp, err := suite.client.Ping(ctx, req) 157 assert.NoError(suite.T(), err) 158 assert.NotNil(suite.T(), resp) 159 // sleep to prevent burst 160 time.Sleep(100 * time.Millisecond) 161 requestCnt++ 162 } 163 164 // request more than the limit should fail 165 _, err := suite.client.Ping(ctx, req) 166 suite.assertRateLimitError(err) 167 } 168 169 // TestRatelimitingWithBurst tests that burst limit is correctly applied to an Access API call 170 func (suite *RateLimitTestSuite) TestRatelimitingWithBurst() { 171 172 req := &accessproto.PingRequest{} 173 ctx := context.Background() 174 175 // expect rpc.defaultBurst number of upstream calls 176 suite.execClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.burstLimit) 177 suite.collClient.On("Ping", mock.Anything, mock.Anything).Return(nil, nil).Times(suite.burstLimit) 178 179 requestCnt := 0 180 // generate a permissible burst of request and assert that they succeed 181 for requestCnt < suite.burstLimit { 182 resp, err := suite.client.Ping(ctx, req) 183 assert.NoError(suite.T(), err) 184 assert.NotNil(suite.T(), resp) 185 requestCnt++ 186 } 187 188 // request more than the permissible burst and assert that it fails 189 _, err := suite.client.Ping(ctx, req) 190 suite.assertRateLimitError(err) 191 } 192 193 func (suite *RateLimitTestSuite) assertRateLimitError(err error) { 194 assert.Error(suite.T(), err) 195 status, ok := status.FromError(err) 196 assert.True(suite.T(), ok) 197 assert.Equal(suite.T(), codes.ResourceExhausted, status.Code()) 198 } 199 200 func accessAPIClient(address string) (accessproto.AccessAPIClient, io.Closer, error) { 201 conn, err := grpc.Dial( 202 address, 203 grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(grpcutils.DefaultMaxMsgSize)), 204 grpc.WithInsecure()) //nolint:staticcheck 205 if err != nil { 206 return nil, nil, fmt.Errorf("failed to connect to address %s: %w", address, err) 207 } 208 client := accessproto.NewAccessAPIClient(conn) 209 closer := io.Closer(conn) 210 return client, closer, nil 211 }