github.com/v2fly/v2ray-core/v4@v4.45.2/transport/internet/websocket/hub.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package websocket 5 6 import ( 7 "bytes" 8 "context" 9 "crypto/tls" 10 "encoding/base64" 11 "io" 12 "net/http" 13 "strings" 14 "sync" 15 "time" 16 17 "github.com/gorilla/websocket" 18 19 "github.com/v2fly/v2ray-core/v4/common" 20 "github.com/v2fly/v2ray-core/v4/common/net" 21 http_proto "github.com/v2fly/v2ray-core/v4/common/protocol/http" 22 "github.com/v2fly/v2ray-core/v4/common/session" 23 "github.com/v2fly/v2ray-core/v4/transport/internet" 24 v2tls "github.com/v2fly/v2ray-core/v4/transport/internet/tls" 25 ) 26 27 type requestHandler struct { 28 path string 29 ln *Listener 30 earlyDataEnabled bool 31 earlyDataHeaderName string 32 } 33 34 var upgrader = &websocket.Upgrader{ 35 ReadBufferSize: 4 * 1024, 36 WriteBufferSize: 4 * 1024, 37 HandshakeTimeout: time.Second * 4, 38 CheckOrigin: func(r *http.Request) bool { 39 return true 40 }, 41 } 42 43 func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 44 var earlyData io.Reader 45 if !h.earlyDataEnabled { // nolint: gocritic 46 if request.URL.Path != h.path { 47 writer.WriteHeader(http.StatusNotFound) 48 return 49 } 50 } else if h.earlyDataHeaderName != "" { 51 if request.URL.Path != h.path { 52 writer.WriteHeader(http.StatusNotFound) 53 return 54 } 55 earlyDataStr := request.Header.Get(h.earlyDataHeaderName) 56 earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) 57 } else { 58 if strings.HasPrefix(request.URL.RequestURI(), h.path) { 59 earlyDataStr := request.URL.RequestURI()[len(h.path):] 60 earlyData = base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader([]byte(earlyDataStr))) 61 } else { 62 writer.WriteHeader(http.StatusNotFound) 63 return 64 } 65 } 66 67 conn, err := upgrader.Upgrade(writer, request, nil) 68 if err != nil { 69 newError("failed to convert to WebSocket connection").Base(err).WriteToLog() 70 return 71 } 72 73 forwardedAddrs := http_proto.ParseXForwardedFor(request.Header) 74 remoteAddr := conn.RemoteAddr() 75 if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() { 76 remoteAddr = &net.TCPAddr{ 77 IP: forwardedAddrs[0].IP(), 78 Port: int(0), 79 } 80 } 81 if earlyData == nil { 82 h.ln.addConn(newConnection(conn, remoteAddr)) 83 } else { 84 h.ln.addConn(newConnectionWithEarlyData(conn, remoteAddr, earlyData)) 85 } 86 } 87 88 type Listener struct { 89 sync.Mutex 90 server http.Server 91 listener net.Listener 92 config *Config 93 addConn internet.ConnHandler 94 locker *internet.FileLocker // for unix domain socket 95 } 96 97 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { 98 l := &Listener{ 99 addConn: addConn, 100 } 101 wsSettings := streamSettings.ProtocolSettings.(*Config) 102 l.config = wsSettings 103 if l.config != nil { 104 if streamSettings.SocketSettings == nil { 105 streamSettings.SocketSettings = &internet.SocketConfig{} 106 } 107 streamSettings.SocketSettings.AcceptProxyProtocol = l.config.AcceptProxyProtocol 108 } 109 var listener net.Listener 110 var err error 111 if port == net.Port(0) { // unix 112 listener, err = internet.ListenSystem(ctx, &net.UnixAddr{ 113 Name: address.Domain(), 114 Net: "unix", 115 }, streamSettings.SocketSettings) 116 if err != nil { 117 return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err) 118 } 119 newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx)) 120 locker := ctx.Value(address.Domain()) 121 if locker != nil { 122 l.locker = locker.(*internet.FileLocker) 123 } 124 } else { // tcp 125 listener, err = internet.ListenSystem(ctx, &net.TCPAddr{ 126 IP: address.IP(), 127 Port: int(port), 128 }, streamSettings.SocketSettings) 129 if err != nil { 130 return nil, newError("failed to listen TCP(for WS) on ", address, ":", port).Base(err) 131 } 132 newError("listening TCP(for WS) on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx)) 133 } 134 135 if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol { 136 newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx)) 137 } 138 139 if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil { 140 if tlsConfig := config.GetTLSConfig(); tlsConfig != nil { 141 listener = tls.NewListener(listener, tlsConfig) 142 } 143 } 144 145 l.listener = listener 146 useEarlyData := false 147 earlyDataHeaderName := "" 148 if wsSettings.MaxEarlyData != 0 { 149 useEarlyData = true 150 earlyDataHeaderName = wsSettings.EarlyDataHeaderName 151 } 152 153 l.server = http.Server{ 154 Handler: &requestHandler{ 155 path: wsSettings.GetNormalizedPath(), 156 ln: l, 157 earlyDataEnabled: useEarlyData, 158 earlyDataHeaderName: earlyDataHeaderName, 159 }, 160 ReadHeaderTimeout: time.Second * 4, 161 MaxHeaderBytes: http.DefaultMaxHeaderBytes, 162 } 163 164 go func() { 165 if err := l.server.Serve(l.listener); err != nil { 166 newError("failed to serve http for WebSocket").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 167 } 168 }() 169 170 return l, err 171 } 172 173 // Addr implements net.Listener.Addr(). 174 func (ln *Listener) Addr() net.Addr { 175 return ln.listener.Addr() 176 } 177 178 // Close implements net.Listener.Close(). 179 func (ln *Listener) Close() error { 180 if ln.locker != nil { 181 ln.locker.Release() 182 } 183 return ln.listener.Close() 184 } 185 186 func init() { 187 common.Must(internet.RegisterTransportListener(protocolName, ListenWS)) 188 }