github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2raywebsocket/server.go (about) 1 package v2raywebsocket 2 3 import ( 4 "context" 5 "encoding/base64" 6 "net" 7 "net/http" 8 "os" 9 "strings" 10 11 "github.com/sagernet/sing-box/adapter" 12 "github.com/sagernet/sing-box/common/tls" 13 C "github.com/sagernet/sing-box/constant" 14 "github.com/sagernet/sing-box/option" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/buf" 17 "github.com/sagernet/sing/common/bufio" 18 E "github.com/sagernet/sing/common/exceptions" 19 M "github.com/sagernet/sing/common/metadata" 20 N "github.com/sagernet/sing/common/network" 21 aTLS "github.com/sagernet/sing/common/tls" 22 sHttp "github.com/sagernet/sing/protocol/http" 23 "github.com/sagernet/ws" 24 ) 25 26 var _ adapter.V2RayServerTransport = (*Server)(nil) 27 28 type Server struct { 29 ctx context.Context 30 tlsConfig tls.ServerConfig 31 handler adapter.V2RayServerTransportHandler 32 httpServer *http.Server 33 path string 34 maxEarlyData uint32 35 earlyDataHeaderName string 36 upgrader ws.HTTPUpgrader 37 } 38 39 func NewServer(ctx context.Context, options option.V2RayWebsocketOptions, tlsConfig tls.ServerConfig, handler adapter.V2RayServerTransportHandler) (*Server, error) { 40 server := &Server{ 41 ctx: ctx, 42 tlsConfig: tlsConfig, 43 handler: handler, 44 path: options.Path, 45 maxEarlyData: options.MaxEarlyData, 46 earlyDataHeaderName: options.EarlyDataHeaderName, 47 upgrader: ws.HTTPUpgrader{ 48 Timeout: C.TCPTimeout, 49 Header: options.Headers.Build(), 50 }, 51 } 52 if !strings.HasPrefix(server.path, "/") { 53 server.path = "/" + server.path 54 } 55 server.httpServer = &http.Server{ 56 Handler: server, 57 ReadHeaderTimeout: C.TCPTimeout, 58 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 59 BaseContext: func(net.Listener) context.Context { 60 return ctx 61 }, 62 } 63 return server, nil 64 } 65 66 func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 67 if s.maxEarlyData == 0 || s.earlyDataHeaderName != "" { 68 if request.URL.Path != s.path { 69 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 70 return 71 } 72 } 73 var ( 74 earlyData []byte 75 err error 76 conn net.Conn 77 ) 78 if s.earlyDataHeaderName == "" { 79 if strings.HasPrefix(request.URL.RequestURI(), s.path) { 80 earlyDataStr := request.URL.RequestURI()[len(s.path):] 81 earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) 82 } else { 83 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 84 return 85 } 86 } else { 87 if request.URL.Path != s.path { 88 s.invalidRequest(writer, request, http.StatusNotFound, E.New("bad path: ", request.URL.Path)) 89 return 90 } 91 earlyDataStr := request.Header.Get(s.earlyDataHeaderName) 92 if earlyDataStr != "" { 93 earlyData, err = base64.RawURLEncoding.DecodeString(earlyDataStr) 94 } 95 } 96 if err != nil { 97 s.invalidRequest(writer, request, http.StatusBadRequest, E.Cause(err, "decode early data")) 98 return 99 } 100 wsConn, _, _, err := ws.UpgradeHTTP(request, writer) 101 if err != nil { 102 s.invalidRequest(writer, request, 0, E.Cause(err, "upgrade websocket connection")) 103 return 104 } 105 var metadata M.Metadata 106 metadata.Source = sHttp.SourceAddress(request) 107 conn = NewConn(wsConn, metadata.Source.TCPAddr(), ws.StateServerSide) 108 if len(earlyData) > 0 { 109 conn = bufio.NewCachedConn(conn, buf.As(earlyData)) 110 } 111 s.handler.NewConnection(request.Context(), conn, metadata) 112 } 113 114 func (s *Server) invalidRequest(writer http.ResponseWriter, request *http.Request, statusCode int, err error) { 115 if statusCode > 0 { 116 writer.WriteHeader(statusCode) 117 } 118 s.handler.NewError(request.Context(), E.Cause(err, "process connection from ", request.RemoteAddr)) 119 } 120 121 func (s *Server) Network() []string { 122 return []string{N.NetworkTCP} 123 } 124 125 func (s *Server) Serve(listener net.Listener) error { 126 if s.tlsConfig != nil { 127 listener = aTLS.NewListener(listener, s.tlsConfig) 128 } 129 return s.httpServer.Serve(listener) 130 } 131 132 func (s *Server) ServePacket(listener net.PacketConn) error { 133 return os.ErrInvalid 134 } 135 136 func (s *Server) Close() error { 137 return common.Close(common.PtrOrNil(s.httpServer)) 138 }