github.com/fzfile/BaiduPCS-Go@v0.0.0-20200606205115-4408961cf336/requester/dial.go (about)

     1  package requester
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"github.com/fzfile/BaiduPCS-Go/baidupcs/expires"
     8  	"github.com/fzfile/BaiduPCS-Go/baidupcs/expires/cachemap"
     9  	mathrand "math/rand"
    10  	"net"
    11  	"net/http"
    12  	"net/url"
    13  	"strconv"
    14  	"time"
    15  )
    16  
    17  const (
    18  	// MaxDuration 最大的Duration
    19  	MaxDuration = 1<<63 - 1
    20  )
    21  
    22  var (
    23  	localTCPAddrList = []*net.TCPAddr{}
    24  
    25  	// ProxyAddr 代理地址
    26  	ProxyAddr string
    27  
    28  	// ErrProxyAddrEmpty 代理地址为空
    29  	ErrProxyAddrEmpty = errors.New("proxy addr is empty")
    30  
    31  	tcpCache = cachemap.GlobalCacheOpMap.LazyInitCachePoolOp("requester/tcp")
    32  )
    33  
    34  // SetLocalTCPAddrList 设置网卡地址
    35  func SetLocalTCPAddrList(ips ...string) {
    36  	list := make([]*net.TCPAddr, 0, len(ips))
    37  	for k := range ips {
    38  		p := net.ParseIP(ips[k])
    39  		if p == nil {
    40  			continue
    41  		}
    42  
    43  		list = append(list, &net.TCPAddr{
    44  			IP: p,
    45  		})
    46  	}
    47  	localTCPAddrList = list
    48  }
    49  
    50  func proxyFunc(req *http.Request) (*url.URL, error) {
    51  	u, err := checkProxyAddr(ProxyAddr)
    52  	if err != nil {
    53  		return http.ProxyFromEnvironment(req)
    54  	}
    55  
    56  	return u, err
    57  }
    58  
    59  func getLocalTCPAddr() *net.TCPAddr {
    60  	if len(localTCPAddrList) == 0 {
    61  		return nil
    62  	}
    63  	i := mathrand.Intn(len(localTCPAddrList))
    64  	return localTCPAddrList[i]
    65  }
    66  
    67  func getDialer() *net.Dialer {
    68  	return &net.Dialer{
    69  		Timeout:   30 * time.Second,
    70  		KeepAlive: 30 * time.Second,
    71  		LocalAddr: getLocalTCPAddr(),
    72  		DualStack: true,
    73  	}
    74  }
    75  
    76  func checkProxyAddr(proxyAddr string) (u *url.URL, err error) {
    77  	if proxyAddr == "" {
    78  		return nil, ErrProxyAddrEmpty
    79  	}
    80  
    81  	host, port, err := net.SplitHostPort(proxyAddr)
    82  	if err == nil {
    83  		u = &url.URL{
    84  			Host: net.JoinHostPort(host, port),
    85  		}
    86  		return
    87  	}
    88  
    89  	u, err = url.Parse(proxyAddr)
    90  	if err == nil {
    91  		return
    92  	}
    93  
    94  	return
    95  }
    96  
    97  // SetGlobalProxy 设置代理
    98  func SetGlobalProxy(proxyAddr string) {
    99  	ProxyAddr = proxyAddr
   100  }
   101  
   102  // SetTCPHostBind 设置host绑定ip
   103  func SetTCPHostBind(host, ip string) {
   104  	tcpCache.Store(host, expires.NewDataExpires(net.ParseIP(ip), MaxDuration))
   105  	return
   106  }
   107  
   108  func getServerName(address string) string {
   109  	host, _, err := net.SplitHostPort(address)
   110  	if err != nil {
   111  		return address
   112  	}
   113  	return host
   114  }
   115  
   116  // resolveTCPHost
   117  // 解析的tcpaddr没有port!!!
   118  func resolveTCPHost(ctx context.Context, host string) (ip net.IP, err error) {
   119  	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
   120  	if err != nil {
   121  		return
   122  	}
   123  
   124  	return addrs[0].IP, nil
   125  }
   126  
   127  func dialContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
   128  	switch network {
   129  	case "tcp", "tcp4", "tcp6":
   130  		host, portStr, err := net.SplitHostPort(address)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  		data, err := cachemap.GlobalCacheOpMap.CacheOperationWithError("requester/tcp", host, func() (expires.DataExpires, error) {
   135  			ip, err := resolveTCPHost(ctx, host)
   136  			if err != nil {
   137  				return nil, err
   138  			}
   139  			return expires.NewDataExpires(ip, 10*time.Minute), nil // 传值
   140  		})
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  
   145  		port, err := strconv.Atoi(portStr)
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  
   150  		return net.DialTCP(network, getLocalTCPAddr(), &net.TCPAddr{
   151  			IP:   data.Data().(net.IP),
   152  			Port: port, // 设置端口
   153  		})
   154  	}
   155  
   156  	// 非 tcp 请求
   157  	conn, err = getDialer().DialContext(ctx, network, address)
   158  	return
   159  }
   160  
   161  func dial(network, address string) (conn net.Conn, err error) {
   162  	return dialContext(context.Background(), network, address)
   163  }
   164  
   165  func (h *HTTPClient) dialTLSFunc() func(network, address string) (tlsConn net.Conn, err error) {
   166  	return func(network, address string) (tlsConn net.Conn, err error) {
   167  		conn, err := dialContext(context.Background(), network, address)
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  
   172  		return tls.Client(conn, &tls.Config{
   173  			ServerName:         getServerName(address),
   174  			InsecureSkipVerify: !h.https,
   175  		}), nil
   176  	}
   177  }