// (c) 2019-2020, Ava Labs, Inc.
//
// This file is a derived work, based on the go-ethereum library whose original
// notices appear below.
//
// It is distributed under a license compatible with the licensing terms of the
// original code from which it is derived.
//
// Much love to the original authors for their work.
// **********
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
26 27 package rpc 28 29 import ( 30 "context" 31 "errors" 32 "io" 33 "net" 34 "net/http" 35 "net/http/httptest" 36 "net/http/httputil" 37 "net/url" 38 "os" 39 "strings" 40 "sync/atomic" 41 "testing" 42 "time" 43 44 "github.com/gorilla/websocket" 45 ) 46 47 func TestWebsocketClientHeaders(t *testing.T) { 48 t.Parallel() 49 50 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 51 if err != nil { 52 t.Fatalf("wsGetConfig failed: %s", err) 53 } 54 if endpoint != "wss://example.com:1234" { 55 t.Fatal("User should have been stripped from the URL") 56 } 57 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 58 t.Fatal("Basic auth header is incorrect") 59 } 60 if header.Get("origin") != "https://example.com" { 61 t.Fatal("Origin not set") 62 } 63 } 64 65 // This test checks that the server rejects connections from disallowed origins. 66 func TestWebsocketOriginCheck(t *testing.T) { 67 t.Parallel() 68 69 var ( 70 srv = newTestServer() 71 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 72 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 73 ) 74 defer srv.Stop() 75 defer httpsrv.Close() 76 77 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 78 if err == nil { 79 client.Close() 80 t.Fatal("no error for wrong origin") 81 } 82 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 83 if !errors.Is(err, wantErr) { 84 t.Fatalf("wrong error for wrong origin: %q", err) 85 } 86 87 // Connections without origin header should work. 88 client, err = DialWebsocket(context.Background(), wsURL, "") 89 if err != nil { 90 t.Fatalf("error for empty origin: %v", err) 91 } 92 client.Close() 93 } 94 95 // This test checks whether calls exceeding the request size limit are rejected. 
96 /* 97 func TestWebsocketLargeCall(t *testing.T) { 98 t.Parallel() 99 100 var ( 101 srv = newTestServer() 102 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 103 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 104 ) 105 defer srv.Stop() 106 defer httpsrv.Close() 107 108 client, err := DialWebsocket(context.Background(), wsURL, "") 109 if err != nil { 110 t.Fatalf("can't dial: %v", err) 111 } 112 defer client.Close() 113 114 // This call sends slightly less than the limit and should work. 115 var result echoResult 116 arg := strings.Repeat("x", maxRequestContentLength-200) 117 if err := client.Call(&result, "test_echo", arg, 1); err != nil { 118 t.Fatalf("valid call didn't work: %v", err) 119 } 120 if result.String != arg { 121 t.Fatal("wrong string echoed") 122 } 123 124 // This call sends twice the allowed size and shouldn't work. 125 arg = strings.Repeat("x", maxRequestContentLength*2) 126 err = client.Call(&result, "test_echo", arg) 127 if err == nil { 128 t.Fatal("no error for too large call") 129 } 130 } 131 */ 132 133 func TestWebsocketPeerInfo(t *testing.T) { 134 var ( 135 s = newTestServer() 136 ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) 137 tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") 138 ) 139 defer s.Stop() 140 defer ts.Close() 141 142 ctx := context.Background() 143 c, err := DialWebsocket(ctx, tsurl, "origin.example.com") 144 if err != nil { 145 t.Fatal(err) 146 } 147 148 // Request peer information. 
149 var connInfo PeerInfo 150 if err := c.Call(&connInfo, "test_peerInfo"); err != nil { 151 t.Fatal(err) 152 } 153 154 if connInfo.RemoteAddr == "" { 155 t.Error("RemoteAddr not set") 156 } 157 if connInfo.Transport != "ws" { 158 t.Errorf("wrong Transport %q", connInfo.Transport) 159 } 160 if connInfo.HTTP.UserAgent != "Go-http-client/1.1" { 161 t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent) 162 } 163 if connInfo.HTTP.Origin != "origin.example.com" { 164 t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent) 165 } 166 } 167 168 // This test checks that client handles WebSocket ping frames correctly. 169 func TestClientWebsocketPing(t *testing.T) { 170 t.Parallel() 171 172 var ( 173 sendPing = make(chan struct{}) 174 server = wsPingTestServer(t, sendPing) 175 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 176 ) 177 defer cancel() 178 defer server.Shutdown(ctx) 179 180 client, err := DialContext(ctx, "ws://"+server.Addr) 181 if err != nil { 182 t.Fatalf("client dial error: %v", err) 183 } 184 defer client.Close() 185 186 resultChan := make(chan int) 187 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 188 if err != nil { 189 t.Fatalf("client subscribe error: %v", err) 190 } 191 // Note: Unsubscribe is not called on this subscription because the mockup 192 // server can't handle the request. 193 194 // Wait for the context's deadline to be reached before proceeding. 195 // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 196 <-ctx.Done() 197 close(sendPing) 198 199 // Wait for the subscription result. 
200 timeout := time.NewTimer(5 * time.Second) 201 defer timeout.Stop() 202 for { 203 select { 204 case err := <-sub.Err(): 205 t.Error("client subscription error:", err) 206 case result := <-resultChan: 207 t.Log("client got result:", result) 208 return 209 case <-timeout.C: 210 t.Error("didn't get any result within the test timeout") 211 return 212 } 213 } 214 } 215 216 // This checks that the websocket transport can deal with large messages. 217 func TestClientWebsocketLargeMessage(t *testing.T) { 218 var ( 219 srv = NewServer(0) 220 httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) 221 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 222 ) 223 defer srv.Stop() 224 defer httpsrv.Close() 225 226 respLength := wsMessageSizeLimit - 50 227 srv.RegisterName("test", largeRespService{respLength}) 228 229 c, err := DialWebsocket(context.Background(), wsURL, "") 230 if err != nil { 231 t.Fatal(err) 232 } 233 234 var r string 235 if err := c.Call(&r, "test_largeResp"); err != nil { 236 t.Fatal("call failed:", err) 237 } 238 if len(r) != respLength { 239 t.Fatalf("response has wrong length %d, want %d", len(r), respLength) 240 } 241 } 242 243 func TestClientWebsocketSevered(t *testing.T) { 244 if os.Getenv("RUN_FLAKY_TESTS") != "true" { 245 t.Skip("FLAKY") 246 } 247 t.Parallel() 248 249 var ( 250 server = wsPingTestServer(t, nil) 251 ctx = context.Background() 252 ) 253 defer server.Shutdown(ctx) 254 255 u, err := url.Parse("http://" + server.Addr) 256 if err != nil { 257 t.Fatal(err) 258 } 259 rproxy := httputil.NewSingleHostReverseProxy(u) 260 var severable *severableReadWriteCloser 261 rproxy.ModifyResponse = func(response *http.Response) error { 262 severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} 263 response.Body = severable 264 return nil 265 } 266 frontendProxy := httptest.NewServer(rproxy) 267 defer frontendProxy.Close() 268 269 wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") 270 client, err := 
DialWebsocket(ctx, wsURL, "") 271 if err != nil { 272 t.Fatalf("client dial error: %v", err) 273 } 274 defer client.Close() 275 276 resultChan := make(chan int) 277 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 278 if err != nil { 279 t.Fatalf("client subscribe error: %v", err) 280 } 281 282 // sever the connection 283 severable.Sever() 284 285 // Wait for subscription error. 286 timeout := time.NewTimer(5 * wsPingInterval) 287 defer timeout.Stop() 288 for { 289 select { 290 case err := <-sub.Err(): 291 t.Log("client subscription error:", err) 292 return 293 case result := <-resultChan: 294 t.Error("unexpected result:", result) 295 return 296 case <-timeout.C: 297 t.Error("didn't get any error within the test timeout") 298 return 299 } 300 } 301 } 302 303 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 304 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 305 // pong and finally delivers a single subscription result. 306 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 307 var srv http.Server 308 shutdown := make(chan struct{}) 309 srv.RegisterOnShutdown(func() { 310 close(shutdown) 311 }) 312 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 313 // Upgrade to WebSocket. 314 upgrader := websocket.Upgrader{ 315 CheckOrigin: func(r *http.Request) bool { return true }, 316 } 317 conn, err := upgrader.Upgrade(w, r, nil) 318 if err != nil { 319 t.Errorf("server WS upgrade error: %v", err) 320 return 321 } 322 defer conn.Close() 323 324 // Handle the connection. 325 wsPingTestHandler(t, conn, shutdown, sendPing) 326 }) 327 328 // Start the server. 
329 listener, err := net.Listen("tcp", "127.0.0.1:0") 330 if err != nil { 331 t.Fatal("can't listen:", err) 332 } 333 srv.Addr = listener.Addr().String() 334 go srv.Serve(listener) 335 return &srv 336 } 337 338 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 339 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 340 const ( 341 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 342 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 343 ) 344 345 // Handle subscribe request. 346 if _, _, err := conn.ReadMessage(); err != nil { 347 t.Errorf("server read error: %v", err) 348 return 349 } 350 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 351 t.Errorf("server write error: %v", err) 352 return 353 } 354 355 // Read from the connection to process control messages. 356 var pongCh = make(chan string) 357 conn.SetPongHandler(func(d string) error { 358 t.Logf("server got pong: %q", d) 359 pongCh <- d 360 return nil 361 }) 362 go func() { 363 for { 364 typ, msg, err := conn.ReadMessage() 365 if err != nil { 366 return 367 } 368 t.Logf("server got message (%d): %q", typ, msg) 369 } 370 }() 371 372 // Write messages. 
373 var ( 374 wantPong string 375 timer = time.NewTimer(0) 376 ) 377 defer timer.Stop() 378 <-timer.C 379 for { 380 select { 381 case _, open := <-sendPing: 382 if !open { 383 sendPing = nil 384 } 385 t.Logf("server sending ping") 386 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 387 wantPong = "ping" 388 case data := <-pongCh: 389 if wantPong == "" { 390 t.Errorf("unexpected pong") 391 } else if data != wantPong { 392 t.Errorf("got pong with wrong data %q", data) 393 } 394 wantPong = "" 395 timer.Reset(200 * time.Millisecond) 396 case <-timer.C: 397 t.Logf("server sending response") 398 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 399 case <-shutdown: 400 conn.Close() 401 return 402 } 403 } 404 } 405 406 // severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. 407 type severableReadWriteCloser struct { 408 io.ReadWriteCloser 409 severed int32 // atomic 410 } 411 412 func (s *severableReadWriteCloser) Sever() { 413 atomic.StoreInt32(&s.severed, 1) 414 } 415 416 func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) { 417 if atomic.LoadInt32(&s.severed) > 0 { 418 return 0, nil 419 } 420 return s.ReadWriteCloser.Read(p) 421 } 422 423 func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) { 424 if atomic.LoadInt32(&s.severed) > 0 { 425 return len(p), nil 426 } 427 return s.ReadWriteCloser.Write(p) 428 } 429 430 func (s *severableReadWriteCloser) Close() error { 431 return s.ReadWriteCloser.Close() 432 }