// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set"
	"github.com/juliankolbe/go-ethereum/log"
	"github.com/gorilla/websocket"
)

const (
	// Buffer sizes handed to both the server-side websocket.Upgrader and the
	// client-side websocket.Dialer.
	wsReadBuffer  = 1024
	wsWriteBuffer = 1024

	// wsPingInterval is how long a connection may go without a successful
	// JSON write before pingLoop sends a ping frame.
	wsPingInterval = 60 * time.Second
	// wsPingWriteTimeout is the write deadline applied when sending a ping.
	wsPingWriteTimeout = 5 * time.Second
)

// wsBufferPool is shared as the WriteBufferPool of every Upgrader and Dialer
// created in this file, so write buffers can be reused across connections.
var wsBufferPool = new(sync.Pool)

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins lists the origins whose browser requests may be upgraded to a
// WebSocket. An entry of "*" allows connections from any origin. If the list
// contains no non-empty entries, only http://localhost and http://<hostname>
// (when the host name can be determined) are allowed. Requests that carry no
// Origin header at all are not origin-checked.
48 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 49 var upgrader = websocket.Upgrader{ 50 ReadBufferSize: wsReadBuffer, 51 WriteBufferSize: wsWriteBuffer, 52 WriteBufferPool: wsBufferPool, 53 CheckOrigin: wsHandshakeValidator(allowedOrigins), 54 } 55 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 56 conn, err := upgrader.Upgrade(w, r, nil) 57 if err != nil { 58 log.Debug("WebSocket upgrade failed", "err", err) 59 return 60 } 61 codec := newWebsocketCodec(conn) 62 s.ServeCodec(codec, 0) 63 }) 64 } 65 66 // wsHandshakeValidator returns a handler that verifies the origin during the 67 // websocket upgrade process. When a '*' is specified as an allowed origins all 68 // connections are accepted. 69 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 70 origins := mapset.NewSet() 71 allowAllOrigins := false 72 73 for _, origin := range allowedOrigins { 74 if origin == "*" { 75 allowAllOrigins = true 76 } 77 if origin != "" { 78 origins.Add(origin) 79 } 80 } 81 // allow localhost if no allowedOrigins are specified. 82 if len(origins.ToSlice()) == 0 { 83 origins.Add("http://localhost") 84 if hostname, err := os.Hostname(); err == nil { 85 origins.Add("http://" + hostname) 86 } 87 } 88 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 89 90 f := func(req *http.Request) bool { 91 // Skip origin verification if no Origin header is present. The origin check 92 // is supposed to protect against browser based attacks. Browsers always set 93 // Origin. Non-browser software can put anything in origin and checking it doesn't 94 // provide additional security. 95 if _, ok := req.Header["Origin"]; !ok { 96 return true 97 } 98 // Verify origin against whitelist. 
99 origin := strings.ToLower(req.Header.Get("Origin")) 100 if allowAllOrigins || originIsAllowed(origins, origin) { 101 return true 102 } 103 log.Warn("Rejected WebSocket connection", "origin", origin) 104 return false 105 } 106 107 return f 108 } 109 110 type wsHandshakeError struct { 111 err error 112 status string 113 } 114 115 func (e wsHandshakeError) Error() string { 116 s := e.err.Error() 117 if e.status != "" { 118 s += " (HTTP status " + e.status + ")" 119 } 120 return s 121 } 122 123 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 124 it := allowedOrigins.Iterator() 125 for origin := range it.C { 126 if ruleAllowsOrigin(origin.(string), browserOrigin) { 127 return true 128 } 129 } 130 return false 131 } 132 133 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 134 var ( 135 allowedScheme, allowedHostname, allowedPort string 136 browserScheme, browserHostname, browserPort string 137 err error 138 ) 139 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 140 if err != nil { 141 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 142 return false 143 } 144 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 145 if err != nil { 146 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 147 return false 148 } 149 if allowedScheme != "" && allowedScheme != browserScheme { 150 return false 151 } 152 if allowedHostname != "" && allowedHostname != browserHostname { 153 return false 154 } 155 if allowedPort != "" && allowedPort != browserPort { 156 return false 157 } 158 return true 159 } 160 161 func parseOriginURL(origin string) (string, string, string, error) { 162 parsedURL, err := url.Parse(strings.ToLower(origin)) 163 if err != nil { 164 return "", "", "", err 165 } 166 var scheme, hostname, port string 167 if strings.Contains(origin, "://") { 168 scheme = parsedURL.Scheme 169 
hostname = parsedURL.Hostname() 170 port = parsedURL.Port() 171 } else { 172 scheme = "" 173 hostname = parsedURL.Scheme 174 port = parsedURL.Opaque 175 if hostname == "" { 176 hostname = origin 177 } 178 } 179 return scheme, hostname, port, nil 180 } 181 182 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 183 // that is listening on the given endpoint using the provided dialer. 184 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 185 endpoint, header, err := wsClientHeaders(endpoint, origin) 186 if err != nil { 187 return nil, err 188 } 189 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 190 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 191 if err != nil { 192 hErr := wsHandshakeError{err: err} 193 if resp != nil { 194 hErr.status = resp.Status 195 } 196 return nil, hErr 197 } 198 return newWebsocketCodec(conn), nil 199 }) 200 } 201 202 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 203 // that is listening on the given endpoint. 204 // 205 // The context is used for the initial connection establishment. It does not 206 // affect subsequent interactions with the client. 
207 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 208 dialer := websocket.Dialer{ 209 ReadBufferSize: wsReadBuffer, 210 WriteBufferSize: wsWriteBuffer, 211 WriteBufferPool: wsBufferPool, 212 } 213 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 214 } 215 216 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 217 endpointURL, err := url.Parse(endpoint) 218 if err != nil { 219 return endpoint, nil, err 220 } 221 header := make(http.Header) 222 if origin != "" { 223 header.Add("origin", origin) 224 } 225 if endpointURL.User != nil { 226 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 227 header.Add("authorization", "Basic "+b64auth) 228 endpointURL.User = nil 229 } 230 return endpointURL.String(), header, nil 231 } 232 233 type websocketCodec struct { 234 *jsonCodec 235 conn *websocket.Conn 236 237 wg sync.WaitGroup 238 pingReset chan struct{} 239 } 240 241 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 242 conn.SetReadLimit(maxRequestContentLength) 243 wc := &websocketCodec{ 244 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 245 conn: conn, 246 pingReset: make(chan struct{}, 1), 247 } 248 wc.wg.Add(1) 249 go wc.pingLoop() 250 return wc 251 } 252 253 func (wc *websocketCodec) close() { 254 wc.jsonCodec.close() 255 wc.wg.Wait() 256 } 257 258 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 259 err := wc.jsonCodec.writeJSON(ctx, v) 260 if err == nil { 261 // Notify pingLoop to delay the next idle ping. 262 select { 263 case wc.pingReset <- struct{}{}: 264 default: 265 } 266 } 267 return err 268 } 269 270 // pingLoop sends periodic ping frames when the connection is idle. 
271 func (wc *websocketCodec) pingLoop() { 272 var timer = time.NewTimer(wsPingInterval) 273 defer wc.wg.Done() 274 defer timer.Stop() 275 276 for { 277 select { 278 case <-wc.closed(): 279 return 280 case <-wc.pingReset: 281 if !timer.Stop() { 282 <-timer.C 283 } 284 timer.Reset(wsPingInterval) 285 case <-timer.C: 286 wc.jsonCodec.encMu.Lock() 287 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 288 wc.conn.WriteMessage(websocket.PingMessage, nil) 289 wc.jsonCodec.encMu.Unlock() 290 timer.Reset(wsPingInterval) 291 } 292 } 293 }