// github.com/Unheilbar/quorum@v1.0.0/rpc/websocket.go
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set"
	"github.com/ethereum/go-ethereum/log"
	"github.com/gorilla/websocket"
)

const (
	// Read/write buffer sizes (bytes) handed to the websocket Upgrader and Dialers.
	wsReadBuffer  = 1024
	wsWriteBuffer = 1024
	// pingLoop sends a ping frame after this long without a successful write.
	wsPingInterval = 60 * time.Second
	// Write deadline applied to each ping frame.
	wsPingWriteTimeout = 5 * time.Second
	// Maximum size of a single incoming message (15 MiB), set via SetReadLimit.
	wsMessageSizeLimit = 15 * 1024 * 1024
)

// wsBufferPool is the write-buffer pool shared by the server-side Upgrader and
// the client-side Dialers created in this file.
var wsBufferPool = new(sync.Pool)

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins lists the origin URLs that browsers may connect from (one entry
// per slice element). Include "*" to allow connections with any origin. When no
// non-empty entries are given, only localhost and this machine's hostname are
// allowed.
50 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 51 var upgrader = websocket.Upgrader{ 52 ReadBufferSize: wsReadBuffer, 53 WriteBufferSize: wsWriteBuffer, 54 WriteBufferPool: wsBufferPool, 55 CheckOrigin: wsHandshakeValidator(allowedOrigins), 56 } 57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 conn, err := upgrader.Upgrade(w, r, nil) 59 if err != nil { 60 log.Debug("WebSocket upgrade failed", "err", err) 61 return 62 } 63 codec := newWebsocketCodec(conn) 64 s.authenticateHttpRequest(r, codec) 65 s.ServeCodec(codec, 0) 66 }) 67 } 68 69 // wsHandshakeValidator returns a handler that verifies the origin during the 70 // websocket upgrade process. When a '*' is specified as an allowed origins all 71 // connections are accepted. 72 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 73 origins := mapset.NewSet() 74 allowAllOrigins := false 75 76 for _, origin := range allowedOrigins { 77 if origin == "*" { 78 allowAllOrigins = true 79 } 80 if origin != "" { 81 origins.Add(origin) 82 } 83 } 84 // allow localhost if no allowedOrigins are specified. 85 if len(origins.ToSlice()) == 0 { 86 origins.Add("http://localhost") 87 origins.Add("https://localhost") 88 if hostname, err := os.Hostname(); err == nil { 89 origins.Add("http://" + hostname) 90 origins.Add("https://" + hostname) 91 } 92 } 93 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 94 95 f := func(req *http.Request) bool { 96 // Skip origin verification if no Origin header is present. The origin check 97 // is supposed to protect against browser based attacks. Browsers always set 98 // Origin. Non-browser software can put anything in origin and checking it doesn't 99 // provide additional security. 100 if _, ok := req.Header["Origin"]; !ok { 101 return true 102 } 103 // Verify origin against whitelist. 
104 origin := strings.ToLower(req.Header.Get("Origin")) 105 if allowAllOrigins || originIsAllowed(origins, origin) { 106 return true 107 } 108 log.Warn("Rejected WebSocket connection", "origin", origin) 109 return false 110 } 111 112 return f 113 } 114 115 type wsHandshakeError struct { 116 err error 117 status string 118 } 119 120 func (e wsHandshakeError) Error() string { 121 s := e.err.Error() 122 if e.status != "" { 123 s += " (HTTP status " + e.status + ")" 124 } 125 return s 126 } 127 128 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 129 it := allowedOrigins.Iterator() 130 for origin := range it.C { 131 if ruleAllowsOrigin(origin.(string), browserOrigin) { 132 return true 133 } 134 } 135 return false 136 } 137 138 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 139 var ( 140 allowedScheme, allowedHostname, allowedPort string 141 browserScheme, browserHostname, browserPort string 142 err error 143 ) 144 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 145 if err != nil { 146 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 147 return false 148 } 149 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 150 if err != nil { 151 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 152 return false 153 } 154 if allowedScheme != "" && allowedScheme != browserScheme { 155 return false 156 } 157 if allowedHostname != "" && allowedHostname != browserHostname { 158 return false 159 } 160 if allowedPort != "" && allowedPort != browserPort { 161 return false 162 } 163 return true 164 } 165 166 func parseOriginURL(origin string) (string, string, string, error) { 167 parsedURL, err := url.Parse(strings.ToLower(origin)) 168 if err != nil { 169 return "", "", "", err 170 } 171 var scheme, hostname, port string 172 if strings.Contains(origin, "://") { 173 scheme = parsedURL.Scheme 174 
hostname = parsedURL.Hostname() 175 port = parsedURL.Port() 176 } else { 177 scheme = "" 178 hostname = parsedURL.Scheme 179 port = parsedURL.Opaque 180 if hostname == "" { 181 hostname = origin 182 } 183 } 184 return scheme, hostname, port, nil 185 } 186 187 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 188 // that is listening on the given endpoint using the provided dialer. 189 // 190 // The context is used for the initial connection establishment. It does not 191 // affect subsequent interactions with the client. 192 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 193 return DialWebsocketWithCustomTLS(ctx, endpoint, origin, nil) 194 } 195 196 // Quorum 197 // 198 // DialWebsocketWithCustomTLS creates a new RPC client that communicates with a JSON-RPC server 199 // that is listening on the given endpoint. 200 // At the same time, allowing to customize TLSClientConfig of the dialer 201 // 202 // The context is used for the initial connection establishment. It does not 203 // affect subsequent interactions with the client. 
204 func DialWebsocketWithCustomTLS(ctx context.Context, endpoint, origin string, tlsConfig *tls.Config) (*Client, error) { 205 dialer := websocket.Dialer{ 206 ReadBufferSize: wsReadBuffer, 207 WriteBufferSize: wsWriteBuffer, 208 WriteBufferPool: wsBufferPool, 209 } 210 211 endpoint, header, err := wsClientHeaders(endpoint, origin) 212 if err != nil { 213 return nil, err 214 } 215 if tlsConfig != nil { 216 dialer.TLSClientConfig = tlsConfig 217 } 218 ctx = resolvePSIProvider(ctx, endpoint) 219 220 credProviderFunc := CredentialsProviderFromContext(ctx) 221 psiProviderFunc := PSIProviderFromContext(ctx) 222 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 223 if credProviderFunc != nil { 224 token, err := credProviderFunc(ctx) 225 if err != nil { 226 log.Warn("unable to obtain credentials from provider", "err", err) 227 } else { 228 header.Set(HttpAuthorizationHeader, token) 229 } 230 } 231 if psiProviderFunc != nil { 232 psi, err := psiProviderFunc(ctx) 233 if err != nil { 234 log.Warn("unable to obtain PSI from provider", "err", err) 235 } else { 236 header.Set(HttpPrivateStateIdentifierHeader, psi.String()) 237 } 238 } 239 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 240 if err != nil { 241 hErr := wsHandshakeError{err: err} 242 if resp != nil { 243 hErr.status = resp.Status 244 } 245 return nil, hErr 246 } 247 return newWebsocketCodec(conn), nil 248 }) 249 } 250 251 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 252 // that is listening on the given endpoint. 253 // 254 // The context is used for the initial connection establishment. It does not 255 // affect subsequent interactions with the client. 
256 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 257 dialer := websocket.Dialer{ 258 ReadBufferSize: wsReadBuffer, 259 WriteBufferSize: wsWriteBuffer, 260 WriteBufferPool: wsBufferPool, 261 } 262 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 263 } 264 265 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 266 endpointURL, err := url.Parse(endpoint) 267 if err != nil { 268 return endpoint, nil, err 269 } 270 header := make(http.Header) 271 if origin != "" { 272 header.Add("origin", origin) 273 } 274 if endpointURL.User != nil { 275 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 276 header.Add(HttpAuthorizationHeader, "Basic "+b64auth) 277 endpointURL.User = nil 278 } 279 return endpointURL.String(), header, nil 280 } 281 282 type websocketCodec struct { 283 *jsonCodec 284 conn *websocket.Conn 285 286 wg sync.WaitGroup 287 pingReset chan struct{} 288 } 289 290 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 291 conn.SetReadLimit(wsMessageSizeLimit) 292 wc := &websocketCodec{ 293 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 294 conn: conn, 295 pingReset: make(chan struct{}, 1), 296 } 297 wc.wg.Add(1) 298 go wc.pingLoop() 299 return wc 300 } 301 302 func (wc *websocketCodec) close() { 303 wc.jsonCodec.close() 304 wc.wg.Wait() 305 } 306 307 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 308 err := wc.jsonCodec.writeJSON(ctx, v) 309 if err == nil { 310 // Notify pingLoop to delay the next idle ping. 311 select { 312 case wc.pingReset <- struct{}{}: 313 default: 314 } 315 } 316 return err 317 } 318 319 // pingLoop sends periodic ping frames when the connection is idle. 
320 func (wc *websocketCodec) pingLoop() { 321 var timer = time.NewTimer(wsPingInterval) 322 defer wc.wg.Done() 323 defer timer.Stop() 324 325 for { 326 select { 327 case <-wc.closed(): 328 return 329 case <-wc.pingReset: 330 if !timer.Stop() { 331 <-timer.C 332 } 333 timer.Reset(wsPingInterval) 334 case <-timer.C: 335 wc.jsonCodec.encMu.Lock() 336 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 337 wc.conn.WriteMessage(websocket.PingMessage, nil) 338 wc.jsonCodec.encMu.Unlock() 339 timer.Reset(wsPingInterval) 340 } 341 } 342 }