github.com/jimmyx0x/go-ethereum@v1.10.28/rpc/websocket_test.go (about) 1 // Copyright 2018 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "net/http" 24 "net/http/httptest" 25 "strings" 26 "testing" 27 "time" 28 29 "github.com/gorilla/websocket" 30 ) 31 32 func TestWebsocketClientHeaders(t *testing.T) { 33 t.Parallel() 34 35 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 36 if err != nil { 37 t.Fatalf("wsGetConfig failed: %s", err) 38 } 39 if endpoint != "wss://example.com:1234" { 40 t.Fatal("User should have been stripped from the URL") 41 } 42 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 43 t.Fatal("Basic auth header is incorrect") 44 } 45 if header.Get("origin") != "https://example.com" { 46 t.Fatal("Origin not set") 47 } 48 } 49 50 // This test checks that the server rejects connections from disallowed origins. 51 func TestWebsocketOriginCheck(t *testing.T) { 52 t.Parallel() 53 54 var ( 55 srv = newTestServer() 56 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 57 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 58 ) 59 defer srv.Stop() 60 defer httpsrv.Close() 61 62 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 63 if err == nil { 64 client.Close() 65 t.Fatal("no error for wrong origin") 66 } 67 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 68 if !errors.Is(err, wantErr) { 69 t.Fatalf("wrong error for wrong origin: %q", err) 70 } 71 72 // Connections without origin header should work. 73 client, err = DialWebsocket(context.Background(), wsURL, "") 74 if err != nil { 75 t.Fatalf("error for empty origin: %v", err) 76 } 77 client.Close() 78 } 79 80 // This test checks whether calls exceeding the request size limit are rejected. 81 func TestWebsocketLargeCall(t *testing.T) { 82 t.Parallel() 83 84 var ( 85 srv = newTestServer() 86 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 87 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 88 ) 89 defer srv.Stop() 90 defer httpsrv.Close() 91 92 client, err := DialWebsocket(context.Background(), wsURL, "") 93 if err != nil { 94 t.Fatalf("can't dial: %v", err) 95 } 96 defer client.Close() 97 98 // This call sends slightly less than the limit and should work. 99 var result echoResult 100 arg := strings.Repeat("x", maxRequestContentLength-200) 101 if err := client.Call(&result, "test_echo", arg, 1); err != nil { 102 t.Fatalf("valid call didn't work: %v", err) 103 } 104 if result.String != arg { 105 t.Fatal("wrong string echoed") 106 } 107 108 // This call sends twice the allowed size and shouldn't work. 109 arg = strings.Repeat("x", maxRequestContentLength*2) 110 err = client.Call(&result, "test_echo", arg) 111 if err == nil { 112 t.Fatal("no error for too large call") 113 } 114 } 115 116 func TestWebsocketPeerInfo(t *testing.T) { 117 var ( 118 s = newTestServer() 119 ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) 120 tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") 121 ) 122 defer s.Stop() 123 defer ts.Close() 124 125 ctx := context.Background() 126 c, err := DialWebsocket(ctx, tsurl, "origin.example.com") 127 if err != nil { 128 t.Fatal(err) 129 } 130 131 // Request peer information. 132 var connInfo PeerInfo 133 if err := c.Call(&connInfo, "test_peerInfo"); err != nil { 134 t.Fatal(err) 135 } 136 137 if connInfo.RemoteAddr == "" { 138 t.Error("RemoteAddr not set") 139 } 140 if connInfo.Transport != "ws" { 141 t.Errorf("wrong Transport %q", connInfo.Transport) 142 } 143 if connInfo.HTTP.UserAgent != "Go-http-client/1.1" { 144 t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent) 145 } 146 if connInfo.HTTP.Origin != "origin.example.com" { 147 t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent) 148 } 149 } 150 151 // This test checks that client handles WebSocket ping frames correctly. 152 func TestClientWebsocketPing(t *testing.T) { 153 t.Parallel() 154 155 var ( 156 sendPing = make(chan struct{}) 157 server = wsPingTestServer(t, sendPing) 158 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 159 ) 160 defer cancel() 161 defer server.Shutdown(ctx) 162 163 client, err := DialContext(ctx, "ws://"+server.Addr) 164 if err != nil { 165 t.Fatalf("client dial error: %v", err) 166 } 167 defer client.Close() 168 169 resultChan := make(chan int) 170 sub, err := client.EthSubscribe(ctx, resultChan, "foo") 171 if err != nil { 172 t.Fatalf("client subscribe error: %v", err) 173 } 174 // Note: Unsubscribe is not called on this subscription because the mockup 175 // server can't handle the request. 176 177 // Wait for the context's deadline to be reached before proceeding. 178 // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 179 <-ctx.Done() 180 close(sendPing) 181 182 // Wait for the subscription result. 183 timeout := time.NewTimer(5 * time.Second) 184 defer timeout.Stop() 185 for { 186 select { 187 case err := <-sub.Err(): 188 t.Error("client subscription error:", err) 189 case result := <-resultChan: 190 t.Log("client got result:", result) 191 return 192 case <-timeout.C: 193 t.Error("didn't get any result within the test timeout") 194 return 195 } 196 } 197 } 198 199 // This checks that the websocket transport can deal with large messages. 200 func TestClientWebsocketLargeMessage(t *testing.T) { 201 var ( 202 srv = NewServer() 203 httpsrv = httptest.NewServer(srv.WebsocketHandler(nil)) 204 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 205 ) 206 defer srv.Stop() 207 defer httpsrv.Close() 208 209 respLength := wsMessageSizeLimit - 50 210 srv.RegisterName("test", largeRespService{respLength}) 211 212 c, err := DialWebsocket(context.Background(), wsURL, "") 213 if err != nil { 214 t.Fatal(err) 215 } 216 217 var r string 218 if err := c.Call(&r, "test_largeResp"); err != nil { 219 t.Fatal("call failed:", err) 220 } 221 if len(r) != respLength { 222 t.Fatalf("response has wrong length %d, want %d", len(r), respLength) 223 } 224 } 225 226 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 227 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 228 // pong and finally delivers a single subscription result. 229 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 230 var srv http.Server 231 shutdown := make(chan struct{}) 232 srv.RegisterOnShutdown(func() { 233 close(shutdown) 234 }) 235 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 236 // Upgrade to WebSocket. 237 upgrader := websocket.Upgrader{ 238 CheckOrigin: func(r *http.Request) bool { return true }, 239 } 240 conn, err := upgrader.Upgrade(w, r, nil) 241 if err != nil { 242 t.Errorf("server WS upgrade error: %v", err) 243 return 244 } 245 defer conn.Close() 246 247 // Handle the connection. 248 wsPingTestHandler(t, conn, shutdown, sendPing) 249 }) 250 251 // Start the server. 252 listener, err := net.Listen("tcp", "127.0.0.1:0") 253 if err != nil { 254 t.Fatal("can't listen:", err) 255 } 256 srv.Addr = listener.Addr().String() 257 go srv.Serve(listener) 258 return &srv 259 } 260 261 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 262 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 263 const ( 264 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 265 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 266 ) 267 268 // Handle subscribe request. 269 if _, _, err := conn.ReadMessage(); err != nil { 270 t.Errorf("server read error: %v", err) 271 return 272 } 273 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 274 t.Errorf("server write error: %v", err) 275 return 276 } 277 278 // Read from the connection to process control messages. 279 var pongCh = make(chan string) 280 conn.SetPongHandler(func(d string) error { 281 t.Logf("server got pong: %q", d) 282 pongCh <- d 283 return nil 284 }) 285 go func() { 286 for { 287 typ, msg, err := conn.ReadMessage() 288 if err != nil { 289 return 290 } 291 t.Logf("server got message (%d): %q", typ, msg) 292 } 293 }() 294 295 // Write messages. 296 var ( 297 wantPong string 298 timer = time.NewTimer(0) 299 ) 300 defer timer.Stop() 301 <-timer.C 302 for { 303 select { 304 case _, open := <-sendPing: 305 if !open { 306 sendPing = nil 307 } 308 t.Logf("server sending ping") 309 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 310 wantPong = "ping" 311 case data := <-pongCh: 312 if wantPong == "" { 313 t.Errorf("unexpected pong") 314 } else if data != wantPong { 315 t.Errorf("got pong with wrong data %q", data) 316 } 317 wantPong = "" 318 timer.Reset(200 * time.Millisecond) 319 case <-timer.C: 320 t.Logf("server sending response") 321 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 322 case <-shutdown: 323 conn.Close() 324 return 325 } 326 } 327 }