github.com/moqsien/xraycore@v1.8.5/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15 "github.com/moqsien/xraycore/common" 16 "github.com/moqsien/xraycore/common/net" 17 http_proto "github.com/moqsien/xraycore/common/protocol/http" 18 "github.com/moqsien/xraycore/common/session" 19 "github.com/moqsien/xraycore/transport/internet" 20 v2tls "github.com/moqsien/xraycore/transport/internet/tls" 21 ) 22 23 type requestHandler struct { 24 path string 25 ln *Listener 26 } 27 28 var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "") 29 30 var upgrader = &websocket.Upgrader{ 31 ReadBufferSize: 4 * 1024, 32 WriteBufferSize: 4 * 1024, 33 HandshakeTimeout: time.Second * 4, 34 CheckOrigin: func(r *http.Request) bool { 35 return true 36 }, 37 } 38 39 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 40 if request.URL.Path != h.path { 41 writer.WriteHeader(http.StatusNotFound) 42 return 43 } 44 45 var extraReader io.Reader 46 responseHeader := http.Header{} 47 if str := request.Header.Get("Sec-WebSocket-Protocol"); str != "" { 48 if ed, err := base64.RawURLEncoding.DecodeString(replacer.Replace(str)); err == nil && len(ed) > 0 { 49 extraReader = bytes.NewReader(ed) 50 responseHeader.Set("Sec-WebSocket-Protocol", str) 51 } 52 } 53 54 conn, err := upgrader.Upgrade(writer, request, responseHeader) 55 if err != nil { 56 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 57 return 58 } 59 60 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 61 remoteAddr := conn.RemoteAddr() 62 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 63 remoteAddr = &net.TCPAddr{ 64 IP: forwardedAddrs[0].IP(), 65 Port: int(0), 66 } 67 } 68 69 h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) 70 } 71 72 type Listener struct { 73 sync.Mutex 74 server http.Server 75 listener net.Listener 76 config *Config 77 addConn internet.ConnHandler 78 } 79 80 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 81 l := &Listener{ 82 addConn: addConn, 83 } 84 wsSettings := streamSettings.ProtocolSettings.(*Config) 85 l.config = wsSettings 86 if l.config != nil { 87 if streamSettings.SocketSettings == nil { 88 streamSettings.SocketSettings = &internet.SocketConfig{} 89 } 90 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol || streamSettings.SocketSettings.AcceptProxyProtocol 91 } 92 var listener net.Listener 93 var err error 94 if port == net.Port(0) { // unix 95 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 96 Name: address.Domain(), 97 Net: "unix", 98 }, streamSettings.SocketSettings) 99 if err != nil { 100 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 101 } 102 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 103 } else { // tcp 104 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 105 IP: address.IP(), 106 Port: int(port), 107 }, streamSettings.SocketSettings) 108 if err != nil { 109 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 110 } 111 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 112 } 113 114 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 115 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 116 } 117 118 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 119 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 120 listener = tls.NewListener(listener, tlsConfig) 121 } 122 } 123 124 l.listener = listener 125 126 l.server = http.Server{ 127 Handler: &requestHandler{ 128 path: wsSettings.GetNormalizedPath(), 129 ln: l, 130 }, 131 ReadHeaderTimeout: time.Second * 4, 132 MaxHeaderBytes: 4096, 133 } 134 135 go func() { 136 if err := l.server.Serve(l.listener); err != nil { 137 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 138 } 139 }() 140 141 return l, err 142 } 143 144 // Addr implements net.Listener.Addr(). 145 func (ln *Listener) Addr() net.Addr { 146 return ln.listener.Addr() 147 } 148 149 // Close implements net.Listener.Close(). 150 func (ln *Listener) Close() error { 151 return ln.listener.Close() 152 } 153 154 func init() { 155 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 156 }