github.com/polarismesh/polaris@v1.17.8/plugin/ratelimit/token/api_limit.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package token
    19  
    20  import (
    21  	"errors"
    22  	"sync"
    23  
    24  	"golang.org/x/time/rate"
    25  )
    26  
// apiRatelimit holds the state for API-level rate limiting: the named rules
// taken from the configuration and one limiter per configured API.
type apiRatelimit struct {
	rules  map[string]*BucketRatelimit // rule name -> limit settings, filled by parseRules
	apis   sync.Map                    // API name -> *apiLimiter
	config *APILimitConfig             // configuration handed to initialize; may be nil
}
    33  
    34  // newAPIRatelimit 新建一个接口限流类
    35  func newAPIRatelimit(config *APILimitConfig) (*apiRatelimit, error) {
    36  	art := &apiRatelimit{}
    37  	if err := art.initialize(config); err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	return art, nil
    42  }
    43  
    44  // initialize 接口限流具体实现
    45  func (art *apiRatelimit) initialize(config *APILimitConfig) error {
    46  	art.config = config
    47  	if config == nil || !config.Open {
    48  		log.Infof("[Plugin][%s] api rate limit is not open", PluginName)
    49  		return nil
    50  	}
    51  
    52  	log.Infof("[Plugin][%s] api ratelimit open", PluginName)
    53  	if err := art.parseRules(config.Rules); err != nil {
    54  		return err
    55  	}
    56  	if err := art.parseApis(config.Apis); err != nil {
    57  		return err
    58  	}
    59  	return nil
    60  }
    61  
    62  // parseRules 解析限流规则
    63  func (art *apiRatelimit) parseRules(rules []*RateLimitRule) error {
    64  	if len(rules) == 0 {
    65  		return errors.New("invalid api rate limit config, rules are empty")
    66  	}
    67  
    68  	art.rules = make(map[string]*BucketRatelimit, len(rules))
    69  	for _, entry := range rules {
    70  		if entry.Name == "" {
    71  			return errors.New("invalid api rate limit config, some rules name are empty")
    72  		}
    73  		if entry.Limit == nil {
    74  			return errors.New("invalid api rate limit config, some rules limit are null")
    75  		}
    76  		if entry.Limit.Open && (entry.Limit.Bucket <= 0 || entry.Limit.Rate <= 0) {
    77  			return errors.New("invalid api rate limit config, rules bucket or rate is more than 0")
    78  		}
    79  		art.rules[entry.Name] = entry.Limit
    80  	}
    81  
    82  	return nil
    83  }
    84  
    85  // parseApis 解析每个api的限流
    86  func (art *apiRatelimit) parseApis(apis []*APILimitInfo) error {
    87  	if len(apis) == 0 {
    88  		return errors.New("invalid api rate limit config, apis are empty")
    89  	}
    90  
    91  	for _, entry := range apis {
    92  		if entry.Name == "" {
    93  			return errors.New("invalid api rate limit config, api name is empty")
    94  		}
    95  		if entry.Rule == "" {
    96  			return errors.New("invalid api rate limit config, api rule is empty")
    97  		}
    98  
    99  		limit, ok := art.rules[entry.Rule]
   100  		if !ok {
   101  			return errors.New("invalid api rate limit config, api rule is not found")
   102  		}
   103  		art.createLimiter(entry.Name, limit)
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  // createLimiter 创建一个私有limiter
   110  func (art *apiRatelimit) createLimiter(name string, limit *BucketRatelimit) *apiLimiter {
   111  	limiter := newAPILimiter(name, limit.Open, limit.Rate, limit.Bucket)
   112  	art.apis.Store(name, limiter)
   113  	return limiter
   114  }
   115  
   116  // 获取limiter
   117  func (art *apiRatelimit) acquireLimiter(name string) *apiLimiter {
   118  	if value, ok := art.apis.Load(name); ok {
   119  		return value.(*apiLimiter)
   120  	}
   121  
   122  	return nil
   123  }
   124  
   125  // 系统是否开启API限流
   126  func (art *apiRatelimit) isOpen() bool {
   127  	return art.config != nil && art.config.Open
   128  }
   129  
   130  // 令牌桶限流
   131  func (art *apiRatelimit) allow(name string) bool {
   132  	// 检查系统是否开启API限流
   133  	// 系统不开启API限流,则返回true通过
   134  	if !art.isOpen() {
   135  		return true
   136  	}
   137  
   138  	limiter := art.acquireLimiter(name)
   139  	if limiter == nil {
   140  		// 找不到limiter,默认返回true
   141  		return true
   142  	}
   143  
   144  	return limiter.Allow()
   145  }
   146  
// apiLimiter wraps a rate.Limiter for a single API.
// Each configured API gets its own apiLimiter.
type apiLimiter struct {
	open          bool   // whether rate limiting is enabled for this API
	name          string // API name
	*rate.Limiter        // embedded token bucket; left nil by newAPILimiter when open is false
}
   154  
   155  // newAPILimiter 新建一个apiLimiter
   156  func newAPILimiter(name string, open bool, r int, b int) *apiLimiter {
   157  	limiter := &apiLimiter{
   158  		open:    false,
   159  		name:    name,
   160  		Limiter: nil,
   161  	}
   162  	if !open {
   163  		return limiter
   164  	}
   165  
   166  	limiter.open = true
   167  	limiter.Limiter = rate.NewLimiter(rate.Limit(r), b)
   168  	return limiter
   169  }
   170  
   171  // Allow 继承rate.Limiter.Allow函数
   172  func (a *apiLimiter) Allow() bool {
   173  	// 当前接口不开启限流
   174  	if !a.open {
   175  		return true
   176  	}
   177  
   178  	return a.Limiter.Allow()
   179  }