github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/transport/http_proxy.go (about)

     1  package transport
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httputil"
    14  	"net/url"
    15  	"strconv"
    16  	"sync"
    17  
    18  	"golang.org/x/net/http2"
    19  	"golang.org/x/net/proxy"
    20  )
    21  
    22  const (
    23  	proxyAuthHeader = "Proxy-Authorization"
    24  )
    25  
    26  // connectDialer allows to configure one-time use HTTP CONNECT client
    27  type (
    28  	pbuffer struct {
    29  		net.Conn
    30  		r io.Reader
    31  	}
    32  
    33  	connectDialer struct {
    34  		ProxyURL      url.URL
    35  		DefaultHeader http.Header
    36  
    37  		Dialer net.Dialer // overridden dialer allow to control establishment of TCP connection
    38  
    39  		// overridden DialTLS allows user to control establishment of TLS connection
    40  		// MUST return connection with completed Handshake, and NegotiatedProtocol
    41  		DialTLS func(network string, address string) (net.Conn, string, error)
    42  
    43  		EnableH2ConnReuse  bool
    44  		cacheH2Mu          sync.Mutex
    45  		cachedH2ClientConn *http2.ClientConn
    46  		cachedH2RawConn    net.Conn
    47  	}
    48  
    49  	// ContextKeyHeader Users of context.WithValue should define their own types for keys
    50  	ContextKeyHeader struct{}
    51  
    52  	http2Conn struct {
    53  		net.Conn
    54  		in  *io.PipeWriter
    55  		out io.ReadCloser
    56  	}
    57  )
    58  
    59  // newConnectDialer creates a dialer to issue CONNECT requests and tunnel traffic via HTTP/S proxy.
    60  // proxyUrlStr must provide Scheme and Host, may provide credentials and port.
    61  // Example: https://username:password@golang.org:443
    62  func NewProxyDialer(proxyURLStr string, UserAgent string) (proxy.Dialer, error) {
    63  	proxyURL, err := url.Parse(proxyURLStr)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	if proxyURL.Host == "" || proxyURL.Host == "undefined" {
    69  		return nil, errors.New("invalid url `" + proxyURLStr +
    70  			"`, make sure to specify full url like https://username:password@hostname.com:443/")
    71  	}
    72  
    73  	switch proxyURL.Scheme {
    74  	case "http":
    75  		if proxyURL.Port() == "" {
    76  			proxyURL.Host = net.JoinHostPort(proxyURL.Host, "80")
    77  		}
    78  	case "https":
    79  		if proxyURL.Port() == "" {
    80  			proxyURL.Host = net.JoinHostPort(proxyURL.Host, "443")
    81  		}
    82  	case "socks5":
    83  		{
    84  		}
    85  	case "":
    86  		return nil, errors.New("specify scheme explicitly (https://)")
    87  	default:
    88  		return nil, errors.New("scheme " + proxyURL.Scheme + " is not supported")
    89  	}
    90  
    91  	client := &connectDialer{
    92  		ProxyURL:          *proxyURL,
    93  		DefaultHeader:     make(http.Header),
    94  		EnableH2ConnReuse: true,
    95  	}
    96  
    97  	if proxyURL.User != nil {
    98  		if proxyURL.User.Username() != "" {
    99  			// password, _ := proxyUrl.User.Password()
   100  			// client.DefaultHeader.Set("Proxy-Authorization", "Basic "+
   101  			// 	base64.StdEncoding.EncodeToString([]byte(proxyUrl.User.Username()+":"+password)))
   102  
   103  			username := proxyURL.User.Username()
   104  			password, _ := proxyURL.User.Password()
   105  
   106  			// client.DefaultHeader.SetBasicAuth(username, password)
   107  			auth := username + ":" + password
   108  			basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(auth))
   109  			client.DefaultHeader.Add("Proxy-Authorization", basicAuth)
   110  		}
   111  	}
   112  	client.DefaultHeader.Set("User-Agent", UserAgent)
   113  	return client, nil
   114  }
   115  
   116  func (c *connectDialer) Dial(network, address string) (net.Conn, error) {
   117  	return c.DialContext(context.Background(), network, address)
   118  }
   119  
   120  // ctx.Value will be inspected for optional ContextKeyHeader{} key, with `http.Header` value,
   121  // which will be added to outgoing request headers, overriding any colliding c.DefaultHeader
   122  func (c *connectDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   123  	req := (&http.Request{
   124  		Method: "CONNECT",
   125  		URL:    &url.URL{Host: address},
   126  		Header: make(http.Header),
   127  		Host:   address,
   128  	}).WithContext(ctx)
   129  	for k, v := range c.DefaultHeader {
   130  		req.Header[k] = v
   131  	}
   132  	if ctxHeader, ctxHasHeader := ctx.Value(ContextKeyHeader{}).(http.Header); ctxHasHeader {
   133  		for k, v := range ctxHeader {
   134  			req.Header[k] = v
   135  		}
   136  	}
   137  	connectHTTP2 := func(rawConn net.Conn, h2clientConn *http2.ClientConn) (net.Conn, error) {
   138  		req.Proto = "HTTP/2.0"
   139  		req.ProtoMajor = 2
   140  		req.ProtoMinor = 0
   141  		pr, pw := io.Pipe()
   142  		req.Body = pr
   143  
   144  		resp, err := h2clientConn.RoundTrip(req)
   145  		if err != nil {
   146  			_ = rawConn.Close()
   147  			return nil, err
   148  		}
   149  
   150  		if resp.StatusCode != http.StatusOK {
   151  			_ = rawConn.Close()
   152  			return nil, errors.New("Proxy responded with non 200 code: " + resp.Status + "StatusCode:" + strconv.Itoa(resp.StatusCode))
   153  		}
   154  		return newHTTP2Conn(rawConn, pw, resp.Body), nil
   155  	}
   156  
   157  	connectHTTP1 := func(rawConn net.Conn) (net.Conn, error) {
   158  		req.Proto = "HTTP/1.1"
   159  		req.ProtoMajor = 1
   160  		req.ProtoMinor = 1
   161  
   162  		err := req.Write(rawConn)
   163  		if err != nil {
   164  			_ = rawConn.Close()
   165  			return nil, err
   166  		}
   167  
   168  		resp, err := http.ReadResponse(bufio.NewReader(rawConn), req)
   169  		if err != nil {
   170  			_ = rawConn.Close()
   171  			return nil, err
   172  		}
   173  
   174  		if resp.StatusCode != http.StatusOK {
   175  			_ = rawConn.Close()
   176  			return nil, errors.New("Proxy responded with non 200 code: " + resp.Status + " StatusCode:" + strconv.Itoa(resp.StatusCode))
   177  		}
   178  		return rawConn, nil
   179  	}
   180  
   181  	if c.EnableH2ConnReuse {
   182  		c.cacheH2Mu.Lock()
   183  		unlocked := false
   184  		if c.cachedH2ClientConn != nil && c.cachedH2RawConn != nil {
   185  			if c.cachedH2ClientConn.CanTakeNewRequest() {
   186  				rc := c.cachedH2RawConn
   187  				cc := c.cachedH2ClientConn
   188  				c.cacheH2Mu.Unlock()
   189  				unlocked = true
   190  				proxyConn, err := connectHTTP2(rc, cc)
   191  				if err == nil {
   192  					return proxyConn, err
   193  				}
   194  				// else: carry on and try again
   195  			}
   196  		}
   197  		if !unlocked {
   198  			c.cacheH2Mu.Unlock()
   199  		}
   200  	}
   201  
   202  	var err error
   203  	var rawConn net.Conn
   204  	negotiatedProtocol := ""
   205  	switch c.ProxyURL.Scheme {
   206  	case "http":
   207  		rawConn, err = c.Dialer.DialContext(ctx, network, c.ProxyURL.Host)
   208  		if err != nil {
   209  			return nil, err
   210  		}
   211  	case "https":
   212  		if c.DialTLS != nil {
   213  			rawConn, negotiatedProtocol, err = c.DialTLS(network, c.ProxyURL.Host)
   214  			if err != nil {
   215  				return nil, err
   216  			}
   217  		} else {
   218  			tlsConf := tls.Config{
   219  				NextProtos: []string{"h2", "http/1.1"},
   220  				ServerName: c.ProxyURL.Hostname(),
   221  			}
   222  			tlsConn, err := tls.Dial(network, c.ProxyURL.Host, &tlsConf)
   223  			if err != nil {
   224  				return nil, err
   225  			}
   226  			err = tlsConn.Handshake()
   227  			if err != nil {
   228  				return nil, err
   229  			}
   230  			negotiatedProtocol = tlsConn.ConnectionState().NegotiatedProtocol
   231  			rawConn = tlsConn
   232  		}
   233  	default:
   234  		return nil, errors.New("scheme " + c.ProxyURL.Scheme + " is not supported")
   235  	}
   236  
   237  	switch negotiatedProtocol {
   238  	case "":
   239  		fallthrough
   240  	case "http/1.1":
   241  		return connectHTTP1(rawConn)
   242  	case "h2":
   243  		t := http2.Transport{}
   244  		h2clientConn, err := t.NewClientConn(rawConn)
   245  		if err != nil {
   246  			_ = rawConn.Close()
   247  			return nil, err
   248  		}
   249  
   250  		proxyConn, err := connectHTTP2(rawConn, h2clientConn)
   251  		if err != nil {
   252  			_ = rawConn.Close()
   253  			return nil, err
   254  		}
   255  		if c.EnableH2ConnReuse {
   256  			c.cacheH2Mu.Lock()
   257  			c.cachedH2ClientConn = h2clientConn
   258  			c.cachedH2RawConn = rawConn
   259  			c.cacheH2Mu.Unlock()
   260  		}
   261  		return proxyConn, err
   262  	default:
   263  		_ = rawConn.Close()
   264  		return nil, errors.New("negotiated unsupported application layer protocol: " +
   265  			negotiatedProtocol)
   266  	}
   267  }
   268  
   269  func newHTTP2Conn(c net.Conn, pipedReqBody *io.PipeWriter, respBody io.ReadCloser) net.Conn {
   270  	return &http2Conn{Conn: c, in: pipedReqBody, out: respBody}
   271  }
   272  
   273  func (h *http2Conn) Read(p []byte) (n int, err error) {
   274  	return h.out.Read(p)
   275  }
   276  
   277  func (h *http2Conn) Write(p []byte) (n int, err error) {
   278  	return h.in.Write(p)
   279  }
   280  
   281  func (h *http2Conn) Close() error {
   282  	var retErr error = nil
   283  	if err := h.in.Close(); err != nil {
   284  		retErr = err
   285  	}
   286  	if err := h.out.Close(); err != nil {
   287  		retErr = err
   288  	}
   289  	return retErr
   290  }
   291  
   292  func (h *http2Conn) CloseConn() error {
   293  	return h.Conn.Close()
   294  }
   295  
   296  func (h *http2Conn) CloseWrite() error {
   297  	return h.in.Close()
   298  }
   299  
   300  func (h *http2Conn) CloseRead() error {
   301  	return h.out.Close()
   302  }
   303  
   304  func getURL(addr string) (*url.URL, error) {
   305  	r := &http.Request{
   306  		URL: &url.URL{
   307  			Scheme: "https",
   308  			Host:   addr,
   309  		},
   310  	}
   311  	return http.ProxyFromEnvironment(r)
   312  }
   313  
   314  func (p *pbuffer) Read(b []byte) (int, error) {
   315  	return p.r.Read(b)
   316  }
   317  
   318  func proxyDial(conn net.Conn, addr string, proxyURL *url.URL) (_ net.Conn, err error) {
   319  	defer func() {
   320  		if err != nil {
   321  			conn.Close()
   322  		}
   323  	}()
   324  
   325  	r := &http.Request{
   326  		Method: http.MethodConnect,
   327  		URL:    &url.URL{Host: addr},
   328  		Header: map[string][]string{"User-Agent": {"volts/latest"}},
   329  	}
   330  
   331  	if user := proxyURL.User; user != nil {
   332  		u := user.Username()
   333  		p, _ := user.Password()
   334  		auth := []byte(u + ":" + p)
   335  		basicAuth := base64.StdEncoding.EncodeToString(auth)
   336  		r.Header.Add(proxyAuthHeader, "Basic "+basicAuth)
   337  	}
   338  
   339  	if err := r.Write(conn); err != nil {
   340  		return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
   341  	}
   342  
   343  	br := bufio.NewReader(conn)
   344  	rsp, err := http.ReadResponse(br, r)
   345  	if err != nil {
   346  		return nil, fmt.Errorf("reading server HTTP response: %v", err)
   347  	}
   348  	defer rsp.Body.Close()
   349  	if rsp.StatusCode != http.StatusOK {
   350  		dump, err := httputil.DumpResponse(rsp, true)
   351  		if err != nil {
   352  			return nil, fmt.Errorf("failed to do connect handshake, status code: %s", rsp.Status)
   353  		}
   354  		return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
   355  	}
   356  
   357  	return &pbuffer{Conn: conn, r: br}, nil
   358  }
   359  
   360  // Creates a new connection
   361  func newConn(dial func(string) (net.Conn, error)) func(string) (net.Conn, error) {
   362  	return func(addr string) (net.Conn, error) {
   363  		// get the proxy url
   364  		proxyURL, err := getURL(addr)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  
   369  		// set to addr
   370  		callAddr := addr
   371  
   372  		// got proxy
   373  		if proxyURL != nil {
   374  			callAddr = proxyURL.Host
   375  		}
   376  
   377  		// dial the addr
   378  		c, err := dial(callAddr)
   379  		if err != nil {
   380  			return nil, err
   381  		}
   382  
   383  		// do proxy connect if we have proxy url
   384  		if proxyURL != nil {
   385  			c, err = proxyDial(c, addr, proxyURL)
   386  		}
   387  
   388  		return c, err
   389  	}
   390  }