github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/client/http_roundtripper.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	_errors "errors"
     7  	"fmt"
     8  	"net"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  
    13  	utls "github.com/refraction-networking/utls"
    14  	"golang.org/x/net/http2"
    15  	"golang.org/x/net/proxy"
    16  )
    17  
    18  type (
    19  	roundTripper struct {
    20  		sync.Mutex
    21  		// fix typing
    22  		JA3       string
    23  		UserAgent string
    24  
    25  		cachedConnections map[string]net.Conn
    26  		cachedTransports  map[string]http.RoundTripper
    27  		DialTLS           func(network, addr string) (net.Conn, error)
    28  		Dialer            proxy.Dialer
    29  	}
    30  )
    31  
    32  var errProtocolNegotiated = _errors.New("protocol negotiated")
    33  
    34  func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    35  	if rt.Dialer == nil {
    36  		rt.Dialer = proxy.Direct
    37  	}
    38  	/*
    39  		// Fix this later for proper cookie parsing
    40  		for _, properties := range rt.Cookies {
    41  			req.AddCookie(&http.Cookie{Name: properties.Name,
    42  				Value:      properties.Value,
    43  				Path:       properties.Path,
    44  				Domain:     properties.Domain,
    45  				Expires:    properties.JSONExpires.Time, //TODO: scuffed af
    46  				RawExpires: properties.RawExpires,
    47  				MaxAge:     properties.MaxAge,
    48  				HttpOnly:   properties.HTTPOnly,
    49  				Secure:     properties.Secure,
    50  				SameSite:   properties.SameSite,
    51  				Raw:        properties.Raw,
    52  				Unparsed:   properties.Unparsed,
    53  			})
    54  		}*/
    55  	req.Header.Set("User-Agent", rt.UserAgent)
    56  	addr := rt.getDialTLSAddr(req)
    57  
    58  	if rt.cachedTransports == nil {
    59  		rt.cachedTransports = make(map[string]http.RoundTripper)
    60  	}
    61  	if rt.cachedConnections == nil {
    62  		rt.cachedConnections = make(map[string]net.Conn)
    63  	}
    64  
    65  	if _, ok := rt.cachedTransports[addr]; !ok {
    66  		if err := rt.getTransport(req, addr); err != nil {
    67  			return nil, err
    68  		}
    69  	}
    70  	return rt.cachedTransports[addr].RoundTrip(req)
    71  }
    72  
    73  func (rt *roundTripper) getDialTLSAddr(req *http.Request) string {
    74  	host, port, err := net.SplitHostPort(req.URL.Host)
    75  	if err == nil {
    76  		return net.JoinHostPort(host, port)
    77  	}
    78  	return net.JoinHostPort(req.URL.Host, "443") // we can assume port is 443 at this point
    79  }
    80  
    81  func (rt *roundTripper) getTransport(req *http.Request, addr string) error {
    82  	switch strings.ToLower(req.URL.Scheme) {
    83  	case "http":
    84  		rt.cachedTransports[addr] = &http.Transport{Dial: rt.Dialer.Dial, DisableKeepAlives: true}
    85  		return nil
    86  	case "https":
    87  	default:
    88  		return fmt.Errorf("invalid URL scheme: [%v]", req.URL.Scheme)
    89  	}
    90  
    91  	_, err := rt.dialTLS(context.Background(), "tcp", addr)
    92  	switch err {
    93  	case errProtocolNegotiated:
    94  	case nil:
    95  		// Should never happen.
    96  		//panic("dialTLS returned no error when determining cachedTransports")
    97  	default:
    98  		return err
    99  	}
   100  
   101  	return nil
   102  }
   103  
   104  func (rt *roundTripper) dialTLS(ctx context.Context, network, addr string) (net.Conn, error) {
   105  	rt.Lock()
   106  	defer rt.Unlock()
   107  
   108  	// If we have the connection from when we determined the HTTPS
   109  	// cachedTransports to use, return that.
   110  	if conn := rt.cachedConnections[addr]; conn != nil {
   111  		delete(rt.cachedConnections, addr)
   112  		return conn, nil
   113  	}
   114  	conn, err := rt.DialTLS(network, addr)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	//////////
   119  	if rt.cachedTransports[addr] != nil {
   120  		return conn, nil
   121  	}
   122  
   123  	// No http.Transport constructed yet, create one based on the results
   124  	// of ALPN.
   125  	if c, ok := conn.(*utls.UConn); ok {
   126  		switch c.ConnectionState().NegotiatedProtocol {
   127  		case http2.NextProtoTLS:
   128  			// The remote peer is speaking HTTP 2 + TLS.
   129  			rt.cachedTransports[addr] = &http2.Transport{DialTLS: rt.dialTLSHTTP2}
   130  		default:
   131  			// Assume the remote peer is speaking HTTP 1.x + TLS.
   132  			rt.cachedTransports[addr] = &http.Transport{DialTLSContext: rt.dialTLS}
   133  
   134  		}
   135  	}
   136  
   137  	// Stash the connection just established for use servicing the
   138  	// actual request (should be near-immediate).
   139  	rt.cachedConnections[addr] = conn
   140  
   141  	return nil, errProtocolNegotiated
   142  }
   143  
   144  func (rt *roundTripper) dialTLSHTTP2(network, addr string, _ *tls.Config) (net.Conn, error) {
   145  	return rt.dialTLS(context.Background(), network, addr)
   146  }