github.com/turingchain2020/turingchain@v1.1.21/system/p2p/dht/manage/connectionGater.go (about)

     1  package manage
     2  
     3  import (
     4  	"container/list"
     5  	"context"
     6  	"runtime"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/kevinms/leakybucket-go"
    11  	"github.com/libp2p/go-libp2p-core/control"
    12  	"github.com/libp2p/go-libp2p-core/host"
    13  	"github.com/libp2p/go-libp2p-core/network"
    14  	"github.com/libp2p/go-libp2p-core/peer"
    15  	"github.com/multiformats/go-multiaddr"
    16  	net "github.com/multiformats/go-multiaddr-net"
    17  )
    18  
    19  const (
    20  	// limit for rate limiter when processing new inbound dials.
    21  	ipLimit = 4
    22  	// burst limit for inbound dials.
    23  	ipBurst = 8
    24  	//缓存的临时的节点连接数量,虽然达到了最大限制,但是有的节点连接是查询需要,开辟缓冲区
    25  )
    26  
    27  //CacheLimit cachebuffer
    28  var CacheLimit int32 = 50
    29  
// Conngater provides the Intercept* hooks of a libp2p connection gater. It
// rate-limits inbound dials per IP, enforces an optional peer whitelist and a
// dial blacklist, and caps the number of connected peers.
type Conngater struct {
	// host is the libp2p host whose connected peers are counted by isPeerAtLimit.
	host *host.Host

	// maxConnectNum is the configured connection limit; 0 disables the limit.
	maxConnectNum int32

	// ipLimiter rate-limits inbound dials, keyed by the remote IP string.
	ipLimiter *leakybucket.Collector

	// blacklist holds peer IDs (in p.Pretty() form) that may not be dialed.
	blacklist *TimeCache

	// whitPeerList maps whitelisted peer IDs to one of their addresses.
	// A nil map means whitelisting is disabled.
	whitPeerList map[peer.ID]multiaddr.Multiaddr
}
    38  
    39  //NewConnGater connect gater
    40  func NewConnGater(h *host.Host, limit int32, cache *TimeCache, whitPeers []*peer.AddrInfo) *Conngater {
    41  	gater := &Conngater{}
    42  	gater.host = h
    43  	gater.maxConnectNum = limit
    44  	gater.blacklist = cache
    45  	if gater.blacklist == nil {
    46  		gater.blacklist = NewTimeCache(context.Background(), time.Minute*5)
    47  	}
    48  	gater.ipLimiter = leakybucket.NewCollector(ipLimit, ipBurst, true)
    49  
    50  	for _, pr := range whitPeers {
    51  		if gater.whitPeerList == nil {
    52  			gater.whitPeerList = make(map[peer.ID]multiaddr.Multiaddr)
    53  		}
    54  		gater.whitPeerList[pr.ID] = pr.Addrs[0]
    55  	}
    56  	return gater
    57  }
    58  
    59  // InterceptPeerDial tests whether we're permitted to Dial the specified peer.
    60  func (s *Conngater) InterceptPeerDial(p peer.ID) (allow bool) {
    61  	//具体的拦截策略
    62  	//黑名单检查
    63  	//1.增加校验p2p白名单节点列表
    64  	if !s.checkWhitePeerList(p) {
    65  		return false
    66  	}
    67  	return !s.blacklist.Has(p.Pretty())
    68  
    69  }
    70  
    71  func (s *Conngater) checkWhitePeerList(p peer.ID) bool {
    72  	if s.whitPeerList != nil {
    73  		if _, ok := s.whitPeerList[p]; !ok {
    74  			return false
    75  		}
    76  	}
    77  	return true
    78  }
    79  
    80  // InterceptAddrDial tests whether we're permitted to dial the specified
    81  // multiaddr for the given peer.
    82  func (s *Conngater) InterceptAddrDial(_ peer.ID, m multiaddr.Multiaddr) (allow bool) {
    83  	return true
    84  }
    85  
    86  // InterceptAccept tests whether an incipient inbound connection is allowed.
    87  func (s *Conngater) InterceptAccept(n network.ConnMultiaddrs) (allow bool) {
    88  	if !s.validateDial(n.RemoteMultiaddr()) { //对连进来的节点进行速率限制
    89  		// Allow other go-routines to run in the event
    90  		// we receive a large amount of junk connections.
    91  		runtime.Gosched()
    92  		return false
    93  	}
    94  	//增加校验p2p白名单节点列表
    95  	if !s.checkWhitAddr(n.RemoteMultiaddr()) {
    96  		return false
    97  	}
    98  	return !s.isPeerAtLimit(network.DirInbound)
    99  
   100  }
   101  
   102  func (s *Conngater) checkWhitAddr(addr multiaddr.Multiaddr) bool {
   103  	if s.whitPeerList == nil {
   104  		return true
   105  	}
   106  	iswhiteIP := false
   107  	checkIP, _ := net.ToIP(addr)
   108  	for _, maddr := range s.whitPeerList {
   109  		ip, err := net.ToIP(maddr)
   110  		if err != nil {
   111  			continue
   112  		}
   113  		if ip.String() == checkIP.String() {
   114  			iswhiteIP = true
   115  		}
   116  	}
   117  
   118  	return iswhiteIP
   119  }
   120  
   121  // InterceptSecured tests whether a given connection, now authenticated,
   122  // is allowed.
   123  func (s *Conngater) InterceptSecured(_ network.Direction, _ peer.ID, n network.ConnMultiaddrs) (allow bool) {
   124  	return true
   125  }
   126  
   127  // InterceptUpgraded tests whether a fully capable connection is allowed.
   128  func (s *Conngater) InterceptUpgraded(n network.Conn) (allow bool, reason control.DisconnectReason) {
   129  	return true, 0
   130  }
   131  
   132  func (s *Conngater) validateDial(addr multiaddr.Multiaddr) bool {
   133  	ip, err := net.ToIP(addr)
   134  	if err != nil {
   135  		return false
   136  	}
   137  	remaining := s.ipLimiter.Remaining(ip.String())
   138  	if remaining <= 0 {
   139  		return false
   140  	}
   141  	s.ipLimiter.Add(ip.String(), 1)
   142  	return true
   143  }
   144  
   145  func (s *Conngater) isPeerAtLimit(direction network.Direction) bool {
   146  	if s.maxConnectNum == 0 { //不对连接节点数量进行限制
   147  		return false
   148  	}
   149  	numOfConns := len((*s.host).Network().Peers())
   150  	var maxPeers int
   151  	if direction == network.DirInbound { //inbound connect
   152  		maxPeers = int(s.maxConnectNum + CacheLimit/2)
   153  	} else {
   154  		maxPeers = int(s.maxConnectNum + CacheLimit)
   155  	}
   156  	return numOfConns >= maxPeers
   157  }
   158  
   159  //TimeCache data struct
   160  type TimeCache struct {
   161  	cacheLock sync.Mutex
   162  	Q         *list.List
   163  	M         map[string]time.Time
   164  	ctx       context.Context
   165  	span      time.Duration
   166  }
   167  
   168  //NewTimeCache new time cache obj.
   169  func NewTimeCache(ctx context.Context, span time.Duration) *TimeCache {
   170  	cache := &TimeCache{
   171  		Q:    list.New(),
   172  		M:    make(map[string]time.Time),
   173  		span: span,
   174  		ctx:  ctx,
   175  	}
   176  	go cache.sweep()
   177  	return cache
   178  }
   179  
   180  //Add add key
   181  func (tc *TimeCache) Add(s string, lifetime time.Duration) {
   182  	tc.cacheLock.Lock()
   183  	defer tc.cacheLock.Unlock()
   184  	_, ok := tc.M[s]
   185  	if ok {
   186  		return
   187  	}
   188  	if lifetime == 0 {
   189  		lifetime = tc.span
   190  	}
   191  	tc.M[s] = time.Now().Add(lifetime)
   192  	tc.Q.PushFront(s)
   193  }
   194  
   195  func (tc *TimeCache) sweep() {
   196  	ticker := time.NewTicker(time.Second)
   197  	for {
   198  		select {
   199  		case <-ticker.C:
   200  			tc.checkOvertimekey()
   201  		case <-tc.ctx.Done():
   202  			return
   203  		}
   204  	}
   205  
   206  }
   207  
   208  func (tc *TimeCache) checkOvertimekey() {
   209  	tc.cacheLock.Lock()
   210  	defer tc.cacheLock.Unlock()
   211  
   212  	back := tc.Q.Back()
   213  	if back == nil {
   214  		return
   215  	}
   216  	v := back.Value.(string)
   217  	t, ok := tc.M[v]
   218  	if !ok {
   219  		return
   220  	}
   221  	//if time.Since(t) > tc.span {
   222  	if time.Now().After(t) {
   223  		tc.Q.Remove(back)
   224  		delete(tc.M, v)
   225  	}
   226  }
   227  
   228  //Has check key
   229  func (tc *TimeCache) Has(s string) bool {
   230  	tc.cacheLock.Lock()
   231  	defer tc.cacheLock.Unlock()
   232  
   233  	_, ok := tc.M[s]
   234  	return ok
   235  }