github.com/polarismesh/polaris@v1.17.8/cache/service/ratelimit_config_test.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package service
    19  
    20  import (
    21  	"encoding/json"
    22  	"fmt"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/golang/mock/gomock"
    27  	"github.com/golang/protobuf/ptypes/duration"
    28  	apimodel "github.com/polarismesh/specification/source/go/api/v1/model"
    29  	apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage"
    30  	"github.com/stretchr/testify/assert"
    31  
    32  	types "github.com/polarismesh/polaris/cache/api"
    33  	cachemock "github.com/polarismesh/polaris/cache/mock"
    34  	"github.com/polarismesh/polaris/common/model"
    35  	"github.com/polarismesh/polaris/common/utils"
    36  	"github.com/polarismesh/polaris/store/mock"
    37  )
    38  
    39  /**
    40   * @brief 创建一个测试mock rateLimitCache
    41   */
    42  func newTestRateLimitCache(t *testing.T) (*gomock.Controller, *mock.MockStore, *rateLimitCache) {
    43  	ctl := gomock.NewController(t)
    44  
    45  	storage := mock.NewMockStore(ctl)
    46  	mockCacheMgr := cachemock.NewMockCacheManager(ctl)
    47  
    48  	mockSvcCache := NewServiceCache(storage, mockCacheMgr)
    49  	mockInstCache := NewInstanceCache(storage, mockCacheMgr)
    50  	mockRateLimitCache := NewRateLimitCache(storage, mockCacheMgr)
    51  
    52  	mockCacheMgr.EXPECT().GetCacher(types.CacheService).Return(mockSvcCache).AnyTimes()
    53  	mockCacheMgr.EXPECT().GetCacher(types.CacheInstance).Return(mockInstCache).AnyTimes()
    54  
    55  	storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil)
    56  	var opt map[string]interface{}
    57  	_ = mockRateLimitCache.Initialize(opt)
    58  	_ = mockSvcCache.Initialize(opt)
    59  	_ = mockInstCache.Initialize(opt)
    60  	return ctl, storage, mockRateLimitCache.(*rateLimitCache)
    61  }
    62  
    63  func buildRateLimitRuleProtoWithLabels(name string, method string) *apitraffic.Rule {
    64  	rule := &apitraffic.Rule{
    65  		Priority: utils.NewUInt32Value(0),
    66  		Resource: apitraffic.Rule_QPS,
    67  		Type:     apitraffic.Rule_LOCAL,
    68  		Labels: map[string]*apimodel.MatchString{"http.method": {
    69  			Type:  apimodel.MatchString_EXACT,
    70  			Value: utils.NewStringValue("post"),
    71  		}},
    72  		Amounts: []*apitraffic.Amount{{
    73  			MaxAmount:     utils.NewUInt32Value(100),
    74  			ValidDuration: &duration.Duration{Seconds: 1},
    75  		}},
    76  		Action:       utils.NewStringValue("reject"),
    77  		Disable:      utils.NewBoolValue(false),
    78  		RegexCombine: utils.NewBoolValue(false),
    79  		Failover:     apitraffic.Rule_FAILOVER_LOCAL,
    80  		Method: &apimodel.MatchString{
    81  			Type:  apimodel.MatchString_EXACT,
    82  			Value: utils.NewStringValue(method),
    83  		},
    84  		Name: utils.NewStringValue(name),
    85  	}
    86  	return rule
    87  }
    88  
    89  func buildRateLimitRuleProtoWithArguments(name string, method string) *apitraffic.Rule {
    90  	rule := &apitraffic.Rule{
    91  		Priority: utils.NewUInt32Value(0),
    92  		Resource: apitraffic.Rule_QPS,
    93  		Type:     apitraffic.Rule_LOCAL,
    94  		Arguments: []*apitraffic.MatchArgument{
    95  			{
    96  				Type: apitraffic.MatchArgument_HEADER,
    97  				Key:  "host",
    98  				Value: &apimodel.MatchString{
    99  					Type:  apimodel.MatchString_EXACT,
   100  					Value: utils.NewStringValue("localhost"),
   101  				},
   102  			},
   103  		},
   104  		Amounts: []*apitraffic.Amount{{
   105  			MaxAmount:     utils.NewUInt32Value(100),
   106  			ValidDuration: &duration.Duration{Seconds: 1},
   107  		}},
   108  		Action:       utils.NewStringValue("reject"),
   109  		Disable:      utils.NewBoolValue(false),
   110  		RegexCombine: utils.NewBoolValue(false),
   111  		Failover:     apitraffic.Rule_FAILOVER_LOCAL,
   112  		Method: &apimodel.MatchString{
   113  			Type:  apimodel.MatchString_EXACT,
   114  			Value: utils.NewStringValue(method),
   115  		},
   116  		Name: utils.NewStringValue(name),
   117  	}
   118  	return rule
   119  }
   120  
   121  // genRateLimitsWithLabels 生成限流规则测试数据
   122  func genRateLimits(
   123  	beginNum, totalServices, totalRateLimits int, withLabels bool) []*model.RateLimit {
   124  	rateLimits := make([]*model.RateLimit, 0, totalRateLimits)
   125  	rulePerService := totalRateLimits / totalServices
   126  
   127  	for i := beginNum; i < totalServices+beginNum; i++ {
   128  		for j := 0; j < rulePerService; j++ {
   129  			name := fmt.Sprintf("limit-rule-%d-%d", i, j)
   130  			method := fmt.Sprintf("/test-%d", j)
   131  			var rule *apitraffic.Rule
   132  			if withLabels {
   133  				rule = buildRateLimitRuleProtoWithLabels(name, method)
   134  			} else {
   135  				rule = buildRateLimitRuleProtoWithArguments(name, method)
   136  			}
   137  			rule.Service = utils.NewStringValue(fmt.Sprintf("service-%d", i))
   138  			rule.Namespace = utils.NewStringValue("default")
   139  			str, _ := json.Marshal(rule)
   140  			labels, _ := json.Marshal(rule.GetLabels())
   141  			rateLimit := &model.RateLimit{
   142  				ID:        fmt.Sprintf("id-%d-%d", i, j),
   143  				ServiceID: fmt.Sprintf("service-%d", i),
   144  				Name:      name,
   145  				Method:    method,
   146  				Rule:      string(str),
   147  				Revision:  fmt.Sprintf("revision-%d-%d", i, j),
   148  				Labels:    string(labels),
   149  				Valid:     true,
   150  			}
   151  			rateLimits = append(rateLimits, rateLimit)
   152  		}
   153  	}
   154  	return rateLimits
   155  }
   156  
   157  /**
   158   * @brief 统计缓存中的限流数据
   159   */
   160  func getRateLimitsCount(serviceKey model.ServiceKey, rlc *rateLimitCache) int {
   161  	ret, _ := rlc.GetRateLimitRules(serviceKey)
   162  	return len(ret)
   163  }
   164  
   165  /**
   166   * TestRateLimitUpdate 测试更新缓存操作
   167   */
   168  func TestRateLimitUpdate(t *testing.T) {
   169  	ctl, storage, rlc := newTestRateLimitCache(t)
   170  	defer ctl.Finish()
   171  
   172  	totalServices := 5
   173  	totalRateLimits := 15
   174  	rateLimits := genRateLimits(0, totalServices, totalRateLimits, false)
   175  
   176  	t.Run("正常更新缓存,可以获取到数据", func(t *testing.T) {
   177  		_ = rlc.Clear()
   178  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).Return(rateLimits, nil)
   179  		if err := rlc.Update(); err != nil {
   180  			t.Fatalf("error: %s", err.Error())
   181  		}
   182  
   183  		// 检查数目是否一致
   184  		for i := 0; i < totalServices; i++ {
   185  			count := getRateLimitsCount(model.ServiceKey{
   186  				Namespace: "default",
   187  				Name:      fmt.Sprintf("service-%d", i),
   188  			}, rlc)
   189  			if count == totalRateLimits/totalServices {
   190  				t.Log("pass")
   191  			} else {
   192  				t.Fatalf("actual count is %d", count)
   193  			}
   194  		}
   195  
   196  		count := rlc.GetRateLimitsCount()
   197  		if count == totalRateLimits {
   198  			t.Log("pass")
   199  		} else {
   200  			t.Fatalf("actual count is %d", count)
   201  		}
   202  	})
   203  
   204  	t.Run("缓存数据为空", func(t *testing.T) {
   205  		_ = rlc.Clear()
   206  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   207  			Return(nil, nil)
   208  		if err := rlc.Update(); err != nil {
   209  			t.Fatalf("error: %s", err.Error())
   210  		}
   211  
   212  		if rlc.GetRateLimitsCount() == 0 {
   213  			t.Log("pass")
   214  		} else {
   215  			t.Fatalf("actual rate limits count is %d",
   216  				rlc.GetRateLimitsCount())
   217  		}
   218  	})
   219  
   220  	t.Run("lastMtime正确更新", func(t *testing.T) {
   221  		_ = rlc.Clear()
   222  
   223  		currentTime := time.Now()
   224  		rateLimits[0].ModifyTime = currentTime
   225  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   226  			Return(rateLimits, nil)
   227  		if err := rlc.Update(); err != nil {
   228  			t.Fatalf("error: %s", err.Error())
   229  		}
   230  
   231  		if rlc.OriginLastFetchTime().Unix() == currentTime.Unix() {
   232  			t.Log("pass")
   233  		} else {
   234  			t.Fatalf("last mtime error")
   235  		}
   236  	})
   237  
   238  	t.Run("数据库返回错误,update错误", func(t *testing.T) {
   239  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   240  			Return(nil, fmt.Errorf("stoarge error"))
   241  		if err := rlc.Update(); err != nil {
   242  			t.Log("pass")
   243  		} else {
   244  			t.Fatalf("error")
   245  		}
   246  	})
   247  }
   248  
   249  /**
   250   * TestRateLimitUpdate2 统计缓存中的限流数据
   251   */
   252  func TestRateLimitUpdate2(t *testing.T) {
   253  	ctl, storage, rlc := newTestRateLimitCache(t)
   254  	defer ctl.Finish()
   255  
   256  	totalServices := 5
   257  	totalRateLimits := 15
   258  
   259  	t.Run("更新缓存后,增加部分数据,缓存正常更新", func(t *testing.T) {
   260  		_ = rlc.Clear()
   261  
   262  		rateLimits := genRateLimits(0, totalServices, totalRateLimits, true)
   263  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   264  			Return(rateLimits, nil)
   265  		if err := rlc.Update(); err != nil {
   266  			t.Fatalf("error: %s", err.Error())
   267  		}
   268  
   269  		rateLimits = genRateLimits(5, totalServices, totalRateLimits, true)
   270  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   271  			Return(rateLimits, nil)
   272  		if err := rlc.Update(); err != nil {
   273  			t.Fatalf("error: %s", err.Error())
   274  		}
   275  
   276  		if rlc.GetRateLimitsCount() == totalRateLimits*2 {
   277  			t.Log("pass")
   278  		} else {
   279  			t.Fatalf("actual rate limits count is %d", rlc.GetRateLimitsCount())
   280  		}
   281  	})
   282  
   283  	t.Run("更新缓存后,删除部分数据,缓存正常更新", func(t *testing.T) {
   284  		_ = rlc.Clear()
   285  
   286  		rateLimits := genRateLimits(0, totalServices, totalRateLimits, true)
   287  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   288  			Return(rateLimits, nil)
   289  		if err := rlc.Update(); err != nil {
   290  			t.Fatalf("error: %s", err.Error())
   291  		}
   292  
   293  		for i := 0; i < totalRateLimits; i += 2 {
   294  			rateLimits[i].Valid = false
   295  		}
   296  
   297  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   298  			Return(rateLimits, nil)
   299  		if err := rlc.Update(); err != nil {
   300  			t.Fatalf("error: %s", err.Error())
   301  		}
   302  
   303  		if rlc.GetRateLimitsCount() == totalRateLimits/2 {
   304  			t.Log("pass")
   305  		} else {
   306  			t.Fatalf("actual rate limits count is %d",
   307  				rlc.GetRateLimitsCount())
   308  		}
   309  	})
   310  }
   311  
   312  /**
   313   * TestGetRateLimitsByServiceID 根据服务id获取限流数据和revision
   314   */
   315  func TestGetRateLimitsByServiceID(t *testing.T) {
   316  	ctl, storage, rlc := newTestRateLimitCache(t)
   317  	defer ctl.Finish()
   318  
   319  	t.Run("通过服务ID获取数据并检查labels", func(t *testing.T) {
   320  		_ = rlc.Clear()
   321  
   322  		totalServices := 5
   323  		totalRateLimits := 15
   324  		rateLimits := genRateLimits(0, totalServices, totalRateLimits, true)
   325  
   326  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   327  			Return(rateLimits, nil)
   328  		if err := rlc.Update(); err != nil {
   329  			t.Fatalf("error: %s", err.Error())
   330  		}
   331  
   332  		rules, _ := rlc.GetRateLimitRules(model.ServiceKey{
   333  			Namespace: "default",
   334  			Name:      "service-1",
   335  		})
   336  		if len(rules) == totalRateLimits/totalServices {
   337  			t.Log("pass")
   338  		} else {
   339  			t.Fatalf("expect num is %d, actual num is %d", totalRateLimits/totalServices, len(rateLimits))
   340  		}
   341  
   342  		for _, rateLimit := range rules {
   343  			assert.Equal(t, 1, len(rateLimit.Proto.Labels))
   344  			assert.Equal(t, 1, len(rateLimit.Proto.Arguments))
   345  			for _, argument := range rateLimit.Proto.Arguments {
   346  				assert.Equal(t, apitraffic.MatchArgument_CUSTOM, argument.Type)
   347  				_, hasKey := rateLimit.Proto.Labels[argument.Key]
   348  				assert.True(t, hasKey)
   349  			}
   350  		}
   351  	})
   352  
   353  	t.Run("通过服务ID获取数据并检查argument", func(t *testing.T) {
   354  		_ = rlc.Clear()
   355  
   356  		totalServices := 5
   357  		totalRateLimits := 15
   358  		rateLimits := genRateLimits(0, totalServices, totalRateLimits, false)
   359  
   360  		storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).
   361  			Return(rateLimits, nil)
   362  		if err := rlc.Update(); err != nil {
   363  			t.Fatalf("error: %s", err.Error())
   364  		}
   365  
   366  		rateLimits, _ = rlc.GetRateLimitRules(model.ServiceKey{
   367  			Namespace: "default",
   368  			Name:      "service-1",
   369  		})
   370  		if len(rateLimits) == totalRateLimits/totalServices {
   371  			t.Log("pass")
   372  		} else {
   373  			t.Fatalf("expect num is %d, actual num is %d", totalRateLimits/totalServices, len(rateLimits))
   374  		}
   375  		for _, rateLimit := range rateLimits {
   376  			assert.Equal(t, 1, len(rateLimit.Proto.Arguments))
   377  			assert.Equal(t, 1, len(rateLimit.Proto.Labels))
   378  			labelValue, hasKey := rateLimit.Proto.Labels["$header.host"]
   379  			assert.True(t, hasKey)
   380  			assert.Equal(t, rateLimit.Proto.Arguments[0].Value.Value.GetValue(), labelValue.GetValue().GetValue())
   381  		}
   382  	})
   383  }
   384  
   385  func Test_QueryRateLimitRules(t *testing.T) {
   386  	ctl, storage, rlc := newTestRateLimitCache(t)
   387  	t.Cleanup(func() {
   388  		ctl.Finish()
   389  	})
   390  
   391  	totalServices := 5
   392  	totalRateLimits := 15
   393  	rateLimits := genRateLimits(0, totalServices, totalRateLimits, true)
   394  
   395  	storage.EXPECT().GetRateLimitsForCache(gomock.Any(), gomock.Any()).AnyTimes().
   396  		Return(rateLimits, nil)
   397  	if err := rlc.Update(); err != nil {
   398  		t.Fatalf("error: %s", err.Error())
   399  	}
   400  
   401  	t.Run("根据ID进行查询", func(t *testing.T) {
   402  		total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   403  			ID:     rateLimits[0].ID,
   404  			Offset: 0,
   405  			Limit:  100,
   406  		})
   407  
   408  		assert.NoError(t, err)
   409  		assert.Equal(t, int64(1), int64(total))
   410  		assert.Equal(t, int64(1), int64(len(ret)))
   411  		assert.Equal(t, rateLimits[0].ID, ret[0].ID)
   412  	})
   413  
   414  	t.Run("根据Name进行查询", func(t *testing.T) {
   415  		total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   416  			Name:   rateLimits[0].Name,
   417  			Offset: 0,
   418  			Limit:  100,
   419  		})
   420  
   421  		assert.NoError(t, err)
   422  		assert.Equal(t, int64(1), int64(total))
   423  		assert.Equal(t, int64(1), int64(len(ret)))
   424  		assert.Equal(t, rateLimits[0].ID, ret[0].ID)
   425  	})
   426  
   427  	t.Run("根据Namespace&Service进行查询", func(t *testing.T) {
   428  		total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   429  			Service:   "service-0",
   430  			Namespace: "default",
   431  			Offset:    0,
   432  			Limit:     100,
   433  		})
   434  
   435  		assert.NoError(t, err)
   436  		assert.Equal(t, int64(3), int64(total))
   437  		assert.Equal(t, int64(3), int64(len(ret)))
   438  		for i := range ret {
   439  			assert.Equal(t, "service-0", ret[i].Proto.Service.Value)
   440  			assert.Equal(t, "default", ret[i].Proto.Namespace.Value)
   441  		}
   442  	})
   443  
   444  	t.Run("根据分页进行查询", func(t *testing.T) {
   445  		total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   446  			Offset: 10,
   447  			Limit:  5,
   448  		})
   449  
   450  		assert.NoError(t, err)
   451  		assert.Equal(t, int64(total), int64(len(rateLimits)))
   452  		assert.Equal(t, int64(5), int64(len(ret)))
   453  
   454  		total, ret, err = rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   455  			Offset: 100,
   456  			Limit:  5,
   457  		})
   458  
   459  		assert.NoError(t, err)
   460  		assert.Equal(t, int64(total), int64(len(rateLimits)))
   461  		assert.Equal(t, int64(0), int64(len(ret)))
   462  	})
   463  
   464  	t.Run("根据Disable进行查询", func(t *testing.T) {
   465  		disable := true
   466  		total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   467  			Disable: &disable,
   468  			Offset:  0,
   469  			Limit:   100,
   470  		})
   471  
   472  		assert.NoError(t, err)
   473  		assert.Equal(t, int64(0), int64(total))
   474  		assert.Equal(t, int64(0), int64(len(ret)))
   475  
   476  		disable = false
   477  		total, ret, err = rlc.QueryRateLimitRules(types.RateLimitRuleArgs{
   478  			Disable: &disable,
   479  			Offset:  0,
   480  			Limit:   100,
   481  		})
   482  
   483  		assert.NoError(t, err)
   484  		assert.Equal(t, int64(total), int64(len(rateLimits)))
   485  		assert.Equal(t, int64(total), int64(len(ret)))
   486  		for i := range ret {
   487  			assert.Equal(t, disable, ret[i].Disable)
   488  		}
   489  	})
   490  
   491  }