github.com/polarismesh/polaris@v1.17.8/test/integrate/http/ratelimit_config.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package http
    19  
    20  import (
    21  	"bytes"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  
    27  	"github.com/golang/protobuf/jsonpb"
    28  	"github.com/golang/protobuf/ptypes/wrappers"
    29  	apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage"
    30  	apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage"
    31  
    32  	api "github.com/polarismesh/polaris/common/api/v1"
    33  )
    34  
    35  /**
    36   * @brief 限流规则数组转JSON
    37   */
    38  func JSONFromRateLimits(rateLimits []*apitraffic.Rule) (*bytes.Buffer, error) {
    39  	m := jsonpb.Marshaler{Indent: " "}
    40  
    41  	buffer := bytes.NewBuffer([]byte{})
    42  
    43  	buffer.Write([]byte("["))
    44  	for index, rateLimit := range rateLimits {
    45  		if index > 0 {
    46  			buffer.Write([]byte(",\n"))
    47  		}
    48  		err := m.Marshal(buffer, rateLimit)
    49  		if err != nil {
    50  			return nil, err
    51  		}
    52  	}
    53  
    54  	buffer.Write([]byte("]"))
    55  	return buffer, nil
    56  }
    57  
    58  /**
    59   * @brief 创建限流规则
    60   */
    61  func (c *Client) CreateRateLimits(rateLimits []*apitraffic.Rule) (*apiservice.BatchWriteResponse, error) {
    62  	fmt.Printf("\ncreate rate limits\n")
    63  
    64  	url := fmt.Sprintf("http://%v/naming/%v/ratelimits", c.Address, c.Version)
    65  
    66  	body, err := JSONFromRateLimits(rateLimits)
    67  	if err != nil {
    68  		fmt.Printf("%v\n", err)
    69  		return nil, err
    70  	}
    71  
    72  	response, err := c.SendRequest("POST", url, body)
    73  	if err != nil {
    74  		fmt.Printf("%v\n", err)
    75  		return nil, err
    76  	}
    77  
    78  	ret, err := GetBatchWriteResponse(response)
    79  	if err != nil {
    80  		fmt.Printf("%v\n", err)
    81  		return ret, err
    82  	}
    83  
    84  	return checkCreateRateLimitsResponse(ret, rateLimits)
    85  }
    86  
    87  /**
    88   * @brief 删除限流规则
    89   */
    90  func (c *Client) DeleteRateLimits(rateLimits []*apitraffic.Rule) error {
    91  	fmt.Printf("\ndelete rate limits\n")
    92  
    93  	url := fmt.Sprintf("http://%v/naming/%v/ratelimits/delete", c.Address, c.Version)
    94  
    95  	body, err := JSONFromRateLimits(rateLimits)
    96  	if err != nil {
    97  		fmt.Printf("%v\n", err)
    98  		return err
    99  	}
   100  
   101  	response, err := c.SendRequest("POST", url, body)
   102  	if err != nil {
   103  		fmt.Printf("%v\n", err)
   104  		return err
   105  	}
   106  
   107  	_, err = GetBatchWriteResponse(response)
   108  	if err != nil {
   109  		if err == io.EOF {
   110  			return nil
   111  		}
   112  
   113  		fmt.Printf("%v\n", err)
   114  		return err
   115  	}
   116  	return nil
   117  }
   118  
   119  /**
   120   * @brief 更新限流规则
   121   */
   122  func (c *Client) UpdateRateLimits(rateLimits []*apitraffic.Rule) error {
   123  	fmt.Printf("\nupdate rate limits\n")
   124  
   125  	url := fmt.Sprintf("http://%v/naming/%v/ratelimits", c.Address, c.Version)
   126  
   127  	body, err := JSONFromRateLimits(rateLimits)
   128  	if err != nil {
   129  		fmt.Printf("%v\n", err)
   130  		return err
   131  	}
   132  
   133  	response, err := c.SendRequest("PUT", url, body)
   134  	if err != nil {
   135  		fmt.Printf("%v\n", err)
   136  		return err
   137  	}
   138  
   139  	_, err = GetBatchWriteResponse(response)
   140  	if err != nil {
   141  		if err == io.EOF {
   142  			return nil
   143  		}
   144  
   145  		fmt.Printf("%v\n", err)
   146  		return err
   147  	}
   148  	return nil
   149  }
   150  
   151  // EnableRateLimits 启用限流规则
   152  func (c *Client) EnableRateLimits(rateLimits []*apitraffic.Rule) error {
   153  	fmt.Printf("\nenable rate limits\n")
   154  
   155  	url := fmt.Sprintf("http://%v/naming/%v/ratelimits/enable", c.Address, c.Version)
   156  
   157  	rateLimitsEnable := make([]*apitraffic.Rule, 0, len(rateLimits))
   158  	for _, rateLimit := range rateLimits {
   159  		rateLimitsEnable = append(rateLimitsEnable, &apitraffic.Rule{
   160  			Id:      rateLimit.GetId(),
   161  			Disable: &wrappers.BoolValue{Value: true},
   162  		})
   163  	}
   164  	body, err := JSONFromRateLimits(rateLimitsEnable)
   165  	if err != nil {
   166  		fmt.Printf("%v\n", err)
   167  		return err
   168  	}
   169  
   170  	response, err := c.SendRequest("PUT", url, body)
   171  	if err != nil {
   172  		fmt.Printf("%v\n", err)
   173  		return err
   174  	}
   175  
   176  	_, err = GetBatchWriteResponse(response)
   177  	if err != nil {
   178  		if err == io.EOF {
   179  			return nil
   180  		}
   181  
   182  		fmt.Printf("%v\n", err)
   183  		return err
   184  	}
   185  	return nil
   186  }
   187  
   188  /**
   189   * @brief 查询限流规则
   190   */
   191  func (c *Client) GetRateLimits(rateLimits []*apitraffic.Rule) error {
   192  	fmt.Printf("\nget rate limits\n")
   193  
   194  	url := fmt.Sprintf("http://%v/naming/%v/ratelimits", c.Address, c.Version)
   195  
   196  	params := map[string][]interface{}{
   197  		"namespace": {rateLimits[0].GetNamespace().GetValue()},
   198  	}
   199  
   200  	url = c.CompleteURL(url, params)
   201  	response, err := c.SendRequest("GET", url, nil)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	ret, err := GetBatchQueryResponse(response)
   207  	if err != nil {
   208  		fmt.Printf("%v\n", err)
   209  		return err
   210  	}
   211  
   212  	if ret.GetCode() == nil || ret.GetCode().GetValue() != api.ExecuteSuccess {
   213  		return errors.New("invalid batch code")
   214  	}
   215  
   216  	rateLimitsSize := len(rateLimits)
   217  
   218  	if ret.GetAmount() == nil || ret.GetAmount().GetValue() != uint32(rateLimitsSize) {
   219  		return errors.New("invalid batch amount")
   220  	}
   221  
   222  	if ret.GetSize() == nil || ret.GetSize().GetValue() != uint32(rateLimitsSize) {
   223  		return errors.New("invalid batch size")
   224  	}
   225  
   226  	collection := make(map[string]*apitraffic.Rule)
   227  	for _, rateLimit := range rateLimits {
   228  		collection[rateLimit.GetService().GetValue()] = rateLimit
   229  	}
   230  
   231  	items := ret.GetRateLimits()
   232  	if items == nil || len(items) != rateLimitsSize {
   233  		return errors.New("invalid batch rate limits")
   234  	}
   235  
   236  	for _, item := range items {
   237  		if correctItem, ok := collection[item.GetService().GetValue()]; ok {
   238  			if result, err := compareRateLimit(correctItem, item); !result {
   239  				return fmt.Errorf("invalid rate limit. namespace is %v, service is %v, err is %s",
   240  					item.GetNamespace().GetValue(), item.GetService().GetValue(), err.Error())
   241  			}
   242  		} else {
   243  			return fmt.Errorf("rate limit not found. namespace is %v, service is %v",
   244  				item.GetNamespace().GetValue(), item.GetService().GetValue())
   245  		}
   246  	}
   247  	return nil
   248  }
   249  
   250  /**
   251   * @brief 检查创建限流规则的回复
   252   */
   253  func checkCreateRateLimitsResponse(ret *apiservice.BatchWriteResponse, rateLimits []*apitraffic.Rule) (
   254  	*apiservice.BatchWriteResponse, error) {
   255  	switch {
   256  	case ret.GetCode().GetValue() != api.ExecuteSuccess:
   257  		return nil, errors.New("invalid batch code")
   258  	case ret.GetSize().GetValue() != uint32(len(rateLimits)):
   259  		return nil, errors.New("invalid batch size")
   260  	case len(ret.GetResponses()) != len(rateLimits):
   261  		return nil, errors.New("invalid batch response")
   262  	}
   263  
   264  	for index, item := range ret.GetResponses() {
   265  		if item.GetCode().GetValue() != api.ExecuteSuccess {
   266  			return nil, errors.New("invalid code")
   267  		}
   268  		rateLimit := item.GetRateLimit()
   269  		if rateLimit == nil {
   270  			return nil, errors.New("empty rate limit")
   271  		}
   272  		if result, err := compareRateLimit(rateLimits[index], rateLimit); !result {
   273  			return nil, err
   274  		}
   275  	}
   276  	return ret, nil
   277  }
   278  
   279  /**
   280   * @brief 比较rate limit是否相等
   281   */
   282  func compareRateLimit(correctItem *apitraffic.Rule, item *apitraffic.Rule) (bool, error) {
   283  	switch {
   284  	case (correctItem.GetId().GetValue()) != "" && (correctItem.GetId().GetValue() != item.GetId().GetValue()):
   285  		return false, fmt.Errorf(
   286  			"invalid id, expect %s, actual %s", correctItem.GetId().GetValue(), item.GetId().GetValue())
   287  	case correctItem.GetService().GetValue() != item.GetService().GetValue():
   288  		return false, fmt.Errorf("error service, expect %s, actual %s",
   289  			correctItem.GetService().GetValue(), item.GetService().GetValue())
   290  	case correctItem.GetNamespace().GetValue() != item.GetNamespace().GetValue():
   291  		return false, fmt.Errorf("error namespace, expect %s, actual %s",
   292  			correctItem.GetNamespace().GetValue(), item.GetNamespace().GetValue())
   293  	case correctItem.GetPriority().GetValue() != item.GetPriority().GetValue():
   294  		return false, fmt.Errorf("invalid priority, expect %v, actual %v",
   295  			correctItem.GetPriority().GetValue(), item.GetPriority().GetValue())
   296  	case correctItem.GetResource() != item.GetResource():
   297  		return false, fmt.Errorf("invalid resource, expect %v, actual %v",
   298  			correctItem.GetResource(), item.GetResource())
   299  	case correctItem.GetType() != item.GetType():
   300  		return false, fmt.Errorf("error type, exepct %v, actual %v", correctItem.GetType(), item.GetType())
   301  	case correctItem.GetAction().GetValue() != item.GetAction().GetValue():
   302  		return false, fmt.Errorf("error action, expect %v, actual %v",
   303  			correctItem.GetAction().GetValue(), item.GetAction().GetValue())
   304  	case correctItem.GetDisable().GetValue() != item.GetDisable().GetValue():
   305  		return false, fmt.Errorf("error disable, expect %v, actual %v",
   306  			correctItem.GetDisable().GetValue(), item.GetDisable().GetValue())
   307  	case correctItem.GetRegexCombine().GetValue() != item.GetRegexCombine().GetValue():
   308  		return false, fmt.Errorf("error regex combine, expect %v, actual %v",
   309  			correctItem.GetRegexCombine().GetValue(), item.GetRegexCombine().GetValue())
   310  	case correctItem.GetAmountMode() != item.GetAmountMode():
   311  		return false, fmt.Errorf("error amount mode, expect %v, actual %v",
   312  			correctItem.GetAmountMode(), item.GetAmountMode())
   313  	case correctItem.GetFailover() != item.GetFailover():
   314  		return false, fmt.Errorf(
   315  			"error fail over, expect %v, actual %v", correctItem.GetFailover(), item.GetFailover())
   316  	default:
   317  		break
   318  	}
   319  
   320  	if equal, err := checkField(correctItem.GetArguments(), item.GetArguments(), "arguments"); !equal {
   321  		return equal, err
   322  	}
   323  
   324  	if equal, err := checkField(correctItem.GetAmounts(), item.GetAmounts(), "amounts"); !equal {
   325  		return equal, err
   326  	}
   327  
   328  	if equal, err := checkField(correctItem.GetAdjuster(), item.GetAdjuster(), "adjuster"); !equal {
   329  		return equal, err
   330  	}
   331  
   332  	return checkField(correctItem.GetName(), item.GetName(), "cluster")
   333  }
   334  
   335  /**
   336   * @brief 检查字段是否一致
   337   */
   338  func checkField(correctItem, actualItem interface{}, name string) (bool, error) {
   339  	expect, err := json.Marshal(correctItem)
   340  	if err != nil {
   341  		panic(err)
   342  	}
   343  	actual, err := json.Marshal(actualItem)
   344  	if err != nil {
   345  		panic(err)
   346  	}
   347  
   348  	if string(expect) != string(actual) {
   349  		return false, fmt.Errorf("error %s, expect %s ,actual %s", name, expect, actual)
   350  	}
   351  	return true, nil
   352  }