github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/http_client_pool.go (about)

     1  package nodes
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
     8  	"github.com/TeaOSLab/EdgeNode/internal/goman"
     9  	"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
    10  	"github.com/cespare/xxhash/v2"
    11  	"github.com/pires/go-proxyproto"
    12  	"golang.org/x/net/http2"
    13  	"net"
    14  	"net/http"
    15  	"runtime"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"time"
    20  )
    21  
    22  // SharedHTTPClientPool HTTP客户端池单例
    23  var SharedHTTPClientPool = NewHTTPClientPool()
    24  
    25  const httpClientProxyProtocolTag = "@ProxyProtocol@"
    26  const maxHTTPRedirects = 8
    27  
// HTTPClientPool is a pool of HTTP clients keyed by origin.
//
// Each entry holds one *HTTPClient that is created on demand by Client and
// evicted by cleanClients after a period without use.
type HTTPClientPool struct {
	// clientsMap maps the xxhash of a client's origin key to the client.
	clientsMap map[uint64]*HTTPClient // origin key => client

	// cleanTicker paces the eviction loop in cleanClients.
	cleanTicker *time.Ticker

	// locker guards clientsMap.
	locker sync.RWMutex
}
    36  
    37  // NewHTTPClientPool 获取新对象
    38  func NewHTTPClientPool() *HTTPClientPool {
    39  	var pool = &HTTPClientPool{
    40  		cleanTicker: time.NewTicker(1 * time.Hour),
    41  		clientsMap:  map[uint64]*HTTPClient{},
    42  	}
    43  
    44  	goman.New(func() {
    45  		pool.cleanClients()
    46  	})
    47  
    48  	return pool
    49  }
    50  
    51  // Client 根据地址获取客户端
    52  func (this *HTTPClientPool) Client(req *HTTPRequest,
    53  	origin *serverconfigs.OriginConfig,
    54  	originAddr string,
    55  	proxyProtocol *serverconfigs.ProxyProtocolConfig,
    56  	followRedirects bool) (rawClient *http.Client, err error) {
    57  	if origin.Addr == nil {
    58  		return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
    59  	}
    60  
    61  	if req == nil || req.RawReq == nil || req.RawReq.URL == nil {
    62  		err = errors.New("invalid request url")
    63  		return
    64  	}
    65  	var originHost = req.RawReq.URL.Host
    66  	var urlPort = req.RawReq.URL.Port()
    67  	if len(urlPort) == 0 {
    68  		if req.RawReq.URL.Scheme == "http" {
    69  			urlPort = "80"
    70  		} else {
    71  			urlPort = "443"
    72  		}
    73  
    74  		originHost += ":" + urlPort
    75  	}
    76  
    77  	var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost
    78  
    79  	// if we are under available ProxyProtocol, we add client ip to key to make every client unique
    80  	var isProxyProtocol = false
    81  	if proxyProtocol != nil && proxyProtocol.IsOn {
    82  		rawKey += httpClientProxyProtocolTag + req.requestRemoteAddr(true)
    83  		isProxyProtocol = true
    84  	}
    85  
    86  	// follow redirects
    87  	if followRedirects {
    88  		rawKey += "@follow"
    89  	}
    90  
    91  	var key = xxhash.Sum64String(rawKey)
    92  
    93  	var isLnRequest = origin.Id == 0
    94  
    95  	this.locker.RLock()
    96  	client, found := this.clientsMap[key]
    97  	this.locker.RUnlock()
    98  	if found {
    99  		client.UpdateAccessTime()
   100  		return client.RawClient(), nil
   101  	}
   102  
   103  	// 这里不能使用RLock,避免因为并发生成多个同样的client实例
   104  	this.locker.Lock()
   105  	defer this.locker.Unlock()
   106  
   107  	// 再次查找
   108  	client, found = this.clientsMap[key]
   109  	if found {
   110  		client.UpdateAccessTime()
   111  		return client.RawClient(), nil
   112  	}
   113  
   114  	var maxConnections = origin.MaxConns
   115  	var connectionTimeout = origin.ConnTimeoutDuration()
   116  	var readTimeout = origin.ReadTimeoutDuration()
   117  	var idleTimeout = origin.IdleTimeoutDuration()
   118  	var idleConns = origin.MaxIdleConns
   119  
   120  	// 超时时间
   121  	if connectionTimeout <= 0 {
   122  		connectionTimeout = 15 * time.Second
   123  	}
   124  
   125  	if idleTimeout <= 0 {
   126  		idleTimeout = 2 * time.Minute
   127  	}
   128  
   129  	var numberCPU = runtime.NumCPU()
   130  	if numberCPU < 8 {
   131  		numberCPU = 8
   132  	}
   133  	if maxConnections <= 0 {
   134  		maxConnections = numberCPU * 64
   135  	}
   136  
   137  	if idleConns <= 0 {
   138  		idleConns = numberCPU * 16
   139  	}
   140  
   141  	if isProxyProtocol { // ProxyProtocol无需保持太多空闲连接
   142  		idleConns = 3
   143  	} else if isLnRequest { // 可以判断为Ln节点请求
   144  		maxConnections *= 8
   145  		idleConns *= 8
   146  		idleTimeout *= 4
   147  	} else if sharedNodeConfig != nil && sharedNodeConfig.Level > 1 {
   148  		// Ln节点可以适当增加连接数
   149  		maxConnections *= 2
   150  		idleConns *= 2
   151  	}
   152  
   153  	// TLS通讯
   154  	var tlsConfig = &tls.Config{
   155  		InsecureSkipVerify: true,
   156  	}
   157  	if origin.Cert != nil {
   158  		var obj = origin.Cert.CertObject()
   159  		if obj != nil {
   160  			tlsConfig.InsecureSkipVerify = false
   161  			tlsConfig.Certificates = []tls.Certificate{*obj}
   162  			if len(origin.Cert.ServerName) > 0 {
   163  				tlsConfig.ServerName = origin.Cert.ServerName
   164  			}
   165  		}
   166  	}
   167  
   168  	var transport = &HTTPClientTransport{
   169  		Transport: &http.Transport{
   170  			DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
   171  				var realAddr = originAddr
   172  
   173  				// for redirections
   174  				if followRedirects && originHost != addr {
   175  					realAddr = addr
   176  				}
   177  
   178  				// connect
   179  				conn, dialErr := (&net.Dialer{
   180  					Timeout:   connectionTimeout,
   181  					KeepAlive: 1 * time.Minute,
   182  				}).DialContext(ctx, network, realAddr)
   183  				if dialErr != nil {
   184  					return nil, dialErr
   185  				}
   186  
   187  				// handle PROXY protocol
   188  				proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol)
   189  				if proxyErr != nil {
   190  					return nil, proxyErr
   191  				}
   192  
   193  				return NewOriginConn(conn), nil
   194  			},
   195  			MaxIdleConns:          0,
   196  			MaxIdleConnsPerHost:   idleConns,
   197  			MaxConnsPerHost:       maxConnections,
   198  			IdleConnTimeout:       idleTimeout,
   199  			ExpectContinueTimeout: 1 * time.Second,
   200  			TLSHandshakeTimeout:   5 * time.Second,
   201  			TLSClientConfig:       tlsConfig,
   202  			ReadBufferSize:        8 * 1024,
   203  			Proxy:                 nil,
   204  		},
   205  	}
   206  
   207  	// support http/2
   208  	if origin.HTTP2Enabled && origin.Addr != nil && origin.Addr.Protocol == serverconfigs.ProtocolHTTPS {
   209  		_ = http2.ConfigureTransport(transport.Transport)
   210  	}
   211  
   212  	rawClient = &http.Client{
   213  		Timeout:   readTimeout,
   214  		Transport: transport,
   215  		CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
   216  			// follow redirects
   217  			if followRedirects && len(via) <= maxHTTPRedirects {
   218  				return nil
   219  			}
   220  
   221  			return http.ErrUseLastResponse
   222  		},
   223  	}
   224  
   225  	this.clientsMap[key] = NewHTTPClient(rawClient, isProxyProtocol)
   226  
   227  	return rawClient, nil
   228  }
   229  
   230  // 清理不使用的Client
   231  func (this *HTTPClientPool) cleanClients() {
   232  	for range this.cleanTicker.C {
   233  		var nowTime = fasttime.Now().Unix()
   234  
   235  		var expiredKeys []uint64
   236  		var expiredClients = []*HTTPClient{}
   237  
   238  		// lookup expired clients
   239  		this.locker.RLock()
   240  		for k, client := range this.clientsMap {
   241  			if client.AccessTime() < nowTime-86400 ||
   242  				(client.IsProxyProtocol() && client.AccessTime() < nowTime-3600) { // 超过 N 秒没有调用就关闭
   243  				expiredKeys = append(expiredKeys, k)
   244  				expiredClients = append(expiredClients, client)
   245  			}
   246  		}
   247  		this.locker.RUnlock()
   248  
   249  		// remove expired keys
   250  		if len(expiredKeys) > 0 {
   251  			this.locker.Lock()
   252  			for _, k := range expiredKeys {
   253  				delete(this.clientsMap, k)
   254  			}
   255  			this.locker.Unlock()
   256  		}
   257  
   258  		// close expired clients
   259  		if len(expiredClients) > 0 {
   260  			for _, client := range expiredClients {
   261  				client.Close()
   262  			}
   263  		}
   264  	}
   265  }
   266  
   267  // 支持PROXY Protocol
   268  func (this *HTTPClientPool) handlePROXYProtocol(conn net.Conn, req *HTTPRequest, proxyProtocol *serverconfigs.ProxyProtocolConfig) error {
   269  	if proxyProtocol != nil &&
   270  		proxyProtocol.IsOn &&
   271  		(proxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || proxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
   272  		var remoteAddr = req.requestRemoteAddr(true)
   273  		var transportProtocol = proxyproto.TCPv4
   274  		if strings.Contains(remoteAddr, ":") {
   275  			transportProtocol = proxyproto.TCPv6
   276  		}
   277  		var destAddr = conn.RemoteAddr()
   278  		var reqConn = req.RawReq.Context().Value(HTTPConnContextKey)
   279  		if reqConn != nil {
   280  			destAddr = reqConn.(net.Conn).LocalAddr()
   281  		}
   282  		var header = proxyproto.Header{
   283  			Version:           byte(proxyProtocol.Version),
   284  			Command:           proxyproto.PROXY,
   285  			TransportProtocol: transportProtocol,
   286  			SourceAddr: &net.TCPAddr{
   287  				IP:   net.ParseIP(remoteAddr),
   288  				Port: req.requestRemotePort(),
   289  			},
   290  			DestinationAddr: destAddr,
   291  		}
   292  		_, err := header.WriteTo(conn)
   293  		if err != nil {
   294  			_ = conn.Close()
   295  			return err
   296  		}
   297  		return nil
   298  	}
   299  
   300  	return nil
   301  }