github.com/tickstep/library-go@v0.1.1/requester/dial.go (about) 1 package requester 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "github.com/tickstep/library-go/expires" 8 "github.com/tickstep/library-go/expires/cachemap" 9 mathrand "math/rand" 10 "net" 11 "net/http" 12 "net/url" 13 "strconv" 14 "strings" 15 "time" 16 ) 17 18 type IPType string 19 20 const ( 21 // MaxDuration 最大的Duration 22 MaxDuration = 1<<63 - 1 23 24 // IPAny 任意IP,默认取第一个域名解析的结果 25 IPAny IPType = "any" 26 // IPv4 优先使用Ipv4的域名解析地址 27 IPv4 IPType = "ipv4" 28 // IPv6 优先使用Ipv6的域名解析地址 29 IPv6 IPType = "ipv6" 30 ) 31 32 var ( 33 localTCPAddrList = []*net.TCPAddr{} 34 35 // ProxyAddr 代理地址 36 ProxyAddr string 37 38 // ErrProxyAddrEmpty 代理地址为空 39 ErrProxyAddrEmpty = errors.New("proxy addr is empty") 40 41 tcpCache = cachemap.GlobalCacheOpMap.LazyInitCachePoolOp("requester/tcp") 42 43 // ipPref 域名解析策略 44 ipPref = IPAny 45 ) 46 47 // SetLocalTCPAddrList 设置网卡地址 48 func SetLocalTCPAddrList(ips ...string) { 49 list := make([]*net.TCPAddr, 0, len(ips)) 50 for k := range ips { 51 p := net.ParseIP(ips[k]) 52 if p == nil { 53 continue 54 } 55 56 list = append(list, &net.TCPAddr{ 57 IP: p, 58 }) 59 } 60 localTCPAddrList = list 61 } 62 63 // SetPreferIPType 设置优先的IP类型 64 func SetPreferIPType(ipType IPType) { 65 ipPref = ipType 66 } 67 68 func proxyFunc(req *http.Request) (*url.URL, error) { 69 u, err := checkProxyAddr(ProxyAddr) 70 if err != nil { 71 return http.ProxyFromEnvironment(req) 72 } 73 74 return u, err 75 } 76 77 func getLocalTCPAddr() *net.TCPAddr { 78 if len(localTCPAddrList) == 0 { 79 return nil 80 } 81 i := mathrand.Intn(len(localTCPAddrList)) 82 return localTCPAddrList[i] 83 } 84 85 func getDialer() *net.Dialer { 86 return &net.Dialer{ 87 Timeout: 30 * time.Second, 88 KeepAlive: 30 * time.Second, 89 LocalAddr: getLocalTCPAddr(), 90 DualStack: true, 91 } 92 } 93 94 func checkProxyAddr(proxyAddr string) (u *url.URL, err error) { 95 if proxyAddr == "" { 96 return nil, ErrProxyAddrEmpty 97 } 98 99 host, port, err := net.SplitHostPort(proxyAddr) 100 if err == nil { 101 u = &url.URL{ 102 Host: net.JoinHostPort(host, port), 103 } 104 return 105 } 106 107 u, err = url.Parse(proxyAddr) 108 if err == nil { 109 return 110 } 111 112 return 113 } 114 115 // SetGlobalProxy 设置代理 116 func SetGlobalProxy(proxyAddr string) { 117 ProxyAddr = proxyAddr 118 } 119 120 // SetTCPHostBind 设置host绑定ip 121 func SetTCPHostBind(host, ip string) { 122 tcpCache.Store(host, expires.NewDataExpires(net.ParseIP(ip), MaxDuration)) 123 return 124 } 125 126 func getServerName(address string) string { 127 host, _, err := net.SplitHostPort(address) 128 if err != nil { 129 return address 130 } 131 return host 132 } 133 134 // resolveTCPHost 135 // 解析的tcpaddr没有port!!! 136 func resolveTCPHost(ctx context.Context, host string) (ip net.IP, err error) { 137 addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) 138 if err != nil { 139 return 140 } 141 142 // 执行域名解析策略 143 for _, ipaddr := range addrs { 144 if ipPref == IPv4 { // 优先IPv4 145 if isIPv4(ipaddr.IP.String()) { 146 return ipaddr.IP, nil 147 } 148 } else if ipPref == IPv6 { // 优先IPv6 149 if isIPv6(ipaddr.IP.String()) { 150 return ipaddr.IP, nil 151 } 152 } 153 } 154 155 // 默认使用第一个解析结果 156 return addrs[0].IP, nil 157 } 158 159 func isIPv4(ip string) bool { 160 return strings.Contains(ip, ".") 161 } 162 163 func isIPv6(ip string) bool { 164 return strings.Contains(ip, ":") 165 } 166 167 func dialContext(ctx context.Context, network, address string) (conn net.Conn, err error) { 168 switch network { 169 case "tcp", "tcp4", "tcp6": 170 host, portStr, err := net.SplitHostPort(address) 171 if err != nil { 172 return nil, err 173 } 174 data, err := cachemap.GlobalCacheOpMap.CacheOperationWithError("requester/tcp", host, func() (expires.DataExpires, error) { 175 ip, err := resolveTCPHost(ctx, host) 176 if err != nil { 177 return nil, err 178 } 179 return expires.NewDataExpires(ip, 10*time.Minute), nil // 传值 180 }) 181 if err != nil { 182 return nil, err 183 } 184 185 port, err := strconv.Atoi(portStr) 186 if err != nil { 187 return nil, err 188 } 189 190 return net.DialTCP(network, getLocalTCPAddr(), &net.TCPAddr{ 191 IP: data.Data().(net.IP), 192 Port: port, // 设置端口 193 }) 194 } 195 196 // 非 tcp 请求 197 conn, err = getDialer().DialContext(ctx, network, address) 198 return 199 } 200 201 func dial(network, address string) (conn net.Conn, err error) { 202 return dialContext(context.Background(), network, address) 203 } 204 205 func (h *HTTPClient) dialTLSFunc() func(network, address string) (tlsConn net.Conn, err error) { 206 return func(network, address string) (tlsConn net.Conn, err error) { 207 conn, err := dialContext(context.Background(), network, address) 208 if err != nil { 209 return nil, err 210 } 211 212 return tls.Client(conn, &tls.Config{ 213 ServerName: getServerName(address), 214 InsecureSkipVerify: !h.https, 215 }), nil 216 } 217 }