// Modifications Copyright 2020 The klaytn Authors
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//
// This file is derived from rpc/websocket_test.go (2020/04/03).
// Modified and improved for the klaytn development.
20 21 package rpc 22 23 import ( 24 "context" 25 "encoding/base64" 26 "net" 27 "net/http" 28 "net/http/httptest" 29 "reflect" 30 "strings" 31 "testing" 32 "time" 33 34 "github.com/gorilla/websocket" 35 "github.com/klaytn/klaytn/common" 36 "github.com/stretchr/testify/assert" 37 ) 38 39 type echoArgs struct { 40 S string 41 } 42 43 type echoResult struct { 44 String string 45 Int int 46 Args *echoArgs 47 } 48 49 func TestWebsocketLargeCall(t *testing.T) { 50 t.Parallel() 51 52 // create server 53 var ( 54 srv = newTestServer("service", new(Service)) 55 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) 56 wsAddr = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 57 ) 58 defer srv.Stop() 59 defer httpsrv.Close() 60 time.Sleep(100 * time.Millisecond) 61 62 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 63 defer cancel() 64 client, err := DialWebsocket(ctx, wsAddr, "") 65 if err != nil { 66 t.Fatalf("can't dial: %v", err) 67 } 68 defer client.Close() 69 70 // set configurations before testing 71 var result echoResult 72 method := "service_echo" 73 74 // This call sends slightly less than the limit and should work. 75 arg := strings.Repeat("x", common.MaxRequestContentLength-200) 76 assert.NoError(t, client.Call(&result, method, arg, 1), "valid call didn't work") 77 assert.Equal(t, arg, result.String, "wrong string echoed") 78 79 // This call sends slightly larger than the allowed size and shouldn't work. 
80 arg = strings.Repeat("x", common.MaxRequestContentLength) 81 assert.Error(t, client.Call(&result, method, arg, 1), "no error for too large call") 82 } 83 84 func newTestListener() net.Listener { 85 ln, err := net.Listen("tcp", "localhost:0") 86 if err != nil { 87 panic(err) 88 } 89 return ln 90 } 91 92 /* 93 func TestWSServer_MaxConnections(t *testing.T) { 94 // create server 95 var ( 96 srv = newTestServer("service", new(Service)) 97 ln = newTestListener() 98 ) 99 defer srv.Stop() 100 defer ln.Close() 101 102 go NewWSServer([]string{"*"}, srv).Serve(ln) 103 time.Sleep(100 * time.Millisecond) 104 105 // set max websocket connections 106 MaxWebsocketConnections = 3 107 testWebsocketMaxConnections(t, "ws://"+ln.Addr().String(), int(MaxWebsocketConnections)) 108 } 109 */ 110 111 func TestFastWSServer_MaxConnections(t *testing.T) { 112 // create server 113 var ( 114 srv = newTestServer("service", new(Service)) 115 ln = newTestListener() 116 ) 117 defer srv.Stop() 118 defer ln.Close() 119 120 go NewFastWSServer([]string{"*"}, srv).Serve(ln) 121 time.Sleep(100 * time.Millisecond) 122 123 // set max websocket connections 124 MaxWebsocketConnections = 3 125 testWebsocketMaxConnections(t, "ws://"+ln.Addr().String(), int(MaxWebsocketConnections)) 126 } 127 128 func testWebsocketMaxConnections(t *testing.T, addr string, maxConnections int) { 129 var closers []*Client 130 131 for i := 0; i <= maxConnections; i++ { 132 client, err := DialWebsocket(context.Background(), addr, "") 133 if err != nil { 134 t.Fatal(err) 135 } 136 closers = append(closers, client) 137 138 var result echoResult 139 method := "service_echo" 140 arg := strings.Repeat("x", i) 141 err = client.Call(&result, method, arg, 1) 142 if i < int(MaxWebsocketConnections) { 143 assert.NoError(t, err) 144 assert.Equal(t, arg, result.String, "wrong string echoed") 145 } else { 146 assert.Error(t, err) 147 // assert.Equal(t, "EOF", err.Error()) 148 } 149 } 150 151 for _, client := range closers { 152 client.Close() 
153 } 154 } 155 156 func TestWebsocketClientHeaders(t *testing.T) { 157 t.Parallel() 158 159 endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") 160 if err != nil { 161 t.Fatalf("wsGetConfig failed: %s", err) 162 } 163 if endpoint != "wss://example.com:1234" { 164 t.Fatal("User should have been stripped from the URL") 165 } 166 if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { 167 t.Fatal("Basic auth header is incorrect") 168 } 169 if header.Get("origin") != "https://example.com" { 170 t.Fatal("Origin not set") 171 } 172 } 173 174 // This test checks that the server rejects connections from disallowed origins. 175 func TestWebsocketOriginCheck(t *testing.T) { 176 t.Parallel() 177 178 var ( 179 srv = newTestServer("service", new(Service)) 180 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 181 wsURL = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") 182 ) 183 defer srv.Stop() 184 defer httpsrv.Close() 185 186 client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") 187 if err == nil { 188 client.Close() 189 t.Fatal("no error for wrong origin") 190 } 191 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} 192 if !reflect.DeepEqual(err, wantErr) { 193 t.Fatalf("wrong error for wrong origin: %q", err) 194 } 195 } 196 197 func TestClientWebsocketPing(t *testing.T) { 198 t.Parallel() 199 200 var ( 201 sendPing = make(chan struct{}) 202 server = wsPingTestServer(t, sendPing) 203 ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) 204 ) 205 defer cancel() 206 defer server.Shutdown(ctx) 207 208 client, err := DialContext(ctx, "ws://"+server.Addr) 209 if err != nil { 210 t.Fatalf("client dial error: %v", err) 211 } 212 resultChan := make(chan int) 213 sub, err := client.KlaySubscribe(ctx, resultChan, "foo") 214 if err != nil { 215 t.Fatalf("client subscribe error: %v", err) 216 } 217 218 // Wait 
for the context's deadline to be reached before proceeding. 219 // This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 220 <-ctx.Done() 221 close(sendPing) 222 223 // Wait for the subscription result. 224 timeout := time.NewTimer(5 * time.Second) 225 for { 226 select { 227 case err := <-sub.Err(): 228 t.Error("client subscription error:", err) 229 case result := <-resultChan: 230 t.Log("client got result:", result) 231 return 232 case <-timeout.C: 233 t.Error("didn't get any result within the test timeout") 234 return 235 } 236 } 237 } 238 239 // wsPingTestServer runs a WebSocket server which accepts a single subscription request. 240 // When a value arrives on sendPing, the server sends a ping frame, waits for a matching 241 // pong and finally delivers a single subscription result. 242 func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { 243 var srv http.Server 244 shutdown := make(chan struct{}) 245 srv.RegisterOnShutdown(func() { 246 close(shutdown) 247 }) 248 srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 249 // Upgrade to WebSocket. 250 upgrader := websocket.Upgrader{ 251 CheckOrigin: func(r *http.Request) bool { return true }, 252 } 253 conn, err := upgrader.Upgrade(w, r, nil) 254 if err != nil { 255 t.Errorf("server WS upgrade error: %v", err) 256 return 257 } 258 defer conn.Close() 259 260 // Handle the connection. 261 wsPingTestHandler(t, conn, shutdown, sendPing) 262 }) 263 264 // Start the server. 265 listener, err := net.Listen("tcp", "127.0.0.1:0") 266 if err != nil { 267 t.Fatal("can't listen:", err) 268 } 269 srv.Addr = listener.Addr().String() 270 go srv.Serve(listener) 271 return &srv 272 } 273 274 func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { 275 // Canned responses for the eth_subscribe call in TestClientWebsocketPing. 
276 const ( 277 subResp = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` 278 subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` 279 ) 280 281 // Handle subscribe request. 282 if _, _, err := conn.ReadMessage(); err != nil { 283 t.Errorf("server read error: %v", err) 284 return 285 } 286 if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { 287 t.Errorf("server write error: %v", err) 288 return 289 } 290 291 // Read from the connection to process control messages. 292 pongCh := make(chan string) 293 conn.SetPongHandler(func(d string) error { 294 t.Logf("server got pong: %q", d) 295 pongCh <- d 296 return nil 297 }) 298 go func() { 299 for { 300 typ, msg, err := conn.ReadMessage() 301 if err != nil { 302 return 303 } 304 t.Logf("server got message (%d): %q", typ, msg) 305 } 306 }() 307 308 // Write messages. 309 var ( 310 sendResponse <-chan time.Time 311 wantPong string 312 ) 313 for { 314 select { 315 case _, open := <-sendPing: 316 if !open { 317 sendPing = nil 318 } 319 t.Logf("server sending ping") 320 conn.WriteMessage(websocket.PingMessage, []byte("ping")) 321 wantPong = "ping" 322 case data := <-pongCh: 323 if wantPong == "" { 324 t.Errorf("unexpected pong") 325 } else if data != wantPong { 326 t.Errorf("got pong with wrong data %q", data) 327 } 328 wantPong = "" 329 sendResponse = time.NewTimer(200 * time.Millisecond).C 330 case <-sendResponse: 331 t.Logf("server sending response") 332 conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) 333 sendResponse = nil 334 case <-shutdown: 335 conn.Close() 336 return 337 } 338 } 339 } 340 341 func TestWebsocketAuthCheck(t *testing.T) { 342 t.Parallel() 343 344 var ( 345 srv = newTestServer("service", new(Service)) 346 httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) 347 wsURL = "ws://testuser:test-PASS_01@" + strings.TrimPrefix(httpsrv.URL, "http://") 348 ) 349 connect := false 350 origHandler := 
httpsrv.Config.Handler 351 httpsrv.Config.Handler = http.HandlerFunc( 352 func(w http.ResponseWriter, r *http.Request) { 353 auth := r.Header.Get("Authorization") 354 expectedAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("testuser:test-PASS_01")) 355 if r.Method == http.MethodGet && auth == expectedAuth { 356 connect = true 357 w.WriteHeader(http.StatusSwitchingProtocols) 358 return 359 } 360 if !connect { 361 http.Error(w, "connect with authorization not received", http.StatusMethodNotAllowed) 362 return 363 } 364 origHandler.ServeHTTP(w, r) 365 }) 366 defer srv.Stop() 367 defer httpsrv.Close() 368 369 client, err := DialWebsocket(context.Background(), wsURL, "http://example.com") 370 if err == nil { 371 client.Close() 372 t.Fatal("no error for connect with auth header") 373 } 374 wantErr := wsHandshakeError{websocket.ErrBadHandshake, "101 Switching Protocols"} 375 if !reflect.DeepEqual(err, wantErr) { 376 t.Fatalf("wrong error for auth header: %q", err) 377 } 378 }