github.com/xraypb/xray-core@v1.6.6/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/xraypb/websocket-1" 15 "github.com/xraypb/xray-core/common" 16 "github.com/xraypb/xray-core/common/net" 17 http_proto "github.com/xraypb/xray-core/common/protocol/http" 18 "github.com/xraypb/xray-core/common/session" 19 "github.com/xraypb/xray-core/transport/internet" 20 v2tls "github.com/xraypb/xray-core/transport/internet/tls" 21 ) 22 23 type requestHandler struct { 24 path string 25 ln *Listener 26 } 27 28 var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") 29 30 var upgrader = &websocket.Upgrader{ 31 ReadBufferSize: 4 * 1024, 32 WriteBufferSize: 4 * 1024, 33 HandshakeTimeout: time.Second * 4, 34 CheckOrigin: func(r *http.Request) bool { 35 return true 36 }, 37 } 38 39 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 40 if request.URL.Path != h.path { 41 writer.WriteHeader(http.StatusNotFound) 42 return 43 } 44 45 var extraReader io.Reader 46 responseHeader := http.Header{} 47 if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { 48 if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 { 49 extraReader = bytes.NewReader(ed) 50 responseHeader.Set("Sec-WebSocket-Protocol", str) 51 } 52 } 53 54 conn, err := upgrader.Upgrade(writer, request, responseHeader) 55 if err != nil { 56 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 57 return 58 } 59 60 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 61 remoteAddr := conn.RemoteAddr() 62 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 63 remoteAddr = &net.TCPAddr{ 64 IP: forwardedAddrs[0].IP(), 65 Port: int(0), 66 } 67 } 68 69 h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) 70 } 71 72 type Listener struct { 73 sync.Mutex 74 server http.Server 75 listener net.Listener 76 config *Config 77 addConn internet.ConnHandler 78 locker *internet.FileLocker // for unix domain socket 79 } 80 81 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 82 l := &Listener{ 83 addConn: addConn, 84 } 85 wsSettings := streamSettings.ProtocolSettings.(*Config) 86 l.config = wsSettings 87 if l.config != nil { 88 if streamSettings.SocketSettings == nil { 89 streamSettings.SocketSettings = &internet.SocketConfig{} 90 } 91 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol 92 } 93 var listener net.Listener 94 var err error 95 if port == net.Port(0) { // unix 96 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 97 Name: address.Domain(), 98 Net: "unix", 99 }, streamSettings.SocketSettings) 100 if err != nil { 101 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 102 } 103 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 104 locker := ctx.Value(address.Domain()) 105 if locker != nil { 106 l.locker = locker.(*internet.FileLocker) 107 } 108 } else { // tcp 109 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 110 IP: address.IP(), 111 Port: int(port), 112 }, streamSettings.SocketSettings) 113 if err != nil { 114 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 115 } 116 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 117 } 118 119 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 120 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 121 } 122 123 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 124 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 125 listener = tls.NewListener(listener, tlsConfig) 126 } 127 } 128 129 l.listener = listener 130 131 l.server = http.Server{ 132 Handler: &requestHandler{ 133 path: wsSettings.GetNormalizedPath(), 134 ln: l, 135 }, 136 ReadHeaderTimeout: time.Second * 4, 137 MaxHeaderBytes: 4096, 138 } 139 140 go func() { 141 if err := l.server.Serve(l.listener); err != nil { 142 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 143 } 144 }() 145 146 return l, err 147 } 148 149 // Addr implements net.Listener.Addr(). 150 func (ln *Listener) Addr() net.Addr { 151 return ln.listener.Addr() 152 } 153 154 // Close implements net.Listener.Close(). 155 func (ln *Listener) Close() error { 156 if ln.locker != nil { 157 ln.locker.Release() 158 } 159 return ln.listener.Close() 160 } 161 162 func init() { 163 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 164 }