github.com/MetalBlockchain/subnet-evm@v0.4.9/rpc/websocket.go (about) 1 // (c) 2019-2020, Ava Labs, Inc. 2 // 3 // This file is a derived work, based on the go-ethereum library whose original 4 // notices appear below. 5 // 6 // It is distributed under a license compatible with the licensing terms of the 7 // original code from which it is derived. 8 // 9 // Much love to the original authors for their work. 10 // ********** 11 // Copyright 2015 The go-ethereum Authors 12 // This file is part of the go-ethereum library. 13 // 14 // The go-ethereum library is free software: you can redistribute it and/or modify 15 // it under the terms of the GNU Lesser General Public License as published by 16 // the Free Software Foundation, either version 3 of the License, or 17 // (at your option) any later version. 18 // 19 // The go-ethereum library is distributed in the hope that it will be useful, 20 // but WITHOUT ANY WARRANTY; without even the implied warranty of 21 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 22 // GNU Lesser General Public License for more details. 23 // 24 // You should have received a copy of the GNU Lesser General Public License 25 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 26 27 package rpc 28 29 import ( 30 "context" 31 "encoding/base64" 32 "fmt" 33 "net/http" 34 "net/url" 35 "os" 36 "strings" 37 "sync" 38 "time" 39 40 mapset "github.com/deckarep/golang-set" 41 "github.com/ethereum/go-ethereum/log" 42 "github.com/gorilla/websocket" 43 ) 44 45 const ( 46 wsReadBuffer = 1024 47 wsWriteBuffer = 1024 48 wsPingInterval = 30 * time.Second 49 wsPingWriteTimeout = 5 * time.Second 50 wsPongTimeout = 30 * time.Second 51 wsMessageSizeLimit = 15 * 1024 * 1024 52 ) 53 54 var wsBufferPool = new(sync.Pool) 55 56 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 57 // 58 // allowedOrigins should be a comma-separated list of allowed origin URLs. 59 // To allow connections with any origin, pass "*". 60 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 61 return s.WebsocketHandlerWithDuration(allowedOrigins, 0, 0, 0) 62 } 63 64 func (s *Server) WebsocketHandlerWithDuration(allowedOrigins []string, apiMaxDuration, refillRate, maxStored time.Duration) http.Handler { 65 var upgrader = websocket.Upgrader{ 66 ReadBufferSize: wsReadBuffer, 67 WriteBufferSize: wsWriteBuffer, 68 WriteBufferPool: wsBufferPool, 69 CheckOrigin: wsHandshakeValidator(allowedOrigins), 70 } 71 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 72 conn, err := upgrader.Upgrade(w, r, nil) 73 if err != nil { 74 log.Debug("WebSocket upgrade failed", "err", err) 75 return 76 } 77 codec := newWebsocketCodec(conn, r.Host, r.Header) 78 s.ServeCodec(codec, 0, apiMaxDuration, refillRate, maxStored) 79 }) 80 } 81 82 // wsHandshakeValidator returns a handler that verifies the origin during the 83 // websocket upgrade process. When a '*' is specified as an allowed origins all 84 // connections are accepted. 85 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 86 origins := mapset.NewSet() 87 allowAllOrigins := false 88 89 for _, origin := range allowedOrigins { 90 if origin == "*" { 91 allowAllOrigins = true 92 } 93 if origin != "" { 94 origins.Add(origin) 95 } 96 } 97 // allow localhost if no allowedOrigins are specified. 98 if len(origins.ToSlice()) == 0 { 99 origins.Add("http://localhost") 100 if hostname, err := os.Hostname(); err == nil { 101 origins.Add("http://" + hostname) 102 } 103 } 104 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 105 106 f := func(req *http.Request) bool { 107 // Skip origin verification if no Origin header is present. The origin check 108 // is supposed to protect against browser based attacks. Browsers always set 109 // Origin. Non-browser software can put anything in origin and checking it doesn't 110 // provide additional security. 111 if _, ok := req.Header["Origin"]; !ok { 112 return true 113 } 114 // Verify origin against allow list. 115 origin := strings.ToLower(req.Header.Get("Origin")) 116 if allowAllOrigins || originIsAllowed(origins, origin) { 117 return true 118 } 119 log.Warn("Rejected WebSocket connection", "origin", origin) 120 return false 121 } 122 123 return f 124 } 125 126 type wsHandshakeError struct { 127 err error 128 status string 129 } 130 131 func (e wsHandshakeError) Error() string { 132 s := e.err.Error() 133 if e.status != "" { 134 s += " (HTTP status " + e.status + ")" 135 } 136 return s 137 } 138 139 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 140 it := allowedOrigins.Iterator() 141 for origin := range it.C { 142 if ruleAllowsOrigin(origin.(string), browserOrigin) { 143 return true 144 } 145 } 146 return false 147 } 148 149 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 150 var ( 151 allowedScheme, allowedHostname, allowedPort string 152 browserScheme, browserHostname, browserPort string 153 err error 154 ) 155 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 156 if err != nil { 157 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 158 return false 159 } 160 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 161 if err != nil { 162 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 163 return false 164 } 165 if allowedScheme != "" && allowedScheme != browserScheme { 166 return false 167 } 168 if allowedHostname != "" && allowedHostname != browserHostname { 169 return false 170 } 171 if allowedPort != "" && allowedPort != browserPort { 172 return false 173 } 174 return true 175 } 176 177 func parseOriginURL(origin string) (string, string, string, error) { 178 parsedURL, err := url.Parse(strings.ToLower(origin)) 179 if err != nil { 180 return "", "", "", err 181 } 182 var scheme, hostname, port string 183 if strings.Contains(origin, "://") { 184 scheme = parsedURL.Scheme 185 hostname = parsedURL.Hostname() 186 port = parsedURL.Port() 187 } else { 188 scheme = "" 189 hostname = parsedURL.Scheme 190 port = parsedURL.Opaque 191 if hostname == "" { 192 hostname = origin 193 } 194 } 195 return scheme, hostname, port, nil 196 } 197 198 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 199 // that is listening on the given endpoint using the provided dialer. 200 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 201 endpoint, header, err := wsClientHeaders(endpoint, origin) 202 if err != nil { 203 return nil, err 204 } 205 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 206 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 207 if err != nil { 208 hErr := wsHandshakeError{err: err} 209 if resp != nil { 210 hErr.status = resp.Status 211 } 212 return nil, hErr 213 } 214 return newWebsocketCodec(conn, endpoint, header), nil 215 }) 216 } 217 218 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 219 // that is listening on the given endpoint. 220 // 221 // The context is used for the initial connection establishment. It does not 222 // affect subsequent interactions with the client. 223 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 224 dialer := websocket.Dialer{ 225 ReadBufferSize: wsReadBuffer, 226 WriteBufferSize: wsWriteBuffer, 227 WriteBufferPool: wsBufferPool, 228 } 229 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 230 } 231 232 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 233 endpointURL, err := url.Parse(endpoint) 234 if err != nil { 235 return endpoint, nil, err 236 } 237 header := make(http.Header) 238 if origin != "" { 239 header.Add("origin", origin) 240 } 241 if endpointURL.User != nil { 242 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 243 header.Add("authorization", "Basic "+b64auth) 244 endpointURL.User = nil 245 } 246 return endpointURL.String(), header, nil 247 } 248 249 type websocketCodec struct { 250 *jsonCodec 251 conn *websocket.Conn 252 info PeerInfo 253 254 wg sync.WaitGroup 255 pingReset chan struct{} 256 } 257 258 func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { 259 conn.SetReadLimit(wsMessageSizeLimit) 260 conn.SetPongHandler(func(appData string) error { 261 conn.SetReadDeadline(time.Time{}) 262 return nil 263 }) 264 wc := &websocketCodec{ 265 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 266 conn: conn, 267 pingReset: make(chan struct{}, 1), 268 info: PeerInfo{ 269 Transport: "ws", 270 RemoteAddr: conn.RemoteAddr().String(), 271 }, 272 } 273 // Fill in connection details. 274 wc.info.HTTP.Host = host 275 wc.info.HTTP.Origin = req.Get("Origin") 276 wc.info.HTTP.UserAgent = req.Get("User-Agent") 277 // Start pinger. 278 wc.wg.Add(1) 279 go wc.pingLoop() 280 return wc 281 } 282 283 func (wc *websocketCodec) close() { 284 wc.jsonCodec.close() 285 wc.wg.Wait() 286 } 287 288 func (wc *websocketCodec) peerInfo() PeerInfo { 289 return wc.info 290 } 291 292 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 293 return wc.writeJSONSkipDeadline(ctx, v, false) 294 } 295 296 func (wc *websocketCodec) writeJSONSkipDeadline(ctx context.Context, v interface{}, skip bool) error { 297 err := wc.jsonCodec.writeJSONSkipDeadline(ctx, v, skip) 298 if err == nil { 299 // Notify pingLoop to delay the next idle ping. 300 select { 301 case wc.pingReset <- struct{}{}: 302 default: 303 } 304 } 305 return err 306 } 307 308 // pingLoop sends periodic ping frames when the connection is idle. 309 func (wc *websocketCodec) pingLoop() { 310 var timer = time.NewTimer(wsPingInterval) 311 defer wc.wg.Done() 312 defer timer.Stop() 313 314 for { 315 select { 316 case <-wc.closed(): 317 return 318 case <-wc.pingReset: 319 if !timer.Stop() { 320 <-timer.C 321 } 322 timer.Reset(wsPingInterval) 323 case <-timer.C: 324 wc.jsonCodec.encMu.Lock() 325 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 326 wc.conn.WriteMessage(websocket.PingMessage, nil) 327 wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) 328 wc.jsonCodec.encMu.Unlock() 329 timer.Reset(wsPingInterval) 330 } 331 } 332 }