github.com/polarismesh/polaris@v1.17.8/service/ratelimit_config.go (about) 1 /** 2 * Tencent is pleased to support the open source community by making Polaris available. 3 * 4 * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 5 * 6 * Licensed under the BSD 3-Clause License (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://opensource.org/licenses/BSD-3-Clause 11 * 12 * Unless required by applicable law or agreed to in writing, software distributed 13 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 14 * CONDITIONS OF ANY KIND, either express or implied. See the License for the 15 * specific language governing permissions and limitations under the License. 16 */ 17 18 package service 19 20 import ( 21 "context" 22 "encoding/json" 23 "fmt" 24 "strconv" 25 "time" 26 27 "github.com/gogo/protobuf/jsonpb" 28 "github.com/golang/protobuf/ptypes" 29 apimodel "github.com/polarismesh/specification/source/go/api/v1/model" 30 apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" 31 apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" 32 33 cachetypes "github.com/polarismesh/polaris/cache/api" 34 api "github.com/polarismesh/polaris/common/api/v1" 35 "github.com/polarismesh/polaris/common/model" 36 commonstore "github.com/polarismesh/polaris/common/store" 37 commontime "github.com/polarismesh/polaris/common/time" 38 "github.com/polarismesh/polaris/common/utils" 39 ) 40 41 var ( 42 // RateLimitFilters rate limit filters 43 RateLimitFilters = map[string]bool{ 44 "id": true, 45 "name": true, 46 "service": true, 47 "namespace": true, 48 "brief": true, 49 "method": true, 50 "labels": true, 51 "disable": true, 52 "offset": true, 53 "limit": true, 54 } 55 ) 56 57 // CreateRateLimits 批量创建限流规则 58 func (s *Server) CreateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { 59 if err := checkBatchRateLimits(request); err != nil { 60 return err 61 } 62 63 responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) 64 for _, rateLimit := range request { 65 response := s.CreateRateLimit(ctx, rateLimit) 66 api.Collect(responses, response) 67 } 68 return api.FormatBatchWriteResponse(responses) 69 } 70 71 // CreateRateLimit 创建限流规则 72 func (s *Server) CreateRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { 73 requestID := utils.ParseRequestID(ctx) 74 75 // 参数校验 76 if resp := checkRateLimitParams(req); resp != nil { 77 return resp 78 } 79 if resp := checkRateLimitRuleParams(requestID, req); resp != nil { 80 return resp 81 } 82 83 // 构造底层数据结构 84 data, err := api2RateLimit(req, nil) 85 if err != nil { 86 log.Error(err.Error(), utils.ZapRequestID(requestID)) 87 return api.NewRateLimitResponse(apimodel.Code_ParseRateLimitException, req) 88 } 89 90 // 存储层操作 91 if err := s.storage.CreateRateLimit(data); err != nil { 92 log.Error(err.Error(), utils.ZapRequestID(requestID)) 93 return wrapperRateLimitStoreResponse(req, err) 94 } 95 96 msg := fmt.Sprintf("create rate limit rule: id=%v, namespace=%v, service=%v, name=%v", 97 data.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), req.GetName().GetValue()) 98 log.Info(msg, utils.ZapRequestID(requestID)) 99 100 s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, data, model.OCreate)) 101 102 req.Id = utils.NewStringValue(data.ID) 103 return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) 104 } 105 106 // DeleteRateLimits 批量删除限流规则 107 func (s *Server) DeleteRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { 108 if err := checkBatchRateLimits(request); err != nil { 109 return err 110 } 111 112 responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) 113 for _, entry := range request { 114 resp := s.DeleteRateLimit(ctx, entry) 115 api.Collect(responses, resp) 116 } 117 return api.FormatBatchWriteResponse(responses) 118 } 119 120 // DeleteRateLimit 删除单个限流规则 121 func (s *Server) DeleteRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { 122 requestID := utils.ParseRequestID(ctx) 123 platformID := utils.ParsePlatformID(ctx) 124 125 // 参数校验 126 if resp := checkRevisedRateLimitParams(req); resp != nil { 127 return resp 128 } 129 130 // 检查限流规则是否存在 131 rateLimit, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) 132 if resp != nil { 133 if resp.GetCode().GetValue() == uint32(apimodel.Code_NotFoundRateLimit) { 134 return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) 135 } 136 return resp 137 } 138 139 // 生成新的revision 140 rateLimit.Revision = utils.NewUUID() 141 142 // 存储层操作 143 if err := s.storage.DeleteRateLimit(rateLimit); err != nil { 144 log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) 145 return wrapperRateLimitStoreResponse(req, err) 146 } 147 148 msg := fmt.Sprintf("delete rate limit rule: id=%v, namespace=%v, service=%v, name=%v", 149 rateLimit.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), rateLimit.Labels) 150 log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) 151 152 s.RecordHistory(ctx, 153 rateLimitRecordEntry(ctx, req, rateLimit, model.ODelete)) 154 return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) 155 } 156 157 func (s *Server) EnableRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { 158 if err := checkBatchRateLimits(request); err != nil { 159 return err 160 } 161 responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) 162 for _, entry := range request { 163 response := s.EnableRateLimit(ctx, entry) 164 api.Collect(responses, response) 165 } 166 return api.FormatBatchWriteResponse(responses) 167 } 168 169 // EnableRateLimit 启用限流规则 170 func (s *Server) EnableRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { 171 requestID := utils.ParseRequestID(ctx) 172 platformID := utils.ParsePlatformID(ctx) 173 174 // 参数校验 175 if resp := checkRevisedRateLimitParams(req); resp != nil { 176 return resp 177 } 178 179 // 检查限流规则是否存在 180 data, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) 181 if resp != nil { 182 return resp 183 } 184 185 // 构造底层数据结构 186 rateLimit := &model.RateLimit{} 187 rateLimit.ID = data.ID 188 rateLimit.ServiceID = data.ServiceID 189 rateLimit.Disable = req.GetDisable().GetValue() 190 rateLimit.Revision = utils.NewUUID() 191 192 if err := s.storage.EnableRateLimit(rateLimit); err != nil { 193 log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) 194 return wrapperRateLimitStoreResponse(req, err) 195 } 196 197 msg := fmt.Sprintf("enable rate limit: id=%v, disable=%v", 198 rateLimit.ID, rateLimit.Disable) 199 log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) 200 201 s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.OUpdateEnable)) 202 return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) 203 } 204 205 // UpdateRateLimits 批量更新限流规则 206 func (s *Server) UpdateRateLimits(ctx context.Context, request []*apitraffic.Rule) *apiservice.BatchWriteResponse { 207 if err := checkBatchRateLimits(request); err != nil { 208 return err 209 } 210 211 responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) 212 for _, entry := range request { 213 response := s.UpdateRateLimit(ctx, entry) 214 api.Collect(responses, response) 215 } 216 return api.FormatBatchWriteResponse(responses) 217 } 218 219 // UpdateRateLimit 更新限流规则 220 func (s *Server) UpdateRateLimit(ctx context.Context, req *apitraffic.Rule) *apiservice.Response { 221 requestID := utils.ParseRequestID(ctx) 222 // 参数校验 223 if resp := checkRevisedRateLimitParams(req); resp != nil { 224 return resp 225 } 226 if resp := checkRateLimitRuleParams(requestID, req); resp != nil { 227 return resp 228 } 229 if resp := checkRateLimitParamsDbLen(req); resp != nil { 230 return resp 231 } 232 233 // 检查限流规则是否存在 234 data, resp := s.checkRateLimitExisted(req.GetId().GetValue(), requestID, req) 235 if resp != nil { 236 return resp 237 } 238 239 // 构造底层数据结构 240 rateLimit, err := api2RateLimit(req, data) 241 if err != nil { 242 log.Error(err.Error(), utils.ZapRequestID(requestID)) 243 return api.NewRateLimitResponse(apimodel.Code_ParseRateLimitException, req) 244 } 245 rateLimit.ID = data.ID 246 if err := s.storage.UpdateRateLimit(rateLimit); err != nil { 247 log.Error(err.Error(), utils.ZapRequestID(requestID)) 248 return wrapperRateLimitStoreResponse(req, err) 249 } 250 251 msg := fmt.Sprintf("update rate limit: id=%v, namespace=%v, service=%v, name=%v", 252 rateLimit.ID, req.GetNamespace().GetValue(), req.GetService().GetValue(), rateLimit.Name) 253 log.Info(msg, utils.ZapRequestID(requestID)) 254 255 s.RecordHistory(ctx, rateLimitRecordEntry(ctx, req, rateLimit, model.OUpdate)) 256 return api.NewRateLimitResponse(apimodel.Code_ExecuteSuccess, req) 257 } 258 259 // GetRateLimits 查询限流规则 260 func (s *Server) GetRateLimits(ctx context.Context, query map[string]string) *apiservice.BatchQueryResponse { 261 // 处理offset和limit 262 args, errResp := parseRateLimitArgs(query) 263 if errResp != nil { 264 return errResp 265 } 266 267 total, extendRateLimits, err := s.Cache().RateLimit().QueryRateLimitRules(*args) 268 if err != nil { 269 log.Errorf("get rate limits store err: %s", err.Error()) 270 return api.NewBatchQueryResponse(commonstore.StoreCode2APICode(err)) 271 } 272 273 out := api.NewBatchQueryResponse(apimodel.Code_ExecuteSuccess) 274 out.Amount = utils.NewUInt32Value(total) 275 out.Size = utils.NewUInt32Value(uint32(len(extendRateLimits))) 276 out.RateLimits = make([]*apitraffic.Rule, 0, len(extendRateLimits)) 277 for _, item := range extendRateLimits { 278 limit, err := rateLimit2Console(item) 279 if err != nil { 280 log.Errorf("get rate limits convert err: %s", err.Error()) 281 return api.NewBatchQueryResponse(apimodel.Code_ParseRateLimitException) 282 } 283 out.RateLimits = append(out.RateLimits, limit) 284 } 285 286 return out 287 } 288 289 func parseRateLimitArgs(query map[string]string) (*cachetypes.RateLimitRuleArgs, *apiservice.BatchQueryResponse) { 290 for key := range query { 291 if _, ok := RateLimitFilters[key]; !ok { 292 log.Errorf("params %s is not allowed in querying rate limits", key) 293 return nil, api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) 294 } 295 } 296 // 处理offset和limit 297 offset, limit, err := utils.ParseOffsetAndLimit(query) 298 if err != nil { 299 return nil, api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) 300 } 301 302 args := &cachetypes.RateLimitRuleArgs{ 303 Filter: query, 304 ID: query["id"], 305 Name: query["name"], 306 Service: query["service"], 307 Namespace: query["namespace"], 308 Offset: offset, 309 Limit: limit, 310 OrderField: query["order_field"], 311 OrderType: query["order_type"], 312 } 313 if val, ok := query["disable"]; ok { 314 disable, _ := strconv.ParseBool(val) 315 args.Disable = &disable 316 } 317 318 return args, nil 319 } 320 321 // checkBatchRateLimits 检查批量请求的限流规则 322 func checkBatchRateLimits(req []*apitraffic.Rule) *apiservice.BatchWriteResponse { 323 if len(req) == 0 { 324 return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) 325 } 326 327 if len(req) > MaxBatchSize { 328 return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) 329 } 330 331 return nil 332 } 333 334 // checkRateLimitValid 检查限流规则是否允许修改/删除 335 func (s *Server) checkRateLimitValid(ctx context.Context, serviceID string, req *apitraffic.Rule) ( 336 *model.Service, *apiservice.Response) { 337 requestID := utils.ParseRequestID(ctx) 338 339 service, err := s.storage.GetServiceByID(serviceID) 340 if err != nil { 341 log.Error(err.Error(), utils.ZapRequestID(requestID)) 342 return nil, api.NewRateLimitResponse(commonstore.StoreCode2APICode(err), req) 343 } 344 345 return service, nil 346 } 347 348 // checkRateLimitParams 检查限流规则基础参数 349 func checkRateLimitParams(req *apitraffic.Rule) *apiservice.Response { 350 if req == nil { 351 return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) 352 } 353 if err := checkResourceName(req.GetNamespace()); err != nil { 354 return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) 355 } 356 if err := checkResourceName(req.GetService()); err != nil { 357 return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) 358 } 359 if resp := checkRateLimitParamsDbLen(req); nil != resp { 360 return resp 361 } 362 return nil 363 } 364 365 // checkRateLimitParams 检查限流规则基础参数 366 func checkRateLimitParamsDbLen(req *apitraffic.Rule) *apiservice.Response { 367 if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { 368 return api.NewRateLimitResponse(apimodel.Code_InvalidServiceName, req) 369 } 370 if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { 371 return api.NewRateLimitResponse(apimodel.Code_InvalidNamespaceName, req) 372 } 373 if err := utils.CheckDbStrFieldLen(req.GetName(), MaxDbRateLimitName); err != nil { 374 return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitName, req) 375 } 376 return nil 377 } 378 379 // checkRateLimitRuleParams 检查限流规则其他参数 380 func checkRateLimitRuleParams(requestID string, req *apitraffic.Rule) *apiservice.Response { 381 // 检查amounts是否有重复周期 382 amounts := req.GetAmounts() 383 durations := make(map[time.Duration]bool) 384 for _, amount := range amounts { 385 d := amount.GetValidDuration() 386 duration, err := ptypes.Duration(d) 387 if err != nil { 388 log.Error(err.Error(), utils.ZapRequestID(requestID)) 389 return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) 390 } 391 durations[duration] = true 392 } 393 if len(amounts) != len(durations) { 394 return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitAmounts, req) 395 } 396 return nil 397 } 398 399 // checkRevisedRateLimitParams 检查修改/删除限流规则基础参数 400 func checkRevisedRateLimitParams(req *apitraffic.Rule) *apiservice.Response { 401 if req == nil { 402 return api.NewRateLimitResponse(apimodel.Code_EmptyRequest, req) 403 } 404 if req.GetId().GetValue() == "" { 405 return api.NewRateLimitResponse(apimodel.Code_InvalidRateLimitID, req) 406 } 407 return nil 408 } 409 410 // checkRateLimitExisted 检查限流规则是否存在 411 func (s *Server) checkRateLimitExisted( 412 id, requestID string, req *apitraffic.Rule) (*model.RateLimit, *apiservice.Response) { 413 rateLimit, err := s.storage.GetRateLimitWithID(id) 414 if err != nil { 415 log.Error(err.Error(), utils.ZapRequestID(requestID)) 416 return nil, api.NewRateLimitResponse(commonstore.StoreCode2APICode(err), req) 417 } 418 if rateLimit == nil { 419 return nil, api.NewRateLimitResponse(apimodel.Code_NotFoundRateLimit, req) 420 } 421 return rateLimit, nil 422 } 423 424 const ( 425 defaultRuleAction = "REJECT" 426 ) 427 428 // api2RateLimit 把API参数转化为内部数据结构 429 func api2RateLimit(req *apitraffic.Rule, old *model.RateLimit) (*model.RateLimit, error) { 430 rule, err := marshalRateLimitRules(req) 431 if err != nil { 432 return nil, err 433 } 434 435 labels := req.GetLabels() 436 var labelStr []byte 437 if len(labels) > 0 { 438 labelStr, err = json.Marshal(labels) 439 } 440 441 out := &model.RateLimit{ 442 ID: utils.NewUUID(), 443 Name: req.GetName().GetValue(), 444 Method: req.GetMethod().GetValue().GetValue(), 445 Disable: req.GetDisable().GetValue(), 446 Priority: req.GetPriority().GetValue(), 447 Labels: string(labelStr), 448 Rule: rule, 449 Revision: utils.NewUUID(), 450 } 451 return out, nil 452 } 453 454 // rateLimit2api 把内部数据结构转化为API参数 455 func rateLimit2Console(rateLimit *model.RateLimit) (*apitraffic.Rule, error) { 456 if rateLimit == nil { 457 return nil, nil 458 } 459 if len(rateLimit.Rule) > 0 { 460 rateLimit.Proto = &apitraffic.Rule{} 461 // 控制台查询的请求 462 if err := json.Unmarshal([]byte(rateLimit.Rule), rateLimit.Proto); err != nil { 463 return nil, err 464 } 465 // 存量标签适配到参数列表 466 if err := rateLimit.AdaptLabels(); err != nil { 467 return nil, err 468 } 469 } 470 rule := &apitraffic.Rule{} 471 rule.Id = utils.NewStringValue(rateLimit.ID) 472 rule.Name = utils.NewStringValue(rateLimit.Name) 473 rule.Priority = utils.NewUInt32Value(rateLimit.Priority) 474 rule.Ctime = utils.NewStringValue(commontime.Time2String(rateLimit.CreateTime)) 475 rule.Mtime = utils.NewStringValue(commontime.Time2String(rateLimit.ModifyTime)) 476 rule.Disable = utils.NewBoolValue(rateLimit.Disable) 477 if rateLimit.EnableTime.Year() > 2000 { 478 rule.Etime = utils.NewStringValue(commontime.Time2String(rateLimit.EnableTime)) 479 } else { 480 rule.Etime = utils.NewStringValue("") 481 } 482 rule.Revision = utils.NewStringValue(rateLimit.Revision) 483 if nil != rateLimit.Proto { 484 copyRateLimitProto(rateLimit, rule) 485 } else { 486 rule.Method = &apimodel.MatchString{Value: utils.NewStringValue(rateLimit.Method)} 487 } 488 return rule, nil 489 } 490 491 func populateDefaultRuleValue(rule *apitraffic.Rule) { 492 if rule.GetAction().GetValue() == "" { 493 rule.Action = utils.NewStringValue(defaultRuleAction) 494 } 495 } 496 497 func copyRateLimitProto(rateLimit *model.RateLimit, rule *apitraffic.Rule) { 498 // copy proto values 499 rule.Namespace = rateLimit.Proto.Namespace 500 rule.Service = rateLimit.Proto.Service 501 rule.Method = rateLimit.Proto.Method 502 rule.Arguments = rateLimit.Proto.Arguments 503 rule.Labels = rateLimit.Proto.Labels 504 rule.Resource = rateLimit.Proto.Resource 505 rule.Type = rateLimit.Proto.Type 506 rule.Amounts = rateLimit.Proto.Amounts 507 rule.RegexCombine = rateLimit.Proto.RegexCombine 508 rule.Action = rateLimit.Proto.Action 509 rule.Failover = rateLimit.Proto.Failover 510 rule.AmountMode = rateLimit.Proto.AmountMode 511 rule.Adjuster = rateLimit.Proto.Adjuster 512 rule.MaxQueueDelay = rateLimit.Proto.MaxQueueDelay 513 populateDefaultRuleValue(rule) 514 } 515 516 // rateLimit2api 把内部数据结构转化为API参数 517 func rateLimit2Client( 518 service string, namespace string, rateLimit *model.RateLimit) (*apitraffic.Rule, error) { 519 if rateLimit == nil { 520 return nil, nil 521 } 522 523 rule := &apitraffic.Rule{} 524 rule.Id = utils.NewStringValue(rateLimit.ID) 525 rule.Name = utils.NewStringValue(rateLimit.Name) 526 rule.Service = utils.NewStringValue(service) 527 rule.Namespace = utils.NewStringValue(namespace) 528 rule.Priority = utils.NewUInt32Value(rateLimit.Priority) 529 rule.Revision = utils.NewStringValue(rateLimit.Revision) 530 rule.Disable = utils.NewBoolValue(rateLimit.Disable) 531 copyRateLimitProto(rateLimit, rule) 532 return rule, nil 533 } 534 535 // marshalRateLimitRules 序列化限流规则具体内容 536 func marshalRateLimitRules(req *apitraffic.Rule) (string, error) { 537 r := &apitraffic.Rule{ 538 Name: req.GetName(), 539 Resource: req.GetResource(), 540 Service: req.GetService(), 541 Namespace: req.GetNamespace(), 542 Type: req.GetType(), 543 Amounts: req.GetAmounts(), 544 Action: req.GetAction(), 545 Disable: req.GetDisable(), 546 Report: req.GetReport(), 547 Adjuster: req.GetAdjuster(), 548 RegexCombine: req.GetRegexCombine(), 549 AmountMode: req.GetAmountMode(), 550 Failover: req.GetFailover(), 551 Arguments: req.GetArguments(), 552 Method: req.GetMethod(), 553 MaxQueueDelay: req.GetMaxQueueDelay(), 554 } 555 rule, err := json.Marshal(r) 556 if err != nil { 557 return "", err 558 } 559 return string(rule), nil 560 } 561 562 // rateLimitRecordEntry 构建rateLimit的记录entry 563 func rateLimitRecordEntry(ctx context.Context, req *apitraffic.Rule, md *model.RateLimit, 564 opt model.OperationType) *model.RecordEntry { 565 566 marshaler := jsonpb.Marshaler{} 567 detail, _ := marshaler.MarshalToString(req) 568 569 entry := &model.RecordEntry{ 570 ResourceType: model.RRateLimit, 571 ResourceName: fmt.Sprintf("%s(%s)", md.Name, md.ID), 572 Namespace: req.GetNamespace().GetValue(), 573 Operator: utils.ParseOperator(ctx), 574 OperationType: opt, 575 Detail: detail, 576 HappenTime: time.Now(), 577 } 578 579 return entry 580 } 581 582 // wrapperRateLimitStoreResponse 封装路由存储层错误 583 func wrapperRateLimitStoreResponse(rule *apitraffic.Rule, err error) *apiservice.Response { 584 resp := storeError2Response(err) 585 if resp == nil { 586 return nil 587 } 588 resp.RateLimit = rule 589 return resp 590 }