github.com/yaling888/clash@v1.53.0/listener/mitm/proxy.go (about)

     1  package mitm
     2  
import (
	"bufio"
	"bytes"
	"context"
	"crypto/tls"
	"encoding/pem"
	"fmt"
	"html"
	"io"
	"net"
	"net/http"
	"net/netip"
	"os"
	"strings"
	"time"

	"github.com/yaling888/clash/common/cache"
	N "github.com/yaling888/clash/common/net"
	"github.com/yaling888/clash/component/auth"
	C "github.com/yaling888/clash/constant"
	H "github.com/yaling888/clash/listener/http"
)
    24  
    25  func HandleConn(c net.Conn, opt *C.MitmOption, in chan<- C.ConnContext, cache *cache.LruCache[string, bool], auth auth.Authenticator) {
    26  	var (
    27  		clientIP   = netip.MustParseAddrPort(c.RemoteAddr().String()).Addr()
    28  		sourceAddr net.Addr
    29  		serverConn *N.BufferedConn
    30  		connState  *tls.ConnectionState
    31  	)
    32  
    33  	defer func() {
    34  		if serverConn != nil {
    35  			_ = serverConn.Close()
    36  		}
    37  	}()
    38  
    39  	conn := N.NewBufferedConn(c)
    40  
    41  	trusted := cache == nil // disable authenticate if cache is nil
    42  	if !trusted {
    43  		trusted = clientIP.IsLoopback() || clientIP.IsUnspecified()
    44  	}
    45  
    46  readLoop:
    47  	for {
    48  		// use SetReadDeadline instead of Proxy-Connection keep-alive
    49  		if err := conn.SetReadDeadline(time.Now().Add(65 * time.Second)); err != nil {
    50  			break
    51  		}
    52  
    53  		request, err := H.ReadRequest(conn.Reader())
    54  		if err != nil {
    55  			break
    56  		}
    57  
    58  		var response *http.Response
    59  
    60  		session := C.NewMitmSession(conn, request, response)
    61  
    62  		sourceAddr = parseSourceAddress(session.Request, conn.RemoteAddr(), sourceAddr)
    63  		session.Request.RemoteAddr = sourceAddr.String()
    64  
    65  		if !trusted {
    66  			session.Response = H.Authenticate(session.Request, cache, auth)
    67  
    68  			trusted = session.Response == nil
    69  		}
    70  
    71  		if trusted {
    72  			if session.Request.Method == http.MethodConnect {
    73  				if session.Request.ProtoMajor > 1 {
    74  					session.Request.ProtoMajor = 1
    75  					session.Request.ProtoMinor = 1
    76  				}
    77  
    78  				// Manual writing to support CONNECT for http 1.0 (workaround for uplay client)
    79  				if _, err = fmt.Fprintf(session.Conn, "HTTP/%d.%d %03d %s\r\n\r\n", session.Request.ProtoMajor, session.Request.ProtoMinor, http.StatusOK, "Connection established"); err != nil {
    80  					handleError(opt, session, err)
    81  					break // close connection
    82  				}
    83  
    84  				if strings.HasSuffix(session.Request.URL.Host, ":80") {
    85  					goto readLoop
    86  				}
    87  
    88  				b, err1 := conn.Peek(1)
    89  				if err1 != nil {
    90  					handleError(opt, session, err1)
    91  					break // close connection
    92  				}
    93  
    94  				// TLS handshake.
    95  				if b[0] == 0x16 {
    96  					tlsConn := tls.Server(conn, opt.CertConfig.NewTLSConfigForHost(session.Request.URL.Hostname()))
    97  
    98  					ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
    99  					// handshake with the local client
   100  					if err = tlsConn.HandshakeContext(ctx); err != nil {
   101  						cancel()
   102  						session.Response = session.NewErrorResponse(fmt.Errorf("handshake failed: %w", err))
   103  						_ = writeResponse(session, false)
   104  						break // close connection
   105  					}
   106  					cancel()
   107  
   108  					cs := tlsConn.ConnectionState()
   109  					connState = &cs
   110  
   111  					conn = N.NewBufferedConn(tlsConn)
   112  				}
   113  
   114  				if strings.HasSuffix(session.Request.URL.Host, ":443") {
   115  					goto readLoop
   116  				}
   117  
   118  				if conn.SetReadDeadline(time.Now().Add(time.Second)) != nil {
   119  					break
   120  				}
   121  
   122  				buf, err2 := conn.Peek(7)
   123  				if err2 != nil {
   124  					if err2 != bufio.ErrBufferFull && !os.IsTimeout(err2) {
   125  						handleError(opt, session, err2)
   126  						break // close connection
   127  					}
   128  				}
   129  
   130  				// others protocol over tcp
   131  				if !isHTTPTraffic(buf) {
   132  					if connState != nil {
   133  						session.Request.TLS = connState
   134  					}
   135  
   136  					serverConn, err = getServerConn(serverConn, session.Request, sourceAddr, conn.LocalAddr(), in)
   137  					if err != nil {
   138  						break
   139  					}
   140  
   141  					if conn.SetReadDeadline(time.Time{}) != nil {
   142  						break
   143  					}
   144  
   145  					N.Relay(serverConn, conn)
   146  					return // hijack connection
   147  				}
   148  
   149  				goto readLoop
   150  			}
   151  
   152  			prepareRequest(connState, session.Request)
   153  
   154  			// hijack api
   155  			if session.Request.URL.Hostname() == opt.ApiHost {
   156  				if err = handleApiRequest(session, opt); err != nil {
   157  					handleError(opt, session, err)
   158  				}
   159  				break
   160  			}
   161  
   162  			// forward websocket
   163  			if isWebsocketRequest(request) {
   164  				serverConn, err = getServerConn(serverConn, session.Request, sourceAddr, conn.LocalAddr(), in)
   165  				if err != nil {
   166  					break
   167  				}
   168  
   169  				session.Request.RequestURI = ""
   170  				if session.Response = H.HandleUpgrade(conn, serverConn, request, in); session.Response == nil {
   171  					return // hijack connection
   172  				}
   173  			}
   174  
   175  			if session.Response == nil {
   176  				H.RemoveHopByHopHeaders(session.Request.Header)
   177  				H.RemoveExtraHTTPHostPort(session.Request)
   178  
   179  				// hijack custom request and write back custom response if necessary
   180  				newReq, newRes := opt.Handler.HandleRequest(session)
   181  				if newReq != nil {
   182  					session.Request = newReq
   183  				}
   184  				if newRes != nil {
   185  					session.Response = newRes
   186  
   187  					if err = writeResponse(session, false); err != nil {
   188  						handleError(opt, session, err)
   189  						break
   190  					}
   191  					continue
   192  				}
   193  
   194  				session.Request.RequestURI = ""
   195  
   196  				if session.Request.URL.Host == "" {
   197  					session.Response = session.NewErrorResponse(C.ErrInvalidURL)
   198  				} else {
   199  					serverConn, err = getServerConn(serverConn, session.Request, sourceAddr, conn.LocalAddr(), in)
   200  					if err != nil {
   201  						break
   202  					}
   203  
   204  					// send the request to remote server
   205  					err = session.Request.Write(serverConn)
   206  					if err != nil {
   207  						break
   208  					}
   209  
   210  					session.Response, err = http.ReadResponse(serverConn.Reader(), request)
   211  					if err != nil {
   212  						break
   213  					}
   214  				}
   215  			}
   216  		}
   217  
   218  		if err = writeResponseWithHandler(session, opt); err != nil {
   219  			handleError(opt, session, err)
   220  			break // close connection
   221  		}
   222  	}
   223  
   224  	_ = conn.Close()
   225  }
   226  
   227  func writeResponseWithHandler(session *C.MitmSession, opt *C.MitmOption) error {
   228  	res := opt.Handler.HandleResponse(session)
   229  	if res != nil {
   230  		session.Response = res
   231  	}
   232  
   233  	return writeResponse(session, true)
   234  }
   235  
   236  func writeResponse(session *C.MitmSession, keepAlive bool) error {
   237  	H.RemoveHopByHopHeaders(session.Response.Header)
   238  
   239  	if keepAlive {
   240  		session.Response.Header.Set("Connection", "keep-alive")
   241  		session.Response.Header.Set("Keep-Alive", "timeout=60")
   242  	}
   243  
   244  	return session.WriteResponse()
   245  }
   246  
   247  func handleApiRequest(session *C.MitmSession, opt *C.MitmOption) error {
   248  	if opt.CertConfig != nil && strings.ToLower(session.Request.URL.Path) == "/cert.crt" {
   249  		b := pem.EncodeToMemory(&pem.Block{
   250  			Type:  "CERTIFICATE",
   251  			Bytes: opt.CertConfig.GetRootCA().Raw,
   252  		})
   253  
   254  		session.Response = session.NewResponse(http.StatusOK, bytes.NewReader(b))
   255  
   256  		session.Response.Close = true
   257  		session.Response.Header.Set("Content-Type", "application/x-x509-ca-cert")
   258  		session.Response.ContentLength = int64(len(b))
   259  
   260  		return session.WriteResponse()
   261  	}
   262  
   263  	b := `<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">
   264  <html><head>
   265  <title>Clash MITM Proxy Services - 404 Not Found</title>
   266  </head><body>
   267  <h1>Not Found</h1>
   268  <p>The requested URL %s was not found on this server.</p>
   269  </body></html>
   270  `
   271  
   272  	if opt.Handler.HandleApiRequest(session) {
   273  		return nil
   274  	}
   275  
   276  	b = fmt.Sprintf(b, session.Request.URL.Path)
   277  
   278  	session.Response = session.NewResponse(http.StatusNotFound, bytes.NewReader([]byte(b)))
   279  	session.Response.Close = true
   280  	session.Response.Header.Set("Content-Type", "text/html;charset=utf-8")
   281  	session.Response.ContentLength = int64(len(b))
   282  
   283  	return session.WriteResponse()
   284  }
   285  
   286  func handleError(opt *C.MitmOption, session *C.MitmSession, err error) {
   287  	if session.Response != nil {
   288  		defer func() {
   289  			_, _ = io.Copy(io.Discard, session.Response.Body)
   290  			_ = session.Response.Body.Close()
   291  		}()
   292  	}
   293  	opt.Handler.HandleError(session, err)
   294  }
   295  
   296  func prepareRequest(connState *tls.ConnectionState, request *http.Request) {
   297  	host := request.Header.Get("Host")
   298  	if host != "" {
   299  		request.Host = host
   300  	}
   301  
   302  	if request.URL.Host == "" {
   303  		request.URL.Host = request.Host
   304  	}
   305  
   306  	if request.URL.Scheme == "" {
   307  		request.URL.Scheme = "http"
   308  	}
   309  
   310  	if connState != nil {
   311  		request.TLS = connState
   312  		request.URL.Scheme = "https"
   313  	}
   314  
   315  	if request.Header.Get("Accept-Encoding") != "" {
   316  		request.Header.Set("Accept-Encoding", "gzip")
   317  	}
   318  }
   319  
   320  func parseSourceAddress(req *http.Request, connSource, source net.Addr) net.Addr {
   321  	if source != nil {
   322  		return source
   323  	}
   324  
   325  	sourceAddress := req.Header.Get("Origin-Request-Source-Address")
   326  	if sourceAddress == "" {
   327  		return connSource
   328  	}
   329  
   330  	req.Header.Del("Origin-Request-Source-Address")
   331  
   332  	addrPort, err := netip.ParseAddrPort(sourceAddress)
   333  	if err != nil {
   334  		return connSource
   335  	}
   336  
   337  	return net.TCPAddrFromAddrPort(addrPort)
   338  }
   339  
   340  func isWebsocketRequest(req *http.Request) bool {
   341  	return strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
   342  }
   343  
   344  func isHTTPTraffic(buf []byte) bool {
   345  	method, _, _ := strings.Cut(string(buf), " ")
   346  	return validMethod(method)
   347  }