github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/client_conn_base.go (about)

     1  // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  
     3  package nodes
     4  
     5  import (
     6  	"crypto/tls"
     7  	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
     8  	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
     9  	"net"
    10  	"sync/atomic"
    11  	"time"
    12  )
    13  
    14  type BaseClientConn struct {
    15  	rawConn net.Conn
    16  
    17  	isBound    bool
    18  	userId     int64
    19  	userPlanId int64
    20  	serverId   int64
    21  	remoteAddr string
    22  	hasLimit   bool
    23  
    24  	isPersistent bool // 是否为持久化连接
    25  	fingerprint  []byte
    26  
    27  	isClosed bool
    28  
    29  	rawIP string
    30  
    31  	totalSentBytes int64
    32  }
    33  
    34  func (this *BaseClientConn) IsClosed() bool {
    35  	return this.isClosed
    36  }
    37  
    38  // IsBound 是否已绑定服务
    39  func (this *BaseClientConn) IsBound() bool {
    40  	return this.isBound
    41  }
    42  
    43  // Bind 绑定服务
    44  func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerServer int, maxConnsPerIP int) bool {
    45  	if this.isBound {
    46  		return true
    47  	}
    48  	this.isBound = true
    49  	this.serverId = serverId
    50  	this.remoteAddr = remoteAddr
    51  	this.hasLimit = true
    52  
    53  	// 检查是否可以连接
    54  	return sharedClientConnLimiter.Add(this.rawConn.RemoteAddr().String(), serverId, remoteAddr, maxConnsPerServer, maxConnsPerIP)
    55  }
    56  
    57  // SetServerId 设置服务ID
    58  func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
    59  	goNext = true
    60  
    61  	// 检查服务相关IP黑名单
    62  	var rawIP = this.RawIP()
    63  	if serverId > 0 && len(rawIP) > 0 {
    64  		// 是否在白名单中
    65  		ok, _, expiresAt := iplibrary.AllowIP(rawIP, serverId)
    66  		if !ok {
    67  			_ = this.rawConn.Close()
    68  			firewalls.DropTemporaryTo(rawIP, expiresAt)
    69  			return false
    70  		}
    71  	}
    72  
    73  	this.serverId = serverId
    74  
    75  	// 设置包装前连接
    76  	switch conn := this.rawConn.(type) {
    77  	case *tls.Conn:
    78  		nativeConn, ok := conn.NetConn().(ClientConnInterface)
    79  		if ok {
    80  			nativeConn.SetServerId(serverId)
    81  		}
    82  	case *ClientConn:
    83  		conn.SetServerId(serverId)
    84  	}
    85  
    86  	return true
    87  }
    88  
    89  // ServerId 读取当前连接绑定的服务ID
    90  func (this *BaseClientConn) ServerId() int64 {
    91  	return this.serverId
    92  }
    93  
    94  // SetUserId 设置所属服务的用户ID
    95  func (this *BaseClientConn) SetUserId(userId int64) {
    96  	this.userId = userId
    97  
    98  	// 设置包装前连接
    99  	switch conn := this.rawConn.(type) {
   100  	case *tls.Conn:
   101  		nativeConn, ok := conn.NetConn().(ClientConnInterface)
   102  		if ok {
   103  			nativeConn.SetUserId(userId)
   104  		}
   105  	case *ClientConn:
   106  		conn.SetUserId(userId)
   107  	}
   108  }
   109  
   110  func (this *BaseClientConn) SetUserPlanId(userPlanId int64) {
   111  	this.userPlanId = userPlanId
   112  
   113  	// 设置包装前连接
   114  	switch conn := this.rawConn.(type) {
   115  	case *tls.Conn:
   116  		nativeConn, ok := conn.NetConn().(ClientConnInterface)
   117  		if ok {
   118  			nativeConn.SetUserPlanId(userPlanId)
   119  		}
   120  	case *ClientConn:
   121  		conn.SetUserPlanId(userPlanId)
   122  	}
   123  }
   124  
   125  // UserId 获取当前连接所属服务的用户ID
   126  func (this *BaseClientConn) UserId() int64 {
   127  	return this.userId
   128  }
   129  
   130  // UserPlanId 用户套餐ID
   131  func (this *BaseClientConn) UserPlanId() int64 {
   132  	return this.userPlanId
   133  }
   134  
   135  // RawIP 原本IP
   136  func (this *BaseClientConn) RawIP() string {
   137  	if len(this.rawIP) > 0 {
   138  		return this.rawIP
   139  	}
   140  
   141  	ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String())
   142  	this.rawIP = ip
   143  	return ip
   144  }
   145  
   146  // TCPConn 转换为TCPConn
   147  func (this *BaseClientConn) TCPConn() (tcpConn *net.TCPConn, ok bool) {
   148  	// 设置包装前连接
   149  	switch conn := this.rawConn.(type) {
   150  	case *tls.Conn:
   151  		var internalConn = conn.NetConn()
   152  		clientConn, isClientConn := internalConn.(*ClientConn)
   153  		if isClientConn {
   154  			return clientConn.TCPConn()
   155  		}
   156  		tcpConn, ok = internalConn.(*net.TCPConn)
   157  	default:
   158  		tcpConn, ok = this.rawConn.(*net.TCPConn)
   159  	}
   160  	return
   161  }
   162  
   163  // SetLinger 设置Linger
   164  func (this *BaseClientConn) SetLinger(seconds int) error {
   165  	tcpConn, ok := this.TCPConn()
   166  	if ok {
   167  		return tcpConn.SetLinger(seconds)
   168  	}
   169  	return nil
   170  }
   171  
   172  func (this *BaseClientConn) SetIsPersistent(isPersistent bool) {
   173  	this.isPersistent = isPersistent
   174  
   175  	_ = this.rawConn.SetDeadline(time.Time{})
   176  }
   177  
   178  // SetFingerprint 设置指纹信息
   179  func (this *BaseClientConn) SetFingerprint(fingerprint []byte) {
   180  	this.fingerprint = fingerprint
   181  }
   182  
   183  // Fingerprint 读取指纹信息
   184  func (this *BaseClientConn) Fingerprint() []byte {
   185  	return this.fingerprint
   186  }
   187  
   188  // LastRequestBytes 读取上一次请求发送的字节数
   189  func (this *BaseClientConn) LastRequestBytes() int64 {
   190  	var result = atomic.LoadInt64(&this.totalSentBytes)
   191  	atomic.StoreInt64(&this.totalSentBytes, 0)
   192  	return result
   193  }