github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2rayhttp/server.go (about) 1 package v2rayhttp 2 3 import ( 4 "context" 5 "net" 6 "net/http" 7 "os" 8 "strings" 9 "time" 10 11 "github.com/sagernet/sing-box/adapter" 12 "github.com/sagernet/sing-box/common/tls" 13 C "github.com/sagernet/sing-box/constant" 14 "github.com/sagernet/sing-box/option" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/buf" 17 "github.com/sagernet/sing/common/bufio" 18 E "github.com/sagernet/sing/common/exceptions" 19 M "github.com/sagernet/sing/common/metadata" 20 N "github.com/sagernet/sing/common/network" 21 aTLS "github.com/sagernet/sing/common/tls" 22 sHttp "github.com/sagernet/sing/protocol/http" 23 24 "golang.org/x/net/http2" 25 "golang.org/x/net/http2/h2c" 26 ) 27 28 var _ adapter.V2RayServerTransport = (*Server)(nil) 29 30 type Server struct { 31 ctx context.Context 32 tlsConfig tls.ServerConfig 33 handler adapter.V2RayServerTransportHandler 34 httpServer *http.Server 35 h2Server *http2.Server 36 h2cHandler http.Handler 37 host []string 38 path string 39 method string 40 headers http.Header 41 } 42 43 func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { 44 server := &Server{ 45 ctx: ctx, 46 tlsConfig: tlsConfig, 47 handler: handler, 48 h2Server: &http2.Server{ 49 IdleTimeout: time.Duration(options.IdleTimeout), 50 }, 51 host: options.Host, 52 path: options.Path, 53 method: options.Method, 54 headers: options.Headers.Build(), 55 } 56 if !strings.HasPrefix(server.path, "/") { 57 server.path = "/" + server.path 58 } 59 server.httpServer = &http.Server{ 60 Handler: server, 61 ReadHeaderTimeout: C.TCPTimeout, 62 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 63 BaseContext: func(net.Listener) context.Context { 64 return ctx 65 }, 66 } 67 server.h2cHandler = h2c.NewHandler(server, server.h2Server) 68 return server, nil 69 } 70 71 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 72 if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" { 73 s.h2cHandler.ServeHTTP(writer, request) 74 return 75 } 76 host := request.Host 77 if len(s.host) > 0 && !common.Contains(s.host, host) { 78 s.invalidRequest(writer, request, http.StatusBadRequest, E.New("bad host: ", host)) 79 return 80 } 81 if !strings.HasPrefix(request.URL.Path, s.path) { 82 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 83 return 84 } 85 if s.method != "" && request.Method != s.method { 86 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad method: ", request.Method)) 87 return 88 } 89 90 writer.Header().Set("Cache-Control", "no-store") 91 92 for key, values := range s.headers { 93 for _, value := range values { 94 writer.Header().Set(key, value) 95 } 96 } 97 98 var metadata M.Metadata 99 metadata.Source = sHttp.SourceAddress(request) 100 if h, ok := writer.(http.Hijacker); ok { 101 var requestBody *buf.Buffer 102 if contentLength := int(request.ContentLength); contentLength > 0 { 103 requestBody = buf.NewSize(contentLength) 104 _, err := requestBody.ReadFullFrom(request.Body, contentLength) 105 if err != nil { 106 s.invalidRequest(writer, request, 0, E.Cause(err, "read request")) 107 return 108 } 109 } 110 writer.WriteHeader(http.StatusOK) 111 writer.(http.Flusher).Flush() 112 conn, reader, err := h.Hijack() 113 if err != nil { 114 s.invalidRequest(writer, request, 0, E.Cause(err, "hijack conn")) 115 return 116 } 117 if cacheLen := reader.Reader.Buffered(); cacheLen > 0 { 118 cache := buf.NewSize(cacheLen) 119 _, err = cache.ReadFullFrom(reader.Reader, cacheLen) 120 if err != nil { 121 conn.Close() 122 s.invalidRequest(writer, request, 0, E.Cause(err, "read cache")) 123 return 124 } 125 conn = bufio.NewCachedConn(conn, cache) 126 } 127 if requestBody != nil { 128 conn = bufio.NewCachedConn(conn, requestBody) 129 } 130 s.handler.NewConnection(request.Context(), conn, metadata) 131 } else { 132 writer.WriteHeader(http.StatusOK) 133 conn := NewHTTP2Wrapper(&ServerHTTPConn{ 134 NewHTTPConn(request.Body, writer), 135 writer.(http.Flusher), 136 }) 137 s.handler.NewConnection(request.Context(), conn, metadata) 138 conn.CloseWrapper() 139 } 140 } 141 142 func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) { 143 if statusCode > 0 { 144 writer.WriteHeader(statusCode) 145 } 146 s.handler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) 147 } 148 149 func (s *Server) Network() []string { 150 return []string{N.NetworkTCP} 151 } 152 153 func (s *Server) Serve(listener net.Listener) error { 154 if s.tlsConfig != nil { 155 if len(s.tlsConfig.NextProtos()) == 0 { 156 s.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"}) 157 } else if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) { 158 s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...)) 159 } 160 listener = aTLS.NewListener(listener, s.tlsConfig) 161 } 162 return s.httpServer.Serve(listener) 163 } 164 165 func (s *Server) ServePacket(listener net.PacketConn) error { 166 return os.ErrInvalid 167 } 168 169 func (s *Server) Close() error { 170 return common.Close(common.PtrOrNil(s.httpServer)) 171 }