github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/protocol/http/handshake.go (about)

     1  package http
     2  
     3  import (
     4  	std_bufio "bufio"
     5  	"context"
     6  	"encoding/base64"
     7  	"net"
     8  	"net/http"
     9  	"strings"
    10  
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/atomic"
    13  	"github.com/sagernet/sing/common/auth"
    14  	"github.com/sagernet/sing/common/buf"
    15  	"github.com/sagernet/sing/common/bufio"
    16  	E "github.com/sagernet/sing/common/exceptions"
    17  	F "github.com/sagernet/sing/common/format"
    18  	M "github.com/sagernet/sing/common/metadata"
    19  	N "github.com/sagernet/sing/common/network"
    20  	"github.com/sagernet/sing/common/pipe"
    21  )
    22  
    23  type Handler = N.TCPConnectionHandler
    24  
    25  func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error {
    26  	for {
    27  		request, err := ReadRequest(reader)
    28  		if err != nil {
    29  			return E.Cause(err, "read http request")
    30  		}
    31  
    32  		if authenticator != nil {
    33  			var (
    34  				username string
    35  				password string
    36  				authOk   bool
    37  			)
    38  			authorization := request.Header.Get("Proxy-Authorization")
    39  			if strings.HasPrefix(authorization, "Basic ") {
    40  				userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
    41  				userPswdArr := strings.SplitN(string(userPassword), ":", 2)
    42  				if len(userPswdArr) == 2 {
    43  					username = userPswdArr[0]
    44  					password = userPswdArr[1]
    45  					authOk = authenticator.Verify(username, password)
    46  					if authOk {
    47  						ctx = auth.ContextWithUser(ctx, userPswdArr[0])
    48  					}
    49  				}
    50  			}
    51  			if !authOk {
    52  				// Since no one else is using the library, use a fixed realm until rewritten
    53  				err = responseWith(
    54  					request, http.StatusProxyAuthRequired,
    55  					"Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`,
    56  				).Write(conn)
    57  				if err != nil {
    58  					return err
    59  				}
    60  				if username != "" {
    61  					return E.New("http: authentication failed, username=", username, ", password=", password)
    62  				} else if authorization != "" {
    63  					return E.New("http: authentication failed, Proxy-Authorization=", authorization)
    64  				} else {
    65  					return E.New("http: authentication failed, no Proxy-Authorization header")
    66  				}
    67  			}
    68  		}
    69  
    70  		if sourceAddress := SourceAddress(request); sourceAddress.IsValid() {
    71  			metadata.Source = sourceAddress
    72  		}
    73  
    74  		if request.Method == "CONNECT" {
    75  			portStr := request.URL.Port()
    76  			if portStr == "" {
    77  				portStr = "80"
    78  			}
    79  			destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr)
    80  			_, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n")))
    81  			if err != nil {
    82  				return E.Cause(err, "write http response")
    83  			}
    84  			metadata.Protocol = "http"
    85  			metadata.Destination = destination
    86  
    87  			var requestConn net.Conn
    88  			if reader.Buffered() > 0 {
    89  				buffer := buf.NewSize(reader.Buffered())
    90  				_, err = buffer.ReadFullFrom(reader, reader.Buffered())
    91  				if err != nil {
    92  					return err
    93  				}
    94  				requestConn = bufio.NewCachedConn(conn, buffer)
    95  			} else {
    96  				requestConn = conn
    97  			}
    98  			return handler.NewConnection(ctx, requestConn, metadata)
    99  		}
   100  
   101  		keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive"
   102  		request.RequestURI = ""
   103  
   104  		removeHopByHopHeaders(request.Header)
   105  		removeExtraHTTPHostPort(request)
   106  
   107  		if hostStr := request.Header.Get("Host"); hostStr != "" {
   108  			if hostStr != request.URL.Host {
   109  				request.Host = hostStr
   110  			}
   111  		}
   112  
   113  		if request.URL.Scheme == "" || request.URL.Host == "" {
   114  			return responseWith(request, http.StatusBadRequest).Write(conn)
   115  		}
   116  
   117  		var innerErr atomic.TypedValue[error]
   118  		httpClient := &http.Client{
   119  			Transport: &http.Transport{
   120  				DisableCompression: true,
   121  				DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
   122  					metadata.Destination = M.ParseSocksaddr(address)
   123  					metadata.Protocol = "http"
   124  					input, output := pipe.Pipe()
   125  					go func() {
   126  						hErr := handler.NewConnection(ctx, output, metadata)
   127  						if hErr != nil {
   128  							innerErr.Store(hErr)
   129  							common.Close(input, output)
   130  						}
   131  					}()
   132  					return input, nil
   133  				},
   134  			},
   135  			CheckRedirect: func(req *http.Request, via []*http.Request) error {
   136  				return http.ErrUseLastResponse
   137  			},
   138  		}
   139  		requestCtx, cancel := context.WithCancel(ctx)
   140  		response, err := httpClient.Do(request.WithContext(requestCtx))
   141  		if err != nil {
   142  			cancel()
   143  			return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn))
   144  		}
   145  
   146  		removeHopByHopHeaders(response.Header)
   147  
   148  		if keepAlive {
   149  			response.Header.Set("Proxy-Connection", "keep-alive")
   150  			response.Header.Set("Connection", "keep-alive")
   151  			response.Header.Set("Keep-Alive", "timeout=4")
   152  		}
   153  
   154  		response.Close = !keepAlive
   155  
   156  		err = response.Write(conn)
   157  		if err != nil {
   158  			cancel()
   159  			return E.Errors(innerErr.Load(), err)
   160  		}
   161  
   162  		cancel()
   163  		if !keepAlive {
   164  			return conn.Close()
   165  		}
   166  	}
   167  }
   168  
   169  func removeHopByHopHeaders(header http.Header) {
   170  	// Strip hop-by-hop header based on RFC:
   171  	// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1
   172  	// https://www.mnot.net/blog/2011/07/11/what_proxies_must_do
   173  
   174  	header.Del("Proxy-Connection")
   175  	header.Del("Proxy-Authenticate")
   176  	header.Del("Proxy-Authorization")
   177  	header.Del("TE")
   178  	header.Del("Trailers")
   179  	header.Del("Transfer-Encoding")
   180  	header.Del("Upgrade")
   181  
   182  	connections := header.Get("Connection")
   183  	header.Del("Connection")
   184  	if len(connections) == 0 {
   185  		return
   186  	}
   187  	for _, h := range strings.Split(connections, ",") {
   188  		header.Del(strings.TrimSpace(h))
   189  	}
   190  }
   191  
   192  func removeExtraHTTPHostPort(req *http.Request) {
   193  	host := req.Host
   194  	if host == "" {
   195  		host = req.URL.Host
   196  	}
   197  
   198  	if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" {
   199  		if M.ParseAddr(pHost).Is6() {
   200  			pHost = "[" + pHost + "]"
   201  		}
   202  		host = pHost
   203  	}
   204  
   205  	req.Host = host
   206  	req.URL.Host = host
   207  }
   208  
   209  func responseWith(request *http.Request, statusCode int, headers ...string) *http.Response {
   210  	var header http.Header
   211  	if len(headers) > 0 {
   212  		header = make(http.Header)
   213  		for i := 0; i < len(headers); i += 2 {
   214  			header.Add(headers[i], headers[i+1])
   215  		}
   216  	}
   217  	return &http.Response{
   218  		StatusCode: statusCode,
   219  		Status:     http.StatusText(statusCode),
   220  		Proto:      request.Proto,
   221  		ProtoMajor: request.ProtoMajor,
   222  		ProtoMinor: request.ProtoMinor,
   223  		Header:     header,
   224  	}
   225  }