github.com/MetalBlockchain/subnet-evm@v0.4.9/rpc/websocket_test.go (about) 1 // (c) 2019-2020, Ava Labs, Inc. 2 // 3 // This file is a derived work, based on the go-ethereum library whose original 4 // notices appear below. 5 // 6 // It is distributed under a license compatible with the licensing terms of the 7 // original code from which it is derived. 8 // 9 // Much love to the original authors for their work. 10 // ********** 11 // Copyright 2018 The go-ethereum Authors 12 // This file is part of the go-ethereum library. 13 // 14 // The go-ethereum library is free software: you can redistribute it and/or modify 15 // it under the terms of the GNU Lesser General Public License as published by 16 // the Free Software Foundation, either version 3 of the License, or 17 // (at your option) any later version. 18 // 19 // The go-ethereum library is distributed in the hope that it will be useful, 20 // but WITHOUT ANY WARRANTY; without even the implied warranty of 21 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 22 // GNU Lesser General Public License for more details. 23 // 24 // You should have received a copy of the GNU Lesser General Public License 25 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 26 27 package rpc 28 29 import ( 30 "context" 31 "errors" 32 "io" 33 "net" 34 "net/http" 35 "net/http/httptest" 36 "net/http/httputil" 37 "net/url" 38 "strings" 39 "sync/atomic" 40 "testing" 41 "time" 42 43 "github.com/gorilla/websocket" 44 ) 45 46 func TestWebsocketClientHeaders(t *testing.T) { 47 t.Parallel() 48 49 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 50 if err != nil { 51 t.Fatalf("wsGetConfig failed: %s", err) 52 } 53 if endpoint != "wss://example.com:1234" { 54 t.Fatal("User should have been stripped from the URL") 55 } 56 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 57 t.Fatal("Basic auth header is incorrect") 58 } 59 if header.Get("origin") != "https://example.com" { 60 t.Fatal("Origin not set") 61 } 62 } 63 64 // This test checks that the server rejects connections from disallowed origins. 65 func TestWebsocketOriginCheck(t *testing.T) { 66 t.Parallel() 67 68 var ( 69 srv = newTestServer() 70 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 71 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 72 ) 73 defer srv.Stop() 74 defer httpsrv.Close() 75 76 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 77 if err == nil { 78 client.Close() 79 t.Fatal("no error for wrong origin") 80 } 81 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 82 if !errors.Is(err, wantErr) { 83 t.Fatalf("wrong error for wrong origin: %q", err) 84 } 85 86 // Connections without origin header should work. 87 client, err = DialWebsocket(context.Background(), wsURL, "") 88 if err != nil { 89 t.Fatalf("error for empty origin: %v", err) 90 } 91 client.Close() 92 } 93 94 // This test checks whether calls exceeding the request size limit are rejected. 95 /* 96 func TestWebsocketLargeCall(t *testing.T) { 97 t.Parallel() 98 99 var ( 100 srv = newTestServer() 101 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 102 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 103 ) 104 defer srv.Stop() 105 defer httpsrv.Close() 106 107 client, err := DialWebsocket(context.Background(), wsURL, "") 108 if err != nil { 109 t.Fatalf("can't dial: %v", err) 110 } 111 defer client.Close() 112 113 // This call sends slightly less than the limit and should work. 114 var result echoResult 115 arg := strings.Repeat("x", maxRequestContentLength-200) 116 if err := client.Call(&result, "test_echo", arg, 1); err != nil { 117 t.Fatalf("valid call didn't work: %v", err) 118 } 119 if result.String != arg { 120 t.Fatal("wrong string echoed") 121 } 122 123 // This call sends twice the allowed size and shouldn't work. 124 arg = strings.Repeat("x", maxRequestContentLength*2) 125 err = client.Call(&result, "test_echo", arg) 126 if err == nil { 127 t.Fatal("no error for too large call") 128 } 129 } 130 */ 131 132 func TestWebsocketPeerInfo(t *testing.T) { 133 var ( 134 s = newTestServer() 135 ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) 136 tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") 137 ) 138 defer s.Stop() 139 defer ts.Close() 140 141 ctx := context.Background() 142 c, err := DialWebsocket(ctx, tsurl, "origin.example.com") 143 if err != nil { 144 t.Fatal(err) 145 } 146 147 // Request peer information. 148 var connInfo PeerInfo 149 if err := c.Call(&connInfo, "test_peerInfo"); err != nil { 150 t.Fatal(err) 151 } 152 153 if connInfo.RemoteAddr == "" { 154 t.Error("RemoteAddr not set") 155 } 156 if connInfo.Transport != "ws" { 157 t.Errorf("wrong Transport %q", connInfo.Transport) 158 } 159 if connInfo.HTTP.UserAgent != "Go-http-client/1.1" { 160 t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent) 161 } 162 if connInfo.HTTP.Origin != "origin.example.com" { 163 t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent) 164 } 165 } 166 167 // This test checks that client handles WebSocket ping frames correctly. 168 func TestClientWebsocketPing(t *testing.T) { 169 t.Parallel() 170 171 var ( 172 sendPing = make(chan struct{}) 173 server = wsPingTestServer(t, sendPing) 174 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 175 ) 176 defer cancel() 177 defer server.Shutdown(ctx) 178 179 client, err := DialContext(ctx, "ws://"+server.Addr) 180 if err != nil { 181 t.Fatalf("client dial error: %v", err) 182 } 183 defer client.Close() 184 185 resultChan := make(chan int) 186 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 187 if err != nil { 188 t.Fatalf("client subscribe error: %v", err) 189 } 190 // Note: Unsubscribe is not called on this subscription because the mockup 191 // server can't handle the request. 192 193 // Wait for the context's deadline to be reached before proceeding. 194 // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 195 <-ctx.Done() 196 close(sendPing) 197 198 // Wait for the subscription result. 199 timeout := time.NewTimer(5 * time.Second) 200 defer timeout.Stop() 201 for { 202 select { 203 case err := <-sub.Err(): 204 t.Error("client subscription error:", err) 205 case result := <-resultChan: 206 t.Log("client got result:", result) 207 return 208 case <-timeout.C: 209 t.Error("didn't get any result within the test timeout") 210 return 211 } 212 } 213 } 214 215 // This checks that the websocket transport can deal with large messages. 216 func TestClientWebsocketLargeMessage(t *testing.T) { 217 var ( 218 srv = NewServer(0) 219 httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) 220 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 221 ) 222 defer srv.Stop() 223 defer httpsrv.Close() 224 225 respLength := wsMessageSizeLimit - 50 226 srv.RegisterName("test", largeRespService{respLength}) 227 228 c, err := DialWebsocket(context.Background(), wsURL, "") 229 if err != nil { 230 t.Fatal(err) 231 } 232 233 var r string 234 if err := c.Call(&r, "test_largeResp"); err != nil { 235 t.Fatal("call failed:", err) 236 } 237 if len(r) != respLength { 238 t.Fatalf("response has wrong length %d, want %d", len(r), respLength) 239 } 240 } 241 242 func TestClientWebsocketSevered(t *testing.T) { 243 t.Skip("FLAKY") 244 t.Parallel() 245 246 var ( 247 server = wsPingTestServer(t, nil) 248 ctx = context.Background() 249 ) 250 defer server.Shutdown(ctx) 251 252 u, err := url.Parse("http://" + server.Addr) 253 if err != nil { 254 t.Fatal(err) 255 } 256 rproxy := httputil.NewSingleHostReverseProxy(u) 257 var severable *severableReadWriteCloser 258 rproxy.ModifyResponse = func(response *http.Response) error { 259 severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} 260 response.Body = severable 261 return nil 262 } 263 frontendProxy := httptest.NewServer(rproxy) 264 defer frontendProxy.Close() 265 266 wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") 267 client, err := DialWebsocket(ctx, wsURL, "") 268 if err != nil { 269 t.Fatalf("client dial error: %v", err) 270 } 271 defer client.Close() 272 273 resultChan := make(chan int) 274 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 275 if err != nil { 276 t.Fatalf("client subscribe error: %v", err) 277 } 278 279 // sever the connection 280 severable.Sever() 281 282 // Wait for subscription error. 283 timeout := time.NewTimer(5 * wsPingInterval) 284 defer timeout.Stop() 285 for { 286 select { 287 case err := <-sub.Err(): 288 t.Log("client subscription error:", err) 289 return 290 case result := <-resultChan: 291 t.Error("unexpected result:", result) 292 return 293 case <-timeout.C: 294 t.Error("didn't get any error within the test timeout") 295 return 296 } 297 } 298 } 299 300 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 301 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 302 // pong and finally delivers a single subscription result. 303 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 304 var srv http.Server 305 shutdown := make(chan struct{}) 306 srv.RegisterOnShutdown(func() { 307 close(shutdown) 308 }) 309 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 310 // Upgrade to WebSocket. 311 upgrader := websocket.Upgrader{ 312 CheckOrigin: func(r *http.Request) bool { return true }, 313 } 314 conn, err := upgrader.Upgrade(w, r, nil) 315 if err != nil { 316 t.Errorf("server WS upgrade error: %v", err) 317 return 318 } 319 defer conn.Close() 320 321 // Handle the connection. 322 wsPingTestHandler(t, conn, shutdown, sendPing) 323 }) 324 325 // Start the server. 326 listener, err := net.Listen("tcp", "127.0.0.1:0") 327 if err != nil { 328 t.Fatal("can't listen:", err) 329 } 330 srv.Addr = listener.Addr().String() 331 go srv.Serve(listener) 332 return &srv 333 } 334 335 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 336 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 337 const ( 338 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 339 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 340 ) 341 342 // Handle subscribe request. 343 if _, _, err := conn.ReadMessage(); err != nil { 344 t.Errorf("server read error: %v", err) 345 return 346 } 347 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 348 t.Errorf("server write error: %v", err) 349 return 350 } 351 352 // Read from the connection to process control messages. 353 var pongCh = make(chan string) 354 conn.SetPongHandler(func(d string) error { 355 t.Logf("server got pong: %q", d) 356 pongCh <- d 357 return nil 358 }) 359 go func() { 360 for { 361 typ, msg, err := conn.ReadMessage() 362 if err != nil { 363 return 364 } 365 t.Logf("server got message (%d): %q", typ, msg) 366 } 367 }() 368 369 // Write messages. 370 var ( 371 wantPong string 372 timer = time.NewTimer(0) 373 ) 374 defer timer.Stop() 375 <-timer.C 376 for { 377 select { 378 case _, open := <-sendPing: 379 if !open { 380 sendPing = nil 381 } 382 t.Logf("server sending ping") 383 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 384 wantPong = "ping" 385 case data := <-pongCh: 386 if wantPong == "" { 387 t.Errorf("unexpected pong") 388 } else if data != wantPong { 389 t.Errorf("got pong with wrong data %q", data) 390 } 391 wantPong = "" 392 timer.Reset(200 * time.Millisecond) 393 case <-timer.C: 394 t.Logf("server sending response") 395 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 396 case <-shutdown: 397 conn.Close() 398 return 399 } 400 } 401 } 402 403 // severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. 404 type severableReadWriteCloser struct { 405 io.ReadWriteCloser 406 severed int32 // atomic 407 } 408 409 func (s *severableReadWriteCloser) Sever() { 410 atomic.StoreInt32(&s.severed, 1) 411 } 412 413 func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) { 414 if atomic.LoadInt32(&s.severed) > 0 { 415 return 0, nil 416 } 417 return s.ReadWriteCloser.Read(p) 418 } 419 420 func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) { 421 if atomic.LoadInt32(&s.severed) > 0 { 422 return len(p), nil 423 } 424 return s.ReadWriteCloser.Write(p) 425 } 426 427 func (s *severableReadWriteCloser) Close() error { 428 return s.ReadWriteCloser.Close() 429 }