github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15 "github.com/xtls/xray-core/common" 16 "github.com/xtls/xray-core/common/net" 17 http_proto "github.com/xtls/xray-core/common/protocol/http" 18 "github.com/xtls/xray-core/common/session" 19 "github.com/xtls/xray-core/transport/internet" 20 v2tls "github.com/xtls/xray-core/transport/internet/tls" 21 ) 22 23 type requestHandler struct { 24 host string 25 path string 26 ln *Listener 27 } 28 29 var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") 30 31 var upgrader = &websocket.Upgrader{ 32 ReadBufferSize: 0, 33 WriteBufferSize: 0, 34 HandshakeTimeout: time.Second * 4, 35 CheckOrigin: func(r *http.Request) bool { 36 return true 37 }, 38 } 39 40 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 41 if len(h.host) > 0 && request.Host != h.host { 42 newError("failed to validate host, request:", request.Host, ", config:", h.host).WriteToLog() 43 writer.WriteHeader(http.StatusNotFound) 44 return 45 } 46 if request.URL.Path != h.path { 47 newError("failed to validate path, request:", request.URL.Path, ", config:", h.path).WriteToLog() 48 writer.WriteHeader(http.StatusNotFound) 49 return 50 } 51 52 var extraReader io.Reader 53 responseHeader := http.Header{} 54 if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { 55 if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 { 56 extraReader = bytes.NewReader(ed) 57 responseHeader.Set("Sec-WebSocket-Protocol", str) 58 } 59 } 60 61 conn, err := upgrader.Upgrade(writer, request, responseHeader) 62 if err != nil { 63 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 64 return 65 } 66 67 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 68 remoteAddr := conn.RemoteAddr() 69 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 70 remoteAddr = &net.TCPAddr{ 71 IP: forwardedAddrs[0].IP(), 72 Port: int(0), 73 } 74 } 75 76 h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) 77 } 78 79 type Listener struct { 80 sync.Mutex 81 server http.Server 82 listener net.Listener 83 config *Config 84 addConn internet.ConnHandler 85 } 86 87 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 88 l := &Listener{ 89 addConn: addConn, 90 } 91 wsSettings := streamSettings.ProtocolSettings.(*Config) 92 l.config = wsSettings 93 if l.config != nil { 94 if streamSettings.SocketSettings == nil { 95 streamSettings.SocketSettings = &internet.SocketConfig{} 96 } 97 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol 98 } 99 var listener net.Listener 100 var err error 101 if port == net.Port(0) { // unix 102 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 103 Name: address.Domain(), 104 Net: "unix", 105 }, streamSettings.SocketSettings) 106 if err != nil { 107 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 108 } 109 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 110 } else { // tcp 111 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 112 IP: address.IP(), 113 Port: int(port), 114 }, streamSettings.SocketSettings) 115 if err != nil { 116 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 117 } 118 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 119 } 120 121 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 122 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 123 } 124 125 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 126 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 127 listener = tls.NewListener(listener, tlsConfig) 128 } 129 } 130 131 l.listener = listener 132 133 l.server = http.Server{ 134 Handler: &requestHandler{ 135 host: wsSettings.Host, 136 path: wsSettings.GetNormalizedPath(), 137 ln: l, 138 }, 139 ReadHeaderTimeout: time.Second * 4, 140 MaxHeaderBytes: 8192, 141 } 142 143 go func() { 144 if err := l.server.Serve(l.listener); err != nil { 145 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 146 } 147 }() 148 149 return l, err 150 } 151 152 // Addr implements net.Listener.Addr(). 153 func (ln *Listener) Addr() net.Addr { 154 return ln.listener.Addr() 155 } 156 157 // Close implements net.Listener.Close(). 158 func (ln *Listener) Close() error { 159 return ln.listener.Close() 160 } 161 162 func init() { 163 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 164 }