// github.com/core-coin/go-core/v2@v2.1.9/rpc/websocket.go

// Copyright 2015 by the Authors
// This file is part of the go-core library.
//
// The go-core library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-core library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-core library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	mapset "github.com/deckarep/golang-set"
	"github.com/gorilla/websocket"

	"github.com/core-coin/go-core/v2/log"
)

const (
	// Buffer sizes passed as ReadBufferSize/WriteBufferSize to both the
	// server-side websocket.Upgrader and the client-side websocket.Dialer.
	wsReadBuffer  = 1024
	wsWriteBuffer = 1024
	// wsPingInterval is how long pingLoop waits (since the last successful
	// write or ping) before sending a ping frame.
	wsPingInterval = 60 * time.Second
	// wsPingWriteTimeout is the write deadline applied to each ping frame.
	wsPingWriteTimeout = 5 * time.Second
)

// wsBufferPool is shared as the WriteBufferPool of both the upgrader and the
// default client dialer.
var wsBufferPool = new(sync.Pool)

// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
//
// allowedOrigins is a list of allowed origin URLs; to allow connections with
// any origin, include "*". If the list contains no non-empty entry, only the
// origins http://localhost and http://<os hostname> are accepted. Requests
// that carry no Origin header are always accepted. Each upgraded connection
// is served by s.ServeCodec.
49 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 50 var upgrader = websocket.Upgrader{ 51 ReadBufferSize: wsReadBuffer, 52 WriteBufferSize: wsWriteBuffer, 53 WriteBufferPool: wsBufferPool, 54 CheckOrigin: wsHandshakeValidator(allowedOrigins), 55 } 56 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 57 conn, err := upgrader.Upgrade(w, r, nil) 58 if err != nil { 59 log.Debug("WebSocket upgrade failed", "err", err) 60 return 61 } 62 codec := newWebsocketCodec(conn, r.Host, r.Header) 63 s.ServeCodec(codec, 0) 64 }) 65 } 66 67 // wsHandshakeValidator returns a handler that verifies the origin during the 68 // websocket upgrade process. When a '*' is specified as an allowed origins all 69 // connections are accepted. 70 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 71 origins := mapset.NewSet() 72 allowAllOrigins := false 73 74 for _, origin := range allowedOrigins { 75 if origin == "*" { 76 allowAllOrigins = true 77 } 78 if origin != "" { 79 origins.Add(origin) 80 } 81 } 82 // allow localhost if no allowedOrigins are specified. 83 if len(origins.ToSlice()) == 0 { 84 origins.Add("http://localhost") 85 if hostname, err := os.Hostname(); err == nil { 86 origins.Add("http://" + hostname) 87 } 88 } 89 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 90 91 f := func(req *http.Request) bool { 92 // Skip origin verification if no Origin header is present. The origin check 93 // is supposed to protect against browser based attacks. Browsers always set 94 // Origin. Non-browser software can put anything in origin and checking it doesn't 95 // provide additional security. 96 if _, ok := req.Header["Origin"]; !ok { 97 return true 98 } 99 // Verify origin against whitelist. 
100 origin := strings.ToLower(req.Header.Get("Origin")) 101 if allowAllOrigins || originIsAllowed(origins, origin) { 102 return true 103 } 104 log.Warn("Rejected WebSocket connection", "origin", origin) 105 return false 106 } 107 108 return f 109 } 110 111 type wsHandshakeError struct { 112 err error 113 status string 114 } 115 116 func (e wsHandshakeError) Error() string { 117 s := e.err.Error() 118 if e.status != "" { 119 s += " (HTTP status " + e.status + ")" 120 } 121 return s 122 } 123 124 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 125 it := allowedOrigins.Iterator() 126 for origin := range it.C { 127 if ruleAllowsOrigin(origin.(string), browserOrigin) { 128 return true 129 } 130 } 131 return false 132 } 133 134 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 135 var ( 136 allowedScheme, allowedHostname, allowedPort string 137 browserScheme, browserHostname, browserPort string 138 err error 139 ) 140 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 141 if err != nil { 142 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 143 return false 144 } 145 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 146 if err != nil { 147 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 148 return false 149 } 150 if allowedScheme != "" && allowedScheme != browserScheme { 151 return false 152 } 153 if allowedHostname != "" && allowedHostname != browserHostname { 154 return false 155 } 156 if allowedPort != "" && allowedPort != browserPort { 157 return false 158 } 159 return true 160 } 161 162 func parseOriginURL(origin string) (string, string, string, error) { 163 parsedURL, err := url.Parse(strings.ToLower(origin)) 164 if err != nil { 165 return "", "", "", err 166 } 167 var scheme, hostname, port string 168 if strings.Contains(origin, "://") { 169 scheme = parsedURL.Scheme 170 
hostname = parsedURL.Hostname() 171 port = parsedURL.Port() 172 } else { 173 scheme = "" 174 hostname = parsedURL.Scheme 175 port = parsedURL.Opaque 176 if hostname == "" { 177 hostname = origin 178 } 179 } 180 return scheme, hostname, port, nil 181 } 182 183 // DialWebsocketWithDialer creates a new RPC client using WebSocket. 184 // 185 // The context is used for the initial connection establishment. It does not 186 // affect subsequent interactions with the client. 187 // 188 // Deprecated: use DialOptions and the WithWebsocketDialer option. 189 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 190 cfg := new(clientConfig) 191 cfg.wsDialer = &dialer 192 if origin != "" { 193 cfg.setHeader("origin", origin) 194 } 195 connect, err := newClientTransportWS(endpoint, cfg) 196 if err != nil { 197 return nil, err 198 } 199 return newClient(ctx, connect) 200 } 201 202 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 203 // that is listening on the given endpoint. 204 // 205 // The context is used for the initial connection establishment. It does not 206 // affect subsequent interactions with the client. 
207 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 208 cfg := new(clientConfig) 209 if origin != "" { 210 cfg.setHeader("origin", origin) 211 } 212 connect, err := newClientTransportWS(endpoint, cfg) 213 if err != nil { 214 return nil, err 215 } 216 return newClient(ctx, connect) 217 } 218 219 func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, error) { 220 dialer := cfg.wsDialer 221 if dialer == nil { 222 dialer = &websocket.Dialer{ 223 ReadBufferSize: wsReadBuffer, 224 WriteBufferSize: wsWriteBuffer, 225 WriteBufferPool: wsBufferPool, 226 } 227 } 228 229 dialURL, header, err := wsClientHeaders(endpoint, "") 230 if err != nil { 231 return nil, err 232 } 233 for key, values := range cfg.httpHeaders { 234 header[key] = values 235 } 236 237 connect := func(ctx context.Context) (ServerCodec, error) { 238 header := header.Clone() 239 if cfg.httpAuth != nil { 240 if err := cfg.httpAuth(header); err != nil { 241 return nil, err 242 } 243 } 244 conn, resp, err := dialer.DialContext(ctx, dialURL, header) 245 if err != nil { 246 hErr := wsHandshakeError{err: err} 247 if resp != nil { 248 hErr.status = resp.Status 249 } 250 return nil, hErr 251 } 252 return newWebsocketCodec(conn, dialURL, header), nil 253 } 254 return connect, nil 255 } 256 257 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 258 endpointURL, err := url.Parse(endpoint) 259 if err != nil { 260 return endpoint, nil, err 261 } 262 header := make(http.Header) 263 if origin != "" { 264 header.Add("origin", origin) 265 } 266 if endpointURL.User != nil { 267 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 268 header.Add("authorization", "Basic "+b64auth) 269 endpointURL.User = nil 270 } 271 return endpointURL.String(), header, nil 272 } 273 274 type websocketCodec struct { 275 *jsonCodec 276 conn *websocket.Conn 277 info PeerInfo 278 279 wg sync.WaitGroup 280 pingReset chan struct{} 281 } 282 
283 func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { 284 conn.SetReadLimit(maxRequestContentLength) 285 wc := &websocketCodec{ 286 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 287 conn: conn, 288 info: PeerInfo{ 289 Transport: "ws", 290 RemoteAddr: conn.RemoteAddr().String(), 291 }, 292 } 293 // Fill in connection details. 294 wc.info.HTTP.Host = host 295 wc.info.HTTP.Origin = req.Get("Origin") 296 wc.info.HTTP.UserAgent = req.Get("User-Agent") 297 // Start pinger. 298 wc.wg.Add(1) 299 go wc.pingLoop() 300 return wc 301 } 302 303 func (wc *websocketCodec) close() { 304 wc.jsonCodec.close() 305 wc.wg.Wait() 306 } 307 308 func (wc *websocketCodec) peerInfo() PeerInfo { 309 return wc.info 310 } 311 312 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 313 err := wc.jsonCodec.writeJSON(ctx, v) 314 if err == nil { 315 // Notify pingLoop to delay the next idle ping. 316 select { 317 case wc.pingReset <- struct{}{}: 318 default: 319 } 320 } 321 return err 322 } 323 324 // pingLoop sends periodic ping frames when the connection is idle. 325 func (wc *websocketCodec) pingLoop() { 326 var timer = time.NewTimer(wsPingInterval) 327 defer wc.wg.Done() 328 defer timer.Stop() 329 330 for { 331 select { 332 case <-wc.closed(): 333 return 334 case <-wc.pingReset: 335 if !timer.Stop() { 336 <-timer.C 337 } 338 timer.Reset(wsPingInterval) 339 case <-timer.C: 340 wc.jsonCodec.encMu.Lock() 341 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 342 wc.conn.WriteMessage(websocket.PingMessage, nil) 343 wc.jsonCodec.encMu.Unlock() 344 timer.Reset(wsPingInterval) 345 } 346 } 347 }