github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2raywebsocket/server.go (about) 1 package v2raywebsocket 2 3 import ( 4 "context" 5 "encoding/base64" 6 "net" 7 "net/http" 8 "os" 9 "strings" 10 11 "github.com/inazumav/sing-box/adapter" 12 "github.com/inazumav/sing-box/common/tls" 13 C "github.com/inazumav/sing-box/constant" 14 "github.com/inazumav/sing-box/option" 15 "github.com/inazumav/sing-box/transport/v2rayhttp" 16 "github.com/sagernet/sing/common" 17 "github.com/sagernet/sing/common/buf" 18 "github.com/sagernet/sing/common/bufio" 19 E "github.com/sagernet/sing/common/exceptions" 20 M "github.com/sagernet/sing/common/metadata" 21 N "github.com/sagernet/sing/common/network" 22 aTLS "github.com/sagernet/sing/common/tls" 23 sHttp "github.com/sagernet/sing/protocol/http" 24 "github.com/sagernet/websocket" 25 ) 26 27 var _ adapter.V2RayServerTransport = (*Server)(nil) 28 29 type Server struct { 30 ctx context.Context 31 tlsConfig tls.ServerConfig 32 handler adapter.V2RayServerTransportHandler 33 httpServer *http.Server 34 path string 35 maxEarlyData uint32 36 earlyDataHeaderName string 37 } 38 39 func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { 40 server := &Server{ 41 ctx: ctx, 42 tlsConfig: tlsConfig, 43 handler: handler, 44 path: options.Path, 45 maxEarlyData: options.MaxEarlyData, 46 earlyDataHeaderName: options.EarlyDataHeaderName, 47 } 48 if !strings.HasPrefix(server.path, "/") { 49 server.path = "/" + server.path 50 } 51 server.httpServer = &http.Server{ 52 Handler: server, 53 ReadHeaderTimeout: C.TCPTimeout, 54 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 55 BaseContext: func(net.Listener) context.Context { 56 return ctx 57 }, 58 } 59 return server, nil 60 } 61 62 var upgrader = websocket.Upgrader{ 63 HandshakeTimeout: C.TCPTimeout, 64 CheckOrigin: func(r *http.Request) bool { 65 return true 66 }, 67 } 68 69 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 70 if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" { 71 if request.URL.Path != s.path { 72 s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 73 return 74 } 75 } 76 var ( 77 earlyData []byte 78 err error 79 conn net.Conn 80 ) 81 if s.earlyDataHeaderName == "" { 82 if strings.HasPrefix(request.URL.RequestURI(), s.path) { 83 earlyDataStr := request.URL.RequestURI()[len(s.path):] 84 earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) 85 } else { 86 s.fallbackRequest(request.Context(), writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 87 return 88 } 89 } else { 90 earlyDataStr := request.Header.Get(s.earlyDataHeaderName) 91 if earlyDataStr != "" { 92 earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) 93 } 94 } 95 if err != nil { 96 s.fallbackRequest(request.Context(), writer, request, http.StatusBadRequest, E.Cause(err, "decode early data")) 97 return 98 } 99 wsConn, err := upgrader.Upgrade(writer, request, nil) 100 if err != nil { 101 s.fallbackRequest(request.Context(), writer, request, 0, E.Cause(err, "upgrade websocket connection")) 102 return 103 } 104 var metadata M.Metadata 105 metadata.Source = sHttp.SourceAddress(request) 106 conn = NewServerConn(wsConn, metadata.Source.TCPAddr()) 107 if len(earlyData) > 0 { 108 conn = bufio.NewCachedConn(conn, buf.As(earlyData)) 109 } 110 s.handler.NewConnection(request.Context(), conn, metadata) 111 } 112 113 func (s *Server) fallbackRequest(ctx context.Context, writer http.ResponseWriter, request *http.Request, statusCode int, err error) { 114 conn := v2rayhttp.NewHTTPConn(request.Body, writer) 115 fErr := s.handler.FallbackConnection(ctx, &conn, M.Metadata{}) 116 if fErr == nil { 117 return 118 } else if fErr == os.ErrInvalid { 119 fErr = nil 120 } 121 if statusCode > 0 { 122 writer.WriteHeader(statusCode) 123 } 124 s.handler.NewError(request.Context(), E.Cause(E.Errors(err, E.Cause(fErr, "fallback connection")), "process connection from ", request.RemoteAddr)) 125 } 126 127 func (s *Server) Network() []string { 128 return []string{N.NetworkTCP} 129 } 130 131 func (s *Server) Serve(listener net.Listener) error { 132 if s.tlsConfig != nil { 133 listener = aTLS.NewListener(listener, s.tlsConfig) 134 } 135 return s.httpServer.Serve(listener) 136 } 137 138 func (s *Server) ServePacket(listener net.PacketConn) error { 139 return os.ErrInvalid 140 } 141 142 func (s *Server) Close() error { 143 return common.Close(common.PtrOrNil(s.httpServer)) 144 }