github.com/amazechain/amc@v0.1.3/modules/rpc/jsonrpc/websocket.go (about) 1 // Copyright 2022 The AmazeChain Authors 2 // This file is part of the AmazeChain library. 3 // 4 // The AmazeChain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The AmazeChain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package jsonrpc 18 19 import ( 20 "context" 21 "encoding/base64" 22 "fmt" 23 "github.com/amazechain/amc/log" 24 "net/http" 25 "net/url" 26 "os" 27 "strings" 28 "sync" 29 "time" 30 31 mapset "github.com/deckarep/golang-set" 32 "github.com/gorilla/websocket" 33 ) 34 35 const ( 36 wsReadBuffer = 1024 37 wsWriteBuffer = 1024 38 wsPingInterval = 60 * time.Second 39 wsPingWriteTimeout = 5 * time.Second 40 wsPongTimeout = 30 * time.Second 41 wsMessageSizeLimit = 15 * 1024 * 1024 42 ) 43 44 var wsBufferPool = new(sync.Pool) 45 46 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 47 // 48 // allowedOrigins should be a comma-separated list of allowed origin URLs. 49 // To allow connections with any origin, pass "*". 50 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 51 var upgrader = websocket.Upgrader{ 52 ReadBufferSize: wsReadBuffer, 53 WriteBufferSize: wsWriteBuffer, 54 WriteBufferPool: wsBufferPool, 55 CheckOrigin: wsHandshakeValidator(allowedOrigins), 56 } 57 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 conn, err := upgrader.Upgrade(w, r, nil) 59 if err != nil { 60 log.Debug("WebSocket upgrade failed", "err", err) 61 return 62 } 63 codec := newWebsocketCodec(conn, r.Host, r.Header) 64 s.ServeCodec(codec, 0) 65 }) 66 } 67 68 // wsHandshakeValidator returns a handler that verifies the origin during the 69 // websocket upgrade process. When a '*' is specified as an allowed origins all 70 // connections are accepted. 71 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 72 origins := mapset.NewSet() 73 allowAllOrigins := false 74 75 for _, origin := range allowedOrigins { 76 if origin == "*" { 77 allowAllOrigins = true 78 } 79 if origin != "" { 80 origins.Add(origin) 81 } 82 } 83 // allow localhost if no allowedOrigins are specified. 84 if len(origins.ToSlice()) == 0 { 85 origins.Add("http://localhost") 86 if hostname, err := os.Hostname(); err == nil { 87 origins.Add("http://" + hostname) 88 } 89 } 90 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 91 92 f := func(req *http.Request) bool { 93 // Skip origin verification if no Origin header is present. The origin check 94 // is supposed to protect against browser based attacks. Browsers always set 95 // Origin. Non-browser software can put anything in origin and checking it doesn't 96 // provide additional security. 97 if _, ok := req.Header["Origin"]; !ok { 98 return true 99 } 100 // Verify origin against allow list. 101 origin := strings.ToLower(req.Header.Get("Origin")) 102 if allowAllOrigins || originIsAllowed(origins, origin) { 103 return true 104 } 105 log.Warn("Rejected WebSocket connection", "origin", origin) 106 return false 107 } 108 109 return f 110 } 111 112 type wsHandshakeError struct { 113 err error 114 status string 115 } 116 117 func (e wsHandshakeError) Error() string { 118 s := e.err.Error() 119 if e.status != "" { 120 s += " (HTTP status " + e.status + ")" 121 } 122 return s 123 } 124 125 func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { 126 it := allowedOrigins.Iterator() 127 for origin := range it.C { 128 if ruleAllowsOrigin(origin.(string), browserOrigin) { 129 return true 130 } 131 } 132 return false 133 } 134 135 func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { 136 var ( 137 allowedScheme, allowedHostname, allowedPort string 138 browserScheme, browserHostname, browserPort string 139 err error 140 ) 141 allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) 142 if err != nil { 143 log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) 144 return false 145 } 146 browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) 147 if err != nil { 148 log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) 149 return false 150 } 151 if allowedScheme != "" && allowedScheme != browserScheme { 152 return false 153 } 154 if allowedHostname != "" && allowedHostname != browserHostname { 155 return false 156 } 157 if allowedPort != "" && allowedPort != browserPort { 158 return false 159 } 160 return true 161 } 162 163 func parseOriginURL(origin string) (string, string, string, error) { 164 parsedURL, err := url.Parse(strings.ToLower(origin)) 165 if err != nil { 166 return "", "", "", err 167 } 168 var scheme, hostname, port string 169 if strings.Contains(origin, "://") { 170 scheme = parsedURL.Scheme 171 hostname = parsedURL.Hostname() 172 port = parsedURL.Port() 173 } else { 174 scheme = "" 175 hostname = parsedURL.Scheme 176 port = parsedURL.Opaque 177 if hostname == "" { 178 hostname = origin 179 } 180 } 181 return scheme, hostname, port, nil 182 } 183 184 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 185 // that is listening on the given endpoint using the provided dialer. 186 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 187 endpoint, header, err := wsClientHeaders(endpoint, origin) 188 if err != nil { 189 return nil, err 190 } 191 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 192 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 193 if err != nil { 194 hErr := wsHandshakeError{err: err} 195 if resp != nil { 196 hErr.status = resp.Status 197 } 198 return nil, hErr 199 } 200 return newWebsocketCodec(conn, endpoint, header), nil 201 }) 202 } 203 204 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 205 // that is listening on the given endpoint. 206 // 207 // The context is used for the initial connection establishment. It does not 208 // affect subsequent interactions with the client. 209 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 210 dialer := websocket.Dialer{ 211 ReadBufferSize: wsReadBuffer, 212 WriteBufferSize: wsWriteBuffer, 213 WriteBufferPool: wsBufferPool, 214 } 215 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 216 } 217 218 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 219 endpointURL, err := url.Parse(endpoint) 220 if err != nil { 221 return endpoint, nil, err 222 } 223 header := make(http.Header) 224 if origin != "" { 225 header.Add("origin", origin) 226 } 227 if endpointURL.User != nil { 228 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 229 header.Add("authorization", "Basic "+b64auth) 230 endpointURL.User = nil 231 } 232 return endpointURL.String(), header, nil 233 } 234 235 type websocketCodec struct { 236 *jsonCodec 237 conn *websocket.Conn 238 //info PeerInfo 239 240 wg sync.WaitGroup 241 pingReset chan struct{} 242 } 243 244 func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { 245 conn.SetReadLimit(wsMessageSizeLimit) 246 conn.SetPongHandler(func(appData string) error { 247 conn.SetReadDeadline(time.Time{}) 248 return nil 249 }) 250 wc := &websocketCodec{ 251 jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), 252 conn: conn, 253 pingReset: make(chan struct{}, 1), 254 //info: PeerInfo{ 255 // Transport: "ws", 256 // RemoteAddr: conn.RemoteAddr().String(), 257 //}, 258 } 259 // Fill in connection details. 260 //wc.info.HTTP.Host = host 261 //wc.info.HTTP.Origin = req.Get("Origin") 262 //wc.info.HTTP.UserAgent = req.Get("User-Agent") 263 // Start pinger. 264 wc.wg.Add(1) 265 go wc.pingLoop() 266 return wc 267 } 268 269 func (wc *websocketCodec) close() { 270 wc.jsonCodec.close() 271 wc.wg.Wait() 272 } 273 274 //func (wc *websocketCodec) peerInfo() PeerInfo { 275 // return wc.info 276 //} 277 278 func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { 279 err := wc.jsonCodec.writeJSON(ctx, v) 280 if err == nil { 281 // Notify pingLoop to delay the next idle ping. 282 select { 283 case wc.pingReset <- struct{}{}: 284 default: 285 } 286 } 287 return err 288 } 289 290 // pingLoop sends periodic ping frames when the connection is idle. 291 func (wc *websocketCodec) pingLoop() { 292 var timer = time.NewTimer(wsPingInterval) 293 defer wc.wg.Done() 294 defer timer.Stop() 295 296 for { 297 select { 298 case <-wc.closed(): 299 return 300 case <-wc.pingReset: 301 if !timer.Stop() { 302 <-timer.C 303 } 304 timer.Reset(wsPingInterval) 305 case <-timer.C: 306 wc.jsonCodec.encMu.Lock() 307 wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout)) 308 wc.conn.WriteMessage(websocket.PingMessage, nil) 309 wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout)) 310 wc.jsonCodec.encMu.Unlock() 311 timer.Reset(wsPingInterval) 312 } 313 } 314 }