github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/protocol/http/handshake.go (about) 1 package http 2 3 import ( 4 std_bufio "bufio" 5 "context" 6 "encoding/base64" 7 "net" 8 "net/http" 9 "strings" 10 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/atomic" 13 "github.com/sagernet/sing/common/auth" 14 "github.com/sagernet/sing/common/buf" 15 "github.com/sagernet/sing/common/bufio" 16 E "github.com/sagernet/sing/common/exceptions" 17 F "github.com/sagernet/sing/common/format" 18 M "github.com/sagernet/sing/common/metadata" 19 N "github.com/sagernet/sing/common/network" 20 "github.com/sagernet/sing/common/pipe" 21 ) 22 23 type Handler = N.TCPConnectionHandler 24 25 func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { 26 for { 27 request, err := ReadRequest(reader) 28 if err != nil { 29 return E.Cause(err, "read http request") 30 } 31 32 if authenticator != nil { 33 var ( 34 username string 35 password string 36 authOk bool 37 ) 38 authorization := request.Header.Get("Proxy-Authorization") 39 if strings.HasPrefix(authorization, "Basic ") { 40 userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) 41 userPswdArr := strings.SplitN(string(userPassword), ":", 2) 42 if len(userPswdArr) == 2 { 43 username = userPswdArr[0] 44 password = userPswdArr[1] 45 authOk = authenticator.Verify(username, password) 46 if authOk { 47 ctx = auth.ContextWithUser(ctx, userPswdArr[0]) 48 } 49 } 50 } 51 if !authOk { 52 // Since no one else is using the library, use a fixed realm until rewritten 53 err = responseWith( 54 request, http.StatusProxyAuthRequired, 55 "Proxy-Authenticate", `Basic realm="sing-box" charset="UTF-8"`, 56 ).Write(conn) 57 if err != nil { 58 return err 59 } 60 if username != "" { 61 return E.New("http: authentication failed, username=", username, ", password=", password) 62 } else if authorization != "" { 63 return E.New("http: authentication failed, Proxy-Authorization=", authorization) 64 } else { 65 return E.New("http: authentication failed, no Proxy-Authorization header") 66 } 67 } 68 } 69 70 if sourceAddress := SourceAddress(request); sourceAddress.IsValid() { 71 metadata.Source = sourceAddress 72 } 73 74 if request.Method == "CONNECT" { 75 portStr := request.URL.Port() 76 if portStr == "" { 77 portStr = "80" 78 } 79 destination := M.ParseSocksaddrHostPortStr(request.URL.Hostname(), portStr) 80 _, err = conn.Write([]byte(F.ToString("HTTP/", request.ProtoMajor, ".", request.ProtoMinor, " 200 Connection established\r\n\r\n"))) 81 if err != nil { 82 return E.Cause(err, "write http response") 83 } 84 metadata.Protocol = "http" 85 metadata.Destination = destination 86 87 var requestConn net.Conn 88 if reader.Buffered() > 0 { 89 buffer := buf.NewSize(reader.Buffered()) 90 _, err = buffer.ReadFullFrom(reader, reader.Buffered()) 91 if err != nil { 92 return err 93 } 94 requestConn = bufio.NewCachedConn(conn, buffer) 95 } else { 96 requestConn = conn 97 } 98 return handler.NewConnection(ctx, requestConn, metadata) 99 } 100 101 keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" 102 request.RequestURI = "" 103 104 removeHopByHopHeaders(request.Header) 105 removeExtraHTTPHostPort(request) 106 107 if hostStr := request.Header.Get("Host"); hostStr != "" { 108 if hostStr != request.URL.Host { 109 request.Host = hostStr 110 } 111 } 112 113 if request.URL.Scheme == "" || request.URL.Host == "" { 114 return responseWith(request, http.StatusBadRequest).Write(conn) 115 } 116 117 var innerErr atomic.TypedValue[error] 118 httpClient := &http.Client{ 119 Transport: &http.Transport{ 120 DisableCompression: true, 121 DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { 122 metadata.Destination = M.ParseSocksaddr(address) 123 metadata.Protocol = "http" 124 input, output := pipe.Pipe() 125 go func() { 126 hErr := handler.NewConnection(ctx, output, metadata) 127 if hErr != nil { 128 innerErr.Store(hErr) 129 common.Close(input, output) 130 } 131 }() 132 return input, nil 133 }, 134 }, 135 CheckRedirect: func(req *http.Request, via []*http.Request) error { 136 return http.ErrUseLastResponse 137 }, 138 } 139 requestCtx, cancel := context.WithCancel(ctx) 140 response, err := httpClient.Do(request.WithContext(requestCtx)) 141 if err != nil { 142 cancel() 143 return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) 144 } 145 146 removeHopByHopHeaders(response.Header) 147 148 if keepAlive { 149 response.Header.Set("Proxy-Connection", "keep-alive") 150 response.Header.Set("Connection", "keep-alive") 151 response.Header.Set("Keep-Alive", "timeout=4") 152 } 153 154 response.Close = !keepAlive 155 156 err = response.Write(conn) 157 if err != nil { 158 cancel() 159 return E.Errors(innerErr.Load(), err) 160 } 161 162 cancel() 163 if !keepAlive { 164 return conn.Close() 165 } 166 } 167 } 168 169 func removeHopByHopHeaders(header http.Header) { 170 // Strip hop-by-hop header based on RFC: 171 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 172 // https://www.mnot.net/blog/2011/07/11/what_proxies_must_do 173 174 header.Del("Proxy-Connection") 175 header.Del("Proxy-Authenticate") 176 header.Del("Proxy-Authorization") 177 header.Del("TE") 178 header.Del("Trailers") 179 header.Del("Transfer-Encoding") 180 header.Del("Upgrade") 181 182 connections := header.Get("Connection") 183 header.Del("Connection") 184 if len(connections) == 0 { 185 return 186 } 187 for _, h := range strings.Split(connections, ",") { 188 header.Del(strings.TrimSpace(h)) 189 } 190 } 191 192 func removeExtraHTTPHostPort(req *http.Request) { 193 host := req.Host 194 if host == "" { 195 host = req.URL.Host 196 } 197 198 if pHost, port, err := net.SplitHostPort(host); err == nil && port == "80" { 199 if M.ParseAddr(pHost).Is6() { 200 pHost = "[" + pHost + "]" 201 } 202 host = pHost 203 } 204 205 req.Host = host 206 req.URL.Host = host 207 } 208 209 func responseWith(request *http.Request, statusCode int, headers ...string) *http.Response { 210 var header http.Header 211 if len(headers) > 0 { 212 header = make(http.Header) 213 for i := 0; i < len(headers); i += 2 { 214 header.Add(headers[i], headers[i+1]) 215 } 216 } 217 return &http.Response{ 218 StatusCode: statusCode, 219 Status: http.StatusText(statusCode), 220 Proto: request.Proto, 221 ProtoMajor: request.ProtoMajor, 222 ProtoMinor: request.ProtoMinor, 223 Header: header, 224 } 225 }