gitlab.com/flarenetwork/coreth@v0.1.1/rpc/websocket.go (about) 1 // (c) 2019-2020, Ava Labs, Inc. 2 // 3 // This file is a derived work, based on the go-ethereum library whose original 4 // notices appear below. 5 // 6 // It is distributed under a license compatible with the licensing terms of the 7 // original code from which it is derived. 8 // 9 // Much love to the original authors for their work. 10 // ********** 11 // Copyright 2015 The go-ethereum Authors 12 // This file is part of the go-ethereum library. 13 // 14 // The go-ethereum library is free software: you can redistribute it and/or modify 15 // it under the terms of the GNU Lesser General Public License as published by 16 // the Free Software Foundation, either version 3 of the License, or 17 // (at your option) any later version. 18 // 19 // The go-ethereum library is distributed in the hope that it will be useful, 20 // but WITHOUT ANY WARRANTY; without even the implied warranty of 21 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 22 // GNU Lesser General Public License for more details. 23 // 24 // You should have received a copy of the GNU Lesser General Public License 25 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 26 27 package rpc 28 29 import ( 30 "context" 31 "encoding/base64" 32 "fmt" 33 "net/http" 34 "net/url" 35 "os" 36 "strings" 37 "sync" 38 "time" 39 40 mapset "github.com/deckarep/golang-set" 41 "github.com/ethereum/go-ethereum/log" 42 "github.com/gorilla/websocket" 43 ) 44 45 const ( 46 wsReadBuffer = 1024 47 wsWriteBuffer = 1024 48 wsPingInterval = 30 * time.Second 49 wsPingWriteTimeout = 5 * time.Second 50 wsMessageSizeLimit = 15 * 1024 * 1024 51 ) 52 53 var wsBufferPool = new(sync.Pool) 54 55 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 56 // 57 // allowedOrigins should be a comma-separated list of allowed origin URLs. 58 // To allow connections with any origin, pass "*". 59 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 60 return s.WebsocketHandlerWithDuration(allowedOrigins, 0) 61 } 62 63 func (s *Server) WebsocketHandlerWithDuration(allowedOrigins []string, apiMaxDuration time.Duration) http.Handler { 64 var upgrader = websocket.Upgrader{ 65 ReadBufferSize: wsReadBuffer, 66 WriteBufferSize: wsWriteBuffer, 67 WriteBufferPool: wsBufferPool, 68 CheckOrigin: wsHandshakeValidator(allowedOrigins), 69 } 70 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 71 conn, err := upgrader.Upgrade(w, r, nil) 72 if err != nil { 73 log.Debug("WebSocket upgrade failed", "err", err) 74 return 75 } 76 codec := newWebsocketCodec(conn) 77 s.ServeCodec(codec, 0, apiMaxDuration) 78 }) 79 } 80 81 // wsHandshakeValidator returns a handler that verifies the origin during the 82 // websocket upgrade process. When a '*' is specified as an allowed origins all 83 // connections are accepted. 84 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 85 origins := mapset.NewSet() 86 allowAllOrigins := false 87 88 for _, origin := range allowedOrigins { 89 if origin == "*" { 90 allowAllOrigins = true 91 } 92 if origin != "" { 93 origins.Add(origin) 94 } 95 } 96 // allow localhost if no allowedOrigins are specified. 97 if len(origins.ToSlice()) == 0 { 98 origins.Add("http://localhost") 99 if hostname, err := os.Hostname(); err == nil { 100 origins.Add("http://" + hostname) 101 } 102 } 103 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 104 105 f := func(req *http.Request) bool { 106 // Skip origin verification if no Origin header is present. The origin check 107 // is supposed to protect against browser based attacks. Browsers always set 108 // Origin. Non-browser software can put anything in origin and checking it doesn't 109 // provide additional security. 110 if _, ok := req.Header["Origin"]; !ok { 111 return true 112 } 113 // Verify origin against allow list. 114 origin := strings.ToLower(req.Header.Get("Origin")) 115 if allowAllOrigins || originIsAllowed(origins, origin) { 116 return true 117 } 118 log.Warn("Rejected WebSocket connection", "origin", origin) 119 return false 120 } 121 122 return f 123 } 124 125 type wsHandshakeError struct { 126 err error 127 status string 128 } 129 130 func (e wsHandshakeError) Error() string { 131 s := e.err.Error() 132 if e.status != "" { 133 s += " (HTTP status " + e.status + ")" 134 } 135 return s 136 } 137 138 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 139 it := allowedOrigins.Iterator() 140 for origin := range it.C { 141 if ruleAllowsOrigin(origin.(string), browserOrigin) { 142 return true 143 } 144 } 145 return false 146 } 147 148 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 149 var ( 150 allowedScheme, allowedHostname, allowedPort string 151 browserScheme, browserHostname, browserPort string 152 err error 153 ) 154 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 155 if err != nil { 156 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 157 return false 158 } 159 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 160 if err != nil { 161 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 162 return false 163 } 164 if allowedScheme != "" && allowedScheme != browserScheme { 165 return false 166 } 167 if allowedHostname != "" && allowedHostname != browserHostname { 168 return false 169 } 170 if allowedPort != "" && allowedPort != browserPort { 171 return false 172 } 173 return true 174 } 175 176 func parseOriginURL(origin string) (string, string, string, error) { 177 parsedURL, err := url.Parse(strings.ToLower(origin)) 178 if err != nil { 179 return "", "", "", err 180 } 181 var scheme, hostname, port string 182 if strings.Contains(origin, "://") { 183 scheme = parsedURL.Scheme 184 hostname = parsedURL.Hostname() 185 port = parsedURL.Port() 186 } else { 187 scheme = "" 188 hostname = parsedURL.Scheme 189 port = parsedURL.Opaque 190 if hostname == "" { 191 hostname = origin 192 } 193 } 194 return scheme, hostname, port, nil 195 } 196 197 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 198 // that is listening on the given endpoint using the provided dialer. 199 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 200 endpoint, header, err := wsClientHeaders(endpoint, origin) 201 if err != nil { 202 return nil, err 203 } 204 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 205 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 206 if err != nil { 207 hErr := wsHandshakeError{err: err} 208 if resp != nil { 209 hErr.status = resp.Status 210 } 211 return nil, hErr 212 } 213 return newWebsocketCodec(conn), nil 214 }) 215 } 216 217 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 218 // that is listening on the given endpoint. 219 // 220 // The context is used for the initial connection establishment. It does not 221 // affect subsequent interactions with the client. 222 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 223 dialer := websocket.Dialer{ 224 ReadBufferSize: wsReadBuffer, 225 WriteBufferSize: wsWriteBuffer, 226 WriteBufferPool: wsBufferPool, 227 } 228 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 229 } 230 231 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 232 endpointURL, err := url.Parse(endpoint) 233 if err != nil { 234 return endpoint, nil, err 235 } 236 header := make(http.Header) 237 if origin != "" { 238 header.Add("origin", origin) 239 } 240 if endpointURL.User != nil { 241 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 242 header.Add("authorization", "Basic "+b64auth) 243 endpointURL.User = nil 244 } 245 return endpointURL.String(), header, nil 246 } 247 248 type websocketCodec struct { 249 *jsonCodec 250 conn *websocket.Conn 251 252 wg sync.WaitGroup 253 pingReset chan struct{} 254 } 255 256 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 257 conn.SetReadLimit(wsMessageSizeLimit) 258 conn.SetPongHandler(func(appData string) error { 259 conn.SetReadDeadline(time.Time{}) 260 return nil 261 }) 262 wc := &websocketCodec{ 263 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 264 conn: conn, 265 pingReset: make(chan struct{}, 1), 266 } 267 wc.wg.Add(1) 268 go wc.pingLoop() 269 return wc 270 } 271 272 func (wc *websocketCodec) close() { 273 wc.jsonCodec.close() 274 wc.wg.Wait() 275 } 276 277 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 278 return wc.writeJSONSkipDeadline(ctx, v, false) 279 } 280 281 func (wc *websocketCodec) writeJSONSkipDeadline(ctx context.Context, v interface{}, skip bool) error { 282 err := wc.jsonCodec.writeJSONSkipDeadline(ctx, v, skip) 283 if err == nil { 284 // Notify pingLoop to delay the next idle ping. 285 select { 286 case wc.pingReset <- struct{}{}: 287 default: 288 } 289 } 290 return err 291 } 292 293 // pingLoop sends periodic ping frames when the connection is idle. 294 func (wc *websocketCodec) pingLoop() { 295 var timer = time.NewTimer(wsPingInterval) 296 defer wc.wg.Done() 297 defer timer.Stop() 298 299 for { 300 select { 301 case <-wc.closed(): 302 return 303 case <-wc.pingReset: 304 if !timer.Stop() { 305 <-timer.C 306 } 307 timer.Reset(wsPingInterval) 308 case <-timer.C: 309 wc.jsonCodec.encMu.Lock() 310 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 311 wc.conn.WriteMessage(websocket.PingMessage, nil) 312 wc.jsonCodec.encMu.Unlock() 313 timer.Reset(wsPingInterval) 314 } 315 } 316 }