github.com/polarismesh/polaris@v1.17.8/common/conn/limit/listener_test.go (about) 1 /** 2 * Tencent is pleased to support the open source community by making Polaris available. 3 * 4 * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 5 * 6 * Licensed under the BSD 3-Clause License (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://opensource.org/licenses/BSD-3-Clause 11 * 12 * Unless required by applicable law or agreed to in writing, software distributed 13 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 14 * CONDITIONS OF ANY KIND, either express or implied. See the License for the 15 * specific language governing permissions and limitations under the License. 16 */ 17 18 package connlimit 19 20 import ( 21 "context" 22 "fmt" 23 "math/rand" 24 "net" 25 "sync" 26 "sync/atomic" 27 "testing" 28 "time" 29 30 "github.com/golang/mock/gomock" 31 . "github.com/smartystreets/goconvey/convey" 32 33 "github.com/polarismesh/polaris/common/conn/limit/mock_net" 34 "github.com/polarismesh/polaris/common/utils" 35 ) 36 37 // TestConnLimit 模拟一下连接限制 38 func TestConnLimit(t *testing.T) { 39 addr := "127.0.0.1:44444" 40 host := "127.0.0.1" 41 config := &Config{ 42 OpenConnLimit: true, 43 MaxConnPerHost: 5, 44 MaxConnLimit: 3, 45 PurgeCounterInterval: time.Hour, 46 PurgeCounterExpire: time.Minute, 47 } 48 connCount := 100 49 lis, err := net.Listen("tcp", addr) 50 if err != nil { 51 t.Fatalf("%s", err) 52 } 53 54 lis, err = NewListener(lis, "tcp", config) 55 if err != nil { 56 t.Fatalf("%s", err) 57 } 58 59 if lis.(*Listener).GetHostConnCount(host) != 0 { 60 t.Fatalf("%s connNum should be 0 when no connections", host) 61 } 62 63 // 启动Server 64 go func() { 65 for { 66 conn, _ := lis.Accept() 67 go func(c net.Conn) { 68 buf := make([]byte, 10) 69 if _, err := c.Read(buf); err != nil { 70 t.Logf("server read err: %s", err.Error()) 71 _ = c.Close() 72 return 73 } 74 t.Logf("server read data: %s", string(buf)) 75 time.Sleep(time.Millisecond * 200) 76 _ = c.Close() 77 }(conn) 78 } 79 }() 80 time.Sleep(1 * time.Second) 81 82 var total int32 83 for i := 0; i < connCount; i++ { 84 go func(index int) { 85 conn, err := net.Dial("tcp", addr) 86 atomic.AddInt32(&total, 1) 87 if err != nil { 88 t.Logf("client conn server error: %s", err.Error()) 89 return 90 } 91 buf := []byte("hello") 92 if _, err := conn.Write(buf); err != nil { 93 t.Logf("client write error: %s", err.Error()) 94 _ = conn.Close() 95 return 96 } 97 }(i) 98 } 99 100 // 等待连接全部关闭 101 // time.Sleep(5 * time.Second) 102 ticker := time.NewTicker(time.Second) 103 defer ticker.Stop() 104 for range ticker.C { 105 if atomic.LoadInt32(&total) != int32(connCount) { 106 t.Logf("connection is not finished") 107 continue 108 } 109 hostCnt := lis.(*Listener).GetHostConnCount(host) 110 lisCnt := lis.(*Listener).GetListenerConnCount() 111 if hostCnt == 0 && lisCnt == 0 { 112 t.Logf("pass") 113 return 114 } 115 116 t.Logf("host conn count:%d: lis conn count:%d", hostCnt, lisCnt) 117 } 118 } 119 120 // test readTimeout场景 121 /*func TestConnLimiterReadTimeout(t *testing.T) { 122 lis, err := net.Listen("tcp", "127.0.0.1:55555") 123 if err != nil { 124 t.Fatalf("%s", err) 125 } 126 127 cfg := &Config{ 128 OpenConnLimit: true, 129 MaxConnLimit: 16, 130 MaxConnPerHost: 8, 131 ReadTimeout: time.Millisecond * 500, 132 } 133 lis, err = NewListener(lis, "http", cfg) 134 if err != nil { 135 t.Fatalf("error: %s", err.Error()) 136 } 137 defer lis.Close() 138 handler := func(conn net.Conn) { 139 for { 140 reader := bufio.NewReader(conn) 141 buf := make([]byte, 12) 142 if _, err := io.ReadFull(reader, buf); err != nil { 143 t.Logf("read full return: %s", err.Error()) 144 if e, ok := err.(net.Error); ok && e.Timeout() { 145 t.Logf("pass") 146 } else { 147 t.Fatalf("error") 148 } 149 return 150 } 151 t.Logf("%s", string(buf)) 152 go func() {conn.Close()}() 153 } 154 } 155 go func() { 156 conn, err := lis.Accept() 157 if err != nil { 158 t.Fatalf("error: %s", err.Error()) 159 } 160 go handler(conn) 161 }() 162 163 conn, err := net.Dial("tcp", "127.0.0.1:55555") 164 if err != nil { 165 t.Fatalf("error: %s", err.Error()) 166 } 167 //time.Sleep(time.Second * 1) 168 _, err = conn.Write([]byte("hello world!")) 169 if err != nil { 170 t.Logf("%s", err.Error()) 171 } 172 time.Sleep(time.Second) 173 conn.Close() 174 time.Sleep(time.Second) 175 }*/ 176 177 // TestInvalidParams test invalid conn limit param 178 func TestInvalidParams(t *testing.T) { 179 lis, err := net.Listen("tcp", "127.0.0.1:44445") 180 if err != nil { 181 t.Fatalf("%s", err) 182 } 183 defer func() { _ = lis.Close() }() 184 config := &Config{ 185 OpenConnLimit: true, 186 MaxConnPerHost: 0, 187 MaxConnLimit: 10, 188 PurgeCounterInterval: time.Hour, 189 PurgeCounterExpire: time.Minute, 190 } 191 192 t.Run("host连接限制小于1", func(t *testing.T) { 193 if _, newErr := NewListener(lis, "tcp", config); newErr == nil { 194 t.Fatalf("must be wrong for invalidMaxConnNum") 195 } 196 }) 197 t.Run("protocol为空", func(t *testing.T) { 198 config.MaxConnPerHost = 10 199 if _, err := NewListener(lis, "", config); err == nil { 200 t.Fatalf("error") 201 } else { 202 t.Logf("%s", err.Error()) 203 } 204 }) 205 t.Run("purge参数错误", func(t *testing.T) { 206 config.PurgeCounterInterval = 0 207 if _, err := NewListener(lis, "tcp1", config); err == nil { 208 t.Fatalf("error") 209 } 210 config.PurgeCounterInterval = time.Hour 211 config.PurgeCounterExpire = 0 212 if _, err := NewListener(lis, "tcp2", config); err == nil { 213 t.Fatalf("error") 214 } else { 215 t.Logf("%s", err.Error()) 216 } 217 }) 218 } 219 220 // TestListener_Accept 测试accept 221 func TestListener_Accept(t *testing.T) { 222 Convey("正常accept", t, func() { 223 ctrl := gomock.NewController(t) 224 defer ctrl.Finish() 225 addr := mock_net.NewMockAddr(ctrl) 226 conn := mock_net.NewMockConn(ctrl) 227 conn.EXPECT().Close().Return(nil).AnyTimes() 228 addr.EXPECT().String().Return("1.2.3.4:8080").AnyTimes() 229 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 230 lis := NewTestLimitListener(100, 10) 231 So(lis.accept(conn).(*Conn).isValid(), ShouldBeTrue) 232 }) 233 } 234 235 // TestLimitListener_Acquire 测试acquire 236 func TestLimitListener_Acquire(t *testing.T) { 237 ctrl := gomock.NewController(t) 238 defer ctrl.Finish() 239 conn := mock_net.NewMockConn(ctrl) 240 conn.EXPECT().Close().Return(nil).AnyTimes() 241 Convey("acquire测试", t, func() { 242 Convey("超过server监听的最大限制,返回false", func() { 243 lis := &Listener{maxConnPerHost: 1, maxConnLimit: 10, connCount: 10} 244 c := lis.acquire(conn, "1.2.3.4:8080", "1.2.3.4") 245 So(c.isValid(), ShouldBeFalse) 246 }) 247 Convey("host首次请求,可以正常获取连接", func() { 248 lis := NewTestLimitListener(100, 10) 249 c := lis.acquire(conn, "2.3.4.5:8080", "2.3.4.5") 250 So(c.isValid(), ShouldBeTrue) 251 }) 252 Convey("host多次获取,正常", func() { 253 lis := NewTestLimitListener(15, 10) 254 for i := 0; i < 10; i++ { 255 So(lis.acquire(conn, fmt.Sprintf("1.2.3.4:%d", i), "1.2.3.4").isValid(), ShouldBeTrue) 256 } 257 So(lis.acquire(conn, fmt.Sprintf("1.2.3.4:%d", 20), "1.2.3.4").isValid(), ShouldBeFalse) 258 259 // 其他host没有超过限制,true 260 So(lis.acquire(conn, fmt.Sprintf("1.2.3.9:%d", 200), "1.2.3.9").isValid(), ShouldBeTrue) 261 // 占满listen的最大连接,前面成功了11个,剩下4个还没有满 262 for i := 0; i < 4; i++ { 263 So(lis.acquire(conn, fmt.Sprintf("1.2.3.8:%d", i), "1.2.3.8").isValid(), ShouldBeTrue) 264 } 265 266 // 总连接数被占满,false 267 So(lis.acquire(conn, fmt.Sprintf("1.2.3.19:%d", 123), "1.2.3.9").isValid(), ShouldBeFalse) 268 }) 269 }) 270 } 271 272 // TestLimitListener_ReLease release 273 func TestLimitListener_ReLease(t *testing.T) { 274 ctrl := gomock.NewController(t) 275 defer ctrl.Finish() 276 conn := mock_net.NewMockConn(ctrl) 277 conn.EXPECT().Close().Return(nil).AnyTimes() 278 t.Run("并发释放测试", func(t *testing.T) { 279 lis := NewTestLimitListener(2048000, 204800) 280 conns := make([]net.Conn, 0, 10240) 281 for i := 0; i < 10240; i++ { 282 c := lis.acquire(conn, "1.2.3.4:8080", "1.2.3.4") 283 conns = append(conns, c) 284 } 285 286 var wg sync.WaitGroup 287 wg.Add(1) 288 go func() { 289 defer wg.Done() 290 for i := 0; i < 10240; i++ { 291 lis.acquire(conn, "1.2.3.4:8080", "1.2.3.4") 292 } 293 }() 294 295 for i := 0; i < 2048; i++ { 296 wg.Add(1) 297 go func(index int) { 298 for j := 0; j < 5; j++ { 299 c := conns[index*5+j] 300 _ = c.Close() 301 } 302 wg.Done() 303 }(i) 304 } 305 306 wg.Wait() 307 var remain int32 = 10240 + 10240 - 2048*5 308 if lis.GetListenerConnCount() == remain && lis.GetHostConnCount("1.2.3.4") == remain { 309 t.Logf("pass") 310 } else { 311 t.Fatalf("error: %d, %d", lis.GetListenerConnCount(), lis.GetHostConnCount("1.2.3.4")) 312 } 313 }) 314 } 315 316 // TestWhiteList 白名单测试 317 func TestWhiteList(t *testing.T) { 318 ctrl := gomock.NewController(t) 319 defer ctrl.Finish() 320 conn := mock_net.NewMockConn(ctrl) 321 conn.EXPECT().Close().Return(nil).AnyTimes() 322 323 Convey("白名单下,限制不生效", t, func() { 324 listener := NewTestLimitListener(100, 2) 325 listener.whiteList = map[string]bool{ 326 "8.8.8.8": true, 327 } 328 for i := 0; i < 100; i++ { 329 So(listener.acquire(conn, "8.8.8.8:123", "8.8.8.8").isValid(), ShouldBeTrue) 330 } 331 // 超过了机器的100限制,白名单也不放过 332 So(listener.acquire(conn, "8.8.8.8:123", "8.8.8.8").isValid(), ShouldBeFalse) 333 So(listener.acquire(conn, "8.8.8.9:123", "8.8.8.9").isValid(), ShouldBeFalse) 334 So(listener.acquire(conn, "8.8.8.10:123", "8.8.8.10").isValid(), ShouldBeFalse) 335 }) 336 } 337 338 // TestActiveConns 测试activeConns 339 func TestActiveConns(t *testing.T) { 340 ctrl := gomock.NewController(t) 341 defer ctrl.Finish() 342 conn := mock_net.NewMockConn(ctrl) 343 conn.EXPECT().Close().Return(nil).AnyTimes() 344 listener := NewTestLimitListener(1024, 64) 345 var conns []*Conn 346 Convey("初始化", t, func() { 347 for i := 0; i < 32; i++ { 348 c := listener.acquire(conn, fmt.Sprintf("8.8.8.8:%d", i), "8.8.8.8") 349 So(c.isValid(), ShouldBeTrue) 350 conns = append(conns, c) 351 } 352 }) 353 Convey("测试活跃连接", t, func() { 354 Convey("已活跃的连接可以正常存储", func() { 355 actives := listener.GetHostActiveConns("8.8.8.8") 356 So(actives, ShouldNotBeNil) 357 So(len(actives), ShouldEqual, 32) 358 }) 359 Convey("连接关闭,活跃连接map会剔除", func() { 360 for i := 0; i < 8; i++ { 361 _ = conns[i].Close() 362 } 363 actives := listener.GetHostActiveConns("8.8.8.8") 364 So(actives, ShouldNotBeNil) 365 So(len(actives), ShouldEqual, 24) // 32 - 8 366 }) 367 Convey("重复关闭连接,活跃连接map不受影响,size不受影响", func() { 368 for i := 0; i < 8; i++ { 369 _ = conns[i].Close() 370 } 371 actives := listener.GetHostActiveConns("8.8.8.8") 372 So(actives, ShouldNotBeNil) 373 So(len(actives), ShouldEqual, 24) 374 So(listener.GetHostConnCount("8.8.8.8"), ShouldEqual, 24) 375 }) 376 Convey("多主机数据,可以正常存储", func() { 377 for i := 0; i < 16; i++ { 378 c := listener.acquire(conn, fmt.Sprintf("8.8.8.16:%d", i), "8.8.8.16") 379 So(c.isValid(), ShouldBeTrue) 380 conns = append(conns, c) 381 } 382 actives := listener.GetHostActiveConns("8.8.8.16") 383 So(actives, ShouldNotBeNil) 384 So(len(actives), ShouldEqual, 16) 385 }) 386 }) 387 } 388 389 // TestPurgeExpireCounterHandler 测试回收过期Counter函数 390 func TestPurgeExpireCounterHandler(t *testing.T) { 391 Convey("可以正常purge", t, func() { 392 listener := NewTestLimitListener(1024, 16) 393 listener.purgeCounterExpire = 3 394 for i := 0; i < 102400; i++ { 395 ct := newCounter() 396 ct.size = 0 397 listener.conns.Store(fmt.Sprintf("127.0.0.:%d", i), ct) 398 } 399 time.Sleep(time.Second * 4) 400 for i := 0; i < 102400; i++ { 401 ct := newCounter() 402 ct.size = 0 403 listener.conns.Store(fmt.Sprintf("127.0.1.%d", i), ct) 404 } 405 So(listener.GetDistinctHostCount(), ShouldEqual, 204800) 406 listener.purgeExpireCounterHandler() 407 So(listener.GetDistinctHostCount(), ShouldEqual, 102400) 408 }) 409 Convey("并发store和range,扫描的速度测试", t, func() { 410 listener := NewTestLimitListener(1024, 16) 411 listener.purgeCounterInterval = time.Microsecond * 10 412 listener.purgeCounterExpire = 1 413 rand.Seed(time.Now().UnixNano()) 414 ctx, cancel := context.WithCancel(context.Background()) 415 for i := 0; i < 10240; i++ { 416 go func(index int) { 417 for { 418 select { 419 case <-ctx.Done(): 420 return 421 default: 422 } 423 for j := 0; j < 100; j++ { 424 ct := newCounter() 425 ct.size = 0 426 listener.conns.Store(fmt.Sprintf("%d.%d", index, j), ct) 427 time.Sleep(time.Millisecond) 428 } 429 } 430 431 }(i) 432 } 433 listener.purgeExpireCounter(ctx) 434 <-time.After(time.Second * 5) 435 cancel() 436 }) 437 } 438 439 // NewTestLimitListener 返回一个测试listener 440 func NewTestLimitListener(maxLimit int32, hostLimit int32) *Listener { 441 return &Listener{ 442 maxConnLimit: maxLimit, 443 maxConnPerHost: hostLimit, 444 purgeCounterInterval: time.Hour, 445 purgeCounterExpire: 300, 446 conns: utils.NewSyncMap[string, *counter](), 447 } 448 }