github.com/polarismesh/polaris@v1.17.8/test/integrate/http/ratelimit_config.go (about) 1 /** 2 * Tencent is pleased to support the open source community by making Polaris available. 3 * 4 * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 5 * 6 * Licensed under the BSD 3-Clause License (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://opensource.org/licenses/BSD-3-Clause 11 * 12 * Unless required by applicable law or agreed to in writing, software distributed 13 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 14 * CONDITIONS OF ANY KIND, either express or implied. See the License for the 15 * specific language governing permissions and limitations under the License. 16 */ 17 18 package http 19 20 import ( 21 "bytes" 22 "encoding/json" 23 "errors" 24 "fmt" 25 "io" 26 27 "github.com/golang/protobuf/jsonpb" 28 "github.com/golang/protobuf/ptypes/wrappers" 29 apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" 30 apitraffic "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" 31 32 api "github.com/polarismesh/polaris/common/api/v1" 33 ) 34 35 /** 36 * @brief 限流规则数组转JSON 37 */ 38 func JSONFromRateLimits(rateLimits []*apitraffic.Rule) (*bytes.Buffer, error) { 39 m := jsonpb.Marshaler{Indent: " "} 40 41 buffer := bytes.NewBuffer([]byte{}) 42 43 buffer.Write([]byte("[")) 44 for index, rateLimit := range rateLimits { 45 if index > 0 { 46 buffer.Write([]byte(",\n")) 47 } 48 err := m.Marshal(buffer, rateLimit) 49 if err != nil { 50 return nil, err 51 } 52 } 53 54 buffer.Write([]byte("]")) 55 return buffer, nil 56 } 57 58 /** 59 * @brief 创建限流规则 60 */ 61 func (c *Client) CreateRateLimits(rateLimits []*apitraffic.Rule) (*apiservice.BatchWriteResponse, error) { 62 fmt.Printf("\ncreate rate limits\n") 63 64 url := fmt.Sprintf("http://%v/naming/%v/ratelimits", c.Address, c.Version) 65 66 body, err := JSONFromRateLimits(rateLimits) 67 if err != nil { 68 fmt.Printf("%v\n", err) 69 return nil, err 70 } 71 72 response, err := c.SendRequest("POST", url, body) 73 if err != nil { 74 fmt.Printf("%v\n", err) 75 return nil, err 76 } 77 78 ret, err := GetBatchWriteResponse(response) 79 if err != nil { 80 fmt.Printf("%v\n", err) 81 return ret, err 82 } 83 84 return checkCreateRateLimitsResponse(ret, rateLimits) 85 } 86 87 /** 88 * @brief 删除限流规则 89 */ 90 func (c *Client) DeleteRateLimits(rateLimits []*apitraffic.Rule) error { 91 fmt.Printf("\ndelete rate limits\n") 92 93 url := fmt.Sprintf("http://%v/naming/%v/ratelimits/delete", c.Address, c.Version) 94 95 body, err := JSONFromRateLimits(rateLimits) 96 if err != nil { 97 fmt.Printf("%v\n", err) 98 return err 99 } 100 101 response, err := c.SendRequest("POST", url, body) 102 if err != nil { 103 fmt.Printf("%v\n", err) 104 return err 105 } 106 107 _, err = GetBatchWriteResponse(response) 108 if err != nil { 109 if err == io.EOF { 110 return nil 111 } 112 113 fmt.Printf("%v\n", err) 114 return err 115 } 116 return nil 117 } 118 119 /** 120 * @brief 更新限流规则 121 */ 122 func (c *Client) UpdateRateLimits(rateLimits []*apitraffic.Rule) error { 123 fmt.Printf("\nupdate rate limits\n") 124 125 url := fmt.Sprintf("http://%v/naming/%v/ratelimits", c.Address, c.Version) 126 127 body, err := JSONFromRateLimits(rateLimits) 128 if err != nil { 129 fmt.Printf("%v\n", err) 130 return err 131 } 132 133 response, err := c.SendRequest("PUT", url, body) 134 if err != nil { 135 fmt.Printf("%v\n", err) 136 return err 137 } 138 139 _, err = GetBatchWriteResponse(response) 140 if err != nil { 141 if err == io.EOF { 142 return nil 143 } 144 145 fmt.Printf("%v\n", err) 146 return err 147 } 148 return nil 149 } 150 151 // EnableRateLimits 启用限流规则 152 func (c *Client) EnableRateLimits(rateLimits []*apitraffic.Rule) error { 153 fmt.Printf("\nenable rate limits\n") 154 155 url := fmt.Sprintf("http://%v/naming/%v/ratelimits/enable", c.Address, c.Version) 156 157 rateLimitsEnable := make([]*apitraffic.Rule, 0, len(rateLimits)) 158 for _, rateLimit := range rateLimits { 159 rateLimitsEnable = append(rateLimitsEnable, &apitraffic.Rule{ 160 Id: rateLimit.GetId(), 161 Disable: &wrappers.BoolValue{Value: true}, 162 }) 163 } 164 body, err := JSONFromRateLimits(rateLimitsEnable) 165 if err != nil { 166 fmt.Printf("%v\n", err) 167 return err 168 } 169 170 response, err := c.SendRequest("PUT", url, body) 171 if err != nil { 172 fmt.Printf("%v\n", err) 173 return err 174 } 175 176 _, err = GetBatchWriteResponse(response) 177 if err != nil { 178 if err == io.EOF { 179 return nil 180 } 181 182 fmt.Printf("%v\n", err) 183 return err 184 } 185 return nil 186 } 187 188 /** 189 * @brief 查询限流规则 190 */ 191 func (c *Client) GetRateLimits(rateLimits []*apitraffic.Rule) error { 192 fmt.Printf("\nget rate limits\n") 193 194 url := fmt.Sprintf("http://%v/naming/%v/ratelimits", c.Address, c.Version) 195 196 params := map[string][]interface{}{ 197 "namespace": {rateLimits[0].GetNamespace().GetValue()}, 198 } 199 200 url = c.CompleteURL(url, params) 201 response, err := c.SendRequest("GET", url, nil) 202 if err != nil { 203 return err 204 } 205 206 ret, err := GetBatchQueryResponse(response) 207 if err != nil { 208 fmt.Printf("%v\n", err) 209 return err 210 } 211 212 if ret.GetCode() == nil || ret.GetCode().GetValue() != api.ExecuteSuccess { 213 return errors.New("invalid batch code") 214 } 215 216 rateLimitsSize := len(rateLimits) 217 218 if ret.GetAmount() == nil || ret.GetAmount().GetValue() != uint32(rateLimitsSize) { 219 return errors.New("invalid batch amount") 220 } 221 222 if ret.GetSize() == nil || ret.GetSize().GetValue() != uint32(rateLimitsSize) { 223 return errors.New("invalid batch size") 224 } 225 226 collection := make(map[string]*apitraffic.Rule) 227 for _, rateLimit := range rateLimits { 228 collection[rateLimit.GetService().GetValue()] = rateLimit 229 } 230 231 items := ret.GetRateLimits() 232 if items == nil || len(items) != rateLimitsSize { 233 return errors.New("invalid batch rate limits") 234 } 235 236 for _, item := range items { 237 if correctItem, ok := collection[item.GetService().GetValue()]; ok { 238 if result, err := compareRateLimit(correctItem, item); !result { 239 return fmt.Errorf("invalid rate limit. namespace is %v, service is %v, err is %s", 240 item.GetNamespace().GetValue(), item.GetService().GetValue(), err.Error()) 241 } 242 } else { 243 return fmt.Errorf("rate limit not found. namespace is %v, service is %v", 244 item.GetNamespace().GetValue(), item.GetService().GetValue()) 245 } 246 } 247 return nil 248 } 249 250 /** 251 * @brief 检查创建限流规则的回复 252 */ 253 func checkCreateRateLimitsResponse(ret *apiservice.BatchWriteResponse, rateLimits []*apitraffic.Rule) ( 254 *apiservice.BatchWriteResponse, error) { 255 switch { 256 case ret.GetCode().GetValue() != api.ExecuteSuccess: 257 return nil, errors.New("invalid batch code") 258 case ret.GetSize().GetValue() != uint32(len(rateLimits)): 259 return nil, errors.New("invalid batch size") 260 case len(ret.GetResponses()) != len(rateLimits): 261 return nil, errors.New("invalid batch response") 262 } 263 264 for index, item := range ret.GetResponses() { 265 if item.GetCode().GetValue() != api.ExecuteSuccess { 266 return nil, errors.New("invalid code") 267 } 268 rateLimit := item.GetRateLimit() 269 if rateLimit == nil { 270 return nil, errors.New("empty rate limit") 271 } 272 if result, err := compareRateLimit(rateLimits[index], rateLimit); !result { 273 return nil, err 274 } 275 } 276 return ret, nil 277 } 278 279 /** 280 * @brief 比较rate limit是否相等 281 */ 282 func compareRateLimit(correctItem *apitraffic.Rule, item *apitraffic.Rule) (bool, error) { 283 switch { 284 case (correctItem.GetId().GetValue()) != "" && (correctItem.GetId().GetValue() != item.GetId().GetValue()): 285 return false, fmt.Errorf( 286 "invalid id, expect %s, actual %s", correctItem.GetId().GetValue(), item.GetId().GetValue()) 287 case correctItem.GetService().GetValue() != item.GetService().GetValue(): 288 return false, fmt.Errorf("error service, expect %s, actual %s", 289 correctItem.GetService().GetValue(), item.GetService().GetValue()) 290 case correctItem.GetNamespace().GetValue() != item.GetNamespace().GetValue(): 291 return false, fmt.Errorf("error namespace, expect %s, actual %s", 292 correctItem.GetNamespace().GetValue(), item.GetNamespace().GetValue()) 293 case correctItem.GetPriority().GetValue() != item.GetPriority().GetValue(): 294 return false, fmt.Errorf("invalid priority, expect %v, actual %v", 295 correctItem.GetPriority().GetValue(), item.GetPriority().GetValue()) 296 case correctItem.GetResource() != item.GetResource(): 297 return false, fmt.Errorf("invalid resource, expect %v, actual %v", 298 correctItem.GetResource(), item.GetResource()) 299 case correctItem.GetType() != item.GetType(): 300 return false, fmt.Errorf("error type, exepct %v, actual %v", correctItem.GetType(), item.GetType()) 301 case correctItem.GetAction().GetValue() != item.GetAction().GetValue(): 302 return false, fmt.Errorf("error action, expect %v, actual %v", 303 correctItem.GetAction().GetValue(), item.GetAction().GetValue()) 304 case correctItem.GetDisable().GetValue() != item.GetDisable().GetValue(): 305 return false, fmt.Errorf("error disable, expect %v, actual %v", 306 correctItem.GetDisable().GetValue(), item.GetDisable().GetValue()) 307 case correctItem.GetRegexCombine().GetValue() != item.GetRegexCombine().GetValue(): 308 return false, fmt.Errorf("error regex combine, expect %v, actual %v", 309 correctItem.GetRegexCombine().GetValue(), item.GetRegexCombine().GetValue()) 310 case correctItem.GetAmountMode() != item.GetAmountMode(): 311 return false, fmt.Errorf("error amount mode, expect %v, actual %v", 312 correctItem.GetAmountMode(), item.GetAmountMode()) 313 case correctItem.GetFailover() != item.GetFailover(): 314 return false, fmt.Errorf( 315 "error fail over, expect %v, actual %v", correctItem.GetFailover(), item.GetFailover()) 316 default: 317 break 318 } 319 320 if equal, err := checkField(correctItem.GetArguments(), item.GetArguments(), "arguments"); !equal { 321 return equal, err 322 } 323 324 if equal, err := checkField(correctItem.GetAmounts(), item.GetAmounts(), "amounts"); !equal { 325 return equal, err 326 } 327 328 if equal, err := checkField(correctItem.GetAdjuster(), item.GetAdjuster(), "adjuster"); !equal { 329 return equal, err 330 } 331 332 return checkField(correctItem.GetName(), item.GetName(), "cluster") 333 } 334 335 /** 336 * @brief 检查字段是否一致 337 */ 338 func checkField(correctItem, actualItem interface{}, name string) (bool, error) { 339 expect, err := json.Marshal(correctItem) 340 if err != nil { 341 panic(err) 342 } 343 actual, err := json.Marshal(actualItem) 344 if err != nil { 345 panic(err) 346 } 347 348 if string(expect) != string(actual) { 349 return false, fmt.Errorf("error %s, expect %s ,actual %s", name, expect, actual) 350 } 351 return true, nil 352 }