github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2raygrpclite/server.go (about)

     1  package v2raygrpclite
     2  
import (
	"context"
	"errors"
	"net"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/inazumav/sing-box/adapter"
	"github.com/inazumav/sing-box/common/tls"
	"github.com/inazumav/sing-box/option"
	"github.com/inazumav/sing-box/transport/v2rayhttp"
	"github.com/sagernet/sing/common"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	aTLS "github.com/sagernet/sing/common/tls"
	sHttp "github.com/sagernet/sing/protocol/http"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/h2c"
)
    25  
    26  var _ adapter.V2RayServerTransport = (*Server)(nil)
    27  
    28  type Server struct {
    29  	tlsConfig    tls.ServerConfig
    30  	handler      adapter.V2RayServerTransportHandler
    31  	errorHandler E.Handler
    32  	httpServer   *http.Server
    33  	h2Server     *http2.Server
    34  	h2cHandler   http.Handler
    35  	path         string
    36  }
    37  
    38  func (s *Server) Network() []string {
    39  	return []string{N.NetworkTCP}
    40  }
    41  
    42  func NewServer(ctx context.Context, options option.V2RayGRPCOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) {
    43  	server := &Server{
    44  		tlsConfig: tlsConfig,
    45  		handler:   handler,
    46  		path:      "/" + options.ServiceName + "/Tun",
    47  		h2Server: &http2.Server{
    48  			IdleTimeout: time.Duration(options.IdleTimeout),
    49  		},
    50  	}
    51  	server.httpServer = &http.Server{
    52  		Handler: server,
    53  		BaseContext: func(net.Listener) context.Context {
    54  			return ctx
    55  		},
    56  	}
    57  	server.h2cHandler = h2c.NewHandler(server, server.h2Server)
    58  	return server, nil
    59  }
    60  
    61  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    62  	if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" {
    63  		s.h2cHandler.ServeHTTP(writer, request)
    64  		return
    65  	}
    66  	if request.URL.Path != s.path {
    67  		s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    68  		return
    69  	}
    70  	if request.Method != http.MethodPost {
    71  		s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad method: ", request.Method))
    72  		return
    73  	}
    74  	if ct := request.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/grpc") {
    75  		s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad content type: ", ct))
    76  		return
    77  	}
    78  	writer.Header().Set("Content-Type", "application/grpc")
    79  	writer.Header().Set("TE", "trailers")
    80  	writer.WriteHeader(http.StatusOK)
    81  	var metadata M.Metadata
    82  	metadata.Source = sHttp.SourceAddress(request)
    83  	conn := v2rayhttp.NewHTTP2Wrapper(newGunConn(request.Body, writer, writer.(http.Flusher)))
    84  	s.handler.NewConnection(request.Context(), conn, metadata)
    85  	conn.CloseWrapper()
    86  }
    87  
    88  func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) {
    89  	conn := v2rayhttp.NewHTTPConn(request.Body, writer)
    90  	fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{})
    91  	if fErr == nil {
    92  		return
    93  	} else if fErr == os.ErrInvalid {
    94  		fErr = nil
    95  	}
    96  	if statusCode > 0 {
    97  		writer.WriteHeader(statusCode)
    98  	}
    99  	s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr))
   100  }
   101  
   102  func (s *Server) Serve(listener net.Listener) error {
   103  	if s.tlsConfig != nil {
   104  		if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
   105  			s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
   106  		}
   107  		listener = aTLS.NewListener(listener, s.tlsConfig)
   108  	}
   109  	return s.httpServer.Serve(listener)
   110  }
   111  
   112  func (s *Server) ServePacket(listener net.PacketConn) error {
   113  	return os.ErrInvalid
   114  }
   115  
   116  func (s *Server) Close() error {
   117  	return common.Close(common.PtrOrNil(s.httpServer))
   118  }