github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/client/http_roundtripper.go (about) 1 package client 2 3 import ( 4 "context" 5 "crypto/tls" 6 _errors "errors" 7 "fmt" 8 "net" 9 "net/http" 10 "strings" 11 "sync" 12 13 utls "github.com/refraction-networking/utls" 14 "golang.org/x/net/http2" 15 "golang.org/x/net/proxy" 16 ) 17 18 type ( 19 roundTripper struct { 20 sync.Mutex 21 // fix typing 22 JA3 string 23 UserAgent string 24 25 cachedConnections map[string]net.Conn 26 cachedTransports map[string]http.RoundTripper 27 DialTLS func(network, addr string) (net.Conn, error) 28 Dialer proxy.Dialer 29 } 30 ) 31 32 var errProtocolNegotiated = _errors.New("protocol negotiated") 33 34 func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 35 if rt.Dialer == nil { 36 rt.Dialer = proxy.Direct 37 } 38 /* 39 // Fix this later for proper cookie parsing 40 for _, properties := range rt.Cookies { 41 req.AddCookie(&http.Cookie{Name: properties.Name, 42 Value: properties.Value, 43 Path: properties.Path, 44 Domain: properties.Domain, 45 Expires: properties.JSONExpires.Time, //TODO: scuffed af 46 RawExpires: properties.RawExpires, 47 MaxAge: properties.MaxAge, 48 HttpOnly: properties.HTTPOnly, 49 Secure: properties.Secure, 50 SameSite: properties.SameSite, 51 Raw: properties.Raw, 52 Unparsed: properties.Unparsed, 53 }) 54 }*/ 55 req.Header.Set("User-Agent", rt.UserAgent) 56 addr := rt.getDialTLSAddr(req) 57 58 if rt.cachedTransports == nil { 59 rt.cachedTransports = make(map[string]http.RoundTripper) 60 } 61 if rt.cachedConnections == nil { 62 rt.cachedConnections = make(map[string]net.Conn) 63 } 64 65 if _, ok := rt.cachedTransports[addr]; !ok { 66 if err := rt.getTransport(req, addr); err != nil { 67 return nil, err 68 } 69 } 70 return rt.cachedTransports[addr].RoundTrip(req) 71 } 72 73 func (rt *roundTripper) getDialTLSAddr(req *http.Request) string { 74 host, port, err := net.SplitHostPort(req.URL.Host) 75 if err == nil { 76 return net.JoinHostPort(host, port) 77 } 78 return net.JoinHostPort(req.URL.Host, "443") // we can assume port is 443 at this point 79 } 80 81 func (rt *roundTripper) getTransport(req *http.Request, addr string) error { 82 switch strings.ToLower(req.URL.Scheme) { 83 case "http": 84 rt.cachedTransports[addr] = &http.Transport{Dial: rt.Dialer.Dial, DisableKeepAlives: true} 85 return nil 86 case "https": 87 default: 88 return fmt.Errorf("invalid URL scheme: [%v]", req.URL.Scheme) 89 } 90 91 _, err := rt.dialTLS(context.Background(), "tcp", addr) 92 switch err { 93 case errProtocolNegotiated: 94 case nil: 95 // Should never happen. 96 //panic("dialTLS returned no error when determining cachedTransports") 97 default: 98 return err 99 } 100 101 return nil 102 } 103 104 func (rt *roundTripper) dialTLS(ctx context.Context, network, addr string) (net.Conn, error) { 105 rt.Lock() 106 defer rt.Unlock() 107 108 // If we have the connection from when we determined the HTTPS 109 // cachedTransports to use, return that. 110 if conn := rt.cachedConnections[addr]; conn != nil { 111 delete(rt.cachedConnections, addr) 112 return conn, nil 113 } 114 conn, err := rt.DialTLS(network, addr) 115 if err != nil { 116 return nil, err 117 } 118 ////////// 119 if rt.cachedTransports[addr] != nil { 120 return conn, nil 121 } 122 123 // No http.Transport constructed yet, create one based on the results 124 // of ALPN. 125 if c, ok := conn.(*utls.UConn); ok { 126 switch c.ConnectionState().NegotiatedProtocol { 127 case http2.NextProtoTLS: 128 // The remote peer is speaking HTTP 2 + TLS. 129 rt.cachedTransports[addr] = &http2.Transport{DialTLS: rt.dialTLSHTTP2} 130 default: 131 // Assume the remote peer is speaking HTTP 1.x + TLS. 132 rt.cachedTransports[addr] = &http.Transport{DialTLSContext: rt.dialTLS} 133 134 } 135 } 136 137 // Stash the connection just established for use servicing the 138 // actual request (should be near-immediate). 139 rt.cachedConnections[addr] = conn 140 141 return nil, errProtocolNegotiated 142 } 143 144 func (rt *roundTripper) dialTLSHTTP2(network, addr string, _ *tls.Config) (net.Conn, error) { 145 return rt.dialTLS(context.Background(), network, addr) 146 }