github.com/ethereum/go-ethereum@v1.16.1/rpc/websocket_test.go (about) 1 // Copyright 2018 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "net/http" 24 "net/http/httptest" 25 "strings" 26 "testing" 27 "time" 28 29 "github.com/gorilla/websocket" 30 ) 31 32 func TestWebsocketClientHeaders(t *testing.T) { 33 t.Parallel() 34 35 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 36 if err != nil { 37 t.Fatalf("wsGetConfig failed: %s", err) 38 } 39 if endpoint != "wss://example.com:1234" { 40 t.Fatal("User should have been stripped from the URL") 41 } 42 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 43 t.Fatal("Basic auth header is incorrect") 44 } 45 if header.Get("origin") != "https://example.com" { 46 t.Fatal("Origin not set") 47 } 48 } 49 50 // This test checks that the server rejects connections from disallowed origins. 51 func TestWebsocketOriginCheck(t *testing.T) { 52 t.Parallel() 53 54 var ( 55 srv = newTestServer() 56 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 57 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 58 ) 59 defer srv.Stop() 60 defer httpsrv.Close() 61 62 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 63 if err == nil { 64 client.Close() 65 t.Fatal("no error for wrong origin") 66 } 67 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 68 if !errors.Is(err, wantErr) { 69 t.Fatalf("wrong error for wrong origin: %q", err) 70 } 71 72 // Connections without origin header should work. 73 client, err = DialWebsocket(context.Background(), wsURL, "") 74 if err != nil { 75 t.Fatalf("error for empty origin: %v", err) 76 } 77 client.Close() 78 } 79 80 // This test checks whether calls exceeding the request size limit are rejected. 81 func TestWebsocketLargeCall(t *testing.T) { 82 t.Parallel() 83 84 var ( 85 srv = newTestServer() 86 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 87 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 88 ) 89 defer srv.Stop() 90 defer httpsrv.Close() 91 92 client, err := DialWebsocket(context.Background(), wsURL, "") 93 if err != nil { 94 t.Fatalf("can't dial: %v", err) 95 } 96 defer client.Close() 97 98 // This call sends slightly less than the limit and should work. 99 var result echoResult 100 arg := strings.Repeat("x", defaultBodyLimit-200) 101 if err := client.Call(&result, "test_echo", arg, 1); err != nil { 102 t.Fatalf("valid call didn't work: %v", err) 103 } 104 if result.String != arg { 105 t.Fatal("wrong string echoed") 106 } 107 108 // This call sends twice the allowed size and shouldn't work. 109 arg = strings.Repeat("x", defaultBodyLimit*2) 110 err = client.Call(&result, "test_echo", arg) 111 if err == nil { 112 t.Fatal("no error for too large call") 113 } 114 } 115 116 // This test checks whether the wsMessageSizeLimit option is obeyed. 117 func TestWebsocketLargeRead(t *testing.T) { 118 t.Parallel() 119 120 var ( 121 srv = newTestServer() 122 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 123 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 124 ) 125 defer srv.Stop() 126 defer httpsrv.Close() 127 128 testLimit := func(limit *int64) { 129 opts := []ClientOption{} 130 expLimit := int64(wsDefaultReadLimit) 131 if limit != nil && *limit >= 0 { 132 opts = append(opts, WithWebsocketMessageSizeLimit(*limit)) 133 if *limit > 0 { 134 expLimit = *limit // 0 means infinite 135 } 136 } 137 client, err := DialOptions(context.Background(), wsURL, opts...) 138 if err != nil { 139 t.Fatalf("can't dial: %v", err) 140 } 141 defer client.Close() 142 // Remove some bytes for json encoding overhead. 143 underLimit := int(expLimit - 128) 144 overLimit := expLimit + 1 145 if expLimit == wsDefaultReadLimit { 146 // No point trying the full 32MB in tests. Just sanity-check that 147 // it's not obviously limited. 148 underLimit = 1024 149 overLimit = -1 150 } 151 var res string 152 // Check under limit 153 if err = client.Call(&res, "test_repeat", "A", underLimit); err != nil { 154 t.Fatalf("unexpected error with limit %d: %v", expLimit, err) 155 } 156 if len(res) != underLimit || strings.Count(res, "A") != underLimit { 157 t.Fatal("incorrect data") 158 } 159 // Check over limit 160 if overLimit > 0 { 161 err = client.Call(&res, "test_repeat", "A", expLimit+1) 162 if err == nil || err != websocket.ErrReadLimit { 163 t.Fatalf("wrong error with limit %d: %v expecting %v", expLimit, err, websocket.ErrReadLimit) 164 } 165 } 166 } 167 ptr := func(v int64) *int64 { return &v } 168 169 testLimit(ptr(-1)) // Should be ignored (use default) 170 testLimit(ptr(0)) // Should be ignored (use default) 171 testLimit(nil) // Should be ignored (use default) 172 testLimit(ptr(200)) 173 testLimit(ptr(wsDefaultReadLimit * 2)) 174 } 175 176 func TestWebsocketPeerInfo(t *testing.T) { 177 t.Parallel() 178 179 var ( 180 s = newTestServer() 181 ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) 182 tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") 183 ) 184 defer s.Stop() 185 defer ts.Close() 186 187 ctx := context.Background() 188 c, err := DialWebsocket(ctx, tsurl, "origin.example.com") 189 if err != nil { 190 t.Fatal(err) 191 } 192 defer c.Close() 193 194 // Request peer information. 195 var connInfo PeerInfo 196 if err := c.Call(&connInfo, "test_peerInfo"); err != nil { 197 t.Fatal(err) 198 } 199 200 if connInfo.RemoteAddr == "" { 201 t.Error("RemoteAddr not set") 202 } 203 if connInfo.Transport != "ws" { 204 t.Errorf("wrong Transport %q", connInfo.Transport) 205 } 206 if connInfo.HTTP.UserAgent != "Go-http-client/1.1" { 207 t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent) 208 } 209 if connInfo.HTTP.Origin != "origin.example.com" { 210 t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent) 211 } 212 } 213 214 // This test checks that client handles WebSocket ping frames correctly. 215 func TestClientWebsocketPing(t *testing.T) { 216 t.Parallel() 217 218 var ( 219 sendPing = make(chan struct{}) 220 server = wsPingTestServer(t, sendPing) 221 ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) 222 ) 223 defer cancel() 224 defer server.Shutdown(ctx) 225 226 client, err := DialContext(ctx, "ws://"+server.Addr) 227 if err != nil { 228 t.Fatalf("client dial error: %v", err) 229 } 230 defer client.Close() 231 232 resultChan := make(chan int) 233 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 234 if err != nil { 235 t.Fatalf("client subscribe error: %v", err) 236 } 237 // Note: Unsubscribe is not called on this subscription because the mockup 238 // server can't handle the request. 239 240 // Wait for the context's deadline to be reached before proceeding. 241 // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 242 <-ctx.Done() 243 close(sendPing) 244 245 // Wait for the subscription result. 246 timeout := time.NewTimer(5 * time.Second) 247 defer timeout.Stop() 248 for { 249 select { 250 case err := <-sub.Err(): 251 t.Error("client subscription error:", err) 252 case result := <-resultChan: 253 t.Log("client got result:", result) 254 return 255 case <-timeout.C: 256 t.Error("didn't get any result within the test timeout") 257 return 258 } 259 } 260 } 261 262 // This checks that the websocket transport can deal with large messages. 263 func TestClientWebsocketLargeMessage(t *testing.T) { 264 t.Parallel() 265 266 var ( 267 srv = NewServer() 268 httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) 269 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 270 ) 271 defer srv.Stop() 272 defer httpsrv.Close() 273 274 respLength := wsDefaultReadLimit - 50 275 srv.RegisterName("test", largeRespService{respLength}) 276 277 c, err := DialWebsocket(context.Background(), wsURL, "") 278 if err != nil { 279 t.Fatal(err) 280 } 281 defer c.Close() 282 283 var r string 284 if err := c.Call(&r, "test_largeResp"); err != nil { 285 t.Fatal("call failed:", err) 286 } 287 if len(r) != respLength { 288 t.Fatalf("response has wrong length %d, want %d", len(r), respLength) 289 } 290 } 291 292 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 293 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 294 // pong and finally delivers a single subscription result. 295 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 296 var srv http.Server 297 shutdown := make(chan struct{}) 298 srv.RegisterOnShutdown(func() { 299 close(shutdown) 300 }) 301 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 302 // Upgrade to WebSocket. 303 upgrader := websocket.Upgrader{ 304 CheckOrigin: func(r *http.Request) bool { return true }, 305 } 306 conn, err := upgrader.Upgrade(w, r, nil) 307 if err != nil { 308 t.Errorf("server WS upgrade error: %v", err) 309 return 310 } 311 defer conn.Close() 312 313 // Handle the connection. 314 wsPingTestHandler(t, conn, shutdown, sendPing) 315 }) 316 317 // Start the server. 318 listener, err := net.Listen("tcp", "127.0.0.1:0") 319 if err != nil { 320 t.Fatal("can't listen:", err) 321 } 322 srv.Addr = listener.Addr().String() 323 go srv.Serve(listener) 324 return &srv 325 } 326 327 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 328 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 329 const ( 330 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 331 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 332 ) 333 334 // Handle subscribe request. 335 if _, _, err := conn.ReadMessage(); err != nil { 336 t.Errorf("server read error: %v", err) 337 return 338 } 339 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 340 t.Errorf("server write error: %v", err) 341 return 342 } 343 344 // Read from the connection to process control messages. 345 var pongCh = make(chan string) 346 conn.SetPongHandler(func(d string) error { 347 t.Logf("server got pong: %q", d) 348 pongCh <- d 349 return nil 350 }) 351 go func() { 352 for { 353 typ, msg, err := conn.ReadMessage() 354 if err != nil { 355 return 356 } 357 t.Logf("server got message (%d): %q", typ, msg) 358 } 359 }() 360 361 // Write messages. 362 var ( 363 wantPong string 364 timer = time.NewTimer(0) 365 ) 366 defer timer.Stop() 367 <-timer.C 368 for { 369 select { 370 case _, open := <-sendPing: 371 if !open { 372 sendPing = nil 373 } 374 t.Logf("server sending ping") 375 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 376 wantPong = "ping" 377 case data := <-pongCh: 378 if wantPong == "" { 379 t.Errorf("unexpected pong") 380 } else if data != wantPong { 381 t.Errorf("got pong with wrong data %q", data) 382 } 383 wantPong = "" 384 timer.Reset(200 * time.Millisecond) 385 case <-timer.C: 386 t.Logf("server sending response") 387 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 388 case <-shutdown: 389 conn.Close() 390 return 391 } 392 } 393 } 394 395 func TestWebsocketMethodNameLengthLimit(t *testing.T) { 396 t.Parallel() 397 398 var ( 399 srv = newTestServer() 400 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 401 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 402 ) 403 defer srv.Stop() 404 defer httpsrv.Close() 405 406 client, err := DialWebsocket(context.Background(), wsURL, "") 407 if err != nil { 408 t.Fatalf("can't dial: %v", err) 409 } 410 defer client.Close() 411 412 // Test cases 413 tests := []struct { 414 name string 415 method string 416 params []interface{} 417 expectedError string 418 isSubscription bool 419 }{ 420 { 421 name: "valid method name", 422 method: "test_echo", 423 params: []interface{}{"test", 1}, 424 expectedError: "", 425 isSubscription: false, 426 }, 427 { 428 name: "method name too long", 429 method: "test_" + string(make([]byte, maxMethodNameLength+1)), 430 params: []interface{}{"test", 1}, 431 expectedError: "method name too long", 432 isSubscription: false, 433 }, 434 { 435 name: "valid subscription", 436 method: "nftest_subscribe", 437 params: []interface{}{"someSubscription", 1, 2}, 438 expectedError: "", 439 isSubscription: true, 440 }, 441 { 442 name: "subscription name too long", 443 method: string(make([]byte, maxMethodNameLength+1)) + "_subscribe", 444 params: []interface{}{"newHeads"}, 445 expectedError: "subscription name too long", 446 isSubscription: true, 447 }, 448 } 449 450 for _, tt := range tests { 451 t.Run(tt.name, func(t *testing.T) { 452 var result interface{} 453 err := client.Call(&result, tt.method, tt.params...) 454 if tt.expectedError == "" { 455 if err != nil { 456 t.Errorf("unexpected error: %v", err) 457 } 458 } else { 459 if err == nil { 460 t.Error("expected error, got nil") 461 } else if !strings.Contains(err.Error(), tt.expectedError) { 462 t.Errorf("expected error containing %q, got %q", tt.expectedError, err.Error()) 463 } 464 } 465 }) 466 } 467 }