github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2rayhttpupgrade/server.go (about) 1 package v2rayhttpupgrade 2 3 import ( 4 "context" 5 "net" 6 "net/http" 7 "os" 8 "strings" 9 10 "github.com/sagernet/sing-box/adapter" 11 "github.com/sagernet/sing-box/common/tls" 12 C "github.com/sagernet/sing-box/constant" 13 "github.com/sagernet/sing-box/option" 14 "github.com/sagernet/sing/common" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 aTLS "github.com/sagernet/sing/common/tls" 19 sHttp "github.com/sagernet/sing/protocol/http" 20 ) 21 22 var _ adapter.V2RayServerTransport = (*Server)(nil) 23 24 type Server struct { 25 ctx context.Context 26 tlsConfig tls.ServerConfig 27 handler adapter.V2RayServerTransportHandler 28 httpServer *http.Server 29 host string 30 path string 31 headers http.Header 32 } 33 34 func NewServer(ctx context.Context, options option.V2RayHTTPUpgradeOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { 35 server := &Server{ 36 ctx: ctx, 37 tlsConfig: tlsConfig, 38 handler: handler, 39 host: options.Host, 40 path: options.Path, 41 headers: options.Headers.Build(), 42 } 43 if !strings.HasPrefix(server.path, "/") { 44 server.path = "/" + server.path 45 } 46 server.httpServer = &http.Server{ 47 Handler: server, 48 ReadHeaderTimeout: C.TCPTimeout, 49 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 50 BaseContext: func(net.Listener) context.Context { 51 return ctx 52 }, 53 TLSNextProto: make(map[string]func(*http.Server, *tls.STDConn, http.Handler)), 54 } 55 return server, nil 56 } 57 58 type httpFlusher interface { 59 FlushError() error 60 } 61 62 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 63 host := request.Host 64 if len(s.host) > 0 && host != s.host { 65 s.invalidRequest(writer, request, http.StatusBadRequest, E.New("bad host: ", host)) 66 return 67 } 68 if request.URL.Path != s.path { 69 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 70 return 71 } 72 if request.Method != http.MethodGet { 73 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad method: ", request.Method)) 74 return 75 } 76 if !strings.EqualFold(request.Header.Get("Connection"), "upgrade") { 77 s.invalidRequest(writer, request, http.StatusNotFound, E.New("not a upgrade request")) 78 return 79 } 80 if !strings.EqualFold(request.Header.Get("Upgrade"), "websocket") { 81 s.invalidRequest(writer, request, http.StatusNotFound, E.New("not a websocket request")) 82 return 83 } 84 if request.Header.Get("Sec-WebSocket-Key") != "" { 85 s.invalidRequest(writer, request, http.StatusNotFound, E.New("real websocket request received")) 86 return 87 } 88 writer.Header().Set("Connection", "upgrade") 89 writer.Header().Set("Upgrade", "websocket") 90 writer.WriteHeader(http.StatusSwitchingProtocols) 91 if flusher, isFlusher := writer.(httpFlusher); isFlusher { 92 err := flusher.FlushError() 93 if err != nil { 94 s.invalidRequest(writer, request, http.StatusInternalServerError, E.New("flush response")) 95 } 96 } 97 hijacker, canHijack := writer.(http.Hijacker) 98 if !canHijack { 99 s.invalidRequest(writer, request, http.StatusInternalServerError, E.New("invalid connection, maybe HTTP/2")) 100 return 101 } 102 conn, _, err := hijacker.Hijack() 103 if err != nil { 104 s.invalidRequest(writer, request, http.StatusInternalServerError, E.Cause(err, "hijack failed")) 105 return 106 } 107 var metadata M.Metadata 108 metadata.Source = sHttp.SourceAddress(request) 109 s.handler.NewConnection(request.Context(), conn, metadata) 110 } 111 112 func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) { 113 if statusCode > 0 { 114 writer.WriteHeader(statusCode) 115 } 116 s.handler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) 117 } 118 119 func (s *Server) Network() []string { 120 return []string{N.NetworkTCP} 121 } 122 123 func (s *Server) Serve(listener net.Listener) error { 124 if s.tlsConfig != nil { 125 if len(s.tlsConfig.NextProtos()) == 0 { 126 s.tlsConfig.SetNextProtos([]string{"http/1.1"}) 127 } 128 listener = aTLS.NewListener(listener, s.tlsConfig) 129 } 130 return s.httpServer.Serve(listener) 131 } 132 133 func (s *Server) ServePacket(listener net.PacketConn) error { 134 return os.ErrInvalid 135 } 136 137 func (s *Server) Close() error { 138 return common.Close(common.PtrOrNil(s.httpServer)) 139 }