github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2rayhttp/server.go (about)

     1  package v2rayhttp
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"os"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/sagernet/sing-box/adapter"
    12  	"github.com/sagernet/sing-box/common/tls"
    13  	C "github.com/sagernet/sing-box/constant"
    14  	"github.com/sagernet/sing-box/option"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/buf"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  	aTLS "github.com/sagernet/sing/common/tls"
    22  	sHttp "github.com/sagernet/sing/protocol/http"
    23  
    24  	"golang.org/x/net/http2"
    25  	"golang.org/x/net/http2/h2c"
    26  )
    27  
    28  var _ adapter.V2RayServerTransport = (*Server)(nil)
    29  
    30  type Server struct {
    31  	ctx        context.Context
    32  	tlsConfig  tls.ServerConfig
    33  	handler    adapter.V2RayServerTransportHandler
    34  	httpServer *http.Server
    35  	h2Server   *http2.Server
    36  	h2cHandler http.Handler
    37  	host       []string
    38  	path       string
    39  	method     string
    40  	headers    http.Header
    41  }
    42  
    43  func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) {
    44  	server := &Server{
    45  		ctx:       ctx,
    46  		tlsConfig: tlsConfig,
    47  		handler:   handler,
    48  		h2Server: &http2.Server{
    49  			IdleTimeout: time.Duration(options.IdleTimeout),
    50  		},
    51  		host:    options.Host,
    52  		path:    options.Path,
    53  		method:  options.Method,
    54  		headers: options.Headers.Build(),
    55  	}
    56  	if !strings.HasPrefix(server.path, "/") {
    57  		server.path = "/" + server.path
    58  	}
    59  	server.httpServer = &http.Server{
    60  		Handler:           server,
    61  		ReadHeaderTimeout: C.TCPTimeout,
    62  		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
    63  		BaseContext: func(net.Listener) context.Context {
    64  			return ctx
    65  		},
    66  	}
    67  	server.h2cHandler = h2c.NewHandler(server, server.h2Server)
    68  	return server, nil
    69  }
    70  
    71  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    72  	if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
    73  		s.h2cHandler.ServeHTTP(writer, request)
    74  		return
    75  	}
    76  	host := request.Host
    77  	if len(s.host) > 0 && !common.Contains(s.host, host) {
    78  		s.invalidRequest(writer, request, http.StatusBadRequest, E.New("bad host: ", host))
    79  		return
    80  	}
    81  	if !strings.HasPrefix(request.URL.Path, s.path) {
    82  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    83  		return
    84  	}
    85  	if s.method != "" && request.Method != s.method {
    86  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad method: ", request.Method))
    87  		return
    88  	}
    89  
    90  	writer.Header().Set("Cache-Control", "no-store")
    91  
    92  	for key, values := range s.headers {
    93  		for _, value := range values {
    94  			writer.Header().Set(key, value)
    95  		}
    96  	}
    97  
    98  	var metadata M.Metadata
    99  	metadata.Source = sHttp.SourceAddress(request)
   100  	if h, ok := writer.(http.Hijacker); ok {
   101  		var requestBody *buf.Buffer
   102  		if contentLength := int(request.ContentLength); contentLength > 0 {
   103  			requestBody = buf.NewSize(contentLength)
   104  			_, err := requestBody.ReadFullFrom(request.Body, contentLength)
   105  			if err != nil {
   106  				s.invalidRequest(writer, request, 0, E.Cause(err, "read request"))
   107  				return
   108  			}
   109  		}
   110  		writer.WriteHeader(http.StatusOK)
   111  		writer.(http.Flusher).Flush()
   112  		conn, reader, err := h.Hijack()
   113  		if err != nil {
   114  			s.invalidRequest(writer, request, 0, E.Cause(err, "hijack conn"))
   115  			return
   116  		}
   117  		if cacheLen := reader.Reader.Buffered(); cacheLen > 0 {
   118  			cache := buf.NewSize(cacheLen)
   119  			_, err = cache.ReadFullFrom(reader.Reader, cacheLen)
   120  			if err != nil {
   121  				conn.Close()
   122  				s.invalidRequest(writer, request, 0, E.Cause(err, "read cache"))
   123  				return
   124  			}
   125  			conn = bufio.NewCachedConn(conn, cache)
   126  		}
   127  		if requestBody != nil {
   128  			conn = bufio.NewCachedConn(conn, requestBody)
   129  		}
   130  		s.handler.NewConnection(request.Context(), conn, metadata)
   131  	} else {
   132  		writer.WriteHeader(http.StatusOK)
   133  		conn := NewHTTP2Wrapper(&ServerHTTPConn{
   134  			NewHTTPConn(request.Body, writer),
   135  			writer.(http.Flusher),
   136  		})
   137  		s.handler.NewConnection(request.Context(), conn, metadata)
   138  		conn.CloseWrapper()
   139  	}
   140  }
   141  
   142  func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) {
   143  	if statusCode > 0 {
   144  		writer.WriteHeader(statusCode)
   145  	}
   146  	s.handler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr))
   147  }
   148  
   149  func (s *Server) Network() []string {
   150  	return []string{N.NetworkTCP}
   151  }
   152  
   153  func (s *Server) Serve(listener net.Listener) error {
   154  	if s.tlsConfig != nil {
   155  		if len(s.tlsConfig.NextProtos()) == 0 {
   156  			s.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
   157  		} else if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
   158  			s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
   159  		}
   160  		listener = aTLS.NewListener(listener, s.tlsConfig)
   161  	}
   162  	return s.httpServer.Serve(listener)
   163  }
   164  
   165  func (s *Server) ServePacket(listener net.PacketConn) error {
   166  	return os.ErrInvalid
   167  }
   168  
   169  func (s *Server) Close() error {
   170  	return common.Close(common.PtrOrNil(s.httpServer))
   171  }