github.com/polarismesh/polaris@v1.17.8/store/mysql/ratelimit_config.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package sqldb
    19  
import (
	"database/sql"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/polarismesh/polaris/common/model"
	"github.com/polarismesh/polaris/store"
)
    31  
    32  var _ store.RateLimitStore = (*rateLimitStore)(nil)
    33  
    34  // rateLimitStore RateLimitStore的实现
    35  type rateLimitStore struct {
    36  	master *BaseDB
    37  	slave  *BaseDB
    38  }
    39  
    40  // CreateRateLimit 新建限流规则
    41  func (rls *rateLimitStore) CreateRateLimit(limit *model.RateLimit) error {
    42  	if limit.ID == "" || limit.Revision == "" {
    43  		return errors.New("[Store][database] create rate limit missing some params")
    44  	}
    45  	err := RetryTransaction("createRateLimit", func() error {
    46  		return rls.createRateLimit(limit)
    47  	})
    48  
    49  	return store.Error(err)
    50  }
    51  
    52  func limitToEtimeStr(limit *model.RateLimit) string {
    53  	etimeStr := "sysdate()"
    54  	if limit.Disable {
    55  		etimeStr = emptyEnableTime
    56  	}
    57  	return etimeStr
    58  }
    59  
    60  // createRateLimit
    61  func (rls *rateLimitStore) createRateLimit(limit *model.RateLimit) error {
    62  	tx, err := rls.master.Begin()
    63  	if err != nil {
    64  		log.Errorf("[Store][database] create rate limit(%+v) begin tx err: %s", limit, err.Error())
    65  		return err
    66  	}
    67  
    68  	defer func() {
    69  		_ = tx.Rollback()
    70  	}()
    71  
    72  	etimeStr := limitToEtimeStr(limit)
    73  	// 新建限流规则
    74  	str := fmt.Sprintf(`insert into ratelimit_config(
    75  			id, name, disable, service_id, method, labels, priority, rule, revision, ctime, mtime, etime)
    76  			values(?,?,?,?,?,?,?,?,?,sysdate(),sysdate(), %s)`, etimeStr)
    77  	if _, err := tx.Exec(str, limit.ID, limit.Name, limit.Disable, limit.ServiceID, limit.Method, limit.Labels,
    78  		limit.Priority, limit.Rule, limit.Revision); err != nil {
    79  		log.Errorf("[Store][database] create rate limit(%+v), sql %s err: %s", limit, str, err.Error())
    80  		return err
    81  	}
    82  
    83  	if err := tx.Commit(); err != nil {
    84  		log.Errorf("[Store][database] create rate limit(%+v) commit tx err: %s", limit, err.Error())
    85  		return err
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  // UpdateRateLimit 更新限流规则
    92  func (rls *rateLimitStore) UpdateRateLimit(limit *model.RateLimit) error {
    93  	if limit.ID == "" || limit.Revision == "" {
    94  		return errors.New("[Store][database] update rate limit missing some params")
    95  	}
    96  
    97  	err := RetryTransaction("updateRateLimit", func() error {
    98  		return rls.updateRateLimit(limit)
    99  	})
   100  
   101  	return store.Error(err)
   102  }
   103  
   104  // EnableRateLimit 启用限流规则
   105  func (rls *rateLimitStore) EnableRateLimit(limit *model.RateLimit) error {
   106  	if limit.ID == "" || limit.Revision == "" {
   107  		return errors.New("[Store][database] enable rate limit missing some params")
   108  	}
   109  
   110  	err := RetryTransaction("enableRateLimit", func() error {
   111  		return rls.enableRateLimit(limit)
   112  	})
   113  
   114  	return store.Error(err)
   115  }
   116  
   117  // enableRateLimit
   118  func (rls *rateLimitStore) enableRateLimit(limit *model.RateLimit) error {
   119  	tx, err := rls.master.Begin()
   120  	if err != nil {
   121  		log.Errorf("[Store][database] update rate limit(%+v) begin tx err: %s", limit, err.Error())
   122  		return err
   123  	}
   124  
   125  	defer func() {
   126  		_ = tx.Rollback()
   127  	}()
   128  
   129  	etimeStr := limitToEtimeStr(limit)
   130  	str := fmt.Sprintf(
   131  		`update ratelimit_config set disable = ?, revision = ?, mtime = sysdate(), etime=%s where id = ?`, etimeStr)
   132  	if _, err := tx.Exec(str, limit.Disable, limit.Revision, limit.ID); err != nil {
   133  		log.Errorf("[Store][database] update rate limit(%+v), sql %s, err: %s", limit, str, err)
   134  		return err
   135  	}
   136  
   137  	if err := tx.Commit(); err != nil {
   138  		log.Errorf("[Store][database] update rate limit(%+v) commit tx err: %s", limit, err.Error())
   139  		return err
   140  	}
   141  	return nil
   142  }
   143  
   144  // updateRateLimit
   145  func (rls *rateLimitStore) updateRateLimit(limit *model.RateLimit) error {
   146  	tx, err := rls.master.Begin()
   147  	if err != nil {
   148  		log.Errorf("[Store][database] update rate limit(%+v) begin tx err: %s", limit, err.Error())
   149  		return err
   150  	}
   151  
   152  	defer func() {
   153  		_ = tx.Rollback()
   154  	}()
   155  
   156  	etimeStr := limitToEtimeStr(limit)
   157  	str := fmt.Sprintf(`update ratelimit_config set name = ?, service_id=?, disable = ?, method= ?,
   158  			labels = ?, priority = ?, rule = ?, revision = ?, mtime = sysdate(), etime=%s where id = ?`, etimeStr)
   159  	if _, err := tx.Exec(str, limit.Name, limit.ServiceID, limit.Disable,
   160  		limit.Method, limit.Labels, limit.Priority, limit.Rule, limit.Revision, limit.ID); err != nil {
   161  		log.Errorf("[Store][database] update rate limit(%+v), sql %s, err: %s", limit, str, err)
   162  		return err
   163  	}
   164  
   165  	if err := tx.Commit(); err != nil {
   166  		log.Errorf("[Store][database] update rate limit(%+v) commit tx err: %s", limit, err.Error())
   167  		return err
   168  	}
   169  	return nil
   170  }
   171  
   172  // DeleteRateLimit 删除限流规则
   173  func (rls *rateLimitStore) DeleteRateLimit(limit *model.RateLimit) error {
   174  	if limit.ID == "" || limit.Revision == "" {
   175  		return errors.New("[Store][database] delete rate limit missing some params")
   176  	}
   177  
   178  	err := RetryTransaction("deleteRateLimit", func() error {
   179  		return rls.deleteRateLimit(limit)
   180  	})
   181  
   182  	return store.Error(err)
   183  }
   184  
   185  // deleteRateLimit
   186  func (rls *rateLimitStore) deleteRateLimit(limit *model.RateLimit) error {
   187  	tx, err := rls.master.Begin()
   188  	if err != nil {
   189  		log.Errorf("[Store][database] delete rate limit(%+v) begin tx err: %s", limit, err.Error())
   190  		return err
   191  	}
   192  
   193  	defer func() {
   194  		_ = tx.Rollback()
   195  	}()
   196  
   197  	str := `update ratelimit_config set flag = 1, mtime = sysdate() where id = ?`
   198  	if _, err := tx.Exec(str, limit.ID); err != nil {
   199  		log.Errorf("[Store][database] delete rate limit(%+v) err: %s", limit, err)
   200  		return err
   201  	}
   202  
   203  	if err := tx.Commit(); err != nil {
   204  		log.Errorf("[Store][database] delete rate limit(%+v) commit tx err: %s", limit, err.Error())
   205  		return err
   206  	}
   207  	return nil
   208  }
   209  
   210  // GetRateLimitWithID 根据限流规则ID获取限流规则
   211  func (rls *rateLimitStore) GetRateLimitWithID(id string) (*model.RateLimit, error) {
   212  	if id == "" {
   213  		log.Errorf("[Store][database] get rate limit missing some params")
   214  		return nil, errors.New("get rate limit missing some params")
   215  	}
   216  
   217  	str := `select id, name, disable, service_id, method, labels, priority, rule, revision, flag,
   218  			unix_timestamp(ctime), unix_timestamp(mtime), unix_timestamp(etime)
   219  			from ratelimit_config where id = ? and flag = 0`
   220  	rows, err := rls.master.Query(str, id)
   221  	if err != nil {
   222  		log.Errorf("[Store][database] query rate limit with id(%s) err: %s", id, err.Error())
   223  		return nil, err
   224  	}
   225  	out, err := fetchRateLimitRows(rows)
   226  	if err != nil {
   227  		return nil, err
   228  	}
   229  	if len(out) == 0 {
   230  		return nil, nil
   231  	}
   232  	return out[0], nil
   233  }
   234  
   235  // fetchRateLimitRows 读取限流数据
   236  func fetchRateLimitRows(rows *sql.Rows) ([]*model.RateLimit, error) {
   237  	defer rows.Close()
   238  	var out []*model.RateLimit
   239  	for rows.Next() {
   240  		var rateLimit model.RateLimit
   241  		var flag int
   242  		var ctime, mtime, etime int64
   243  		err := rows.Scan(&rateLimit.ID, &rateLimit.Name, &rateLimit.Disable, &rateLimit.ServiceID, &rateLimit.Method,
   244  			&rateLimit.Labels, &rateLimit.Priority, &rateLimit.Rule, &rateLimit.Revision, &flag, &ctime, &mtime, &etime)
   245  		if err != nil {
   246  			log.Errorf("[Store][database] fetch rate limit scan err: %s", err.Error())
   247  			return nil, err
   248  		}
   249  		rateLimit.CreateTime = time.Unix(ctime, 0)
   250  		rateLimit.ModifyTime = time.Unix(mtime, 0)
   251  		rateLimit.EnableTime = time.Unix(etime, 0)
   252  		rateLimit.Valid = true
   253  		if flag == 1 {
   254  			rateLimit.Valid = false
   255  		}
   256  		out = append(out, &rateLimit)
   257  	}
   258  	if err := rows.Err(); err != nil {
   259  		log.Errorf("[Store][database] fetch rate limit next err: %s", err.Error())
   260  		return nil, err
   261  	}
   262  	return out, nil
   263  }
   264  
   265  // GetRateLimitsForCache 根据修改时间拉取增量限流规则及最新版本号
   266  func (rls *rateLimitStore) GetRateLimitsForCache(mtime time.Time,
   267  	firstUpdate bool) ([]*model.RateLimit, error) {
   268  	str := `select id, name, disable, ratelimit_config.service_id, method, labels, priority, rule, revision, flag,
   269  			unix_timestamp(ratelimit_config.ctime), unix_timestamp(ratelimit_config.mtime), 
   270  			unix_timestamp(ratelimit_config.etime) from ratelimit_config 
   271  			where ratelimit_config.mtime > FROM_UNIXTIME(?)`
   272  	if firstUpdate {
   273  		str += " and flag != 1"
   274  	}
   275  	rows, err := rls.slave.Query(str, timeToTimestamp(mtime))
   276  	if err != nil {
   277  		log.Errorf("[Store][database] query rate limits with mtime err: %s", err.Error())
   278  		return nil, err
   279  	}
   280  	rateLimits, err := fetchRateLimitCacheRows(rows)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	return rateLimits, nil
   285  }
   286  
   287  // fetchRateLimitCacheRows 读取限流数据以及最新版本号
   288  func fetchRateLimitCacheRows(rows *sql.Rows) ([]*model.RateLimit, error) {
   289  	defer rows.Close()
   290  
   291  	var rateLimits []*model.RateLimit
   292  
   293  	for rows.Next() {
   294  		var (
   295  			rateLimit           model.RateLimit
   296  			ctime, mtime, etime int64
   297  			serviceID           string
   298  			flag                int
   299  		)
   300  		err := rows.Scan(&rateLimit.ID, &rateLimit.Name, &rateLimit.Disable, &serviceID, &rateLimit.Method, &rateLimit.Labels,
   301  			&rateLimit.Priority, &rateLimit.Rule, &rateLimit.Revision, &flag, &ctime, &mtime, &etime)
   302  		if err != nil {
   303  			log.Errorf("[Store][database] fetch rate limit cache scan err: %s", err.Error())
   304  			return nil, err
   305  		}
   306  		rateLimit.CreateTime = time.Unix(ctime, 0)
   307  		rateLimit.ModifyTime = time.Unix(mtime, 0)
   308  		rateLimit.Valid = true
   309  		if flag == 1 {
   310  			rateLimit.Valid = false
   311  		}
   312  		rateLimit.ServiceID = serviceID
   313  
   314  		rateLimits = append(rateLimits, &rateLimit)
   315  	}
   316  
   317  	if err := rows.Err(); err != nil {
   318  		log.Errorf("[Store][database] fetch rate limit cache next err: %s", err.Error())
   319  		return nil, err
   320  	}
   321  	return rateLimits, nil
   322  }
   323  
   324  const (
   325  	briefSearch = "brief"
   326  )
   327  
   328  // GetExtendRateLimits 根据过滤条件获取限流规则及数目
   329  func (rls *rateLimitStore) GetExtendRateLimits(
   330  	filter map[string]string, offset uint32, limit uint32) (uint32, []*model.ExtendRateLimit, error) {
   331  	var out []*model.ExtendRateLimit
   332  	var err error
   333  	if bValue, ok := filter[briefSearch]; ok && strings.ToLower(bValue) == "true" {
   334  		out, err = rls.getBriefRateLimits(filter, offset, limit)
   335  	} else {
   336  		out, err = rls.getExpandRateLimits(filter, offset, limit)
   337  	}
   338  	if err != nil {
   339  		return 0, nil, err
   340  	}
   341  	num, err := rls.getExpandRateLimitsCount(filter)
   342  	if err != nil {
   343  		return 0, nil, err
   344  	}
   345  	return num, out, nil
   346  }
   347  
   348  // getBriefRateLimits 获取列表的概要信息
   349  func (rls *rateLimitStore) getBriefRateLimits(
   350  	filter map[string]string, offset uint32, limit uint32) ([]*model.ExtendRateLimit, error) {
   351  	str := `select ratelimit_config.id, ratelimit_config.name, ratelimit_config.disable,
   352              ratelimit_config.service_id, ratelimit_config.method, unix_timestamp(ratelimit_config.ctime), 
   353  			unix_timestamp(ratelimit_config.mtime), unix_timestamp(ratelimit_config.etime) 
   354  			from ratelimit_config where ratelimit_config.flag = 0`
   355  
   356  	queryStr, args := genFilterRateLimitSQL(filter)
   357  	args = append(args, offset, limit)
   358  	str = str + queryStr + ` order by ratelimit_config.mtime desc limit ?, ?`
   359  
   360  	rows, err := rls.master.Query(str, args...)
   361  	if err != nil {
   362  		log.Errorf("[Store][database] query rate limits err: %s", err.Error())
   363  		return nil, err
   364  	}
   365  	out, err := fetchBriefRateLimitRows(rows)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  	return out, nil
   370  }
   371  
   372  // fetchBriefRateLimitRows fetch the brief ratelimit list
   373  func fetchBriefRateLimitRows(rows *sql.Rows) ([]*model.ExtendRateLimit, error) {
   374  	defer rows.Close()
   375  	var out []*model.ExtendRateLimit
   376  	for rows.Next() {
   377  		var expand model.ExtendRateLimit
   378  		expand.RateLimit = &model.RateLimit{}
   379  		var ctime, mtime, etime int64
   380  		err := rows.Scan(
   381  			&expand.RateLimit.ID,
   382  			&expand.RateLimit.Name,
   383  			&expand.RateLimit.Disable,
   384  			&expand.RateLimit.ServiceID,
   385  			&expand.RateLimit.Method, &ctime, &mtime, &etime)
   386  		if err != nil {
   387  			log.Errorf("[Store][database] fetch brief rate limit scan err: %s", err.Error())
   388  			return nil, err
   389  		}
   390  		expand.RateLimit.CreateTime = time.Unix(ctime, 0)
   391  		expand.RateLimit.ModifyTime = time.Unix(mtime, 0)
   392  		expand.RateLimit.EnableTime = time.Unix(etime, 0)
   393  		out = append(out, &expand)
   394  	}
   395  	if err := rows.Err(); err != nil {
   396  		log.Errorf("[Store][database] fetch brief rate limit next err: %s", err.Error())
   397  		return nil, err
   398  	}
   399  	return out, nil
   400  }
   401  
   402  // getExpandRateLimits 根据过滤条件获取限流规则
   403  func (rls *rateLimitStore) getExpandRateLimits(
   404  	filter map[string]string, offset uint32, limit uint32) ([]*model.ExtendRateLimit, error) {
   405  	str := `select ratelimit_config.id, ratelimit_config.name, ratelimit_config.disable,
   406              ratelimit_config.service_id, ratelimit_config.method, ratelimit_config.labels, 
   407              ratelimit_config.priority, ratelimit_config.rule, ratelimit_config.revision, 
   408              unix_timestamp(ratelimit_config.ctime), unix_timestamp(ratelimit_config.mtime), 
   409  			unix_timestamp(ratelimit_config.etime) 
   410  			from ratelimit_config 
   411  			where ratelimit_config.flag = 0`
   412  
   413  	queryStr, args := genFilterRateLimitSQL(filter)
   414  	args = append(args, offset, limit)
   415  	str = str + queryStr + ` order by ratelimit_config.mtime desc limit ?, ?`
   416  
   417  	rows, err := rls.master.Query(str, args...)
   418  	if err != nil {
   419  		log.Errorf("[Store][database] query rate limits err: %s", err.Error())
   420  		return nil, err
   421  	}
   422  	out, err := fetchExpandRateLimitRows(rows)
   423  	if err != nil {
   424  		return nil, err
   425  	}
   426  	return out, nil
   427  }
   428  
   429  // fetchExpandRateLimitRows 读取包含服务信息的限流数据
   430  func fetchExpandRateLimitRows(rows *sql.Rows) ([]*model.ExtendRateLimit, error) {
   431  	defer rows.Close()
   432  	var out []*model.ExtendRateLimit
   433  	for rows.Next() {
   434  		var expand model.ExtendRateLimit
   435  		expand.RateLimit = &model.RateLimit{}
   436  		var ctime, mtime, etime int64
   437  		err := rows.Scan(
   438  			&expand.RateLimit.ID,
   439  			&expand.RateLimit.Name,
   440  			&expand.RateLimit.Disable,
   441  			&expand.RateLimit.ServiceID,
   442  			&expand.RateLimit.Method,
   443  			&expand.RateLimit.Labels,
   444  			&expand.RateLimit.Priority,
   445  			&expand.RateLimit.Rule,
   446  			&expand.RateLimit.Revision, &ctime, &mtime, &etime)
   447  		if err != nil {
   448  			log.Errorf("[Store][database] fetch expand rate limit scan err: %s", err.Error())
   449  			return nil, err
   450  		}
   451  		expand.RateLimit.CreateTime = time.Unix(ctime, 0)
   452  		expand.RateLimit.ModifyTime = time.Unix(mtime, 0)
   453  		expand.RateLimit.EnableTime = time.Unix(etime, 0)
   454  		out = append(out, &expand)
   455  	}
   456  	if err := rows.Err(); err != nil {
   457  		log.Errorf("[Store][database] fetch expand rate limit next err: %s", err.Error())
   458  		return nil, err
   459  	}
   460  	return out, nil
   461  }
   462  
   463  // getExpandRateLimitsCount 根据过滤条件获取限流规则数目
   464  func (rls *rateLimitStore) getExpandRateLimitsCount(filter map[string]string) (uint32, error) {
   465  	str := `select count(*) from ratelimit_config where ratelimit_config.flag = 0`
   466  
   467  	queryStr, args := genFilterRateLimitSQL(filter)
   468  	str = str + queryStr
   469  	var total uint32
   470  	err := rls.master.QueryRow(str, args...).Scan(&total)
   471  	switch {
   472  	case err == sql.ErrNoRows:
   473  		return 0, nil
   474  	case err != nil:
   475  		log.Errorf("[Store][database] get expand rate limits count err: %s", err.Error())
   476  		return 0, err
   477  	default:
   478  	}
   479  	return total, nil
   480  }
   481  
   482  var queryKeyToDbColumn = map[string]string{
   483  	"id":      "ratelimit_config.id",
   484  	"name":    "ratelimit_config.name",
   485  	"method":  "ratelimit_config.method",
   486  	"labels":  "ratelimit_config.labels",
   487  	"disable": "ratelimit_config.disable",
   488  }
   489  
   490  // genFilterRateLimitSQL 生成查询语句的过滤语句
   491  func genFilterRateLimitSQL(query map[string]string) (string, []interface{}) {
   492  	str := ""
   493  	args := make([]interface{}, 0, len(query))
   494  	for key, value := range query {
   495  		var arg interface{}
   496  		sqlKey := queryKeyToDbColumn[key]
   497  		if len(sqlKey) == 0 {
   498  			continue
   499  		}
   500  		if key == "name" || key == "method" || key == "labels" {
   501  			str += fmt.Sprintf(" and %s like ?", sqlKey)
   502  			arg = "%" + value + "%"
   503  		} else if key == "disable" {
   504  			str += fmt.Sprintf(" and %s = ?", sqlKey)
   505  			arg, _ = strconv.ParseBool(value)
   506  		} else {
   507  			str += fmt.Sprintf(" and %s = ?", sqlKey)
   508  			arg = value
   509  		}
   510  		args = append(args, arg)
   511  	}
   512  	return str, args
   513  }