github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2rayhttp/server.go (about)

     1  package v2rayhttp
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"os"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/inazumav/sing-box/adapter"
    12  	"github.com/inazumav/sing-box/common/tls"
    13  	C "github.com/inazumav/sing-box/constant"
    14  	"github.com/inazumav/sing-box/option"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/buf"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  	aTLS "github.com/sagernet/sing/common/tls"
    22  	sHttp "github.com/sagernet/sing/protocol/http"
    23  
    24  	"golang.org/x/net/http2"
    25  	"golang.org/x/net/http2/h2c"
    26  )
    27  
    28  var _ adapter.V2RayServerTransport = (*Server)(nil)
    29  
    30  type Server struct {
    31  	ctx        context.Context
    32  	tlsConfig  tls.ServerConfig
    33  	handler    adapter.V2RayServerTransportHandler
    34  	httpServer *http.Server
    35  	h2Server   *http2.Server
    36  	h2cHandler http.Handler
    37  	host       []string
    38  	path       string
    39  	method     string
    40  	headers    http.Header
    41  }
    42  
    43  func (s *Server) Network() []string {
    44  	return []string{N.NetworkTCP}
    45  }
    46  
    47  func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) {
    48  	server := &Server{
    49  		ctx:       ctx,
    50  		tlsConfig: tlsConfig,
    51  		handler:   handler,
    52  		h2Server: &http2.Server{
    53  			IdleTimeout: time.Duration(options.IdleTimeout),
    54  		},
    55  		host:    options.Host,
    56  		path:    options.Path,
    57  		method:  options.Method,
    58  		headers: make(http.Header),
    59  	}
    60  	if server.method == "" {
    61  		server.method = "PUT"
    62  	}
    63  	if !strings.HasPrefix(server.path, "/") {
    64  		server.path = "/" + server.path
    65  	}
    66  	for key, value := range options.Headers {
    67  		server.headers[key] = value
    68  	}
    69  	server.httpServer = &http.Server{
    70  		Handler:           server,
    71  		ReadHeaderTimeout: C.TCPTimeout,
    72  		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
    73  		BaseContext: func(net.Listener) context.Context {
    74  			return ctx
    75  		},
    76  	}
    77  	server.h2cHandler = h2c.NewHandler(server, server.h2Server)
    78  	return server, nil
    79  }
    80  
    81  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    82  	if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
    83  		s.h2cHandler.ServeHTTP(writer, request)
    84  		return
    85  	}
    86  	host := request.Host
    87  	if len(s.host) > 0 && !common.Contains(s.host, host) {
    88  		s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.New("bad host: ", host))
    89  		return
    90  	}
    91  	if !strings.HasPrefix(request.URL.Path, s.path) {
    92  		s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    93  		return
    94  	}
    95  	if request.Method != s.method {
    96  		s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad method: ", request.Method))
    97  		return
    98  	}
    99  
   100  	writer.Header().Set("Cache-Control", "no-store")
   101  
   102  	for key, values := range s.headers {
   103  		for _, value := range values {
   104  			writer.Header().Set(key, value)
   105  		}
   106  	}
   107  
   108  	var metadata M.Metadata
   109  	metadata.Source = sHttp.SourceAddress(request)
   110  	if h, ok := writer.(http.Hijacker); ok {
   111  		var requestBody *buf.Buffer
   112  		if contentLength := int(request.ContentLength); contentLength > 0 {
   113  			requestBody = buf.NewSize(contentLength)
   114  			_, err := requestBody.ReadFullFrom(request.Body, contentLength)
   115  			if err != nil {
   116  				s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "read request"))
   117  				return
   118  			}
   119  		}
   120  		writer.WriteHeader(http.StatusOK)
   121  		writer.(http.Flusher).Flush()
   122  		conn, reader, err := h.Hijack()
   123  		if err != nil {
   124  			s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "hijack conn"))
   125  			return
   126  		}
   127  		if cacheLen := reader.Reader.Buffered(); cacheLen > 0 {
   128  			cache := buf.NewSize(cacheLen)
   129  			_, err = cache.ReadFullFrom(reader.Reader, cacheLen)
   130  			if err != nil {
   131  				s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "read cache"))
   132  				return
   133  			}
   134  			conn = bufio.NewCachedConn(conn, cache)
   135  		}
   136  		if requestBody != nil {
   137  			conn = bufio.NewCachedConn(conn, requestBody)
   138  		}
   139  		s.handler.NewConnection(request.Context(), conn, metadata)
   140  	} else {
   141  		writer.WriteHeader(http.StatusOK)
   142  		conn := NewHTTP2Wrapper(&ServerHTTPConn{
   143  			NewHTTPConn(request.Body, writer),
   144  			writer.(http.Flusher),
   145  		})
   146  		s.handler.NewConnection(request.Context(), conn, metadata)
   147  		conn.CloseWrapper()
   148  	}
   149  }
   150  
   151  func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) {
   152  	conn := NewHTTPConn(request.Body, writer)
   153  	fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{})
   154  	if fErr == nil {
   155  		return
   156  	} else if fErr == os.ErrInvalid {
   157  		fErr = nil
   158  	}
   159  	if statusCode > 0 {
   160  		writer.WriteHeader(statusCode)
   161  	}
   162  	s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr))
   163  }
   164  
   165  func (s *Server) Serve(listener net.Listener) error {
   166  	if s.tlsConfig != nil {
   167  		if len(s.tlsConfig.NextProtos()) == 0 {
   168  			s.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"})
   169  		} else if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
   170  			s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
   171  		}
   172  		listener = aTLS.NewListener(listener, s.tlsConfig)
   173  	}
   174  	return s.httpServer.Serve(listener)
   175  }
   176  
   177  func (s *Server) ServePacket(listener net.PacketConn) error {
   178  	return os.ErrInvalid
   179  }
   180  
   181  func (s *Server) Close() error {
   182  	return common.Close(common.PtrOrNil(s.httpServer))
   183  }