github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/http/server.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log/slog"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httputil"
    12  	"time"
    13  	_ "unsafe"
    14  
    15  	"github.com/Asutorufa/yuhaiin/pkg/log"
    16  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    17  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
    18  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    19  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    20  )
    21  
    22  type Server struct {
    23  	username, password string
    24  	reverseProxy       *httputil.ReverseProxy
    25  
    26  	*netapi.ChannelServer
    27  
    28  	lis net.Listener
    29  }
    30  
    31  func newServer(o *listener.Inbound_Http, lis net.Listener) *Server {
    32  	h := &Server{
    33  		username:      o.Http.Username,
    34  		password:      o.Http.Password,
    35  		lis:           lis,
    36  		ChannelServer: netapi.NewChannelServer(),
    37  	}
    38  
    39  	type remoteKey struct{}
    40  
    41  	tr := &http.Transport{
    42  		MaxIdleConns:          100,
    43  		IdleConnTimeout:       90 * time.Second,
    44  		TLSHandshakeTimeout:   10 * time.Second,
    45  		ExpectContinueTimeout: 1 * time.Second,
    46  		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
    47  			address, err := netapi.ParseAddress(statistic.Type_tcp, addr)
    48  			if err != nil {
    49  				return nil, fmt.Errorf("parse address failed: %w", err)
    50  			}
    51  
    52  			remoteAddr, _ := ctx.Value(remoteKey{}).(string)
    53  
    54  			source, err := netapi.ParseAddress(statistic.Type_tcp, remoteAddr)
    55  			if err != nil {
    56  				source = netapi.ParseAddressPort(statistic.Type_tcp, remoteAddr, netapi.EmptyPort)
    57  			}
    58  
    59  			local, remote := net.Pipe()
    60  
    61  			sm := &netapi.StreamMeta{
    62  				Source:      source,
    63  				Inbound:     h.lis.Addr(),
    64  				Destination: address,
    65  				Src:         local,
    66  				Address:     address,
    67  			}
    68  
    69  			if h.SendStream(sm) != nil {
    70  				_ = local.Close()
    71  				_ = remote.Close()
    72  				return nil, io.EOF
    73  			}
    74  
    75  			return remote, nil
    76  		},
    77  	}
    78  
    79  	h.reverseProxy = &httputil.ReverseProxy{
    80  		Transport:  tr,
    81  		BufferPool: pool.ReverseProxyBuffer{},
    82  		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
    83  			if err != nil && !errors.Is(err, context.Canceled) {
    84  				log.Error("http: proxy error: ", "err", err)
    85  			}
    86  			w.WriteHeader(http.StatusBadGateway)
    87  		},
    88  		Rewrite: func(pr *httputil.ProxyRequest) {
    89  			pr.Out = pr.Out.WithContext(context.WithValue(pr.Out.Context(), remoteKey{}, pr.In.RemoteAddr))
    90  			pr.Out.RequestURI = ""
    91  		},
    92  	}
    93  
    94  	return h
    95  }
    96  
// parseBasicAuth is bound to the unexported net/http.parseBasicAuth through
// the go:linkname directive below (which is why this file blank-imports
// "unsafe"). It is declared without a body. ServeHTTP calls it with the value
// of the Proxy-Authorization header.
//
// NOTE(review): newer Go toolchains restrict linkname references into the
// standard library — confirm this still links on the targeted Go version.
//
//go:linkname parseBasicAuth net/http.parseBasicAuth
func parseBasicAuth(auth string) (username, password string, ok bool)
    99  
   100  func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   101  	defer r.Body.Close()
   102  
   103  	if h.password != "" || h.username != "" {
   104  		username, password, isHas := parseBasicAuth(r.Header.Get("Proxy-Authorization"))
   105  		if !isHas {
   106  			w.Header().Set("Proxy-Authenticate", "Basic")
   107  			w.WriteHeader(http.StatusProxyAuthRequired)
   108  			return
   109  		}
   110  
   111  		if username != h.username || password != h.password {
   112  			w.WriteHeader(http.StatusForbidden)
   113  			return
   114  		}
   115  	}
   116  
   117  	switch r.Method {
   118  	case http.MethodConnect:
   119  		if err := h.connect(w, r); err != nil {
   120  			slog.Error("connect failed", "err", err)
   121  		}
   122  	default:
   123  		h.reverseProxy.ServeHTTP(w, r)
   124  	}
   125  }
   126  
   127  func (h *Server) connect(w http.ResponseWriter, req *http.Request) error {
   128  	host := req.URL.Host
   129  	if req.URL.Port() == "" {
   130  		switch req.URL.Scheme {
   131  		case "http":
   132  			host = net.JoinHostPort(host, "80")
   133  		case "https":
   134  			host = net.JoinHostPort(host, "443")
   135  		}
   136  	}
   137  
   138  	dst, err := netapi.ParseAddress(statistic.Type_tcp, host)
   139  	if err != nil {
   140  		w.WriteHeader(http.StatusBadGateway)
   141  		return fmt.Errorf("parse address failed: %w", err)
   142  	}
   143  
   144  	w.WriteHeader(http.StatusOK)
   145  
   146  	client, _, err := http.NewResponseController(w).Hijack()
   147  	if err != nil {
   148  		return fmt.Errorf("hijack failed: %w", err)
   149  	}
   150  
   151  	source, err := netapi.ParseAddress(statistic.Type_tcp, req.RemoteAddr)
   152  	if err != nil {
   153  		source = netapi.ParseAddressPort(statistic.Type_tcp, req.RemoteAddr, netapi.EmptyPort)
   154  	}
   155  
   156  	sm := &netapi.StreamMeta{
   157  		Source:      source,
   158  		Inbound:     h.lis.Addr(),
   159  		Destination: dst,
   160  		Src:         client,
   161  		Address:     dst,
   162  	}
   163  
   164  	return h.SendStream(sm)
   165  }
   166  
   167  func (s *Server) AcceptPacket() (*netapi.Packet, error) {
   168  	return nil, io.EOF
   169  }
   170  
   171  func (s *Server) Close() error {
   172  	s.ChannelServer.Close()
   173  	if s.lis != nil {
   174  		return s.lis.Close()
   175  	}
   176  
   177  	return nil
   178  }
   179  
// init registers NewServer with the listener package through
// listener.RegisterProtocol at package load time.
func init() {
	listener.RegisterProtocol(NewServer)
}
   183  
   184  func NewServer(o *listener.Inbound_Http) func(netapi.Listener) (netapi.Accepter, error) {
   185  	return func(ii netapi.Listener) (netapi.Accepter, error) {
   186  		lis, err := ii.Stream(context.TODO())
   187  		if err != nil {
   188  			return nil, err
   189  		}
   190  
   191  		s := newServer(o, lis)
   192  
   193  		go func() {
   194  			defer ii.Close()
   195  			if err := http.Serve(lis, s); err != nil {
   196  				log.Error("http serve failed:", err)
   197  			}
   198  		}()
   199  
   200  		return s, nil
   201  	}
   202  }