github.com/polarismesh/polaris@v1.17.8/common/conn/limit/listener.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package connlimit
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"net"
    24  	"strings"
    25  	"sync"
    26  	"sync/atomic"
    27  	"time"
    28  
    29  	"github.com/pkg/errors"
    30  
    31  	"github.com/polarismesh/polaris/common/log"
    32  	"github.com/polarismesh/polaris/common/utils"
    33  )
    34  
    35  const (
    36  	// 最少连接数
    37  	minHostConnLimit = 1
    38  )
    39  
    40  // 计数器
    41  // limit connections for every ip
    42  type counter struct {
    43  	size       int32
    44  	actives    map[string]*Conn // 活跃的连接
    45  	mu         *sync.RWMutex
    46  	lastAccess int64
    47  }
    48  
    49  // 新增计数器
    50  func newCounter() *counter {
    51  	return &counter{
    52  		size:       1,
    53  		actives:    make(map[string]*Conn),
    54  		mu:         &sync.RWMutex{},
    55  		lastAccess: time.Now().Unix(),
    56  	}
    57  }
    58  
    59  // Listener 包装 net.Listener
    60  type Listener struct {
    61  	net.Listener
    62  	protocol             string                           // 协议,主要用以日志记录与全局对象索引
    63  	conns                *utils.SyncMap[string, *counter] // 保存 ip -> counter
    64  	maxConnPerHost       int32                            // 每个IP最多的连接数
    65  	maxConnLimit         int32                            // 当前listener最大的连接数限制
    66  	whiteList            map[string]bool                  // 白名单列表
    67  	readTimeout          time.Duration                    // 读超时
    68  	connCount            int32                            // 当前listener保持连接的个数
    69  	purgeCounterInterval time.Duration                    // 回收过期counter的
    70  	purgeCounterExpire   int64                            // counter过期的秒数
    71  	purgeCancel          context.CancelFunc               // 停止purge协程的ctx
    72  }
    73  
    74  // NewListener returns a new listener
    75  // @param l 网络连接
    76  // @param protocol 当前listener的七层协议,比如http,grpc等
    77  func NewListener(l net.Listener, protocol string, config *Config) (net.Listener, error) {
    78  	// 参数校验
    79  	if protocol == "" {
    80  		log.Errorf("[ConnLimit] listener is missing protocol")
    81  		return nil, errors.New("listener is missing protocol")
    82  	}
    83  	if config == nil || !config.OpenConnLimit {
    84  		log.Infof("[ConnLimit][%s] apiserver is not open conn limit", protocol)
    85  		return l, nil
    86  	}
    87  	if config.PurgeCounterInterval == 0 || config.PurgeCounterExpire == 0 {
    88  		log.Errorf("[ConnLimit][%s] purge params invalid", protocol)
    89  		return nil, errors.New("purge params invalid")
    90  	}
    91  
    92  	hostConnLimit := int32(config.MaxConnPerHost)
    93  	lisConnLimit := int32(config.MaxConnLimit)
    94  	// 参数校验, perHost阈值不能小于1
    95  	if hostConnLimit < minHostConnLimit {
    96  		return nil, fmt.Errorf("invalid conn limit: %d, can't be smaller than %d", hostConnLimit, minHostConnLimit)
    97  	}
    98  
    99  	whites := strings.Split(config.WhiteList, ",")
   100  	whiteList := make(map[string]bool, len(whites))
   101  	for _, entry := range whites {
   102  		if entry == "" {
   103  			continue
   104  		}
   105  
   106  		whiteList[entry] = true
   107  	}
   108  	log.Infof("[ConnLimit] host conn limit white list: %+v", whites)
   109  
   110  	lis := &Listener{
   111  		Listener:             l,
   112  		protocol:             protocol,
   113  		maxConnPerHost:       hostConnLimit,
   114  		maxConnLimit:         lisConnLimit,
   115  		whiteList:            whiteList,
   116  		readTimeout:          config.ReadTimeout,
   117  		purgeCounterInterval: config.PurgeCounterInterval,
   118  		purgeCounterExpire:   int64(config.PurgeCounterExpire / time.Second),
   119  		conns:                utils.NewSyncMap[string, *counter](),
   120  	}
   121  	// 把listener放到全局变量中,方便外部访问
   122  	if err := SetLimitListener(lis); err != nil {
   123  		return nil, err
   124  	}
   125  	// 启动回收协程,定时回收过期counter
   126  	ctx, cancel := context.WithCancel(context.Background())
   127  	lis.purgeExpireCounter(ctx)
   128  	lis.purgeCancel = cancel
   129  	return lis, nil
   130  }
   131  
   132  // Accept 接收连接
   133  func (l *Listener) Accept() (net.Conn, error) {
   134  	c, err := l.Listener.Accept()
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	return l.accept(c), nil
   139  }
   140  
   141  // Close 关闭连接
   142  func (l *Listener) Close() error {
   143  	log.Infof("[Listener][%s] close the listen fd", l.protocol)
   144  	l.purgeCancel()
   145  	return l.Listener.Close()
   146  }
   147  
   148  // GetHostConnCount 查看对应ip的连接数
   149  func (l *Listener) GetHostConnCount(host string) int32 {
   150  	var connNum int32
   151  	if c, ok := l.conns.Load(host); ok {
   152  		c.mu.RLock()
   153  		connNum = c.size
   154  		c.mu.RUnlock()
   155  	}
   156  
   157  	return connNum
   158  }
   159  
   160  // Range 遍历当前持有连接的host
   161  func (l *Listener) Range(fn func(host string, count int32) bool) {
   162  	l.conns.Range(func(host string, value *counter) bool {
   163  		return fn(host, l.GetHostConnCount(host))
   164  	})
   165  }
   166  
   167  // GetListenerConnCount 查看当前监听server保持的连接数
   168  func (l *Listener) GetListenerConnCount() int32 {
   169  	return atomic.LoadInt32(&l.connCount)
   170  }
   171  
   172  // GetDistinctHostCount 获取当前缓存的host的个数
   173  func (l *Listener) GetDistinctHostCount() int32 {
   174  	var count int32
   175  	l.conns.Range(func(key string, value *counter) bool {
   176  		count++
   177  		return true
   178  	})
   179  	return count
   180  }
   181  
   182  // GetHostActiveConns 获取指定host的活跃的连接
   183  func (l *Listener) GetHostActiveConns(host string) map[string]*Conn {
   184  	ct, ok := l.conns.Load(host)
   185  	if !ok {
   186  		return nil
   187  	}
   188  
   189  	ct.mu.RLock()
   190  	out := make(map[string]*Conn, len(ct.actives))
   191  	for address, conn := range ct.actives {
   192  		out[address] = conn
   193  	}
   194  	ct.mu.RUnlock()
   195  
   196  	return out
   197  }
   198  
   199  // GetHostConnStats 获取客户端连接的stat信息
   200  func (l *Listener) GetHostConnStats(host string) []*HostConnStat {
   201  	loadStat := func(h string, ct *counter) *HostConnStat {
   202  		ct.mu.RLock()
   203  		stat := &HostConnStat{
   204  			Host:       h,
   205  			Amount:     ct.size,
   206  			LastAccess: time.Unix(ct.lastAccess, 0),
   207  			Actives:    make([]string, 0, len(ct.actives)),
   208  		}
   209  		for client := range ct.actives {
   210  			stat.Actives = append(stat.Actives, client)
   211  		}
   212  		ct.mu.RUnlock()
   213  		return stat
   214  	}
   215  
   216  	var out []*HostConnStat
   217  	// 只获取一个,推荐每次只获取一个
   218  	if host != "" {
   219  		if obj, ok := l.conns.Load(host); ok {
   220  			out = append(out, loadStat(host, obj))
   221  			return out
   222  		}
   223  		return nil
   224  	}
   225  
   226  	// 全量扫描,比较耗时
   227  	l.conns.Range(func(key string, value *counter) bool {
   228  		out = append(out, loadStat(key, value))
   229  		return true
   230  	})
   231  	return out
   232  }
   233  
   234  // GetHostConnection 获取指定host和port的连接
   235  func (l *Listener) GetHostConnection(host string, port int) *Conn {
   236  	ct, ok := l.conns.Load(host)
   237  	if !ok {
   238  		return nil
   239  	}
   240  
   241  	target := fmt.Sprintf("%s:%d", host, port)
   242  	ct.mu.RLock()
   243  	defer ct.mu.RUnlock()
   244  	for address, conn := range ct.actives {
   245  		if address == target {
   246  			return conn
   247  		}
   248  	}
   249  
   250  	return nil
   251  }
   252  
   253  // 封装一层,让关键函数acquire的更具备可测试性(不需要mock net.Conn)
   254  func (l *Listener) accept(conn net.Conn) net.Conn {
   255  	address := conn.RemoteAddr().String()
   256  	// addr解析失败, 不做限制
   257  	ipPort := strings.Split(address, ":")
   258  	if len(ipPort) != 2 || ipPort[0] == "" {
   259  		return conn
   260  	}
   261  	return l.acquire(conn, address, ipPort[0])
   262  }
   263  
   264  // 包裹一下conn
   265  // 增加ip的连接计数,如果发现ip连接达到上限,则关闭
   266  // conn 	原始连接
   267  // address 	客户端地址
   268  // host 	处理后的客户端IP地址
   269  func (l *Listener) acquire(conn net.Conn, address string, host string) *Conn {
   270  	limiterConn := &Conn{
   271  		Conn:     conn,
   272  		closed:   false,
   273  		address:  address,
   274  		host:     host,
   275  		listener: l,
   276  	}
   277  
   278  	log.Debugf("acquire conn for: %s", address)
   279  	if ok := l.incConnCount(); !ok {
   280  		log.Errorf("[ConnLimit][%s] host(%s) reach apiserver conn limit(%d)", l.protocol, host, l.maxConnLimit)
   281  		limiterConn.closed = true
   282  		_ = limiterConn.Conn.Close()
   283  		return limiterConn
   284  	}
   285  
   286  	c, ok := l.conns.Load(host)
   287  	// 首次访问, 置1返回ok
   288  	if !ok {
   289  		ctr := newCounter()
   290  		ctr.actives[address] = limiterConn
   291  		l.conns.Store(host, ctr)
   292  		return limiterConn
   293  	}
   294  
   295  	c.mu.Lock() // release是并发的,因此需要加锁
   296  	// 如果连接数已经超过阈值, 则返回失败, 使用方要调用release减少计数
   297  	// 如果在白名单中,则直接忽略host连接限制
   298  	if c.size >= l.maxConnPerHost && !l.ignoreHostConnLimit(host) {
   299  		c.mu.Unlock()
   300  		l.descConnCount() // 前面已经增加了计数,因此这里失败,必须减少计数
   301  		log.Errorf("[ConnLimit][%s] host(%s) reach host conn limit(%d)", l.protocol, host, l.maxConnPerHost)
   302  		limiterConn.closed = true
   303  		_ = limiterConn.Conn.Close()
   304  		return limiterConn
   305  	}
   306  
   307  	// 单个IP的连接,还有冗余,则增加计数
   308  	c.size++
   309  	c.actives[address] = limiterConn
   310  	c.lastAccess = time.Now().Unix()
   311  	// map里面存储的是指针,可以不用store,这里直接对指针的内存操作
   312  	// l.conns.Store(host, c)
   313  	c.mu.Unlock()
   314  	return limiterConn
   315  }
   316  
   317  // 减少连接计数
   318  func (l *Listener) release(conn *Conn) {
   319  	log.Debugf("release conn for: %s", conn.host)
   320  	l.descConnCount()
   321  
   322  	if c, ok := l.conns.Load(conn.host); ok {
   323  		c.mu.Lock()
   324  		c.size--
   325  		// map里面存储的是指针,可以不用store,这里直接对指针的内存操作
   326  		// l.conns.Store(host, c)
   327  		delete(c.actives, conn.address)
   328  		c.mu.Unlock()
   329  	}
   330  }
   331  
   332  // 增加监听server的连接计数
   333  // 这里使用了原子变量来增加计数,先判断是否超过最大限制
   334  // 如果超过了,则立即返回false,否则计数+1
   335  // 在计数+1的过程中,即使有Desc释放过程,也不影响
   336  func (l *Listener) incConnCount() bool {
   337  	if l.maxConnLimit <= 0 {
   338  		return true
   339  	}
   340  	if count := atomic.LoadInt32(&l.connCount); count >= l.maxConnLimit {
   341  		return false
   342  	}
   343  
   344  	atomic.AddInt32(&l.connCount, 1)
   345  	return true
   346  }
   347  
   348  // 释放监听server的连接计数
   349  func (l *Listener) descConnCount() {
   350  	if l.maxConnLimit <= 0 {
   351  		return
   352  	}
   353  
   354  	atomic.AddInt32(&l.connCount, -1)
   355  }
   356  
   357  // 判断host是否在白名单中
   358  // 如果host在白名单中,则忽略host连接限制
   359  func (l *Listener) ignoreHostConnLimit(host string) bool {
   360  	_, ok := l.whiteList[host]
   361  	return ok
   362  }
   363  
   364  // 回收长时间没有访问的IP
   365  // 定时扫描
   366  func (l *Listener) purgeExpireCounter(ctx context.Context) {
   367  	go func() {
   368  		ticker := time.NewTicker(l.purgeCounterInterval)
   369  		defer ticker.Stop()
   370  		log.Infof("[Listener][%s] start doing purge expire counter", l.protocol)
   371  		for {
   372  			select {
   373  			case <-ticker.C:
   374  				l.purgeExpireCounterHandler()
   375  			case <-ctx.Done():
   376  				log.Infof("[Listener][%s] purge expire counter exit", l.protocol)
   377  				return
   378  			}
   379  		}
   380  	}()
   381  }
   382  
   383  // 回收过期counter执行函数
   384  func (l *Listener) purgeExpireCounterHandler() {
   385  	start := time.Now()
   386  	scanCount := 0
   387  	purgeCount := 0
   388  	waitDel := []string{}
   389  	l.conns.Range(func(key string, ct *counter) bool {
   390  		scanCount++
   391  		ct.mu.RLock()
   392  		if ct.size == 0 && time.Now().Unix()-ct.lastAccess > l.purgeCounterExpire {
   393  			waitDel = append(waitDel, key)
   394  			purgeCount++
   395  		}
   396  		ct.mu.RUnlock()
   397  		return true
   398  	})
   399  
   400  	for i := range waitDel {
   401  		// log.Infof("[Listener][%s] purge expire counter: %s", l.protocol, waitDel[i])
   402  		l.conns.Delete(waitDel[i])
   403  	}
   404  
   405  	spendTime := time.Since(start)
   406  	log.Infof("[Listener][%s] purge expire counter total(%d), use time: %+v, scan total(%d), scan qps: %.2f",
   407  		l.protocol, purgeCount, spendTime, scanCount, float64(scanCount)/spendTime.Seconds())
   408  }