github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/client_conn_limiter.go (about)

     1  // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  
     3  package nodes
     4  
     5  import (
     6  	"github.com/TeaOSLab/EdgeNode/internal/zero"
     7  	"sync"
     8  )
     9  
    10  var sharedClientConnLimiter = NewClientConnLimiter()
    11  
    12  // ClientConnRemoteAddr 客户端地址定义
    13  type ClientConnRemoteAddr struct {
    14  	remoteAddr string
    15  	serverId   int64
    16  }
    17  
    18  // ClientConnLimiter 客户端连接数限制
    19  type ClientConnLimiter struct {
    20  	remoteAddrMap map[string]*ClientConnRemoteAddr // raw remote addr => remoteAddr
    21  	ipConns       map[string]map[string]zero.Zero  // remoteAddr => { raw remote addr => Zero }
    22  	serverConns   map[int64]map[string]zero.Zero   // serverId => { remoteAddr => Zero }
    23  
    24  	locker sync.Mutex
    25  }
    26  
    27  func NewClientConnLimiter() *ClientConnLimiter {
    28  	return &ClientConnLimiter{
    29  		remoteAddrMap: map[string]*ClientConnRemoteAddr{},
    30  		ipConns:       map[string]map[string]zero.Zero{},
    31  		serverConns:   map[int64]map[string]zero.Zero{},
    32  	}
    33  }
    34  
    35  // Add 添加新连接
    36  // 返回值为true的时候表示允许添加;否则表示不允许添加
    37  func (this *ClientConnLimiter) Add(rawRemoteAddr string, serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
    38  	if (maxConnsPerServer <= 0 && maxConnsPerIP <= 0) || len(remoteAddr) == 0 || serverId <= 0 {
    39  		return true
    40  	}
    41  
    42  	this.locker.Lock()
    43  	defer this.locker.Unlock()
    44  
    45  	// 检查服务连接数
    46  	var serverMap = this.serverConns[serverId]
    47  	if maxConnsPerServer > 0 {
    48  		if serverMap == nil {
    49  			serverMap = map[string]zero.Zero{}
    50  			this.serverConns[serverId] = serverMap
    51  		}
    52  
    53  		if maxConnsPerServer <= len(serverMap) {
    54  			return false
    55  		}
    56  	}
    57  
    58  	// 检查IP连接数
    59  	var ipMap = this.ipConns[remoteAddr]
    60  	if maxConnsPerIP > 0 {
    61  		if ipMap == nil {
    62  			ipMap = map[string]zero.Zero{}
    63  			this.ipConns[remoteAddr] = ipMap
    64  		}
    65  		if maxConnsPerIP > 0 && maxConnsPerIP <= len(ipMap) {
    66  			return false
    67  		}
    68  	}
    69  
    70  	this.remoteAddrMap[rawRemoteAddr] = &ClientConnRemoteAddr{
    71  		remoteAddr: remoteAddr,
    72  		serverId:   serverId,
    73  	}
    74  
    75  	if maxConnsPerServer > 0 {
    76  		serverMap[rawRemoteAddr] = zero.New()
    77  	}
    78  	if maxConnsPerIP > 0 {
    79  		ipMap[rawRemoteAddr] = zero.New()
    80  	}
    81  
    82  	return true
    83  }
    84  
    85  // Remove 删除连接
    86  func (this *ClientConnLimiter) Remove(rawRemoteAddr string) {
    87  	this.locker.Lock()
    88  	defer this.locker.Unlock()
    89  
    90  	addr, ok := this.remoteAddrMap[rawRemoteAddr]
    91  	if !ok {
    92  		return
    93  	}
    94  
    95  	delete(this.remoteAddrMap, rawRemoteAddr)
    96  	delete(this.ipConns[addr.remoteAddr], rawRemoteAddr)
    97  	delete(this.serverConns[addr.serverId], rawRemoteAddr)
    98  
    99  	if len(this.ipConns[addr.remoteAddr]) == 0 {
   100  		delete(this.ipConns, addr.remoteAddr)
   101  	}
   102  
   103  	if len(this.serverConns[addr.serverId]) == 0 {
   104  		delete(this.serverConns, addr.serverId)
   105  	}
   106  }
   107  
   108  // Conns 获取连接信息
   109  // 用于调试
   110  func (this *ClientConnLimiter) Conns() (ipConns map[string][]string, serverConns map[int64][]string) {
   111  	this.locker.Lock()
   112  	defer this.locker.Unlock()
   113  
   114  	ipConns = map[string][]string{}    // ip => [addr1, addr2, ...]
   115  	serverConns = map[int64][]string{} // serverId => [addr1, addr2, ...]
   116  
   117  	for ip, m := range this.ipConns {
   118  		for addr := range m {
   119  			ipConns[ip] = append(ipConns[ip], addr)
   120  		}
   121  	}
   122  
   123  	for serverId, m := range this.serverConns {
   124  		for addr := range m {
   125  			serverConns[serverId] = append(serverConns[serverId], addr)
   126  		}
   127  	}
   128  
   129  	return
   130  }