github.com/polarismesh/polaris@v1.17.8/common/conn/limit/listener.go (about) 1 /** 2 * Tencent is pleased to support the open source community by making Polaris available. 3 * 4 * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 5 * 6 * Licensed under the BSD 3-Clause License (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://opensource.org/licenses/BSD-3-Clause 11 * 12 * Unless required by applicable law or agreed to in writing, software distributed 13 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 14 * CONDITIONS OF ANY KIND, either express or implied. See the License for the 15 * specific language governing permissions and limitations under the License. 16 */ 17 18 package connlimit 19 20 import ( 21 "context" 22 "fmt" 23 "net" 24 "strings" 25 "sync" 26 "sync/atomic" 27 "time" 28 29 "github.com/pkg/errors" 30 31 "github.com/polarismesh/polaris/common/log" 32 "github.com/polarismesh/polaris/common/utils" 33 ) 34 35 const ( 36 // 最少连接数 37 minHostConnLimit = 1 38 ) 39 40 // 计数器 41 // limit connections for every ip 42 type counter struct { 43 size int32 44 actives map[string]*Conn // 活跃的连接 45 mu *sync.RWMutex 46 lastAccess int64 47 } 48 49 // 新增计数器 50 func newCounter() *counter { 51 return &counter{ 52 size: 1, 53 actives: make(map[string]*Conn), 54 mu: &sync.RWMutex{}, 55 lastAccess: time.Now().Unix(), 56 } 57 } 58 59 // Listener 包装 net.Listener 60 type Listener struct { 61 net.Listener 62 protocol string // 协议,主要用以日志记录与全局对象索引 63 conns *utils.SyncMap[string, *counter] // 保存 ip -> counter 64 maxConnPerHost int32 // 每个IP最多的连接数 65 maxConnLimit int32 // 当前listener最大的连接数限制 66 whiteList map[string]bool // 白名单列表 67 readTimeout time.Duration // 读超时 68 connCount int32 // 当前listener保持连接的个数 69 purgeCounterInterval time.Duration // 回收过期counter的 70 purgeCounterExpire int64 // counter过期的秒数 71 purgeCancel context.CancelFunc // 停止purge协程的ctx 72 } 73 74 // NewListener returns a new listener 75 // @param l 网络连接 76 // @param protocol 当前listener的七层协议,比如http,grpc等 77 func NewListener(l net.Listener, protocol string, config *Config) (net.Listener, error) { 78 // 参数校验 79 if protocol == "" { 80 log.Errorf("[ConnLimit] listener is missing protocol") 81 return nil, errors.New("listener is missing protocol") 82 } 83 if config == nil || !config.OpenConnLimit { 84 log.Infof("[ConnLimit][%s] apiserver is not open conn limit", protocol) 85 return l, nil 86 } 87 if config.PurgeCounterInterval == 0 || config.PurgeCounterExpire == 0 { 88 log.Errorf("[ConnLimit][%s] purge params invalid", protocol) 89 return nil, errors.New("purge params invalid") 90 } 91 92 hostConnLimit := int32(config.MaxConnPerHost) 93 lisConnLimit := int32(config.MaxConnLimit) 94 // 参数校验, perHost阈值不能小于1 95 if hostConnLimit < minHostConnLimit { 96 return nil, fmt.Errorf("invalid conn limit: %d, can't be smaller than %d", hostConnLimit, minHostConnLimit) 97 } 98 99 whites := strings.Split(config.WhiteList, ",") 100 whiteList := make(map[string]bool, len(whites)) 101 for _, entry := range whites { 102 if entry == "" { 103 continue 104 } 105 106 whiteList[entry] = true 107 } 108 log.Infof("[ConnLimit] host conn limit white list: %+v", whites) 109 110 lis := &Listener{ 111 Listener: l, 112 protocol: protocol, 113 maxConnPerHost: hostConnLimit, 114 maxConnLimit: lisConnLimit, 115 whiteList: whiteList, 116 readTimeout: config.ReadTimeout, 117 purgeCounterInterval: config.PurgeCounterInterval, 118 purgeCounterExpire: int64(config.PurgeCounterExpire / time.Second), 119 conns: utils.NewSyncMap[string, *counter](), 120 } 121 // 把listener放到全局变量中,方便外部访问 122 if err := SetLimitListener(lis); err != nil { 123 return nil, err 124 } 125 // 启动回收协程,定时回收过期counter 126 ctx, cancel := context.WithCancel(context.Background()) 127 lis.purgeExpireCounter(ctx) 128 lis.purgeCancel = cancel 129 return lis, nil 130 } 131 132 // Accept 接收连接 133 func (l *Listener) Accept() (net.Conn, error) { 134 c, err := l.Listener.Accept() 135 if err != nil { 136 return nil, err 137 } 138 return l.accept(c), nil 139 } 140 141 // Close 关闭连接 142 func (l *Listener) Close() error { 143 log.Infof("[Listener][%s] close the listen fd", l.protocol) 144 l.purgeCancel() 145 return l.Listener.Close() 146 } 147 148 // GetHostConnCount 查看对应ip的连接数 149 func (l *Listener) GetHostConnCount(host string) int32 { 150 var connNum int32 151 if c, ok := l.conns.Load(host); ok { 152 c.mu.RLock() 153 connNum = c.size 154 c.mu.RUnlock() 155 } 156 157 return connNum 158 } 159 160 // Range 遍历当前持有连接的host 161 func (l *Listener) Range(fn func(host string, count int32) bool) { 162 l.conns.Range(func(host string, value *counter) bool { 163 return fn(host, l.GetHostConnCount(host)) 164 }) 165 } 166 167 // GetListenerConnCount 查看当前监听server保持的连接数 168 func (l *Listener) GetListenerConnCount() int32 { 169 return atomic.LoadInt32(&l.connCount) 170 } 171 172 // GetDistinctHostCount 获取当前缓存的host的个数 173 func (l *Listener) GetDistinctHostCount() int32 { 174 var count int32 175 l.conns.Range(func(key string, value *counter) bool { 176 count++ 177 return true 178 }) 179 return count 180 } 181 182 // GetHostActiveConns 获取指定host的活跃的连接 183 func (l *Listener) GetHostActiveConns(host string) map[string]*Conn { 184 ct, ok := l.conns.Load(host) 185 if !ok { 186 return nil 187 } 188 189 ct.mu.RLock() 190 out := make(map[string]*Conn, len(ct.actives)) 191 for address, conn := range ct.actives { 192 out[address] = conn 193 } 194 ct.mu.RUnlock() 195 196 return out 197 } 198 199 // GetHostConnStats 获取客户端连接的stat信息 200 func (l *Listener) GetHostConnStats(host string) []*HostConnStat { 201 loadStat := func(h string, ct *counter) *HostConnStat { 202 ct.mu.RLock() 203 stat := &HostConnStat{ 204 Host: h, 205 Amount: ct.size, 206 LastAccess: time.Unix(ct.lastAccess, 0), 207 Actives: make([]string, 0, len(ct.actives)), 208 } 209 for client := range ct.actives { 210 stat.Actives = append(stat.Actives, client) 211 } 212 ct.mu.RUnlock() 213 return stat 214 } 215 216 var out []*HostConnStat 217 // 只获取一个,推荐每次只获取一个 218 if host != "" { 219 if obj, ok := l.conns.Load(host); ok { 220 out = append(out, loadStat(host, obj)) 221 return out 222 } 223 return nil 224 } 225 226 // 全量扫描,比较耗时 227 l.conns.Range(func(key string, value *counter) bool { 228 out = append(out, loadStat(key, value)) 229 return true 230 }) 231 return out 232 } 233 234 // GetHostConnection 获取指定host和port的连接 235 func (l *Listener) GetHostConnection(host string, port int) *Conn { 236 ct, ok := l.conns.Load(host) 237 if !ok { 238 return nil 239 } 240 241 target := fmt.Sprintf("%s:%d", host, port) 242 ct.mu.RLock() 243 defer ct.mu.RUnlock() 244 for address, conn := range ct.actives { 245 if address == target { 246 return conn 247 } 248 } 249 250 return nil 251 } 252 253 // 封装一层,让关键函数acquire的更具备可测试性(不需要mock net.Conn) 254 func (l *Listener) accept(conn net.Conn) net.Conn { 255 address := conn.RemoteAddr().String() 256 // addr解析失败, 不做限制 257 ipPort := strings.Split(address, ":") 258 if len(ipPort) != 2 || ipPort[0] == "" { 259 return conn 260 } 261 return l.acquire(conn, address, ipPort[0]) 262 } 263 264 // 包裹一下conn 265 // 增加ip的连接计数,如果发现ip连接达到上限,则关闭 266 // conn 原始连接 267 // address 客户端地址 268 // host 处理后的客户端IP地址 269 func (l *Listener) acquire(conn net.Conn, address string, host string) *Conn { 270 limiterConn := &Conn{ 271 Conn: conn, 272 closed: false, 273 address: address, 274 host: host, 275 listener: l, 276 } 277 278 log.Debugf("acquire conn for: %s", address) 279 if ok := l.incConnCount(); !ok { 280 log.Errorf("[ConnLimit][%s] host(%s) reach apiserver conn limit(%d)", l.protocol, host, l.maxConnLimit) 281 limiterConn.closed = true 282 _ = limiterConn.Conn.Close() 283 return limiterConn 284 } 285 286 c, ok := l.conns.Load(host) 287 // 首次访问, 置1返回ok 288 if !ok { 289 ctr := newCounter() 290 ctr.actives[address] = limiterConn 291 l.conns.Store(host, ctr) 292 return limiterConn 293 } 294 295 c.mu.Lock() // release是并发的,因此需要加锁 296 // 如果连接数已经超过阈值, 则返回失败, 使用方要调用release减少计数 297 // 如果在白名单中,则直接忽略host连接限制 298 if c.size >= l.maxConnPerHost && !l.ignoreHostConnLimit(host) { 299 c.mu.Unlock() 300 l.descConnCount() // 前面已经增加了计数,因此这里失败,必须减少计数 301 log.Errorf("[ConnLimit][%s] host(%s) reach host conn limit(%d)", l.protocol, host, l.maxConnPerHost) 302 limiterConn.closed = true 303 _ = limiterConn.Conn.Close() 304 return limiterConn 305 } 306 307 // 单个IP的连接,还有冗余,则增加计数 308 c.size++ 309 c.actives[address] = limiterConn 310 c.lastAccess = time.Now().Unix() 311 // map里面存储的是指针,可以不用store,这里直接对指针的内存操作 312 // l.conns.Store(host, c) 313 c.mu.Unlock() 314 return limiterConn 315 } 316 317 // 减少连接计数 318 func (l *Listener) release(conn *Conn) { 319 log.Debugf("release conn for: %s", conn.host) 320 l.descConnCount() 321 322 if c, ok := l.conns.Load(conn.host); ok { 323 c.mu.Lock() 324 c.size-- 325 // map里面存储的是指针,可以不用store,这里直接对指针的内存操作 326 // l.conns.Store(host, c) 327 delete(c.actives, conn.address) 328 c.mu.Unlock() 329 } 330 } 331 332 // 增加监听server的连接计数 333 // 这里使用了原子变量来增加计数,先判断是否超过最大限制 334 // 如果超过了,则立即返回false,否则计数+1 335 // 在计数+1的过程中,即使有Desc释放过程,也不影响 336 func (l *Listener) incConnCount() bool { 337 if l.maxConnLimit <= 0 { 338 return true 339 } 340 if count := atomic.LoadInt32(&l.connCount); count >= l.maxConnLimit { 341 return false 342 } 343 344 atomic.AddInt32(&l.connCount, 1) 345 return true 346 } 347 348 // 释放监听server的连接计数 349 func (l *Listener) descConnCount() { 350 if l.maxConnLimit <= 0 { 351 return 352 } 353 354 atomic.AddInt32(&l.connCount, -1) 355 } 356 357 // 判断host是否在白名单中 358 // 如果host在白名单中,则忽略host连接限制 359 func (l *Listener) ignoreHostConnLimit(host string) bool { 360 _, ok := l.whiteList[host] 361 return ok 362 } 363 364 // 回收长时间没有访问的IP 365 // 定时扫描 366 func (l *Listener) purgeExpireCounter(ctx context.Context) { 367 go func() { 368 ticker := time.NewTicker(l.purgeCounterInterval) 369 defer ticker.Stop() 370 log.Infof("[Listener][%s] start doing purge expire counter", l.protocol) 371 for { 372 select { 373 case <-ticker.C: 374 l.purgeExpireCounterHandler() 375 case <-ctx.Done(): 376 log.Infof("[Listener][%s] purge expire counter exit", l.protocol) 377 return 378 } 379 } 380 }() 381 } 382 383 // 回收过期counter执行函数 384 func (l *Listener) purgeExpireCounterHandler() { 385 start := time.Now() 386 scanCount := 0 387 purgeCount := 0 388 waitDel := []string{} 389 l.conns.Range(func(key string, ct *counter) bool { 390 scanCount++ 391 ct.mu.RLock() 392 if ct.size == 0 && time.Now().Unix()-ct.lastAccess > l.purgeCounterExpire { 393 waitDel = append(waitDel, key) 394 purgeCount++ 395 } 396 ct.mu.RUnlock() 397 return true 398 }) 399 400 for i := range waitDel { 401 // log.Infof("[Listener][%s] purge expire counter: %s", l.protocol, waitDel[i]) 402 l.conns.Delete(waitDel[i]) 403 } 404 405 spendTime := time.Since(start) 406 log.Infof("[Listener][%s] purge expire counter total(%d), use time: %+v, scan total(%d), scan qps: %.2f", 407 l.protocol, purgeCount, spendTime, scanCount, float64(scanCount)/spendTime.Seconds()) 408 }