github.com/sagernet/sing@v0.2.6/protocol/http/handshake.go (about)

     1  package http
     2  
     3  import (
     4  	std_bufio "bufio"
     5  	"context"
     6  	"encoding/base64"
     7  	"net"
     8  	"net/http"
     9  	"strings"
    10  
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/auth"
    13  	"github.com/sagernet/sing/common/buf"
    14  	"github.com/sagernet/sing/common/bufio"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	F "github.com/sagernet/sing/common/format"
    17  	M "github.com/sagernet/sing/common/metadata"
    18  	N "github.com/sagernet/sing/common/network"
    19  )
    20  
// Handler is an alias for N.TCPConnectionHandler. HandleConnection passes every
// proxied stream (CONNECT tunnels and the outbound connections dialed for plain
// HTTP requests) to its NewConnection method.
type Handler = N.TCPConnectionHandler
    22  
    23  func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator auth.Authenticator, handler Handler, metadata M.Metadata) error {
    24  	var httpClient *http.Client
    25  	for {
    26  		request, err := ReadRequest(reader)
    27  		if err != nil {
    28  			return E.Cause(err, "read http request")
    29  		}
    30  
    31  		if authenticator != nil {
    32  			var authOk bool
    33  			authorization := request.Header.Get("Proxy-Authorization")
    34  			if strings.HasPrefix(authorization, "Basic ") {
    35  				userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
    36  				userPswdArr := strings.SplitN(string(userPassword), ":", 2)
    37  				authOk = authenticator.Verify(userPswdArr[0], userPswdArr[1])
    38  				if authOk {
    39  					ctx = auth.ContextWithUser(ctx, userPswdArr[0])
    40  				}
    41  			}
    42  			if !authOk {
    43  				err = responseWith(request, http.StatusProxyAuthRequired).Write(conn)
    44  				if err != nil {
    45  					return err
    46  				}
    47  			}
    48  		}
    49  
    50  		if sourceAddress := SourceAddress(request); sourceAddress.IsValid() {
    51  			metadata.Source = sourceAddress
    52  		}
    53  
    54  		if request.Method == "CONNECT" {
    55  			portStr := request.URL.Port()
    56  			if portStr == "" {
    57  				portStr = "80"
    58  			}
    59  			destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
    60  			_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
    61  			if err != nil {
    62  				return E.Cause(err, "write http response")
    63  			}
    64  			metadata.Protocol = "http"
    65  			metadata.Destination = destination
    66  
    67  			var requestConn net.Conn
    68  			if reader.Buffered() > 0 {
    69  				buffer := buf.NewSize(reader.Buffered())
    70  				_, err = buffer.ReadFullFrom(reader, reader.Buffered())
    71  				if err != nil {
    72  					return err
    73  				}
    74  				requestConn = bufio.NewCachedConn(conn, buffer)
    75  			} else {
    76  				requestConn = conn
    77  			}
    78  			return handler.NewConnection(ctx, requestConn, metadata)
    79  		}
    80  
    81  		keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
    82  		request.RequestURI = ""
    83  
    84  		removeHopByHopHeaders(request.Header)
    85  		removeExtraHTTPHostPort(request)
    86  
    87  		if request.URL.Scheme == "" || request.URL.Host == "" {
    88  			return responseWith(request, http.StatusBadRequest).Write(conn)
    89  		}
    90  
    91  		var innerErr error
    92  		if httpClient == nil {
    93  			httpClient = &http.Client{
    94  				Transport: &http.Transport{
    95  					DisableCompression: true,
    96  					DialContext: func(context context.Context, network, address string) (net.Conn, error) {
    97  						metadata.Destination = M.ParseSocksaddr(address)
    98  						metadata.Protocol = "http"
    99  						input, output := net.Pipe()
   100  						go func() {
   101  							hErr := handler.NewConnection(ctx, output, metadata)
   102  							if hErr != nil {
   103  								innerErr = hErr
   104  								common.Close(input, output)
   105  							}
   106  						}()
   107  						return input, nil
   108  					},
   109  				},
   110  				CheckRedirect: func(req *http.Request, via []*http.Request) error {
   111  					return http.ErrUseLastResponse
   112  				},
   113  			}
   114  		}
   115  
   116  		response, err := httpClient.Do(request)
   117  		if err != nil {
   118  			return E.Errors(innerErr, err, responseWith(request, http.StatusBadGateway).Write(conn))
   119  		}
   120  
   121  		removeHopByHopHeaders(response.Header)
   122  
   123  		if keepAlive {
   124  			response.Header.Set("Proxy-Connection", "keep-alive")
   125  			response.Header.Set("Connection", "keep-alive")
   126  			response.Header.Set("Keep-Alive", "timeout=4")
   127  		}
   128  
   129  		response.Close = !keepAlive
   130  
   131  		err = response.Write(conn)
   132  		if err != nil {
   133  			return E.Errors(innerErr, err)
   134  		}
   135  
   136  		if !keepAlive {
   137  			return conn.Close()
   138  		}
   139  	}
   140  }
   141  
   142  func removeHopByHopHeaders(header http.Header) {
   143  	// Strip hop-by-hop header based on RFC:
   144  	// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
   145  	// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
   146  
   147  	header.Del("Proxy-Connection")
   148  	header.Del("Proxy-Authenticate")
   149  	header.Del("Proxy-Authorization")
   150  	header.Del("TE")
   151  	header.Del("Trailers")
   152  	header.Del("Transfer-Encoding")
   153  	header.Del("Upgrade")
   154  
   155  	connections := header.Get("Connection")
   156  	header.Del("Connection")
   157  	if len(connections) == 0 {
   158  		return
   159  	}
   160  	for _, h := range strings.Split(connections, ",") {
   161  		header.Del(strings.TrimSpace(h))
   162  	}
   163  }
   164  
   165  func removeExtraHTTPHostPort(req *http.Request) {
   166  	host := req.Host
   167  	if host == "" {
   168  		host = req.URL.Host
   169  	}
   170  
   171  	if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
   172  		host = pHost
   173  	}
   174  
   175  	req.Host = host
   176  	req.URL.Host = host
   177  }
   178  
   179  func responseWith(request *http.Request, statusCode int) *http.Response {
   180  	return &http.Response{
   181  		StatusCode: statusCode,
   182  		Status:     http.StatusText(statusCode),
   183  		Proto:      request.Proto,
   184  		ProtoMajor: request.ProtoMajor,
   185  		ProtoMinor: request.ProtoMinor,
   186  		Header:     http.Header{},
   187  	}
   188  }