github.com/LampardNguyen234/go-ethereum@v1.10.16-0.20220117140830-b6a3b0260724/rpc/websocket_test.go (about) 1 // Copyright 2018 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "errors" 22 "io" 23 "net" 24 "net/http" 25 "net/http/httptest" 26 "net/http/httputil" 27 "net/url" 28 "strings" 29 "sync/atomic" 30 "testing" 31 "time" 32 33 "github.com/gorilla/websocket" 34 ) 35 36 func TestWebsocketClientHeaders(t *testing.T) { 37 t.Parallel() 38 39 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 40 if err != nil { 41 t.Fatalf("wsGetConfig failed: %s", err) 42 } 43 if endpoint != "wss://example.com:1234" { 44 t.Fatal("User should have been stripped from the URL") 45 } 46 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 47 t.Fatal("Basic auth header is incorrect") 48 } 49 if header.Get("origin") != "https://example.com" { 50 t.Fatal("Origin not set") 51 } 52 } 53 54 // This test checks that the server rejects connections from disallowed origins. 
55 func TestWebsocketOriginCheck(t *testing.T) { 56 t.Parallel() 57 58 var ( 59 srv = newTestServer() 60 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 61 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 62 ) 63 defer srv.Stop() 64 defer httpsrv.Close() 65 66 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 67 if err == nil { 68 client.Close() 69 t.Fatal("no error for wrong origin") 70 } 71 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 72 if !errors.Is(err, wantErr) { 73 t.Fatalf("wrong error for wrong origin: %q", err) 74 } 75 76 // Connections without origin header should work. 77 client, err = DialWebsocket(context.Background(), wsURL, "") 78 if err != nil { 79 t.Fatal("error for empty origin") 80 } 81 client.Close() 82 } 83 84 // This test checks whether calls exceeding the request size limit are rejected. 85 func TestWebsocketLargeCall(t *testing.T) { 86 t.Parallel() 87 88 var ( 89 srv = newTestServer() 90 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 91 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 92 ) 93 defer srv.Stop() 94 defer httpsrv.Close() 95 96 client, err := DialWebsocket(context.Background(), wsURL, "") 97 if err != nil { 98 t.Fatalf("can't dial: %v", err) 99 } 100 defer client.Close() 101 102 // This call sends slightly less than the limit and should work. 103 var result echoResult 104 arg := strings.Repeat("x", maxRequestContentLength-200) 105 if err := client.Call(&result, "test_echo", arg, 1); err != nil { 106 t.Fatalf("valid call didn't work: %v", err) 107 } 108 if result.String != arg { 109 t.Fatal("wrong string echoed") 110 } 111 112 // This call sends twice the allowed size and shouldn't work. 
113 arg = strings.Repeat("x", maxRequestContentLength*2) 114 err = client.Call(&result, "test_echo", arg) 115 if err == nil { 116 t.Fatal("no error for too large call") 117 } 118 } 119 120 // This test checks that client handles WebSocket ping frames correctly. 121 func TestClientWebsocketPing(t *testing.T) { 122 t.Parallel() 123 124 var ( 125 sendPing = make(chan struct{}) 126 server = wsPingTestServer(t, sendPing) 127 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 128 ) 129 defer cancel() 130 defer server.Shutdown(ctx) 131 132 client, err := DialContext(ctx, "ws://"+server.Addr) 133 if err != nil { 134 t.Fatalf("client dial error: %v", err) 135 } 136 defer client.Close() 137 138 resultChan := make(chan int) 139 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 140 if err != nil { 141 t.Fatalf("client subscribe error: %v", err) 142 } 143 // Note: Unsubscribe is not called on this subscription because the mockup 144 // server can't handle the request. 145 146 // Wait for the context's deadline to be reached before proceeding. 147 // This is important for reproducing https://github.com/LampardNguyen234/go-ethereum/issues/19798 148 <-ctx.Done() 149 close(sendPing) 150 151 // Wait for the subscription result. 152 timeout := time.NewTimer(5 * time.Second) 153 defer timeout.Stop() 154 for { 155 select { 156 case err := <-sub.Err(): 157 t.Error("client subscription error:", err) 158 case result := <-resultChan: 159 t.Log("client got result:", result) 160 return 161 case <-timeout.C: 162 t.Error("didn't get any result within the test timeout") 163 return 164 } 165 } 166 } 167 168 // This checks that the websocket transport can deal with large messages. 
169 func TestClientWebsocketLargeMessage(t *testing.T) { 170 var ( 171 srv = NewServer() 172 httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) 173 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 174 ) 175 defer srv.Stop() 176 defer httpsrv.Close() 177 178 respLength := wsMessageSizeLimit - 50 179 srv.RegisterName("test", largeRespService{respLength}) 180 181 c, err := DialWebsocket(context.Background(), wsURL, "") 182 if err != nil { 183 t.Fatal(err) 184 } 185 186 var r string 187 if err := c.Call(&r, "test_largeResp"); err != nil { 188 t.Fatal("call failed:", err) 189 } 190 if len(r) != respLength { 191 t.Fatalf("response has wrong length %d, want %d", len(r), respLength) 192 } 193 } 194 195 func TestClientWebsocketSevered(t *testing.T) { 196 t.Parallel() 197 198 var ( 199 server = wsPingTestServer(t, nil) 200 ctx = context.Background() 201 ) 202 defer server.Shutdown(ctx) 203 204 u, err := url.Parse("http://" + server.Addr) 205 if err != nil { 206 t.Fatal(err) 207 } 208 rproxy := httputil.NewSingleHostReverseProxy(u) 209 var severable *severableReadWriteCloser 210 rproxy.ModifyResponse = func(response *http.Response) error { 211 severable = &severableReadWriteCloser{ReadWriteCloser: response.Body.(io.ReadWriteCloser)} 212 response.Body = severable 213 return nil 214 } 215 frontendProxy := httptest.NewServer(rproxy) 216 defer frontendProxy.Close() 217 218 wsURL := "ws:" + strings.TrimPrefix(frontendProxy.URL, "http:") 219 client, err := DialWebsocket(ctx, wsURL, "") 220 if err != nil { 221 t.Fatalf("client dial error: %v", err) 222 } 223 defer client.Close() 224 225 resultChan := make(chan int) 226 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 227 if err != nil { 228 t.Fatalf("client subscribe error: %v", err) 229 } 230 231 // sever the connection 232 severable.Sever() 233 234 // Wait for subscription error. 
235 timeout := time.NewTimer(3 * wsPingInterval) 236 defer timeout.Stop() 237 for { 238 select { 239 case err := <-sub.Err(): 240 t.Log("client subscription error:", err) 241 return 242 case result := <-resultChan: 243 t.Error("unexpected result:", result) 244 return 245 case <-timeout.C: 246 t.Error("didn't get any error within the test timeout") 247 return 248 } 249 } 250 } 251 252 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 253 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 254 // pong and finally delivers a single subscription result. 255 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 256 var srv http.Server 257 shutdown := make(chan struct{}) 258 srv.RegisterOnShutdown(func() { 259 close(shutdown) 260 }) 261 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 262 // Upgrade to WebSocket. 263 upgrader := websocket.Upgrader{ 264 CheckOrigin: func(r *http.Request) bool { return true }, 265 } 266 conn, err := upgrader.Upgrade(w, r, nil) 267 if err != nil { 268 t.Errorf("server WS upgrade error: %v", err) 269 return 270 } 271 defer conn.Close() 272 273 // Handle the connection. 274 wsPingTestHandler(t, conn, shutdown, sendPing) 275 }) 276 277 // Start the server. 278 listener, err := net.Listen("tcp", "127.0.0.1:0") 279 if err != nil { 280 t.Fatal("can't listen:", err) 281 } 282 srv.Addr = listener.Addr().String() 283 go srv.Serve(listener) 284 return &srv 285 } 286 287 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 288 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 289 const ( 290 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 291 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 292 ) 293 294 // Handle subscribe request. 
295 if _, _, err := conn.ReadMessage(); err != nil { 296 t.Errorf("server read error: %v", err) 297 return 298 } 299 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 300 t.Errorf("server write error: %v", err) 301 return 302 } 303 304 // Read from the connection to process control messages. 305 var pongCh = make(chan string) 306 conn.SetPongHandler(func(d string) error { 307 t.Logf("server got pong: %q", d) 308 pongCh <- d 309 return nil 310 }) 311 go func() { 312 for { 313 typ, msg, err := conn.ReadMessage() 314 if err != nil { 315 return 316 } 317 t.Logf("server got message (%d): %q", typ, msg) 318 } 319 }() 320 321 // Write messages. 322 var ( 323 wantPong string 324 timer = time.NewTimer(0) 325 ) 326 defer timer.Stop() 327 <-timer.C 328 for { 329 select { 330 case _, open := <-sendPing: 331 if !open { 332 sendPing = nil 333 } 334 t.Logf("server sending ping") 335 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 336 wantPong = "ping" 337 case data := <-pongCh: 338 if wantPong == "" { 339 t.Errorf("unexpected pong") 340 } else if data != wantPong { 341 t.Errorf("got pong with wrong data %q", data) 342 } 343 wantPong = "" 344 timer.Reset(200 * time.Millisecond) 345 case <-timer.C: 346 t.Logf("server sending response") 347 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 348 case <-shutdown: 349 conn.Close() 350 return 351 } 352 } 353 } 354 355 // severableReadWriteCloser wraps an io.ReadWriteCloser and provides a Sever() method to drop writes and read empty. 
type severableReadWriteCloser struct {
	io.ReadWriteCloser
	severed int32 // atomic; non-zero once Sever has been called
}

// Sever cuts the connection: from now on writes are discarded and reads
// deliver no data. It is safe to call concurrently with Read and Write.
func (s *severableReadWriteCloser) Sever() {
	atomic.StoreInt32(&s.severed, 1)
}

// Read reads from the wrapped connection until Sever is called. After that it
// never returns data: it keeps reading from, and discarding, the wrapped
// connection until that read fails, then returns 0 and the error.
//
// Blocking on the wrapped reader, instead of returning (0, nil), keeps callers
// such as io.Copy from spinning, and lets them finish once the wrapped
// connection is closed.
func (s *severableReadWriteCloser) Read(p []byte) (n int, err error) {
	if atomic.LoadInt32(&s.severed) == 0 {
		return s.ReadWriteCloser.Read(p)
	}
	if len(p) == 0 {
		// Reading into an empty buffer could never make progress towards an
		// error, so report an empty read, as io.Reader permits for len(p) == 0.
		return 0, nil
	}
	for {
		if _, err := s.ReadWriteCloser.Read(p); err != nil {
			return 0, err
		}
	}
}

// Write forwards p to the wrapped connection, or, once severed, drops it while
// reporting a complete write.
func (s *severableReadWriteCloser) Write(p []byte) (n int, err error) {
	if atomic.LoadInt32(&s.severed) > 0 {
		return len(p), nil
	}
	return s.ReadWriteCloser.Write(p)
}

// Close closes the wrapped connection, whether or not it has been severed.
func (s *severableReadWriteCloser) Close() error {
	return s.ReadWriteCloser.Close()
}