github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2rayhttp/server.go (about) 1 package v2rayhttp 2 3 import ( 4 "context" 5 "net" 6 "net/http" 7 "os" 8 "strings" 9 "time" 10 11 "github.com/inazumav/sing-box/adapter" 12 "github.com/inazumav/sing-box/common/tls" 13 C "github.com/inazumav/sing-box/constant" 14 "github.com/inazumav/sing-box/option" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/buf" 17 "github.com/sagernet/sing/common/bufio" 18 E "github.com/sagernet/sing/common/exceptions" 19 M "github.com/sagernet/sing/common/metadata" 20 N "github.com/sagernet/sing/common/network" 21 aTLS "github.com/sagernet/sing/common/tls" 22 sHttp "github.com/sagernet/sing/protocol/http" 23 24 "golang.org/x/net/http2" 25 "golang.org/x/net/http2/h2c" 26 ) 27 28 var _ adapter.V2RayServerTransport = (*Server)(nil) 29 30 type Server struct { 31 ctx context.Context 32 tlsConfig tls.ServerConfig 33 handler adapter.V2RayServerTransportHandler 34 httpServer *http.Server 35 h2Server *http2.Server 36 h2cHandler http.Handler 37 host []string 38 path string 39 method string 40 headers http.Header 41 } 42 43 func (s *Server) Network() []string { 44 return []string{N.NetworkTCP} 45 } 46 47 func NewServer(ctx context.Context, options option.V2RayHTTPOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { 48 server := &Server{ 49 ctx: ctx, 50 tlsConfig: tlsConfig, 51 handler: handler, 52 h2Server: &http2.Server{ 53 IdleTimeout: time.Duration(options.IdleTimeout), 54 }, 55 host: options.Host, 56 path: options.Path, 57 method: options.Method, 58 headers: make(http.Header), 59 } 60 if server.method == "" { 61 server.method = "PUT" 62 } 63 if !strings.HasPrefix(server.path, "/") { 64 server.path = "/" + server.path 65 } 66 for key, value := range options.Headers { 67 server.headers[key] = value 68 } 69 server.httpServer = &http.Server{ 70 Handler: server, 71 ReadHeaderTimeout: C.TCPTimeout, 72 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 73 BaseContext: func(net.Listener) context.Context { 74 return ctx 75 }, 76 } 77 server.h2cHandler = h2c.NewHandler(server, server.h2Server) 78 return server, nil 79 } 80 81 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 82 if request.Method == "PRI" && len(request.Header) == 0 && request.URL.Path == "*" && request.Proto == "HTTP/2.0" { 83 s.h2cHandler.ServeHTTP(writer, request) 84 return 85 } 86 host := request.Host 87 if len(s.host) > 0 && !common.Contains(s.host, host) { 88 s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.New("bad host: ", host)) 89 return 90 } 91 if !strings.HasPrefix(request.URL.Path, s.path) { 92 s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 93 return 94 } 95 if request.Method != s.method { 96 s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad method: ", request.Method)) 97 return 98 } 99 100 writer.Header().Set("Cache-Control", "no-store") 101 102 for key, values := range s.headers { 103 for _, value := range values { 104 writer.Header().Set(key, value) 105 } 106 } 107 108 var metadata M.Metadata 109 metadata.Source = sHttp.SourceAddress(request) 110 if h, ok := writer.(http.Hijacker); ok { 111 var requestBody *buf.Buffer 112 if contentLength := int(request.ContentLength); contentLength > 0 { 113 requestBody = buf.NewSize(contentLength) 114 _, err := requestBody.ReadFullFrom(request.Body, contentLength) 115 if err != nil { 116 s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "read request")) 117 return 118 } 119 } 120 writer.WriteHeader(http.StatusOK) 121 writer.(http.Flusher).Flush() 122 conn, reader, err := h.Hijack() 123 if err != nil { 124 s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "hijack conn")) 125 return 126 } 127 if cacheLen := reader.Reader.Buffered(); cacheLen > 0 { 128 cache := buf.NewSize(cacheLen) 129 _, err = cache.ReadFullFrom(reader.Reader, cacheLen) 130 if err != nil { 131 s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "read cache")) 132 return 133 } 134 conn = bufio.NewCachedConn(conn, cache) 135 } 136 if requestBody != nil { 137 conn = bufio.NewCachedConn(conn, requestBody) 138 } 139 s.handler.NewConnection(request.Context(), conn, metadata) 140 } else { 141 writer.WriteHeader(http.StatusOK) 142 conn := NewHTTP2Wrapper(&ServerHTTPConn{ 143 NewHTTPConn(request.Body, writer), 144 writer.(http.Flusher), 145 }) 146 s.handler.NewConnection(request.Context(), conn, metadata) 147 conn.CloseWrapper() 148 } 149 } 150 151 func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) { 152 conn := NewHTTPConn(request.Body, writer) 153 fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{}) 154 if fErr == nil { 155 return 156 } else if fErr == os.ErrInvalid { 157 fErr = nil 158 } 159 if statusCode > 0 { 160 writer.WriteHeader(statusCode) 161 } 162 s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr)) 163 } 164 165 func (s *Server) Serve(listener net.Listener) error { 166 if s.tlsConfig != nil { 167 if len(s.tlsConfig.NextProtos()) == 0 { 168 s.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS, "http/1.1"}) 169 } else if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) { 170 s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...)) 171 } 172 listener = aTLS.NewListener(listener, s.tlsConfig) 173 } 174 return s.httpServer.Serve(listener) 175 } 176 177 func (s *Server) ServePacket(listener net.PacketConn) error { 178 return os.ErrInvalid 179 } 180 181 func (s *Server) Close() error { 182 return common.Close(common.PtrOrNil(s.httpServer)) 183 }