github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/websocket/hub.go (about) 1 package websocket 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "io" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/gorilla/websocket" 15 16 "github.com/v2fly/v2ray-core/v5/common" 17 "github.com/v2fly/v2ray-core/v5/common/net" 18 http_proto "github.com/v2fly/v2ray-core/v5/common/protocol/http" 19 "github.com/v2fly/v2ray-core/v5/common/session" 20 "github.com/v2fly/v2ray-core/v5/transport/internet" 21 v2tls "github.com/v2fly/v2ray-core/v5/transport/internet/tls" 22 ) 23 24 type requestHandler struct { 25 path string 26 ln *Listener 27 earlyDataEnabled bool 28 earlyDataHeaderName string 29 } 30 31 var upgrader = &websocket.Upgrader{ 32 ReadBufferSize: 4 * 1024, 33 WriteBufferSize: 4 * 1024, 34 HandshakeTimeout: time.Second * 4, 35 CheckOrigin: func(r *http.Request) bool { 36 return true 37 }, 38 } 39 40 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 41 responseHeader := http.Header{} 42 43 var earlyData io.Reader 44 if !h.earlyDataEnabled { // nolint: gocritic 45 if request.URL.Path != h.path { 46 writer.WriteHeader(http.StatusNotFound) 47 return 48 } 49 } else if h.earlyDataHeaderName != "" { 50 if request.URL.Path != h.path { 51 writer.WriteHeader(http.StatusNotFound) 52 return 53 } 54 earlyDataStr := request.Header.Get(h.earlyDataHeaderName) 55 earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) 56 if strings.EqualFold("Sec-WebSocket-Protocol", h.earlyDataHeaderName) { 57 responseHeader.Set(h.earlyDataHeaderName, earlyDataStr) 58 } 59 } else { 60 if strings.HasPrefix(request.URL.RequestURI(), h.path) { 61 earlyDataStr := request.URL.RequestURI()[len(h.path):] 62 earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) 63 } else { 64 writer.WriteHeader(http.StatusNotFound) 65 return 66 } 67 } 68 69 conn, err := upgrader.Upgrade(writer, request, responseHeader) 70 if err != nil { 71 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 72 return 73 } 74 75 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 76 remoteAddr := conn.RemoteAddr() 77 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 78 remoteAddr = &net.TCPAddr{ 79 IP: forwardedAddrs[0].IP(), 80 Port: int(0), 81 } 82 } 83 if earlyData == nil { 84 h.ln.addConn(newConnection(conn, remoteAddr)) 85 } else { 86 h.ln.addConn(newConnectionWithEarlyData(conn, remoteAddr, earlyData)) 87 } 88 } 89 90 type Listener struct { 91 sync.Mutex 92 server http.Server 93 listener net.Listener 94 config *Config 95 addConn internet.ConnHandler 96 } 97 98 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 99 l := &Listener{ 100 addConn: addConn, 101 } 102 wsSettings := streamSettings.ProtocolSettings.(*Config) 103 l.config = wsSettings 104 if l.config != nil { 105 if streamSettings.SocketSettings == nil { 106 streamSettings.SocketSettings = &internet.SocketConfig{} 107 } 108 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol 109 } 110 var listener net.Listener 111 var err error 112 if port == net.Port(0) { // unix 113 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 114 Name: address.Domain(), 115 Net: "unix", 116 }, streamSettings.SocketSettings) 117 if err != nil { 118 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 119 } 120 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 121 } else { // tcp 122 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 123 IP: address.IP(), 124 Port: int(port), 125 }, streamSettings.SocketSettings) 126 if err != nil { 127 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 128 } 129 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 130 } 131 132 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 133 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 134 } 135 136 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 137 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 138 listener = tls.NewListener(listener, tlsConfig) 139 } 140 } 141 142 l.listener = listener 143 useEarlyData := false 144 earlyDataHeaderName := "" 145 if wsSettings.MaxEarlyData != 0 { 146 useEarlyData = true 147 earlyDataHeaderName = wsSettings.EarlyDataHeaderName 148 } 149 150 l.server = http.Server{ 151 Handler: &requestHandler{ 152 path: wsSettings.GetNormalizedPath(), 153 ln: l, 154 earlyDataEnabled: useEarlyData, 155 earlyDataHeaderName: earlyDataHeaderName, 156 }, 157 ReadHeaderTimeout: time.Second * 4, 158 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 159 } 160 161 go func() { 162 if err := l.server.Serve(l.listener); err != nil { 163 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 164 } 165 }() 166 167 return l, err 168 } 169 170 // Addr implements net.Listener.Addr(). 171 func (ln *Listener) Addr() net.Addr { 172 return ln.listener.Addr() 173 } 174 175 // Close implements net.Listener.Close(). 176 func (ln *Listener) Close() error { 177 return ln.listener.Close() 178 } 179 180 func init() { 181 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 182 }