github.com/ethw3/go-ethereuma@v0.0.0-20221013053120-c14602a4c23c/rpc/websocket_test.go (about) 1 // Copyright 2018 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "errors" 22 "io" 23 "net" 24 "net/http" 25 "net/http/httptest" 26 "net/http/httputil" 27 "net/url" 28 "strings" 29 "sync/atomic" 30 "testing" 31 "time" 32 33 "github.com/gorilla/websocket" 34 ) 35 36 func TestWebsocketClientHeaders(t *testing.T) { 37 t.Parallel() 38 39 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 40 if err != nil { 41 t.Fatalf("wsGetConfig failed: %s", err) 42 } 43 if endpoint != "wss://example.com:1234" { 44 t.Fatal("User should have been stripped from the URL") 45 } 46 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 47 t.Fatal("Basic auth header is incorrect") 48 } 49 if header.Get("origin") != "https://example.com" { 50 t.Fatal("Origin not set") 51 } 52 } 53 54 // This test checks that the server rejects connections from disallowed origins. 55 func TestWebsocketOriginCheck(t *testing.T) { 56 t.Parallel() 57 58 var ( 59 srv = newTestServer() 60 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 61 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 62 ) 63 defer srv.Stop() 64 defer httpsrv.Close() 65 66 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 67 if err == nil { 68 client.Close() 69 t.Fatal("no error for wrong origin") 70 } 71 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 72 if !errors.Is(err, wantErr) { 73 t.Fatalf("wrong error for wrong origin: %q", err) 74 } 75 76 // Connections without origin header should work. 77 client, err = DialWebsocket(context.Background(), wsURL, "") 78 if err != nil { 79 t.Fatalf("error for empty origin: %v", err) 80 } 81 client.Close() 82 } 83 84 // This test checks whether calls exceeding the request size limit are rejected. 85 func TestWebsocketLargeCall(t *testing.T) { 86 t.Parallel() 87 88 var ( 89 srv = newTestServer() 90 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 91 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 92 ) 93 defer srv.Stop() 94 defer httpsrv.Close() 95 96 client, err := DialWebsocket(context.Background(), wsURL, "") 97 if err != nil { 98 t.Fatalf("can't dial: %v", err) 99 } 100 defer client.Close() 101 102 // This call sends slightly less than the limit and should work. 103 var result echoResult 104 arg := strings.Repeat("x", maxRequestContentLength-200) 105 if err := client.Call(&result, "test_echo", arg, 1); err != nil { 106 t.Fatalf("valid call didn't work: %v", err) 107 } 108 if result.String != arg { 109 t.Fatal("wrong string echoed") 110 } 111 112 // This call sends twice the allowed size and shouldn't work. 113 arg = strings.Repeat("x", maxRequestContentLength*2) 114 err = client.Call(&result, "test_echo", arg) 115 if err == nil { 116 t.Fatal("no error for too large call") 117 } 118 } 119 120 func TestWebsocketPeerInfo(t *testing.T) { 121 var ( 122 s = newTestServer() 123 ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) 124 tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") 125 ) 126 defer s.Stop() 127 defer ts.Close() 128 129 ctx := context.Background() 130 c, err := DialWebsocket(ctx, tsurl, "origin.example.com") 131 if err != nil { 132 t.Fatal(err) 133 } 134 135 // Request peer information. 136 var connInfo PeerInfo 137 if err := c.Call(&connInfo, "test_peerInfo"); err != nil { 138 t.Fatal(err) 139 } 140 141 if connInfo.RemoteAddr == "" { 142 t.Error("RemoteAddr not set") 143 } 144 if connInfo.Transport != "ws" { 145 t.Errorf("wrong Transport %q", connInfo.Transport) 146 } 147 if connInfo.HTTP.UserAgent != "Go-http-client/1.1" { 148 t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent) 149 } 150 if connInfo.HTTP.Origin != "origin.example.com" { 151 t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent) 152 } 153 } 154 155 // This test checks that client handles WebSocket ping frames correctly. 156 func TestClientWebsocketPing(t *testing.T) { 157 t.Parallel() 158 159 var ( 160 sendPing = make(chan struct{}) 161 server = wsPingTestServer(t, sendPing) 162 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 163 ) 164 defer cancel() 165 defer server.Shutdown(ctx) 166 167 client, err := DialContext(ctx, "ws://"+server.Addr) 168 if err != nil { 169 t.Fatalf("client dial error: %v", err) 170 } 171 defer client.Close() 172 173 resultChan := make(chan int) 174 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 175 if err != nil { 176 t.Fatalf("client subscribe error: %v", err) 177 } 178 // Note: Unsubscribe is not called on this subscription because the mockup 179 // server can't handle the request. 180 181 // Wait for the context's deadline to be reached before proceeding. 182 // This is important for reproducing https://github.com/ethw3/go-ethereuma/issues/19798 183 <-ctx.Done() 184 close(sendPing) 185 186 // Wait for the subscription result. 187 timeout := time.NewTimer(5 * time.Second) 188 defer timeout.Stop() 189 for { 190 select { 191 case err := <-sub.Err(): 192 t.Error("client subscription error:", err) 193 case result := <-resultChan: 194 t.Log("client got result:", result) 195 return 196 case <-timeout.C: 197 t.Error("didn't get any result within the test timeout") 198 return 199 } 200 } 201 } 202 203 // This checks that the websocket transport can deal with large messages. 204 func TestClientWebsocketLargeMessage(t *testing.T) { 205 var ( 206 srv = NewServer() 207 httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) 208 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 209 ) 210 defer srv.Stop() 211 defer httpsrv.Close() 212 213 respLength := wsMessageSizeLimit - 50 214 srv.RegisterName("test", largeRespService{respLength}) 215 216 c, err := DialWebsocket(context.Background(), wsURL, "") 217 if err != nil { 218 t.Fatal(err) 219 } 220 221 var r string 222 if err := c.Call(&r, "test_largeResp"); err != nil { 223 t.Fatal("call failed:", err) 224 } 225 if len(r) != respLength { 226 t.Fatalf("response has wrong length %d, want %d", len(r), respLength) 227 } 228 } 229 230 func TestClientWebsocketSevered(t *testing.T) { 231 t.Parallel() 232 233 var ( 234 server = wsPingTestServer(t, nil) 235 ctx = context.Background() 236 ) 237 defer server.Shutdown(ctx) 238 239 u, err := url.Parse("http://" + server.Addr) 240 if err != nil { 241 t.Fatal(err) 242 } 243 rproxy := httputil.NewSingleHostReverseProxy(u) 244 var severable *severableReadWriteCloser 245 rproxy.ModifyResponse = func(response *http.Response) error { 246 severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} 247 response.Body = severable 248 return nil 249 } 250 frontendProxy := httptest.NewServer(rproxy) 251 defer frontendProxy.Close() 252 253 wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") 254 client, err := DialWebsocket(ctx, wsURL, "") 255 if err != nil { 256 t.Fatalf("client dial error: %v", err) 257 } 258 defer client.Close() 259 260 resultChan := make(chan int) 261 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 262 if err != nil { 263 t.Fatalf("client subscribe error: %v", err) 264 } 265 266 // sever the connection 267 severable.Sever() 268 269 // Wait for subscription error. 270 timeout := time.NewTimer(3 * wsPingInterval) 271 defer timeout.Stop() 272 for { 273 select { 274 case err := <-sub.Err(): 275 t.Log("client subscription error:", err) 276 return 277 case result := <-resultChan: 278 t.Error("unexpected result:", result) 279 return 280 case <-timeout.C: 281 t.Error("didn't get any error within the test timeout") 282 return 283 } 284 } 285 } 286 287 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 288 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 289 // pong and finally delivers a single subscription result. 290 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 291 var srv http.Server 292 shutdown := make(chan struct{}) 293 srv.RegisterOnShutdown(func() { 294 close(shutdown) 295 }) 296 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 297 // Upgrade to WebSocket. 298 upgrader := websocket.Upgrader{ 299 CheckOrigin: func(r *http.Request) bool { return true }, 300 } 301 conn, err := upgrader.Upgrade(w, r, nil) 302 if err != nil { 303 t.Errorf("server WS upgrade error: %v", err) 304 return 305 } 306 defer conn.Close() 307 308 // Handle the connection. 309 wsPingTestHandler(t, conn, shutdown, sendPing) 310 }) 311 312 // Start the server. 313 listener, err := net.Listen("tcp", "127.0.0.1:0") 314 if err != nil { 315 t.Fatal("can't listen:", err) 316 } 317 srv.Addr = listener.Addr().String() 318 go srv.Serve(listener) 319 return &srv 320 } 321 322 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 323 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 324 const ( 325 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 326 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 327 ) 328 329 // Handle subscribe request. 330 if _, _, err := conn.ReadMessage(); err != nil { 331 t.Errorf("server read error: %v", err) 332 return 333 } 334 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 335 t.Errorf("server write error: %v", err) 336 return 337 } 338 339 // Read from the connection to process control messages. 340 var pongCh = make(chan string) 341 conn.SetPongHandler(func(d string) error { 342 t.Logf("server got pong: %q", d) 343 pongCh <- d 344 return nil 345 }) 346 go func() { 347 for { 348 typ, msg, err := conn.ReadMessage() 349 if err != nil { 350 return 351 } 352 t.Logf("server got message (%d): %q", typ, msg) 353 } 354 }() 355 356 // Write messages. 357 var ( 358 wantPong string 359 timer = time.NewTimer(0) 360 ) 361 defer timer.Stop() 362 <-timer.C 363 for { 364 select { 365 case _, open := <-sendPing: 366 if !open { 367 sendPing = nil 368 } 369 t.Logf("server sending ping") 370 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 371 wantPong = "ping" 372 case data := <-pongCh: 373 if wantPong == "" { 374 t.Errorf("unexpected pong") 375 } else if data != wantPong { 376 t.Errorf("got pong with wrong data %q", data) 377 } 378 wantPong = "" 379 timer.Reset(200 * time.Millisecond) 380 case <-timer.C: 381 t.Logf("server sending response") 382 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 383 case <-shutdown: 384 conn.Close() 385 return 386 } 387 } 388 } 389 390 // severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. 391 type severableReadWriteCloser struct { 392 io.ReadWriteCloser 393 severed int32 // atomic 394 } 395 396 func (s *severableReadWriteCloser) Sever() { 397 atomic.StoreInt32(&s.severed, 1) 398 } 399 400 func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) { 401 if atomic.LoadInt32(&s.severed) > 0 { 402 return 0, nil 403 } 404 return s.ReadWriteCloser.Read(p) 405 } 406 407 func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) { 408 if atomic.LoadInt32(&s.severed) > 0 { 409 return len(p), nil 410 } 411 return s.ReadWriteCloser.Write(p) 412 } 413 414 func (s *severableReadWriteCloser) Close() error { 415 return s.ReadWriteCloser.Close() 416 }