// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/ethereum/go-ethereum/log"
	"github.com/gorilla/websocket"
)

const (
	// Buffer sizes used by both the server-side Upgrader (WebsocketHandler)
	// and the default client Dialer (newClientTransportWS).
	wsReadBuffer       = 1024
	wsWriteBuffer      = 1024
	// pingLoop sends a ping once this long has passed without a successful write.
	wsPingInterval     = 30 * time.Second
	// Write deadline applied while sending a ping frame.
	wsPingWriteTimeout = 5 * time.Second
	// Read deadline armed after each ping; cleared again when the pong arrives.
	wsPongTimeout      = 30 * time.Second
	// Maximum size of a message read from the peer (32 MiB). Always used for
	// server-side connections; the client default unless a non-negative
	// wsMessageSizeLimit is configured.
	wsDefaultReadLimit = 32 * 1024 * 1024
)

// wsBufferPool is shared as the WriteBufferPool of the server Upgrader and the
// default client Dialer.
var wsBufferPool = new(sync.Pool)

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins lists the origins browsers may connect from. Each entry is
// either a full origin ("http://example.com:8080") or a partial rule that
// omits the scheme and/or port ("example.com", "localhost:8545"). Include "*"
// to allow connections with any origin. When no non-empty entries are given,
// only http://localhost and http://<os.Hostname()> are allowed. Requests that
// carry no Origin header are not origin-checked.
50 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 51 var upgrader = websocket.Upgrader{ 52 ReadBufferSize: wsReadBuffer, 53 WriteBufferSize: wsWriteBuffer, 54 WriteBufferPool: wsBufferPool, 55 CheckOrigin: wsHandshakeValidator(allowedOrigins), 56 } 57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 conn, err := upgrader.Upgrade(w, r, nil) 59 if err != nil { 60 log.Debug("WebSocket upgrade failed", "err", err) 61 return 62 } 63 codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit) 64 s.ServeCodec(codec, 0) 65 }) 66 } 67 68 // wsHandshakeValidator returns a handler that verifies the origin during the 69 // websocket upgrade process. When a '*' is specified as an allowed origins all 70 // connections are accepted. 71 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 72 origins := mapset.NewSet[string]() 73 allowAllOrigins := false 74 75 for _, origin := range allowedOrigins { 76 if origin == "*" { 77 allowAllOrigins = true 78 } 79 if origin != "" { 80 origins.Add(origin) 81 } 82 } 83 // allow localhost if no allowedOrigins are specified. 84 if len(origins.ToSlice()) == 0 { 85 origins.Add("http://localhost") 86 if hostname, err := os.Hostname(); err == nil { 87 origins.Add("http://" + hostname) 88 } 89 } 90 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 91 92 f := func(req *http.Request) bool { 93 // Skip origin verification if no Origin header is present. The origin check 94 // is supposed to protect against browser based attacks. Browsers always set 95 // Origin. Non-browser software can put anything in origin and checking it doesn't 96 // provide additional security. 97 if _, ok := req.Header["Origin"]; !ok { 98 return true 99 } 100 // Verify origin against allow list. 
101 origin := strings.ToLower(req.Header.Get("Origin")) 102 if allowAllOrigins || originIsAllowed(origins, origin) { 103 return true 104 } 105 log.Warn("Rejected WebSocket connection", "origin", origin) 106 return false 107 } 108 109 return f 110 } 111 112 type wsHandshakeError struct { 113 err error 114 status string 115 } 116 117 func (e wsHandshakeError) Error() string { 118 s := e.err.Error() 119 if e.status != "" { 120 s += " (HTTP status " + e.status + ")" 121 } 122 return s 123 } 124 125 func (e wsHandshakeError) Unwrap() error { 126 return e.err 127 } 128 129 func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool { 130 it := allowedOrigins.Iterator() 131 for origin := range it.C { 132 if ruleAllowsOrigin(origin, browserOrigin) { 133 return true 134 } 135 } 136 return false 137 } 138 139 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 140 var ( 141 allowedScheme, allowedHostname, allowedPort string 142 browserScheme, browserHostname, browserPort string 143 err error 144 ) 145 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 146 if err != nil { 147 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 148 return false 149 } 150 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 151 if err != nil { 152 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 153 return false 154 } 155 if allowedScheme != "" && allowedScheme != browserScheme { 156 return false 157 } 158 if allowedHostname != "" && allowedHostname != browserHostname { 159 return false 160 } 161 if allowedPort != "" && allowedPort != browserPort { 162 return false 163 } 164 return true 165 } 166 167 func parseOriginURL(origin string) (string, string, string, error) { 168 parsedURL, err := url.Parse(strings.ToLower(origin)) 169 if err != nil { 170 return "", "", "", err 171 } 172 var scheme, hostname, port string 173 
if strings.Contains(origin, "://") { 174 scheme = parsedURL.Scheme 175 hostname = parsedURL.Hostname() 176 port = parsedURL.Port() 177 } else { 178 scheme = "" 179 hostname = parsedURL.Scheme 180 port = parsedURL.Opaque 181 if hostname == "" { 182 hostname = origin 183 } 184 } 185 return scheme, hostname, port, nil 186 } 187 188 // DialWebsocketWithDialer creates a new RPC client using WebSocket. 189 // 190 // The context is used for the initial connection establishment. It does not 191 // affect subsequent interactions with the client. 192 // 193 // Deprecated: use DialOptions and the WithWebsocketDialer option. 194 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 195 cfg := new(clientConfig) 196 cfg.wsDialer = &dialer 197 if origin != "" { 198 cfg.setHeader("origin", origin) 199 } 200 connect, err := newClientTransportWS(endpoint, cfg) 201 if err != nil { 202 return nil, err 203 } 204 return newClient(ctx, cfg, connect) 205 } 206 207 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 208 // that is listening on the given endpoint. 209 // 210 // The context is used for the initial connection establishment. It does not 211 // affect subsequent interactions with the client. 
212 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 213 cfg := new(clientConfig) 214 if origin != "" { 215 cfg.setHeader("origin", origin) 216 } 217 connect, err := newClientTransportWS(endpoint, cfg) 218 if err != nil { 219 return nil, err 220 } 221 return newClient(ctx, cfg, connect) 222 } 223 224 func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { 225 dialer := cfg.wsDialer 226 if dialer == nil { 227 dialer = &websocket.Dialer{ 228 ReadBufferSize: wsReadBuffer, 229 WriteBufferSize: wsWriteBuffer, 230 WriteBufferPool: wsBufferPool, 231 Proxy: http.ProxyFromEnvironment, 232 } 233 } 234 235 dialURL, header, err := wsClientHeaders(endpoint, "") 236 if err != nil { 237 return nil, err 238 } 239 for key, values := range cfg.httpHeaders { 240 header[key] = values 241 } 242 243 connect := func(ctx context.Context) (ServerCodec, error) { 244 header := header.Clone() 245 if cfg.httpAuth != nil { 246 if err := cfg.httpAuth(header); err != nil { 247 return nil, err 248 } 249 } 250 conn, resp, err := dialer.DialContext(ctx, dialURL, header) 251 if err != nil { 252 hErr := wsHandshakeError{err: err} 253 if resp != nil { 254 hErr.status = resp.Status 255 } 256 return nil, hErr 257 } 258 messageSizeLimit := int64(wsDefaultReadLimit) 259 if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 { 260 messageSizeLimit = *cfg.wsMessageSizeLimit 261 } 262 return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil 263 } 264 return connect, nil 265 } 266 267 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 268 endpointURL, err := url.Parse(endpoint) 269 if err != nil { 270 return endpoint, nil, err 271 } 272 header := make(http.Header) 273 if origin != "" { 274 header.Add("origin", origin) 275 } 276 if endpointURL.User != nil { 277 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 278 header.Add("authorization", "Basic "+b64auth) 279 
endpointURL.User = nil 280 } 281 return endpointURL.String(), header, nil 282 } 283 284 type websocketCodec struct { 285 *jsonCodec 286 conn *websocket.Conn 287 info PeerInfo 288 289 wg sync.WaitGroup 290 pingReset chan struct{} 291 pongReceived chan struct{} 292 } 293 294 func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec { 295 conn.SetReadLimit(readLimit) 296 encode := func(v interface{}, isErrorResponse bool) error { 297 return conn.WriteJSON(v) 298 } 299 wc := &websocketCodec{ 300 jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), 301 conn: conn, 302 pingReset: make(chan struct{}, 1), 303 pongReceived: make(chan struct{}), 304 info: PeerInfo{ 305 Transport: "ws", 306 RemoteAddr: conn.RemoteAddr().String(), 307 }, 308 } 309 // Fill in connection details. 310 wc.info.HTTP.Host = host 311 wc.info.HTTP.Origin = req.Get("Origin") 312 wc.info.HTTP.UserAgent = req.Get("User-Agent") 313 // Start pinger. 314 conn.SetPongHandler(func(appData string) error { 315 select { 316 case wc.pongReceived <- struct{}{}: 317 case <-wc.closed(): 318 } 319 return nil 320 }) 321 wc.wg.Add(1) 322 go wc.pingLoop() 323 return wc 324 } 325 326 func (wc *websocketCodec) close() { 327 wc.jsonCodec.close() 328 wc.wg.Wait() 329 } 330 331 func (wc *websocketCodec) peerInfo() PeerInfo { 332 return wc.info 333 } 334 335 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { 336 err := wc.jsonCodec.writeJSON(ctx, v, isError) 337 if err == nil { 338 // Notify pingLoop to delay the next idle ping. 339 select { 340 case wc.pingReset <- struct{}{}: 341 default: 342 } 343 } 344 return err 345 } 346 347 // pingLoop sends periodic ping frames when the connection is idle. 
348 func (wc *websocketCodec) pingLoop() { 349 var pingTimer = time.NewTimer(wsPingInterval) 350 defer wc.wg.Done() 351 defer pingTimer.Stop() 352 353 for { 354 select { 355 case <-wc.closed(): 356 return 357 358 case <-wc.pingReset: 359 if !pingTimer.Stop() { 360 <-pingTimer.C 361 } 362 pingTimer.Reset(wsPingInterval) 363 364 case <-pingTimer.C: 365 wc.jsonCodec.encMu.Lock() 366 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 367 wc.conn.WriteMessage(websocket.PingMessage, nil) 368 wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) 369 wc.jsonCodec.encMu.Unlock() 370 pingTimer.Reset(wsPingInterval) 371 372 case <-wc.pongReceived: 373 wc.conn.SetReadDeadline(time.Time{}) 374 } 375 } 376 }