github.com/cloudwego/kitex@v0.9.0/pkg/remote/bound/limiter_inbound_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package bound
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"testing"
    23  
    24  	"github.com/golang/mock/gomock"
    25  
    26  	mockslimiter "github.com/cloudwego/kitex/internal/mocks/limiter"
    27  	"github.com/cloudwego/kitex/internal/test"
    28  	"github.com/cloudwego/kitex/pkg/kerrors"
    29  	"github.com/cloudwego/kitex/pkg/remote"
    30  	"github.com/cloudwego/kitex/pkg/remote/trans/invoke"
    31  )
    32  
    33  var errFoo = errors.New("mockError")
    34  
    35  // TestNewServerLimiterHandler test NewServerLimiterHandler function and assert result to be not nil.
    36  func TestNewServerLimiterHandler(t *testing.T) {
    37  	ctrl := gomock.NewController(t)
    38  	defer ctrl.Finish()
    39  
    40  	concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
    41  	rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
    42  	limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
    43  
    44  	handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true)
    45  	test.Assert(t, handler != nil)
    46  	handler = NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, false)
    47  	test.Assert(t, handler != nil)
    48  }
    49  
    50  // TestLimiterOnActive test OnActive function of serverLimiterHandler, to assert the concurrencyLimiter 'Acquire' works.
    51  // If not acquired, the message would be rejected with ErrConnOverLimit error.
    52  func TestLimiterOnActive(t *testing.T) {
    53  	t.Run("Test OnActive with limit acquired", func(t *testing.T) {
    54  		ctrl := gomock.NewController(t)
    55  		defer ctrl.Finish()
    56  
    57  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
    58  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
    59  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
    60  		ctx := context.Background()
    61  
    62  		concurrencyLimiter.EXPECT().Acquire(ctx).Return(true).Times(2)
    63  
    64  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, false)
    65  		ctx, err := handler.OnActive(ctx, invoke.NewMessage(nil, nil))
    66  		test.Assert(t, ctx != nil)
    67  		test.Assert(t, err == nil)
    68  
    69  		muxHandler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true)
    70  		ctx, err = muxHandler.OnActive(ctx, invoke.NewMessage(nil, nil))
    71  		test.Assert(t, ctx != nil)
    72  		test.Assert(t, err == nil)
    73  	})
    74  
    75  	t.Run("Test OnActive with limit acquire false and nil reporter", func(t *testing.T) {
    76  		ctrl := gomock.NewController(t)
    77  		defer ctrl.Finish()
    78  
    79  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
    80  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
    81  		ctx := context.Background()
    82  
    83  		concurrencyLimiter.EXPECT().Acquire(ctx).Return(false).Times(2)
    84  
    85  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, nil, false)
    86  		ctx, err := handler.OnActive(ctx, invoke.NewMessage(nil, nil))
    87  		test.Assert(t, ctx != nil)
    88  		test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err))
    89  
    90  		muxHandler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, nil, true)
    91  		ctx, err = muxHandler.OnActive(ctx, invoke.NewMessage(nil, nil))
    92  		test.Assert(t, ctx != nil)
    93  		test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err))
    94  	})
    95  
    96  	t.Run("Test OnActive with limit acquire false and non-nil reporter", func(t *testing.T) {
    97  		ctrl := gomock.NewController(t)
    98  		defer ctrl.Finish()
    99  
   100  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   101  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   102  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
   103  		ctx := context.Background()
   104  
   105  		concurrencyLimiter.EXPECT().Acquire(ctx).Return(false).Times(2)
   106  		limitReporter.EXPECT().ConnOverloadReport().Times(2)
   107  
   108  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, false)
   109  		ctx, err := handler.OnActive(ctx, invoke.NewMessage(nil, nil))
   110  		test.Assert(t, ctx != nil)
   111  		test.Assert(t, err != nil)
   112  		test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err))
   113  
   114  		muxHandler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true)
   115  		ctx, err = muxHandler.OnActive(ctx, invoke.NewMessage(nil, nil))
   116  		test.Assert(t, ctx != nil)
   117  		test.Assert(t, err != nil)
   118  		test.Assert(t, errors.Is(kerrors.ErrConnOverLimit, err))
   119  	})
   120  }
   121  
   122  // TestLimiterOnRead test OnRead function of serverLimiterHandler, to assert the rateLimiter 'Acquire' works.
   123  // If not acquired, the message would be rejected with ErrQPSOverLimit error.
   124  func TestLimiterOnRead(t *testing.T) {
   125  	t.Run("Test OnRead with limit acquired", func(t *testing.T) {
   126  		ctrl := gomock.NewController(t)
   127  		defer ctrl.Finish()
   128  
   129  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   130  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   131  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
   132  		ctx := context.Background()
   133  
   134  		rateLimiter.EXPECT().Acquire(ctx).Return(true).Times(1)
   135  
   136  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, false)
   137  		ctx, err := handler.OnRead(ctx, invoke.NewMessage(nil, nil))
   138  
   139  		test.Assert(t, ctx != nil)
   140  		test.Assert(t, err == nil)
   141  	})
   142  
   143  	t.Run("Test OnRead with limit acquire false and nil reporter", func(t *testing.T) {
   144  		ctrl := gomock.NewController(t)
   145  		defer ctrl.Finish()
   146  
   147  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   148  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   149  		ctx := context.Background()
   150  
   151  		rateLimiter.EXPECT().Acquire(ctx).Return(false)
   152  
   153  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, nil, false)
   154  		ctx, err := handler.OnRead(ctx, invoke.NewMessage(nil, nil))
   155  
   156  		test.Assert(t, ctx != nil)
   157  		test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err))
   158  	})
   159  
   160  	t.Run("Test OnRead with limit acquire false and non-nil reporter", func(t *testing.T) {
   161  		ctrl := gomock.NewController(t)
   162  		defer ctrl.Finish()
   163  
   164  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   165  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   166  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
   167  		ctx := context.Background()
   168  
   169  		rateLimiter.EXPECT().Acquire(ctx).Return(false).Times(1)
   170  		limitReporter.EXPECT().QPSOverloadReport().Times(1)
   171  
   172  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, false)
   173  		ctx, err := handler.OnRead(ctx, invoke.NewMessage(nil, nil))
   174  
   175  		test.Assert(t, ctx != nil)
   176  		test.Assert(t, err != nil)
   177  		test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err))
   178  	})
   179  }
   180  
   181  // TestLimiterOnInactive test OnInactive function of serverLimiterHandler, to assert the release is called.
   182  func TestLimiterOnInactive(t *testing.T) {
   183  	t.Run("Test OnInactive", func(t *testing.T) {
   184  		ctrl := gomock.NewController(t)
   185  		defer ctrl.Finish()
   186  
   187  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   188  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   189  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
   190  		ctx := context.Background()
   191  
   192  		concurrencyLimiter.EXPECT().Release(ctx).Times(2)
   193  
   194  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, false)
   195  		ctx = handler.OnInactive(ctx, invoke.NewMessage(nil, nil))
   196  		test.Assert(t, ctx != nil)
   197  
   198  		muxHandler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true)
   199  		ctx = muxHandler.OnInactive(ctx, invoke.NewMessage(nil, nil))
   200  		test.Assert(t, ctx != nil)
   201  	})
   202  }
   203  
   204  // TestLimiterOnMessage test OnMessage function of serverLimiterHandler
   205  func TestLimiterOnMessage(t *testing.T) {
   206  	t.Run("Test OnMessage with limit acquired", func(t *testing.T) {
   207  		ctrl := gomock.NewController(t)
   208  		defer ctrl.Finish()
   209  
   210  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   211  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   212  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
   213  		ctx := context.Background()
   214  		req := remote.NewMessage(nil, nil, nil, remote.Call, remote.Client)
   215  
   216  		rateLimiter.EXPECT().Acquire(ctx).Return(true).Times(1)
   217  
   218  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true)
   219  		ctx, err := handler.OnMessage(ctx, req, remote.NewMessage(nil, nil, nil, remote.Reply, remote.Client))
   220  
   221  		test.Assert(t, ctx != nil)
   222  		test.Assert(t, err == nil)
   223  	})
   224  
   225  	t.Run("Test OnMessage with limit acquire false and nil reporter", func(t *testing.T) {
   226  		ctrl := gomock.NewController(t)
   227  		defer ctrl.Finish()
   228  
   229  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   230  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   231  		ctx := context.Background()
   232  		req := remote.NewMessage(nil, nil, nil, remote.Call, remote.Client)
   233  
   234  		rateLimiter.EXPECT().Acquire(ctx).Return(false).Times(1)
   235  
   236  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, nil, true)
   237  		ctx, err := handler.OnMessage(ctx, req, remote.NewMessage(nil, nil, nil, remote.Reply, remote.Client))
   238  
   239  		test.Assert(t, ctx != nil)
   240  		test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err))
   241  	})
   242  
   243  	t.Run("Test OnMessage with limit acquire false and non-nil reporter", func(t *testing.T) {
   244  		ctrl := gomock.NewController(t)
   245  		defer ctrl.Finish()
   246  
   247  		concurrencyLimiter := mockslimiter.NewMockConcurrencyLimiter(ctrl)
   248  		rateLimiter := mockslimiter.NewMockRateLimiter(ctrl)
   249  		limitReporter := mockslimiter.NewMockLimitReporter(ctrl)
   250  		ctx := context.Background()
   251  		req := remote.NewMessage(nil, nil, nil, remote.Call, remote.Client)
   252  
   253  		rateLimiter.EXPECT().Acquire(ctx).Return(false).Times(1)
   254  		limitReporter.EXPECT().QPSOverloadReport().Times(1)
   255  
   256  		handler := NewServerLimiterHandler(concurrencyLimiter, rateLimiter, limitReporter, true)
   257  		ctx, err := handler.OnMessage(ctx, req, remote.NewMessage(nil, nil, nil, remote.Reply, remote.Client))
   258  
   259  		test.Assert(t, ctx != nil)
   260  		test.Assert(t, err != nil)
   261  		test.Assert(t, errors.Is(kerrors.ErrQPSOverLimit, err))
   262  	})
   263  }