github.com/eagleql/xray-core@v1.4.4/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15 16 "github.com/eagleql/xray-core/common" 17 "github.com/eagleql/xray-core/common/net" 18 http_proto "github.com/eagleql/xray-core/common/protocol/http" 19 "github.com/eagleql/xray-core/common/session" 20 "github.com/eagleql/xray-core/transport/internet" 21 v2tls "github.com/eagleql/xray-core/transport/internet/tls" 22 ) 23 24 type requestHandler struct { 25 path string 26 ln *Listener 27 } 28 29 var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") 30 31 var upgrader = &websocket.Upgrader{ 32 ReadBufferSize: 4 * 1024, 33 WriteBufferSize: 4 * 1024, 34 HandshakeTimeout: time.Second * 4, 35 CheckOrigin: func(r *http.Request) bool { 36 return true 37 }, 38 } 39 40 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 41 if request.URL.Path != h.path { 42 writer.WriteHeader(http.StatusNotFound) 43 return 44 } 45 46 var extraReader io.Reader 47 var responseHeader = http.Header{} 48 if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { 49 if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 { 50 extraReader = bytes.NewReader(ed) 51 responseHeader.Set("Sec-WebSocket-Protocol", str) 52 } 53 } 54 55 conn, err := upgrader.Upgrade(writer, request, responseHeader) 56 if err != nil { 57 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 58 return 59 } 60 61 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 62 remoteAddr := conn.RemoteAddr() 63 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 64 remoteAddr = &net.TCPAddr{ 65 IP: forwardedAddrs[0].IP(), 66 Port: int(0), 67 } 68 } 69 70 h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) 71 } 72 73 type Listener struct { 74 sync.Mutex 75 server http.Server 76 listener net.Listener 77 config *Config 78 addConn internet.ConnHandler 79 locker *internet.FileLocker // for unix domain socket 80 } 81 82 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 83 l := &Listener{ 84 addConn: addConn, 85 } 86 wsSettings := streamSettings.ProtocolSettings.(*Config) 87 l.config = wsSettings 88 if l.config != nil { 89 if streamSettings.SocketSettings == nil { 90 streamSettings.SocketSettings = &internet.SocketConfig{} 91 } 92 streamSettings.SocketSettings.AcceptProxyProtocol = 93 l.config.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol 94 } 95 var listener net.Listener 96 var err error 97 if port == net.Port(0) { //unix 98 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 99 Name: address.Domain(), 100 Net: "unix", 101 }, streamSettings.SocketSettings) 102 if err != nil { 103 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 104 } 105 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 106 locker := ctx.Value(address.Domain()) 107 if locker != nil { 108 l.locker = locker.(*internet.FileLocker) 109 } 110 } else { //tcp 111 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 112 IP: address.IP(), 113 Port: int(port), 114 }, streamSettings.SocketSettings) 115 if err != nil { 116 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 117 } 118 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 119 } 120 121 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 122 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 123 } 124 125 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 126 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 127 listener = tls.NewListener(listener, tlsConfig) 128 } 129 } 130 131 l.listener = listener 132 133 l.server = http.Server{ 134 Handler: &requestHandler{ 135 path: wsSettings.GetNormalizedPath(), 136 ln: l, 137 }, 138 ReadHeaderTimeout: time.Second * 4, 139 MaxHeaderBytes: 4096, 140 } 141 142 go func() { 143 if err := l.server.Serve(l.listener); err != nil { 144 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 145 } 146 }() 147 148 return l, err 149 } 150 151 // Addr implements net.Listener.Addr(). 152 func (ln *Listener) Addr() net.Addr { 153 return ln.listener.Addr() 154 } 155 156 // Close implements net.Listener.Close(). 157 func (ln *Listener) Close() error { 158 if ln.locker != nil { 159 ln.locker.Release() 160 } 161 return ln.listener.Close() 162 } 163 164 func init() { 165 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 166 }