github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2rayhttpupgrade/server.go (about)

     1  package v2rayhttpupgrade
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"os"
     8  	"strings"
     9  
    10  	"github.com/sagernet/sing-box/adapter"
    11  	"github.com/sagernet/sing-box/common/tls"
    12  	C "github.com/sagernet/sing-box/constant"
    13  	"github.com/sagernet/sing-box/option"
    14  	"github.com/sagernet/sing/common"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	aTLS "github.com/sagernet/sing/common/tls"
    19  	sHttp "github.com/sagernet/sing/protocol/http"
    20  )
    21  
    22  var _ adapter.V2RayServerTransport = (*Server)(nil)
    23  
    24  type Server struct {
    25  	ctx        context.Context
    26  	tlsConfig  tls.ServerConfig
    27  	handler    adapter.V2RayServerTransportHandler
    28  	httpServer *http.Server
    29  	host       string
    30  	path       string
    31  	headers    http.Header
    32  }
    33  
    34  func NewServer(ctx context.Context, options option.V2RayHTTPUpgradeOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) {
    35  	server := &Server{
    36  		ctx:       ctx,
    37  		tlsConfig: tlsConfig,
    38  		handler:   handler,
    39  		host:      options.Host,
    40  		path:      options.Path,
    41  		headers:   options.Headers.Build(),
    42  	}
    43  	if !strings.HasPrefix(server.path, "/") {
    44  		server.path = "/" + server.path
    45  	}
    46  	server.httpServer = &http.Server{
    47  		Handler:           server,
    48  		ReadHeaderTimeout: C.TCPTimeout,
    49  		MaxHeaderBytes:    http.DefaultMaxHeaderBytes,
    50  		BaseContext: func(net.Listener) context.Context {
    51  			return ctx
    52  		},
    53  		TLSNextProto: make(map[string]func(*http.Server, *tls.STDConn, http.Handler)),
    54  	}
    55  	return server, nil
    56  }
    57  
    58  type httpFlusher interface {
    59  	FlushError() error
    60  }
    61  
    62  func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    63  	host := request.Host
    64  	if len(s.host) > 0 && host != s.host {
    65  		s.invalidRequest(writer, request, http.StatusBadRequest, E.New("bad host: ", host))
    66  		return
    67  	}
    68  	if request.URL.Path != s.path {
    69  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path))
    70  		return
    71  	}
    72  	if request.Method != http.MethodGet {
    73  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad method: ", request.Method))
    74  		return
    75  	}
    76  	if !strings.EqualFold(request.Header.Get("Connection"), "upgrade") {
    77  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("not a upgrade request"))
    78  		return
    79  	}
    80  	if !strings.EqualFold(request.Header.Get("Upgrade"), "websocket") {
    81  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("not a websocket request"))
    82  		return
    83  	}
    84  	if request.Header.Get("Sec-WebSocket-Key") != "" {
    85  		s.invalidRequest(writer, request, http.StatusNotFound, E.New("real websocket request received"))
    86  		return
    87  	}
    88  	writer.Header().Set("Connection", "upgrade")
    89  	writer.Header().Set("Upgrade", "websocket")
    90  	writer.WriteHeader(http.StatusSwitchingProtocols)
    91  	if flusher, isFlusher := writer.(httpFlusher); isFlusher {
    92  		err := flusher.FlushError()
    93  		if err != nil {
    94  			s.invalidRequest(writer, request, http.StatusInternalServerError, E.New("flush response"))
    95  		}
    96  	}
    97  	hijacker, canHijack := writer.(http.Hijacker)
    98  	if !canHijack {
    99  		s.invalidRequest(writer, request, http.StatusInternalServerError, E.New("invalid connection, maybe HTTP/2"))
   100  		return
   101  	}
   102  	conn, _, err := hijacker.Hijack()
   103  	if err != nil {
   104  		s.invalidRequest(writer, request, http.StatusInternalServerError, E.Cause(err, "hijack failed"))
   105  		return
   106  	}
   107  	var metadata M.Metadata
   108  	metadata.Source = sHttp.SourceAddress(request)
   109  	s.handler.NewConnection(request.Context(), conn, metadata)
   110  }
   111  
   112  func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) {
   113  	if statusCode > 0 {
   114  		writer.WriteHeader(statusCode)
   115  	}
   116  	s.handler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr))
   117  }
   118  
   119  func (s *Server) Network() []string {
   120  	return []string{N.NetworkTCP}
   121  }
   122  
   123  func (s *Server) Serve(listener net.Listener) error {
   124  	if s.tlsConfig != nil {
   125  		if len(s.tlsConfig.NextProtos()) == 0 {
   126  			s.tlsConfig.SetNextProtos([]string{"http/1.1"})
   127  		}
   128  		listener = aTLS.NewListener(listener, s.tlsConfig)
   129  	}
   130  	return s.httpServer.Serve(listener)
   131  }
   132  
   133  func (s *Server) ServePacket(listener net.PacketConn) error {
   134  	return os.ErrInvalid
   135  }
   136  
   137  func (s *Server) Close() error {
   138  	return common.Close(common.PtrOrNil(s.httpServer))
   139  }