// Source: github.com/kisexp/xdchain@v0.0.0-20211206025815-490d6b732aa7/rpc/websocket.go

// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
	"context"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set"
	"github.com/kisexp/xdchain/log"
	"github.com/gorilla/websocket"
)

const (
	// wsReadBuffer and wsWriteBuffer are the I/O buffer sizes handed to the
	// websocket upgrader (server side) and dialer (client side).
	wsReadBuffer  = 1024
	wsWriteBuffer = 1024
	// wsPingInterval is how long a connection may stay without an outgoing
	// message before pingLoop sends a ping frame.
	wsPingInterval = 60 * time.Second
	// wsPingWriteTimeout is the write deadline applied to each ping frame.
	wsPingWriteTimeout = 5 * time.Second
)

// wsBufferPool is the write buffer pool shared by the upgrader and all dialers
// created in this file.
var wsBufferPool = new(sync.Pool)

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins is the list of allowed origin URLs.
// To allow connections with any origin, pass "*".
49 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 50 var upgrader = websocket.Upgrader{ 51 ReadBufferSize: wsReadBuffer, 52 WriteBufferSize: wsWriteBuffer, 53 WriteBufferPool: wsBufferPool, 54 CheckOrigin: wsHandshakeValidator(allowedOrigins), 55 } 56 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 conn, err := upgrader.Upgrade(w, r, nil) 58 if err != nil { 59 log.Debug("WebSocket upgrade failed", "err", err) 60 return 61 } 62 codec := newWebsocketCodec(conn) 63 s.authenticateHttpRequest(r, codec) 64 s.ServeCodec(codec, 0) 65 }) 66 } 67 68 // wsHandshakeValidator returns a handler that verifies the origin during the 69 // websocket upgrade process. When a '*' is specified as an allowed origins all 70 // connections are accepted. 71 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 72 origins := mapset.NewSet() 73 allowAllOrigins := false 74 75 for _, origin := range allowedOrigins { 76 if origin == "*" { 77 allowAllOrigins = true 78 } 79 if origin != "" { 80 origins.Add(origin) 81 } 82 } 83 // allow localhost if no allowedOrigins are specified. 84 if len(origins.ToSlice()) == 0 { 85 origins.Add("http://localhost") 86 origins.Add("https://localhost") 87 if hostname, err := os.Hostname(); err == nil { 88 origins.Add("http://" + hostname) 89 origins.Add("https://" + hostname) 90 } 91 } 92 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 93 94 f := func(req *http.Request) bool { 95 // Skip origin verification if no Origin header is present. The origin check 96 // is supposed to protect against browser based attacks. Browsers always set 97 // Origin. Non-browser software can put anything in origin and checking it doesn't 98 // provide additional security. 99 if _, ok := req.Header["Origin"]; !ok { 100 return true 101 } 102 // Verify origin against whitelist. 
103 origin := strings.ToLower(req.Header.Get("Origin")) 104 if allowAllOrigins || originIsAllowed(origins, origin) { 105 return true 106 } 107 log.Warn("Rejected WebSocket connection", "origin", origin) 108 return false 109 } 110 111 return f 112 } 113 114 type wsHandshakeError struct { 115 err error 116 status string 117 } 118 119 func (e wsHandshakeError) Error() string { 120 s := e.err.Error() 121 if e.status != "" { 122 s += " (HTTP status " + e.status + ")" 123 } 124 return s 125 } 126 127 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 128 it := allowedOrigins.Iterator() 129 for origin := range it.C { 130 if ruleAllowsOrigin(origin.(string), browserOrigin) { 131 return true 132 } 133 } 134 return false 135 } 136 137 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 138 var ( 139 allowedScheme, allowedHostname, allowedPort string 140 browserScheme, browserHostname, browserPort string 141 err error 142 ) 143 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 144 if err != nil { 145 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 146 return false 147 } 148 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 149 if err != nil { 150 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 151 return false 152 } 153 if allowedScheme != "" && allowedScheme != browserScheme { 154 return false 155 } 156 if allowedHostname != "" && allowedHostname != browserHostname { 157 return false 158 } 159 if allowedPort != "" && allowedPort != browserPort { 160 return false 161 } 162 return true 163 } 164 165 func parseOriginURL(origin string) (string, string, string, error) { 166 parsedURL, err := url.Parse(strings.ToLower(origin)) 167 if err != nil { 168 return "", "", "", err 169 } 170 var scheme, hostname, port string 171 if strings.Contains(origin, "://") { 172 scheme = parsedURL.Scheme 173 
hostname = parsedURL.Hostname() 174 port = parsedURL.Port() 175 } else { 176 scheme = "" 177 hostname = parsedURL.Scheme 178 port = parsedURL.Opaque 179 if hostname == "" { 180 hostname = origin 181 } 182 } 183 return scheme, hostname, port, nil 184 } 185 186 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 187 // that is listening on the given endpoint using the provided dialer. 188 // 189 // The context is used for the initial connection establishment. It does not 190 // affect subsequent interactions with the client. 191 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 192 return DialWebsocketWithCustomTLS(ctx, endpoint, origin, nil) 193 } 194 195 // Quorum 196 // 197 // DialWebsocketWithCustomTLS creates a new RPC client that communicates with a JSON-RPC server 198 // that is listening on the given endpoint. 199 // At the same time, allowing to customize TLSClientConfig of the dialer 200 // 201 // The context is used for the initial connection establishment. It does not 202 // affect subsequent interactions with the client. 
203 func DialWebsocketWithCustomTLS(ctx context.Context, endpoint, origin string, tlsConfig *tls.Config) (*Client, error) { 204 dialer := websocket.Dialer{ 205 ReadBufferSize: wsReadBuffer, 206 WriteBufferSize: wsWriteBuffer, 207 WriteBufferPool: wsBufferPool, 208 } 209 210 endpoint, header, err := wsClientHeaders(endpoint, origin) 211 if err != nil { 212 return nil, err 213 } 214 if tlsConfig != nil { 215 dialer.TLSClientConfig = tlsConfig 216 } 217 ctx = resolvePSIProvider(ctx, endpoint) 218 219 credProviderFunc := CredentialsProviderFromContext(ctx) 220 psiProviderFunc := PSIProviderFromContext(ctx) 221 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 222 if credProviderFunc != nil { 223 token, err := credProviderFunc(ctx) 224 if err != nil { 225 log.Warn("unable to obtain credentials from provider", "err", err) 226 } else { 227 header.Set(HttpAuthorizationHeader, token) 228 } 229 } 230 if psiProviderFunc != nil { 231 psi, err := psiProviderFunc(ctx) 232 if err != nil { 233 log.Warn("unable to obtain PSI from provider", "err", err) 234 } else { 235 header.Set(HttpPrivateStateIdentifierHeader, psi.String()) 236 } 237 } 238 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 239 if err != nil { 240 hErr := wsHandshakeError{err: err} 241 if resp != nil { 242 hErr.status = resp.Status 243 } 244 return nil, hErr 245 } 246 return newWebsocketCodec(conn), nil 247 }) 248 } 249 250 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 251 // that is listening on the given endpoint. 252 // 253 // The context is used for the initial connection establishment. It does not 254 // affect subsequent interactions with the client. 
255 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 256 dialer := websocket.Dialer{ 257 ReadBufferSize: wsReadBuffer, 258 WriteBufferSize: wsWriteBuffer, 259 WriteBufferPool: wsBufferPool, 260 } 261 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 262 } 263 264 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 265 endpointURL, err := url.Parse(endpoint) 266 if err != nil { 267 return endpoint, nil, err 268 } 269 header := make(http.Header) 270 if origin != "" { 271 header.Add("origin", origin) 272 } 273 if endpointURL.User != nil { 274 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 275 header.Add(HttpAuthorizationHeader, "Basic "+b64auth) 276 endpointURL.User = nil 277 } 278 return endpointURL.String(), header, nil 279 } 280 281 type websocketCodec struct { 282 *jsonCodec 283 conn *websocket.Conn 284 285 wg sync.WaitGroup 286 pingReset chan struct{} 287 } 288 289 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 290 conn.SetReadLimit(maxRequestContentLength) 291 wc := &websocketCodec{ 292 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 293 conn: conn, 294 pingReset: make(chan struct{}, 1), 295 } 296 wc.wg.Add(1) 297 go wc.pingLoop() 298 return wc 299 } 300 301 func (wc *websocketCodec) close() { 302 wc.jsonCodec.close() 303 wc.wg.Wait() 304 } 305 306 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 307 err := wc.jsonCodec.writeJSON(ctx, v) 308 if err == nil { 309 // Notify pingLoop to delay the next idle ping. 310 select { 311 case wc.pingReset <- struct{}{}: 312 default: 313 } 314 } 315 return err 316 } 317 318 // pingLoop sends periodic ping frames when the connection is idle. 
319 func (wc *websocketCodec) pingLoop() { 320 var timer = time.NewTimer(wsPingInterval) 321 defer wc.wg.Done() 322 defer timer.Stop() 323 324 for { 325 select { 326 case <-wc.closed(): 327 return 328 case <-wc.pingReset: 329 if !timer.Stop() { 330 <-timer.C 331 } 332 timer.Reset(wsPingInterval) 333 case <-timer.C: 334 wc.jsonCodec.encMu.Lock() 335 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 336 wc.conn.WriteMessage(websocket.PingMessage, nil) 337 wc.jsonCodec.encMu.Unlock() 338 timer.Reset(wsPingInterval) 339 } 340 } 341 }