github.com/imannamdari/v2ray-core/v5@v5.0.5/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15 16 "github.com/imannamdari/v2ray-core/v5/common" 17 "github.com/imannamdari/v2ray-core/v5/common/net" 18 http_proto "github.com/imannamdari/v2ray-core/v5/common/protocol/http" 19 "github.com/imannamdari/v2ray-core/v5/common/session" 20 "github.com/imannamdari/v2ray-core/v5/transport/internet" 21 v2tls "github.com/imannamdari/v2ray-core/v5/transport/internet/tls" 22 ) 23 24 type requestHandler struct { 25 path string 26 ln *Listener 27 earlyDataEnabled bool 28 earlyDataHeaderName string 29 } 30 31 var upgrader = &websocket.Upgrader{ 32 ReadBufferSize: 4 * 1024, 33 WriteBufferSize: 4 * 1024, 34 HandshakeTimeout: time.Second * 4, 35 CheckOrigin: func(r *http.Request) bool { 36 return true 37 }, 38 } 39 40 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 41 responseHeader := http.Header{} 42 43 var earlyData io.Reader 44 if !h.earlyDataEnabled { // nolint: gocritic 45 if request.URL.Path != h.path { 46 writer.WriteHeader(http.StatusNotFound) 47 return 48 } 49 } else if h.earlyDataHeaderName != "" { 50 if request.URL.Path != h.path { 51 writer.WriteHeader(http.StatusNotFound) 52 return 53 } 54 earlyDataStr := request.Header.Get(h.earlyDataHeaderName) 55 earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) 56 if strings.EqualFold("Sec-WebSocket-Protocol", h.earlyDataHeaderName) { 57 responseHeader.Set(h.earlyDataHeaderName, earlyDataStr) 58 } 59 } else { 60 if strings.HasPrefix(request.URL.RequestURI(), h.path) { 61 earlyDataStr := request.URL.RequestURI()[len(h.path):] 62 earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) 63 } else { 64 writer.WriteHeader(http.StatusNotFound) 65 return 66 } 67 } 68 69 conn, err := upgrader.Upgrade(writer, request, responseHeader) 70 if err != nil { 71 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 72 return 73 } 74 75 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 76 remoteAddr := conn.RemoteAddr() 77 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 78 remoteAddr = &net.TCPAddr{ 79 IP: forwardedAddrs[0].IP(), 80 Port: int(0), 81 } 82 } 83 if earlyData == nil { 84 h.ln.addConn(newConnection(conn, remoteAddr)) 85 } else { 86 h.ln.addConn(newConnectionWithEarlyData(conn, remoteAddr, earlyData)) 87 } 88 } 89 90 type Listener struct { 91 sync.Mutex 92 server http.Server 93 listener net.Listener 94 config *Config 95 addConn internet.ConnHandler 96 locker *internet.FileLocker // for unix domain socket 97 } 98 99 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 100 l := &Listener{ 101 addConn: addConn, 102 } 103 wsSettings := streamSettings.ProtocolSettings.(*Config) 104 l.config = wsSettings 105 if l.config != nil { 106 if streamSettings.SocketSettings == nil { 107 streamSettings.SocketSettings = &internet.SocketConfig{} 108 } 109 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol 110 } 111 var listener net.Listener 112 var err error 113 if port == net.Port(0) { // unix 114 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 115 Name: address.Domain(), 116 Net: "unix", 117 }, streamSettings.SocketSettings) 118 if err != nil { 119 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 120 } 121 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 122 locker := ctx.Value(address.Domain()) 123 if locker != nil { 124 l.locker = locker.(*internet.FileLocker) 125 } 126 } else { // tcp 127 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 128 IP: address.IP(), 129 Port: int(port), 130 }, streamSettings.SocketSettings) 131 if err != nil { 132 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 133 } 134 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 135 } 136 137 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 138 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 139 } 140 141 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 142 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 143 listener = tls.NewListener(listener, tlsConfig) 144 } 145 } 146 147 l.listener = listener 148 useEarlyData := false 149 earlyDataHeaderName := "" 150 if wsSettings.MaxEarlyData != 0 { 151 useEarlyData = true 152 earlyDataHeaderName = wsSettings.EarlyDataHeaderName 153 } 154 155 l.server = http.Server{ 156 Handler: &requestHandler{ 157 path: wsSettings.GetNormalizedPath(), 158 ln: l, 159 earlyDataEnabled: useEarlyData, 160 earlyDataHeaderName: earlyDataHeaderName, 161 }, 162 ReadHeaderTimeout: time.Second * 4, 163 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 164 } 165 166 go func() { 167 if err := l.server.Serve(l.listener); err != nil { 168 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 169 } 170 }() 171 172 return l, err 173 } 174 175 // Addr implements net.Listener.Addr(). 176 func (ln *Listener) Addr() net.Addr { 177 return ln.listener.Addr() 178 } 179 180 // Close implements net.Listener.Close(). 181 func (ln *Listener) Close() error { 182 if ln.locker != nil { 183 ln.locker.Release() 184 } 185 return ln.listener.Close() 186 } 187 188 func init() { 189 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 190 }