// Source: github.com/ava-labs/subnet-evm@v0.6.4/rpc/websocket.go
// (c) 2019-2020, Ava Labs, Inc.
//
// This file is a derived work, based on the go-ethereum library whose original
// notices appear below.
//
// It is distributed under a license compatible with the licensing terms of the
// original code from which it is derived.
//
// Much love to the original authors for their work.
// **********
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set/v2"
	"github.com/ethereum/go-ethereum/log"
	"github.com/gorilla/websocket"
)

const (
	// Buffer sizes, in bytes, for the server upgrader and the default client dialer.
	wsReadBuffer  = 1024
	wsWriteBuffer = 1024

	// wsPingInterval is how long a connection may go without a successful
	// write before pingLoop sends a ping frame.
	wsPingInterval = 30 * time.Second

	// wsPingWriteTimeout is the write deadline applied while sending a ping.
	wsPingWriteTimeout = 5 * time.Second

	// wsPongTimeout is the read deadline set after each ping; it is cleared
	// again when a pong is received.
	wsPongTimeout = 30 * time.Second

	// wsMessageSizeLimit caps the size of incoming messages (32 MiB); it is
	// applied to every connection via SetReadLimit.
	wsMessageSizeLimit = 32 * 1024 * 1024
)

// wsBufferPool is the write-buffer pool shared by the server upgrader and the
// default client dialer.
var wsBufferPool = new(sync.Pool)

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins is the list of origins permitted to connect; each entry is
// matched on its own (entries are not split on commas).
59 // To allow connections with any origin, pass "*". 60 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 61 return s.WebsocketHandlerWithDuration(allowedOrigins, 0, 0, 0) 62 } 63 64 func (s *Server) WebsocketHandlerWithDuration(allowedOrigins []string, apiMaxDuration, refillRate, maxStored time.Duration) http.Handler { 65 var upgrader = websocket.Upgrader{ 66 ReadBufferSize: wsReadBuffer, 67 WriteBufferSize: wsWriteBuffer, 68 WriteBufferPool: wsBufferPool, 69 CheckOrigin: wsHandshakeValidator(allowedOrigins), 70 } 71 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 72 conn, err := upgrader.Upgrade(w, r, nil) 73 if err != nil { 74 log.Debug("WebSocket upgrade failed", "err", err) 75 return 76 } 77 codec := newWebsocketCodec(conn, r.Host, r.Header) 78 s.ServeCodec(codec, 0, apiMaxDuration, refillRate, maxStored) 79 }) 80 } 81 82 // wsHandshakeValidator returns a handler that verifies the origin during the 83 // websocket upgrade process. When a '*' is specified as an allowed origins all 84 // connections are accepted. 85 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 86 origins := mapset.NewSet[string]() 87 allowAllOrigins := false 88 89 for _, origin := range allowedOrigins { 90 if origin == "*" { 91 allowAllOrigins = true 92 } 93 if origin != "" { 94 origins.Add(origin) 95 } 96 } 97 // allow localhost if no allowedOrigins are specified. 98 if len(origins.ToSlice()) == 0 { 99 origins.Add("http://localhost") 100 if hostname, err := os.Hostname(); err == nil { 101 origins.Add("http://" + hostname) 102 } 103 } 104 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 105 106 f := func(req *http.Request) bool { 107 // Skip origin verification if no Origin header is present. The origin check 108 // is supposed to protect against browser based attacks. Browsers always set 109 // Origin. 
Non-browser software can put anything in origin and checking it doesn't 110 // provide additional security. 111 if _, ok := req.Header["Origin"]; !ok { 112 return true 113 } 114 // Verify origin against allow list. 115 origin := strings.ToLower(req.Header.Get("Origin")) 116 if allowAllOrigins || originIsAllowed(origins, origin) { 117 return true 118 } 119 log.Warn("Rejected WebSocket connection", "origin", origin) 120 return false 121 } 122 123 return f 124 } 125 126 type wsHandshakeError struct { 127 err error 128 status string 129 } 130 131 func (e wsHandshakeError) Error() string { 132 s := e.err.Error() 133 if e.status != "" { 134 s += " (HTTP status " + e.status + ")" 135 } 136 return s 137 } 138 139 func originIsAllowed(allowedOrigins mapset.Set[string], browserOrigin string) bool { 140 it := allowedOrigins.Iterator() 141 for origin := range it.C { 142 if ruleAllowsOrigin(origin, browserOrigin) { 143 return true 144 } 145 } 146 return false 147 } 148 149 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 150 var ( 151 allowedScheme, allowedHostname, allowedPort string 152 browserScheme, browserHostname, browserPort string 153 err error 154 ) 155 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 156 if err != nil { 157 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 158 return false 159 } 160 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 161 if err != nil { 162 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 163 return false 164 } 165 if allowedScheme != "" && allowedScheme != browserScheme { 166 return false 167 } 168 if allowedHostname != "" && allowedHostname != browserHostname { 169 return false 170 } 171 if allowedPort != "" && allowedPort != browserPort { 172 return false 173 } 174 return true 175 } 176 177 func parseOriginURL(origin string) (string, string, string, error) { 178 
parsedURL, err := url.Parse(strings.ToLower(origin)) 179 if err != nil { 180 return "", "", "", err 181 } 182 var scheme, hostname, port string 183 if strings.Contains(origin, "://") { 184 scheme = parsedURL.Scheme 185 hostname = parsedURL.Hostname() 186 port = parsedURL.Port() 187 } else { 188 scheme = "" 189 hostname = parsedURL.Scheme 190 port = parsedURL.Opaque 191 if hostname == "" { 192 hostname = origin 193 } 194 } 195 return scheme, hostname, port, nil 196 } 197 198 // DialWebsocketWithDialer creates a new RPC client using WebSocket. 199 // 200 // The context is used for the initial connection establishment. It does not 201 // affect subsequent interactions with the client. 202 // 203 // Deprecated: use DialOptions and the WithWebsocketDialer option. 204 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 205 cfg := new(clientConfig) 206 cfg.wsDialer = &dialer 207 if origin != "" { 208 cfg.setHeader("origin", origin) 209 } 210 connect, err := newClientTransportWS(endpoint, cfg) 211 if err != nil { 212 return nil, err 213 } 214 return newClient(ctx, cfg, connect) 215 } 216 217 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 218 // that is listening on the given endpoint. 219 // 220 // The context is used for the initial connection establishment. It does not 221 // affect subsequent interactions with the client. 
222 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 223 cfg := new(clientConfig) 224 if origin != "" { 225 cfg.setHeader("origin", origin) 226 } 227 connect, err := newClientTransportWS(endpoint, cfg) 228 if err != nil { 229 return nil, err 230 } 231 return newClient(ctx, cfg, connect) 232 } 233 234 func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { 235 dialer := cfg.wsDialer 236 if dialer == nil { 237 dialer = &websocket.Dialer{ 238 ReadBufferSize: wsReadBuffer, 239 WriteBufferSize: wsWriteBuffer, 240 WriteBufferPool: wsBufferPool, 241 Proxy: http.ProxyFromEnvironment, 242 } 243 } 244 245 dialURL, header, err := wsClientHeaders(endpoint, "") 246 if err != nil { 247 return nil, err 248 } 249 for key, values := range cfg.httpHeaders { 250 header[key] = values 251 } 252 253 connect := func(ctx context.Context) (ServerCodec, error) { 254 header := header.Clone() 255 if cfg.httpAuth != nil { 256 if err := cfg.httpAuth(header); err != nil { 257 return nil, err 258 } 259 } 260 conn, resp, err := dialer.DialContext(ctx, dialURL, header) 261 if err != nil { 262 hErr := wsHandshakeError{err: err} 263 if resp != nil { 264 hErr.status = resp.Status 265 } 266 return nil, hErr 267 } 268 return newWebsocketCodec(conn, dialURL, header), nil 269 } 270 return connect, nil 271 } 272 273 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 274 endpointURL, err := url.Parse(endpoint) 275 if err != nil { 276 return endpoint, nil, err 277 } 278 header := make(http.Header) 279 if origin != "" { 280 header.Add("origin", origin) 281 } 282 if endpointURL.User != nil { 283 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 284 header.Add("authorization", "Basic "+b64auth) 285 endpointURL.User = nil 286 } 287 return endpointURL.String(), header, nil 288 } 289 290 type websocketCodec struct { 291 *jsonCodec 292 conn *websocket.Conn 293 info PeerInfo 294 295 wg 
sync.WaitGroup 296 pingReset chan struct{} 297 pongReceived chan struct{} 298 } 299 300 func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { 301 conn.SetReadLimit(wsMessageSizeLimit) 302 encode := func(v interface{}, isErrorResponse bool) error { 303 return conn.WriteJSON(v) 304 } 305 wc := &websocketCodec{ 306 jsonCodec: NewFuncCodec(conn, encode, conn.ReadJSON).(*jsonCodec), 307 conn: conn, 308 pingReset: make(chan struct{}, 1), 309 pongReceived: make(chan struct{}), 310 info: PeerInfo{ 311 Transport: "ws", 312 RemoteAddr: conn.RemoteAddr().String(), 313 }, 314 } 315 // Fill in connection details. 316 wc.info.HTTP.Host = host 317 wc.info.HTTP.Origin = req.Get("Origin") 318 wc.info.HTTP.UserAgent = req.Get("User-Agent") 319 // Start pinger. 320 conn.SetPongHandler(func(appData string) error { 321 select { 322 case wc.pongReceived <- struct{}{}: 323 case <-wc.closed(): 324 } 325 return nil 326 }) 327 wc.wg.Add(1) 328 go wc.pingLoop() 329 return wc 330 } 331 332 func (wc *websocketCodec) close() { 333 wc.jsonCodec.close() 334 wc.wg.Wait() 335 } 336 337 func (wc *websocketCodec) peerInfo() PeerInfo { 338 return wc.info 339 } 340 341 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}, isError bool) error { 342 return wc.writeJSONSkipDeadline(ctx, v, isError, false) 343 } 344 345 func (wc *websocketCodec) writeJSONSkipDeadline(ctx context.Context, v interface{}, isError bool, skip bool) error { 346 err := wc.jsonCodec.writeJSONSkipDeadline(ctx, v, isError, skip) 347 if err == nil { 348 // Notify pingLoop to delay the next idle ping. 349 select { 350 case wc.pingReset <- struct{}{}: 351 default: 352 } 353 } 354 return err 355 } 356 357 // pingLoop sends periodic ping frames when the connection is idle. 
358 func (wc *websocketCodec) pingLoop() { 359 var pingTimer = time.NewTimer(wsPingInterval) 360 defer wc.wg.Done() 361 defer pingTimer.Stop() 362 363 for { 364 select { 365 case <-wc.closed(): 366 return 367 368 case <-wc.pingReset: 369 if !pingTimer.Stop() { 370 <-pingTimer.C 371 } 372 pingTimer.Reset(wsPingInterval) 373 374 case <-pingTimer.C: 375 wc.jsonCodec.encMu.Lock() 376 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 377 wc.conn.WriteMessage(websocket.PingMessage, nil) 378 wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) 379 wc.jsonCodec.encMu.Unlock() 380 pingTimer.Reset(wsPingInterval) 381 382 case <-wc.pongReceived: 383 wc.conn.SetReadDeadline(time.Time{}) 384 } 385 } 386 }