github.com/polarismesh/polaris@v1.17.8/cache/service/ratelimit_config_test.go (about) 1 /** 2 * Tencent is pleased to support the open source community by making Polaris available. 3 * 4 * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 5 * 6 * Licensed under the BSD 3-Clause License (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://opensource.org/licenses/BSD-3-Clause 11 * 12 * Unless required by applicable law or agreed to in writing, software distributed 13 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 14 * CONDITIONS OF ANY KIND, either express or implied. See the License for the 15 * specific language governing permissions and limitations under the License. 16 */ 17 18 package service 19 20 import ( 21 "encoding/json" 22 "fmt" 23 "testing" 24 "time" 25 26 "github.com/golang/mock/gomock" 27 "github.com/golang/protobuf/ptypes/duration" 28 apimodel "github.com/polarismesh/specification/source/go/api/v1/model" 29 apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" 30 "github.com/stretchr/testify/assert" 31 32 types "github.com/polarismesh/polaris/cache/api" 33 cachemock "github.com/polarismesh/polaris/cache/mock" 34 "github.com/polarismesh/polaris/common/model" 35 "github.com/polarismesh/polaris/common/utils" 36 "github.com/polarismesh/polaris/store/mock" 37 ) 38 39 /** 40 * @brief 创建一个测试mock rateLimitCache 41 */ 42 func newTestRateLimitCache(t *testing.T) (*gomock.Controller, *mock.MockStore, *rateLimitCache) { 43 ctl := gomock.NewController(t) 44 45 storage := mock.NewMockStore(ctl) 46 mockCacheMgr := cachemock.NewMockCacheManager(ctl) 47 48 mockSvcCache := NewServiceCache(storage, mockCacheMgr) 49 mockInstCache := NewInstanceCache(storage, mockCacheMgr) 50 mockRateLimitCache := NewRateLimitCache(storage, mockCacheMgr) 51 52 mockCacheMgr.EXPECT().GetCacher(types.CacheService).Return(mockSvcCache).AnyTimes() 53 mockCacheMgr.EXPECT().GetCacher(types.CacheInstance).Return(mockInstCache).AnyTimes() 54 55 storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) 56 var opt map[string]interface{} 57 _ = mockRateLimitCache.Initialize(opt) 58 _ = mockSvcCache.Initialize(opt) 59 _ = mockInstCache.Initialize(opt) 60 return ctl, storage, mockRateLimitCache.(*rateLimitCache) 61 } 62 63 func buildRateLimitRuleProtoWithLabels(name string, method string) *apitraffic.Rule { 64 rule := &apitraffic.Rule{ 65 Priority: utils.NewUInt32Value(0), 66 Resource: apitraffic.Rule_QPS, 67 Type: apitraffic.Rule_LOCAL, 68 Labels: map[string]*apimodel.MatchString{"http.method": { 69 Type: apimodel.MatchString_EXACT, 70 Value: utils.NewStringValue("post"), 71 }}, 72 Amounts: []*apitraffic.Amount{{ 73 MaxAmount: utils.NewUInt32Value(100), 74 ValidDuration: &duration.Duration{Seconds: 1}, 75 }}, 76 Action: utils.NewStringValue("reject"), 77 Disable: utils.NewBoolValue(false), 78 RegexCombine: utils.NewBoolValue(false), 79 Failover: apitraffic.Rule_FAILOVER_LOCAL, 80 Method: &apimodel.MatchString{ 81 Type: apimodel.MatchString_EXACT, 82 Value: utils.NewStringValue(method), 83 }, 84 Name: utils.NewStringValue(name), 85 } 86 return rule 87 } 88 89 func buildRateLimitRuleProtoWithArguments(name string, method string) *apitraffic.Rule { 90 rule := &apitraffic.Rule{ 91 Priority: utils.NewUInt32Value(0), 92 Resource: apitraffic.Rule_QPS, 93 Type: apitraffic.Rule_LOCAL, 94 Arguments: []*apitraffic.MatchArgument{ 95 { 96 Type: apitraffic.MatchArgument_HEADER, 97 Key: "host", 98 Value: &apimodel.MatchString{ 99 Type: apimodel.MatchString_EXACT, 100 Value: utils.NewStringValue("localhost"), 101 }, 102 }, 103 }, 104 Amounts: []*apitraffic.Amount{{ 105 MaxAmount: utils.NewUInt32Value(100), 106 ValidDuration: &duration.Duration{Seconds: 1}, 107 }}, 108 Action: utils.NewStringValue("reject"), 109 Disable: utils.NewBoolValue(false), 110 RegexCombine: utils.NewBoolValue(false), 111 Failover: apitraffic.Rule_FAILOVER_LOCAL, 112 Method: &apimodel.MatchString{ 113 Type: apimodel.MatchString_EXACT, 114 Value: utils.NewStringValue(method), 115 }, 116 Name: utils.NewStringValue(name), 117 } 118 return rule 119 } 120 121 // genRateLimitsWithLabels 生成限流规则测试数据 122 func genRateLimits( 123 beginNum, totalServices, totalRateLimits int, withLabels bool) []*model.RateLimit { 124 rateLimits := make([]*model.RateLimit, 0, totalRateLimits) 125 rulePerService := totalRateLimits / totalServices 126 127 for i := beginNum; i < totalServices+beginNum; i++ { 128 for j := 0; j < rulePerService; j++ { 129 name := fmt.Sprintf("limit-rule-%d-%d", i, j) 130 method := fmt.Sprintf("/test-%d", j) 131 var rule *apitraffic.Rule 132 if withLabels { 133 rule = buildRateLimitRuleProtoWithLabels(name, method) 134 } else { 135 rule = buildRateLimitRuleProtoWithArguments(name, method) 136 } 137 rule.Service = utils.NewStringValue(fmt.Sprintf("service-%d", i)) 138 rule.Namespace = utils.NewStringValue("default") 139 str, _ := json.Marshal(rule) 140 labels, _ := json.Marshal(rule.GetLabels()) 141 rateLimit := &model.RateLimit{ 142 ID: fmt.Sprintf("id-%d-%d", i, j), 143 ServiceID: fmt.Sprintf("service-%d", i), 144 Name: name, 145 Method: method, 146 Rule: string(str), 147 Revision: fmt.Sprintf("revision-%d-%d", i, j), 148 Labels: string(labels), 149 Valid: true, 150 } 151 rateLimits = append(rateLimits, rateLimit) 152 } 153 } 154 return rateLimits 155 } 156 157 /** 158 * @brief 统计缓存中的限流数据 159 */ 160 func getRateLimitsCount(serviceKey model.ServiceKey, rlc *rateLimitCache) int { 161 ret, _ := rlc.GetRateLimitRules(serviceKey) 162 return len(ret) 163 } 164 165 /** 166 * TestRateLimitUpdate 测试更新缓存操作 167 */ 168 func TestRateLimitUpdate(t *testing.T) { 169 ctl, storage, rlc := newTestRateLimitCache(t) 170 defer ctl.Finish() 171 172 totalServices := 5 173 totalRateLimits := 15 174 rateLimits := genRateLimits(0, totalServices, totalRateLimits, false) 175 176 t.Run("正常更新缓存,可以获取到数据", func(t *testing.T) { 177 _ = rlc.Clear() 178 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()).Return(rateLimits, nil) 179 if err := rlc.Update(); err != nil { 180 t.Fatalf("error: %s", err.Error()) 181 } 182 183 // 检查数目是否一致 184 for i := 0; i < totalServices; i++ { 185 count := getRateLimitsCount(model.ServiceKey{ 186 Namespace: "default", 187 Name: fmt.Sprintf("service-%d", i), 188 }, rlc) 189 if count == totalRateLimits/totalServices { 190 t.Log("pass") 191 } else { 192 t.Fatalf("actual count is %d", count) 193 } 194 } 195 196 count := rlc.GetRateLimitsCount() 197 if count == totalRateLimits { 198 t.Log("pass") 199 } else { 200 t.Fatalf("actual count is %d", count) 201 } 202 }) 203 204 t.Run("缓存数据为空", func(t *testing.T) { 205 _ = rlc.Clear() 206 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 207 Return(nil, nil) 208 if err := rlc.Update(); err != nil { 209 t.Fatalf("error: %s", err.Error()) 210 } 211 212 if rlc.GetRateLimitsCount() == 0 { 213 t.Log("pass") 214 } else { 215 t.Fatalf("actual rate limits count is %d", 216 rlc.GetRateLimitsCount()) 217 } 218 }) 219 220 t.Run("lastMtime正确更新", func(t *testing.T) { 221 _ = rlc.Clear() 222 223 currentTime := time.Now() 224 rateLimits[0].ModifyTime = currentTime 225 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 226 Return(rateLimits, nil) 227 if err := rlc.Update(); err != nil { 228 t.Fatalf("error: %s", err.Error()) 229 } 230 231 if rlc.OriginLastFetchTime().Unix() == currentTime.Unix() { 232 t.Log("pass") 233 } else { 234 t.Fatalf("last mtime error") 235 } 236 }) 237 238 t.Run("数据库返回错误,update错误", func(t *testing.T) { 239 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 240 Return(nil, fmt.Errorf("stoarge error")) 241 if err := rlc.Update(); err != nil { 242 t.Log("pass") 243 } else { 244 t.Fatalf("error") 245 } 246 }) 247 } 248 249 /** 250 * TestRateLimitUpdate2 统计缓存中的限流数据 251 */ 252 func TestRateLimitUpdate2(t *testing.T) { 253 ctl, storage, rlc := newTestRateLimitCache(t) 254 defer ctl.Finish() 255 256 totalServices := 5 257 totalRateLimits := 15 258 259 t.Run("更新缓存后,增加部分数据,缓存正常更新", func(t *testing.T) { 260 _ = rlc.Clear() 261 262 rateLimits := genRateLimits(0, totalServices, totalRateLimits, true) 263 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 264 Return(rateLimits, nil) 265 if err := rlc.Update(); err != nil { 266 t.Fatalf("error: %s", err.Error()) 267 } 268 269 rateLimits = genRateLimits(5, totalServices, totalRateLimits, true) 270 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 271 Return(rateLimits, nil) 272 if err := rlc.Update(); err != nil { 273 t.Fatalf("error: %s", err.Error()) 274 } 275 276 if rlc.GetRateLimitsCount() == totalRateLimits*2 { 277 t.Log("pass") 278 } else { 279 t.Fatalf("actual rate limits count is %d", rlc.GetRateLimitsCount()) 280 } 281 }) 282 283 t.Run("更新缓存后,删除部分数据,缓存正常更新", func(t *testing.T) { 284 _ = rlc.Clear() 285 286 rateLimits := genRateLimits(0, totalServices, totalRateLimits, true) 287 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 288 Return(rateLimits, nil) 289 if err := rlc.Update(); err != nil { 290 t.Fatalf("error: %s", err.Error()) 291 } 292 293 for i := 0; i < totalRateLimits; i += 2 { 294 rateLimits[i].Valid = false 295 } 296 297 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 298 Return(rateLimits, nil) 299 if err := rlc.Update(); err != nil { 300 t.Fatalf("error: %s", err.Error()) 301 } 302 303 if rlc.GetRateLimitsCount() == totalRateLimits/2 { 304 t.Log("pass") 305 } else { 306 t.Fatalf("actual rate limits count is %d", 307 rlc.GetRateLimitsCount()) 308 } 309 }) 310 } 311 312 /** 313 * TestGetRateLimitsByServiceID 根据服务id获取限流数据和revision 314 */ 315 func TestGetRateLimitsByServiceID(t *testing.T) { 316 ctl, storage, rlc := newTestRateLimitCache(t) 317 defer ctl.Finish() 318 319 t.Run("通过服务ID获取数据并检查labels", func(t *testing.T) { 320 _ = rlc.Clear() 321 322 totalServices := 5 323 totalRateLimits := 15 324 rateLimits := genRateLimits(0, totalServices, totalRateLimits, true) 325 326 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 327 Return(rateLimits, nil) 328 if err := rlc.Update(); err != nil { 329 t.Fatalf("error: %s", err.Error()) 330 } 331 332 rules, _ := rlc.GetRateLimitRules(model.ServiceKey{ 333 Namespace: "default", 334 Name: "service-1", 335 }) 336 if len(rules) == totalRateLimits/totalServices { 337 t.Log("pass") 338 } else { 339 t.Fatalf("expect num is %d, actual num is %d", totalRateLimits/totalServices, len(rateLimits)) 340 } 341 342 for _, rateLimit := range rules { 343 assert.Equal(t, 1, len(rateLimit.Proto.Labels)) 344 assert.Equal(t, 1, len(rateLimit.Proto.Arguments)) 345 for _, argument := range rateLimit.Proto.Arguments { 346 assert.Equal(t, apitraffic.MatchArgument_CUSTOM, argument.Type) 347 _, hasKey := rateLimit.Proto.Labels[argument.Key] 348 assert.True(t, hasKey) 349 } 350 } 351 }) 352 353 t.Run("通过服务ID获取数据并检查argument", func(t *testing.T) { 354 _ = rlc.Clear() 355 356 totalServices := 5 357 totalRateLimits := 15 358 rateLimits := genRateLimits(0, totalServices, totalRateLimits, false) 359 360 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), rlc.IsFirstUpdate()). 361 Return(rateLimits, nil) 362 if err := rlc.Update(); err != nil { 363 t.Fatalf("error: %s", err.Error()) 364 } 365 366 rateLimits, _ = rlc.GetRateLimitRules(model.ServiceKey{ 367 Namespace: "default", 368 Name: "service-1", 369 }) 370 if len(rateLimits) == totalRateLimits/totalServices { 371 t.Log("pass") 372 } else { 373 t.Fatalf("expect num is %d, actual num is %d", totalRateLimits/totalServices, len(rateLimits)) 374 } 375 for _, rateLimit := range rateLimits { 376 assert.Equal(t, 1, len(rateLimit.Proto.Arguments)) 377 assert.Equal(t, 1, len(rateLimit.Proto.Labels)) 378 labelValue, hasKey := rateLimit.Proto.Labels["$header.host"] 379 assert.True(t, hasKey) 380 assert.Equal(t, rateLimit.Proto.Arguments[0].Value.Value.GetValue(), labelValue.GetValue().GetValue()) 381 } 382 }) 383 } 384 385 func Test_QueryRateLimitRules(t *testing.T) { 386 ctl, storage, rlc := newTestRateLimitCache(t) 387 t.Cleanup(func() { 388 ctl.Finish() 389 }) 390 391 totalServices := 5 392 totalRateLimits := 15 393 rateLimits := genRateLimits(0, totalServices, totalRateLimits, true) 394 395 storage.EXPECT().GetRateLimitsForCache(gomock.Any(), gomock.Any()).AnyTimes(). 396 Return(rateLimits, nil) 397 if err := rlc.Update(); err != nil { 398 t.Fatalf("error: %s", err.Error()) 399 } 400 401 t.Run("根据ID进行查询", func(t *testing.T) { 402 total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 403 ID: rateLimits[0].ID, 404 Offset: 0, 405 Limit: 100, 406 }) 407 408 assert.NoError(t, err) 409 assert.Equal(t, int64(1), int64(total)) 410 assert.Equal(t, int64(1), int64(len(ret))) 411 assert.Equal(t, rateLimits[0].ID, ret[0].ID) 412 }) 413 414 t.Run("根据Name进行查询", func(t *testing.T) { 415 total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 416 Name: rateLimits[0].Name, 417 Offset: 0, 418 Limit: 100, 419 }) 420 421 assert.NoError(t, err) 422 assert.Equal(t, int64(1), int64(total)) 423 assert.Equal(t, int64(1), int64(len(ret))) 424 assert.Equal(t, rateLimits[0].ID, ret[0].ID) 425 }) 426 427 t.Run("根据Namespace&Service进行查询", func(t *testing.T) { 428 total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 429 Service: "service-0", 430 Namespace: "default", 431 Offset: 0, 432 Limit: 100, 433 }) 434 435 assert.NoError(t, err) 436 assert.Equal(t, int64(3), int64(total)) 437 assert.Equal(t, int64(3), int64(len(ret))) 438 for i := range ret { 439 assert.Equal(t, "service-0", ret[i].Proto.Service.Value) 440 assert.Equal(t, "default", ret[i].Proto.Namespace.Value) 441 } 442 }) 443 444 t.Run("根据分页进行查询", func(t *testing.T) { 445 total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 446 Offset: 10, 447 Limit: 5, 448 }) 449 450 assert.NoError(t, err) 451 assert.Equal(t, int64(total), int64(len(rateLimits))) 452 assert.Equal(t, int64(5), int64(len(ret))) 453 454 total, ret, err = rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 455 Offset: 100, 456 Limit: 5, 457 }) 458 459 assert.NoError(t, err) 460 assert.Equal(t, int64(total), int64(len(rateLimits))) 461 assert.Equal(t, int64(0), int64(len(ret))) 462 }) 463 464 t.Run("根据Disable进行查询", func(t *testing.T) { 465 disable := true 466 total, ret, err := rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 467 Disable: &disable, 468 Offset: 0, 469 Limit: 100, 470 }) 471 472 assert.NoError(t, err) 473 assert.Equal(t, int64(0), int64(total)) 474 assert.Equal(t, int64(0), int64(len(ret))) 475 476 disable = false 477 total, ret, err = rlc.QueryRateLimitRules(types.RateLimitRuleArgs{ 478 Disable: &disable, 479 Offset: 0, 480 Limit: 100, 481 }) 482 483 assert.NoError(t, err) 484 assert.Equal(t, int64(total), int64(len(rateLimits))) 485 assert.Equal(t, int64(total), int64(len(ret))) 486 for i := range ret { 487 assert.Equal(t, disable, ret[i].Disable) 488 } 489 }) 490 491 }