// Modifications Copyright 2018 The klaytn Authors
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//
// This file is derived from rpc/websocket.go (2018/06/04).
// Modified and improved for the klaytn development.
20 21 package rpc 22 23 import ( 24 "bufio" 25 "bytes" 26 "context" 27 "encoding/base64" 28 "encoding/json" 29 "fmt" 30 "net/http" 31 "net/url" 32 "os" 33 "strings" 34 "sync" 35 "sync/atomic" 36 "time" 37 38 fastws "github.com/clevergo/websocket" 39 mapset "github.com/deckarep/golang-set" 40 "github.com/gorilla/websocket" 41 "github.com/klaytn/klaytn/common" 42 "github.com/valyala/fasthttp" 43 ) 44 45 const ( 46 wsReadBuffer = 1024 47 wsWriteBuffer = 1024 48 ) 49 50 var wsBufferPool = new(sync.Pool) 51 52 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 53 conn.SetReadLimit(int64(common.MaxRequestContentLength)) 54 if WebsocketReadDeadline != 0 { 55 conn.SetReadDeadline(time.Now().Add(time.Duration(WebsocketReadDeadline) * time.Second)) 56 } 57 if WebsocketWriteDeadline != 0 { 58 conn.SetWriteDeadline(time.Now().Add(time.Duration(WebsocketWriteDeadline) * time.Second)) 59 } 60 return NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) 61 } 62 63 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 64 // 65 // allowedOrigins should be a comma-separated list of allowed origin URLs. 66 // To allow connections with any origin, pass "*". 
67 func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 68 upgrader := websocket.Upgrader{ 69 ReadBufferSize: wsReadBuffer, 70 WriteBufferSize: wsWriteBuffer, 71 WriteBufferPool: wsBufferPool, 72 CheckOrigin: wsHandshakeValidator(allowedOrigins), 73 } 74 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 75 if atomic.LoadInt32(&srv.wsConnCount) >= MaxWebsocketConnections { 76 return 77 } 78 atomic.AddInt32(&srv.wsConnCount, 1) 79 wsConnCounter.Inc(1) 80 defer func() { 81 atomic.AddInt32(&srv.wsConnCount, -1) 82 wsConnCounter.Dec(1) 83 }() 84 conn, err := upgrader.Upgrade(w, r, nil) 85 if err != nil { 86 return 87 } 88 codec := newWebsocketCodec(conn) 89 srv.ServeCodec(codec, 0) 90 }) 91 } 92 93 var upgrader = fastws.Upgrader{ 94 ReadBufferSize: 1024, 95 WriteBufferSize: 1024, 96 } 97 98 func (srv *Server) FastWebsocketHandler(ctx *fasthttp.RequestCtx) { 99 // TODO-Klaytn handle websocket protocol 100 protocol := ctx.Request.Header.Peek("Sec-WebSocket-Protocol") 101 if protocol != nil { 102 ctx.Response.Header.Set("Sec-WebSocket-Protocol", string(protocol)) 103 } 104 105 err := upgrader.Upgrade(ctx, func(conn *fastws.Conn) { 106 if atomic.LoadInt32(&srv.wsConnCount) >= MaxWebsocketConnections { 107 return 108 } 109 atomic.AddInt32(&srv.wsConnCount, 1) 110 wsConnCounter.Inc(1) 111 defer func() { 112 atomic.AddInt32(&srv.wsConnCount, -1) 113 wsConnCounter.Dec(1) 114 }() 115 if WebsocketReadDeadline != 0 { 116 conn.SetReadDeadline(time.Now().Add(time.Duration(WebsocketReadDeadline) * time.Second)) 117 } 118 if WebsocketWriteDeadline != 0 { 119 conn.SetWriteDeadline(time.Now().Add(time.Duration(WebsocketWriteDeadline) * time.Second)) 120 } 121 // Create a custom encode/decode pair to enforce payload size and number encoding 122 encoder := func(v interface{}) error { 123 msg, err := json.Marshal(v) 124 if err != nil { 125 return err 126 } 127 err = conn.WriteMessage(websocket.TextMessage, msg) 128 if err != nil { 129 return err 
130 } 131 return err 132 } 133 decoder := func(v interface{}) error { 134 _, data, err := conn.ReadMessage() 135 if err != nil { 136 return err 137 } 138 dec := json.NewDecoder(bytes.NewReader(data)) 139 dec.UseNumber() 140 return dec.Decode(v) 141 } 142 143 reader := bufio.NewReaderSize(bytes.NewReader(ctx.Request.Body()), common.MaxRequestContentLength) 144 srv.ServeCodec(NewFuncCodec(&httpReadWriteNopCloser{reader, ctx.Response.BodyWriter()}, encoder, decoder), 0) 145 }) 146 if err != nil { 147 logger.Error("FastWebsocketHandler fail to upgrade message", "err", err) 148 return 149 } 150 } 151 152 // NewWSServer creates a new websocket RPC server around an API provider. 153 // 154 // Deprecated: use Server.WebsocketHandler 155 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 156 return &http.Server{ 157 Handler: srv.WebsocketHandler(allowedOrigins), 158 } 159 } 160 161 func NewFastWSServer(allowedOrigins []string, srv *Server) *fasthttp.Server { 162 upgrader.CheckOrigin = wsFastHandshakeValidator(allowedOrigins) 163 164 // TODO-Klaytn concurreny default (256 * 1024), goroutine limit (8192) 165 return &fasthttp.Server{ 166 Concurrency: ConcurrencyLimit, 167 MaxRequestBodySize: common.MaxRequestContentLength, 168 Handler: srv.FastWebsocketHandler, 169 } 170 } 171 172 func wsFastHandshakeValidator(allowedOrigins []string) func(ctx *fasthttp.RequestCtx) bool { 173 origins := mapset.NewSet() 174 allowAllOrigins := false 175 176 for _, origin := range allowedOrigins { 177 if origin == "*" { 178 allowAllOrigins = true 179 } 180 if origin != "" { 181 origins.Add(strings.ToLower(origin)) 182 } 183 } 184 185 // allow localhost if no allowedOrigins are specified. 
186 if len(origins.ToSlice()) == 0 { 187 origins.Add("http://localhost") 188 if hostname, err := os.Hostname(); err == nil { 189 origins.Add("http://" + strings.ToLower(hostname)) 190 } 191 } 192 193 logger.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) 194 195 f := func(ctx *fasthttp.RequestCtx) bool { 196 // Skip origin verification if no Origin header is present. The origin check 197 // is supposed to protect against browser based attacks. Browsers always set 198 // Origin. Non-browser software can put anything in origin and checking it doesn't 199 // provide additional security. 200 201 origin := strings.ToLower(string(ctx.Request.Header.Peek("Origin"))) 202 if allowAllOrigins || origins.Contains(origin) || origin == "" { 203 return true 204 } 205 logger.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 206 return false 207 } 208 209 return f 210 } 211 212 // wsHandshakeValidator returns a handler that verifies the origin during the 213 // websocket upgrade process. When a '*' is specified as an allowed origins all 214 // connections are accepted. 215 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 216 origins := mapset.NewSet() 217 allowAllOrigins := false 218 219 for _, origin := range allowedOrigins { 220 if origin == "*" { 221 allowAllOrigins = true 222 } 223 if origin != "" { 224 origins.Add(strings.ToLower(origin)) 225 } 226 } 227 228 // allow localhost if no allowedOrigins are specified. 229 if len(origins.ToSlice()) == 0 { 230 origins.Add("http://localhost") 231 if hostname, err := os.Hostname(); err == nil { 232 origins.Add("http://" + strings.ToLower(hostname)) 233 } 234 } 235 f := func(req *http.Request) bool { 236 // Skip origin verification if no Origin header is present. The origin check 237 // is supposed to protect against browser based attacks. Browsers always set 238 // Origin. 
Non-browser software can put anything in origin and checking it doesn't 239 // provide additional security. 240 if _, ok := req.Header["Origin"]; !ok { 241 return true 242 } 243 // Verify origin against whitelist. 244 origin := strings.ToLower(req.Header.Get("Origin")) 245 if allowAllOrigins || origins.Contains(origin) { 246 return true 247 } 248 249 return false 250 } 251 252 return f 253 } 254 255 type wsHandshakeError struct { 256 err error 257 status string 258 } 259 260 func (e wsHandshakeError) Error() string { 261 s := e.err.Error() 262 if e.status != "" { 263 s += " (HTTP status " + e.status + ")" 264 } 265 return s 266 } 267 268 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 269 // that is listening on the given endpoint. 270 // 271 // The context is used for the initial connection establishment. It does not 272 // affect subsequent interactions with the client. 273 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 274 endpoint, header, err := wsClientHeaders(endpoint, origin) 275 if err != nil { 276 return nil, err 277 } 278 279 dialer := websocket.Dialer{ 280 ReadBufferSize: wsReadBuffer, 281 WriteBufferSize: wsWriteBuffer, 282 WriteBufferPool: wsBufferPool, 283 } 284 285 return NewClient(ctx, func(ctx context.Context) (ServerCodec, error) { 286 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 287 if resp != nil && resp.Body != nil { 288 defer resp.Body.Close() 289 } 290 291 if err != nil { 292 hErr := wsHandshakeError{err: err} 293 if resp != nil { 294 hErr.status = resp.Status 295 } 296 return nil, hErr 297 } 298 return newWebsocketCodec(conn), nil 299 }) 300 } 301 302 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 303 endpointURL, err := url.Parse(endpoint) 304 if err != nil { 305 return endpoint, nil, err 306 } 307 308 header := make(http.Header) 309 310 if origin != "" { 311 header.Add("origin", origin) 312 } 313 314 if endpointURL.User != 
nil { 315 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 316 header.Add("authorization", "Basic "+b64auth) 317 endpointURL.User = nil 318 } 319 return endpointURL.String(), header, nil 320 }