github.com/polarismesh/polaris@v1.17.8/common/conn/limit/listener_test.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package connlimit
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"math/rand"
    24  	"net"
    25  	"sync"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/golang/mock/gomock"
    31  	. "github.com/smartystreets/goconvey/convey"
    32  
    33  	"github.com/polarismesh/polaris/common/conn/limit/mock_net"
    34  	"github.com/polarismesh/polaris/common/utils"
    35  )
    36  
    37  // TestConnLimit 模拟一下连接限制
    38  func TestConnLimit(t *testing.T) {
    39  	addr := "127.0.0.1:44444"
    40  	host := "127.0.0.1"
    41  	config := &Config{
    42  		OpenConnLimit:        true,
    43  		MaxConnPerHost:       5,
    44  		MaxConnLimit:         3,
    45  		PurgeCounterInterval: time.Hour,
    46  		PurgeCounterExpire:   time.Minute,
    47  	}
    48  	connCount := 100
    49  	lis, err := net.Listen("tcp", addr)
    50  	if err != nil {
    51  		t.Fatalf("%s", err)
    52  	}
    53  
    54  	lis, err = NewListener(lis, "tcp", config)
    55  	if err != nil {
    56  		t.Fatalf("%s", err)
    57  	}
    58  
    59  	if lis.(*Listener).GetHostConnCount(host) != 0 {
    60  		t.Fatalf("%s connNum should be 0 when no connections", host)
    61  	}
    62  
    63  	// 启动Server
    64  	go func() {
    65  		for {
    66  			conn, _ := lis.Accept()
    67  			go func(c net.Conn) {
    68  				buf := make([]byte, 10)
    69  				if _, err := c.Read(buf); err != nil {
    70  					t.Logf("server read err: %s", err.Error())
    71  					_ = c.Close()
    72  					return
    73  				}
    74  				t.Logf("server read data: %s", string(buf))
    75  				time.Sleep(time.Millisecond * 200)
    76  				_ = c.Close()
    77  			}(conn)
    78  		}
    79  	}()
    80  	time.Sleep(1 * time.Second)
    81  
    82  	var total int32
    83  	for i := 0; i < connCount; i++ {
    84  		go func(index int) {
    85  			conn, err := net.Dial("tcp", addr)
    86  			atomic.AddInt32(&total, 1)
    87  			if err != nil {
    88  				t.Logf("client conn server error: %s", err.Error())
    89  				return
    90  			}
    91  			buf := []byte("hello")
    92  			if _, err := conn.Write(buf); err != nil {
    93  				t.Logf("client write error: %s", err.Error())
    94  				_ = conn.Close()
    95  				return
    96  			}
    97  		}(i)
    98  	}
    99  
   100  	// 等待连接全部关闭
   101  	// time.Sleep(5 * time.Second)
   102  	ticker := time.NewTicker(time.Second)
   103  	defer ticker.Stop()
   104  	for range ticker.C {
   105  		if atomic.LoadInt32(&total) != int32(connCount) {
   106  			t.Logf("connection is not finished")
   107  			continue
   108  		}
   109  		hostCnt := lis.(*Listener).GetHostConnCount(host)
   110  		lisCnt := lis.(*Listener).GetListenerConnCount()
   111  		if hostCnt == 0 && lisCnt == 0 {
   112  			t.Logf("pass")
   113  			return
   114  		}
   115  
   116  		t.Logf("host conn count:%d: lis conn count:%d", hostCnt, lisCnt)
   117  	}
   118  }
   119  
   120  // TestConnLimiterReadTimeout exercises the read-timeout scenario; it is currently disabled.
   121  /*func TestConnLimiterReadTimeout(t *testing.T) {
   122  	lis, err := net.Listen("tcp", "127.0.0.1:55555")
   123  	if err != nil {
   124  		t.Fatalf("%s", err)
   125  	}
   126  
   127  	cfg := &Config{
   128  		OpenConnLimit:  true,
   129  		MaxConnLimit:   16,
   130  		MaxConnPerHost: 8,
   131  		ReadTimeout:    time.Millisecond * 500,
   132  	}
   133  	lis, err = NewListener(lis, "http", cfg)
   134  	if err != nil {
   135  		t.Fatalf("error: %s", err.Error())
   136  	}
   137  	defer lis.Close()
   138  	handler := func(conn net.Conn) {
   139  		for {
   140  			reader := bufio.NewReader(conn)
   141  			buf := make([]byte, 12)
   142  			if _, err := io.ReadFull(reader, buf); err != nil {
   143  				t.Logf("read full return: %s", err.Error())
   144  				if e, ok := err.(net.Error); ok && e.Timeout() {
   145  					t.Logf("pass")
   146  				} else {
   147  					t.Fatalf("error")
   148  				}
   149  				return
   150  			}
   151  			t.Logf("%s", string(buf))
   152  			go func() {conn.Close()}()
   153  		}
   154  	}
   155  	go func() {
   156  		conn, err := lis.Accept()
   157  		if err != nil {
   158  			t.Fatalf("error: %s", err.Error())
   159  		}
   160  		go handler(conn)
   161  	}()
   162  
   163  	conn, err := net.Dial("tcp", "127.0.0.1:55555")
   164  	if err != nil {
   165  		t.Fatalf("error: %s", err.Error())
   166  	}
   167  	//time.Sleep(time.Second * 1)
   168  	_, err = conn.Write([]byte("hello world!"))
   169  	if err != nil {
   170  		t.Logf("%s", err.Error())
   171  	}
   172  	time.Sleep(time.Second)
   173  	conn.Close()
   174  	time.Sleep(time.Second)
   175  }*/
   176  
   177  // TestInvalidParams test invalid conn limit param
   178  func TestInvalidParams(t *testing.T) {
   179  	lis, err := net.Listen("tcp", "127.0.0.1:44445")
   180  	if err != nil {
   181  		t.Fatalf("%s", err)
   182  	}
   183  	defer func() { _ = lis.Close() }()
   184  	config := &Config{
   185  		OpenConnLimit:        true,
   186  		MaxConnPerHost:       0,
   187  		MaxConnLimit:         10,
   188  		PurgeCounterInterval: time.Hour,
   189  		PurgeCounterExpire:   time.Minute,
   190  	}
   191  
   192  	t.Run("host连接限制小于1", func(t *testing.T) {
   193  		if _, newErr := NewListener(lis, "tcp", config); newErr == nil {
   194  			t.Fatalf("must be wrong for invalidMaxConnNum")
   195  		}
   196  	})
   197  	t.Run("protocol为空", func(t *testing.T) {
   198  		config.MaxConnPerHost = 10
   199  		if _, err := NewListener(lis, "", config); err == nil {
   200  			t.Fatalf("error")
   201  		} else {
   202  			t.Logf("%s", err.Error())
   203  		}
   204  	})
   205  	t.Run("purge参数错误", func(t *testing.T) {
   206  		config.PurgeCounterInterval = 0
   207  		if _, err := NewListener(lis, "tcp1", config); err == nil {
   208  			t.Fatalf("error")
   209  		}
   210  		config.PurgeCounterInterval = time.Hour
   211  		config.PurgeCounterExpire = 0
   212  		if _, err := NewListener(lis, "tcp2", config); err == nil {
   213  			t.Fatalf("error")
   214  		} else {
   215  			t.Logf("%s", err.Error())
   216  		}
   217  	})
   218  }
   219  
   220  // TestListener_Accept 测试accept
   221  func TestListener_Accept(t *testing.T) {
   222  	Convey("正常accept", t, func() {
   223  		ctrl := gomock.NewController(t)
   224  		defer ctrl.Finish()
   225  		addr := mock_net.NewMockAddr(ctrl)
   226  		conn := mock_net.NewMockConn(ctrl)
   227  		conn.EXPECT().Close().Return(nil).AnyTimes()
   228  		addr.EXPECT().String().Return("1.2.3.4:8080").AnyTimes()
   229  		conn.EXPECT().RemoteAddr().Return(addr).AnyTimes()
   230  		lis := NewTestLimitListener(100, 10)
   231  		So(lis.accept(conn).(*Conn).isValid(), ShouldBeTrue)
   232  	})
   233  }
   234  
   235  // TestLimitListener_Acquire 测试acquire
   236  func TestLimitListener_Acquire(t *testing.T) {
   237  	ctrl := gomock.NewController(t)
   238  	defer ctrl.Finish()
   239  	conn := mock_net.NewMockConn(ctrl)
   240  	conn.EXPECT().Close().Return(nil).AnyTimes()
   241  	Convey("acquire测试", t, func() {
   242  		Convey("超过server监听的最大限制,返回false", func() {
   243  			lis := &Listener{maxConnPerHost: 1, maxConnLimit: 10, connCount: 10}
   244  			c := lis.acquire(conn, "1.2.3.4:8080", "1.2.3.4")
   245  			So(c.isValid(), ShouldBeFalse)
   246  		})
   247  		Convey("host首次请求,可以正常获取连接", func() {
   248  			lis := NewTestLimitListener(100, 10)
   249  			c := lis.acquire(conn, "2.3.4.5:8080", "2.3.4.5")
   250  			So(c.isValid(), ShouldBeTrue)
   251  		})
   252  		Convey("host多次获取,正常", func() {
   253  			lis := NewTestLimitListener(15, 10)
   254  			for i := 0; i < 10; i++ {
   255  				So(lis.acquire(conn, fmt.Sprintf("1.2.3.4:%d", i), "1.2.3.4").isValid(), ShouldBeTrue)
   256  			}
   257  			So(lis.acquire(conn, fmt.Sprintf("1.2.3.4:%d", 20), "1.2.3.4").isValid(), ShouldBeFalse)
   258  
   259  			// 其他host没有超过限制,true
   260  			So(lis.acquire(conn, fmt.Sprintf("1.2.3.9:%d", 200), "1.2.3.9").isValid(), ShouldBeTrue)
   261  			// 占满listen的最大连接,前面成功了11个,剩下4个还没有满
   262  			for i := 0; i < 4; i++ {
   263  				So(lis.acquire(conn, fmt.Sprintf("1.2.3.8:%d", i), "1.2.3.8").isValid(), ShouldBeTrue)
   264  			}
   265  
   266  			// 总连接数被占满,false
   267  			So(lis.acquire(conn, fmt.Sprintf("1.2.3.19:%d", 123), "1.2.3.9").isValid(), ShouldBeFalse)
   268  		})
   269  	})
   270  }
   271  
   272  // TestLimitListener_ReLease release
   273  func TestLimitListener_ReLease(t *testing.T) {
   274  	ctrl := gomock.NewController(t)
   275  	defer ctrl.Finish()
   276  	conn := mock_net.NewMockConn(ctrl)
   277  	conn.EXPECT().Close().Return(nil).AnyTimes()
   278  	t.Run("并发释放测试", func(t *testing.T) {
   279  		lis := NewTestLimitListener(2048000, 204800)
   280  		conns := make([]net.Conn, 0, 10240)
   281  		for i := 0; i < 10240; i++ {
   282  			c := lis.acquire(conn, "1.2.3.4:8080", "1.2.3.4")
   283  			conns = append(conns, c)
   284  		}
   285  
   286  		var wg sync.WaitGroup
   287  		wg.Add(1)
   288  		go func() {
   289  			defer wg.Done()
   290  			for i := 0; i < 10240; i++ {
   291  				lis.acquire(conn, "1.2.3.4:8080", "1.2.3.4")
   292  			}
   293  		}()
   294  
   295  		for i := 0; i < 2048; i++ {
   296  			wg.Add(1)
   297  			go func(index int) {
   298  				for j := 0; j < 5; j++ {
   299  					c := conns[index*5+j]
   300  					_ = c.Close()
   301  				}
   302  				wg.Done()
   303  			}(i)
   304  		}
   305  
   306  		wg.Wait()
   307  		var remain int32 = 10240 + 10240 - 2048*5
   308  		if lis.GetListenerConnCount() == remain && lis.GetHostConnCount("1.2.3.4") == remain {
   309  			t.Logf("pass")
   310  		} else {
   311  			t.Fatalf("error: %d, %d", lis.GetListenerConnCount(), lis.GetHostConnCount("1.2.3.4"))
   312  		}
   313  	})
   314  }
   315  
   316  // TestWhiteList 白名单测试
   317  func TestWhiteList(t *testing.T) {
   318  	ctrl := gomock.NewController(t)
   319  	defer ctrl.Finish()
   320  	conn := mock_net.NewMockConn(ctrl)
   321  	conn.EXPECT().Close().Return(nil).AnyTimes()
   322  
   323  	Convey("白名单下,限制不生效", t, func() {
   324  		listener := NewTestLimitListener(100, 2)
   325  		listener.whiteList = map[string]bool{
   326  			"8.8.8.8": true,
   327  		}
   328  		for i := 0; i < 100; i++ {
   329  			So(listener.acquire(conn, "8.8.8.8:123", "8.8.8.8").isValid(), ShouldBeTrue)
   330  		}
   331  		// 超过了机器的100限制,白名单也不放过
   332  		So(listener.acquire(conn, "8.8.8.8:123", "8.8.8.8").isValid(), ShouldBeFalse)
   333  		So(listener.acquire(conn, "8.8.8.9:123", "8.8.8.9").isValid(), ShouldBeFalse)
   334  		So(listener.acquire(conn, "8.8.8.10:123", "8.8.8.10").isValid(), ShouldBeFalse)
   335  	})
   336  }
   337  
   338  // TestActiveConns 测试activeConns
   339  func TestActiveConns(t *testing.T) {
   340  	ctrl := gomock.NewController(t)
   341  	defer ctrl.Finish()
   342  	conn := mock_net.NewMockConn(ctrl)
   343  	conn.EXPECT().Close().Return(nil).AnyTimes()
   344  	listener := NewTestLimitListener(1024, 64)
   345  	var conns []*Conn
   346  	Convey("初始化", t, func() {
   347  		for i := 0; i < 32; i++ {
   348  			c := listener.acquire(conn, fmt.Sprintf("8.8.8.8:%d", i), "8.8.8.8")
   349  			So(c.isValid(), ShouldBeTrue)
   350  			conns = append(conns, c)
   351  		}
   352  	})
   353  	Convey("测试活跃连接", t, func() {
   354  		Convey("已活跃的连接可以正常存储", func() {
   355  			actives := listener.GetHostActiveConns("8.8.8.8")
   356  			So(actives, ShouldNotBeNil)
   357  			So(len(actives), ShouldEqual, 32)
   358  		})
   359  		Convey("连接关闭,活跃连接map会剔除", func() {
   360  			for i := 0; i < 8; i++ {
   361  				_ = conns[i].Close()
   362  			}
   363  			actives := listener.GetHostActiveConns("8.8.8.8")
   364  			So(actives, ShouldNotBeNil)
   365  			So(len(actives), ShouldEqual, 24) // 32 - 8
   366  		})
   367  		Convey("重复关闭连接,活跃连接map不受影响,size不受影响", func() {
   368  			for i := 0; i < 8; i++ {
   369  				_ = conns[i].Close()
   370  			}
   371  			actives := listener.GetHostActiveConns("8.8.8.8")
   372  			So(actives, ShouldNotBeNil)
   373  			So(len(actives), ShouldEqual, 24)
   374  			So(listener.GetHostConnCount("8.8.8.8"), ShouldEqual, 24)
   375  		})
   376  		Convey("多主机数据,可以正常存储", func() {
   377  			for i := 0; i < 16; i++ {
   378  				c := listener.acquire(conn, fmt.Sprintf("8.8.8.16:%d", i), "8.8.8.16")
   379  				So(c.isValid(), ShouldBeTrue)
   380  				conns = append(conns, c)
   381  			}
   382  			actives := listener.GetHostActiveConns("8.8.8.16")
   383  			So(actives, ShouldNotBeNil)
   384  			So(len(actives), ShouldEqual, 16)
   385  		})
   386  	})
   387  }
   388  
   389  // TestPurgeExpireCounterHandler 测试回收过期Counter函数
   390  func TestPurgeExpireCounterHandler(t *testing.T) {
   391  	Convey("可以正常purge", t, func() {
   392  		listener := NewTestLimitListener(1024, 16)
   393  		listener.purgeCounterExpire = 3
   394  		for i := 0; i < 102400; i++ {
   395  			ct := newCounter()
   396  			ct.size = 0
   397  			listener.conns.Store(fmt.Sprintf("127.0.0.:%d", i), ct)
   398  		}
   399  		time.Sleep(time.Second * 4)
   400  		for i := 0; i < 102400; i++ {
   401  			ct := newCounter()
   402  			ct.size = 0
   403  			listener.conns.Store(fmt.Sprintf("127.0.1.%d", i), ct)
   404  		}
   405  		So(listener.GetDistinctHostCount(), ShouldEqual, 204800)
   406  		listener.purgeExpireCounterHandler()
   407  		So(listener.GetDistinctHostCount(), ShouldEqual, 102400)
   408  	})
   409  	Convey("并发store和range,扫描的速度测试", t, func() {
   410  		listener := NewTestLimitListener(1024, 16)
   411  		listener.purgeCounterInterval = time.Microsecond * 10
   412  		listener.purgeCounterExpire = 1
   413  		rand.Seed(time.Now().UnixNano())
   414  		ctx, cancel := context.WithCancel(context.Background())
   415  		for i := 0; i < 10240; i++ {
   416  			go func(index int) {
   417  				for {
   418  					select {
   419  					case <-ctx.Done():
   420  						return
   421  					default:
   422  					}
   423  					for j := 0; j < 100; j++ {
   424  						ct := newCounter()
   425  						ct.size = 0
   426  						listener.conns.Store(fmt.Sprintf("%d.%d", index, j), ct)
   427  						time.Sleep(time.Millisecond)
   428  					}
   429  				}
   430  
   431  			}(i)
   432  		}
   433  		listener.purgeExpireCounter(ctx)
   434  		<-time.After(time.Second * 5)
   435  		cancel()
   436  	})
   437  }
   438  
   439  // NewTestLimitListener 返回一个测试listener
   440  func NewTestLimitListener(maxLimit int32, hostLimit int32) *Listener {
   441  	return &Listener{
   442  		maxConnLimit:         maxLimit,
   443  		maxConnPerHost:       hostLimit,
   444  		purgeCounterInterval: time.Hour,
   445  		purgeCounterExpire:   300,
   446  		conns:                utils.NewSyncMap[string, *counter](),
   447  	}
   448  }