github.com/kisexp/xdchain@v0.0.0-20211206025815-490d6b732aa7/rpc/client_test.go (about) 1 // Copyright 2016 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "fmt" 22 "math/rand" 23 "net" 24 "net/http" 25 "net/http/httptest" 26 "os" 27 "reflect" 28 "runtime" 29 "strings" 30 "sync" 31 "testing" 32 "time" 33 34 "github.com/davecgh/go-spew/spew" 35 "github.com/kisexp/xdchain/core/types" 36 "github.com/kisexp/xdchain/log" 37 "github.com/stretchr/testify/assert" 38 ) 39 40 func TestClientRequest(t *testing.T) { 41 server := newTestServer() 42 defer server.Stop() 43 client := DialInProc(server) 44 defer client.Close() 45 46 var resp echoResult 47 if err := client.Call(&resp, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { 48 t.Fatal(err) 49 } 50 if !reflect.DeepEqual(resp, echoResult{"hello", 10, &echoArgs{"world"}}) { 51 t.Errorf("incorrect result %#v", resp) 52 } 53 } 54 55 func TestClientResponseType(t *testing.T) { 56 server := newTestServer() 57 defer server.Stop() 58 client := DialInProc(server) 59 defer client.Close() 60 61 if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { 62 t.Errorf("Passing nil as result should be fine, but got an error: %v", err) 63 } 64 var resultVar echoResult 65 // Note: passing the var, not a ref 66 err := client.Call(resultVar, "test_echo", "hello", 10, &echoArgs{"world"}) 67 if err == nil { 68 t.Error("Passing a var as result should be an error") 69 } 70 } 71 72 // This test checks that server-returned errors with code and data come out of Client.Call. 73 func TestClientErrorData(t *testing.T) { 74 server := newTestServer() 75 defer server.Stop() 76 client := DialInProc(server) 77 defer client.Close() 78 79 var resp interface{} 80 err := client.Call(&resp, "test_returnError") 81 if err == nil { 82 t.Fatal("expected error") 83 } 84 85 // Check code. 86 if e, ok := err.(Error); !ok { 87 t.Fatalf("client did not return rpc.Error, got %#v", e) 88 } else if e.ErrorCode() != (testError{}.ErrorCode()) { 89 t.Fatalf("wrong error code %d, want %d", e.ErrorCode(), testError{}.ErrorCode()) 90 } 91 // Check data. 92 if e, ok := err.(DataError); !ok { 93 t.Fatalf("client did not return rpc.DataError, got %#v", e) 94 } else if e.ErrorData() != (testError{}.ErrorData()) { 95 t.Fatalf("wrong error data %#v, want %#v", e.ErrorData(), testError{}.ErrorData()) 96 } 97 } 98 99 func TestClientBatchRequest(t *testing.T) { 100 server := newTestServer() 101 defer server.Stop() 102 client := DialInProc(server) 103 defer client.Close() 104 105 batch := []BatchElem{ 106 { 107 Method: "test_echo", 108 Args: []interface{}{"hello", 10, &echoArgs{"world"}}, 109 Result: new(echoResult), 110 }, 111 { 112 Method: "test_echo", 113 Args: []interface{}{"hello2", 11, &echoArgs{"world"}}, 114 Result: new(echoResult), 115 }, 116 { 117 Method: "no_such_method", 118 Args: []interface{}{1, 2, 3}, 119 Result: new(int), 120 }, 121 } 122 if err := client.BatchCall(batch); err != nil { 123 t.Fatal(err) 124 } 125 wantResult := []BatchElem{ 126 { 127 Method: "test_echo", 128 Args: []interface{}{"hello", 10, &echoArgs{"world"}}, 129 Result: &echoResult{"hello", 10, &echoArgs{"world"}}, 130 }, 131 { 132 Method: "test_echo", 133 Args: []interface{}{"hello2", 11, &echoArgs{"world"}}, 134 Result: &echoResult{"hello2", 11, &echoArgs{"world"}}, 135 }, 136 { 137 Method: "no_such_method", 138 Args: []interface{}{1, 2, 3}, 139 Result: new(int), 140 Error: &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"}, 141 }, 142 } 143 if !reflect.DeepEqual(batch, wantResult) { 144 t.Errorf("batch results mismatch:\ngot %swant %s", spew.Sdump(batch), spew.Sdump(wantResult)) 145 } 146 } 147 148 func TestClientNotify(t *testing.T) { 149 server := newTestServer() 150 defer server.Stop() 151 client := DialInProc(server) 152 defer client.Close() 153 154 if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil { 155 t.Fatal(err) 156 } 157 } 158 159 // func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) } 160 func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) } 161 func TestClientCancelHTTP(t *testing.T) { testClientCancel("http", t) } 162 func TestClientCancelIPC(t *testing.T) { testClientCancel("ipc", t) } 163 164 // This test checks that requests made through CallContext can be canceled by canceling 165 // the context. 166 func testClientCancel(transport string, t *testing.T) { 167 // These tests take a lot of time, run them all at once. 168 // You probably want to run with -parallel 1 or comment out 169 // the call to t.Parallel if you enable the logging. 170 t.Parallel() 171 172 server := newTestServer() 173 defer server.Stop() 174 175 // What we want to achieve is that the context gets canceled 176 // at various stages of request processing. The interesting cases 177 // are: 178 // - cancel during dial 179 // - cancel while performing a HTTP request 180 // - cancel while waiting for a response 181 // 182 // To trigger those, the times are chosen such that connections 183 // are killed within the deadline for every other call (maxKillTimeout 184 // is 2x maxCancelTimeout). 185 // 186 // Once a connection is dead, there is a fair chance it won't connect 187 // successfully because the accept is delayed by 1s. 188 maxContextCancelTimeout := 300 * time.Millisecond 189 fl := &flakeyListener{ 190 maxAcceptDelay: 1 * time.Second, 191 maxKillTimeout: 600 * time.Millisecond, 192 } 193 194 var client *Client 195 switch transport { 196 case "ws", "http": 197 c, hs := httpTestClient(server, transport, fl) 198 defer hs.Close() 199 client = c 200 case "ipc": 201 c, l := ipcTestClient(server, fl) 202 defer l.Close() 203 client = c 204 default: 205 panic("unknown transport: " + transport) 206 } 207 208 // The actual test starts here. 209 var ( 210 wg sync.WaitGroup 211 nreqs = 10 212 ncallers = 10 213 ) 214 caller := func(index int) { 215 defer wg.Done() 216 for i := 0; i < nreqs; i++ { 217 var ( 218 ctx context.Context 219 cancel func() 220 timeout = time.Duration(rand.Int63n(int64(maxContextCancelTimeout))) 221 ) 222 if index < ncallers/2 { 223 // For half of the callers, create a context without deadline 224 // and cancel it later. 225 ctx, cancel = context.WithCancel(context.Background()) 226 time.AfterFunc(timeout, cancel) 227 } else { 228 // For the other half, create a context with a deadline instead. This is 229 // different because the context deadline is used to set the socket write 230 // deadline. 231 ctx, cancel = context.WithTimeout(context.Background(), timeout) 232 } 233 234 // Now perform a call with the context. 235 // The key thing here is that no call will ever complete successfully. 236 err := client.CallContext(ctx, nil, "test_block") 237 switch { 238 case err == nil: 239 _, hasDeadline := ctx.Deadline() 240 t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline) 241 // default: 242 // t.Logf("got expected error with %v wait time: %v", timeout, err) 243 } 244 cancel() 245 } 246 } 247 wg.Add(ncallers) 248 for i := 0; i < ncallers; i++ { 249 go caller(i) 250 } 251 wg.Wait() 252 } 253 254 func TestClientSubscribeInvalidArg(t *testing.T) { 255 server := newTestServer() 256 defer server.Stop() 257 client := DialInProc(server) 258 defer client.Close() 259 260 check := func(shouldPanic bool, arg interface{}) { 261 defer func() { 262 err := recover() 263 if shouldPanic && err == nil { 264 t.Errorf("EthSubscribe should've panicked for %#v", arg) 265 } 266 if !shouldPanic && err != nil { 267 t.Errorf("EthSubscribe shouldn't have panicked for %#v", arg) 268 buf := make([]byte, 1024*1024) 269 buf = buf[:runtime.Stack(buf, false)] 270 t.Error(err) 271 t.Error(string(buf)) 272 } 273 }() 274 client.EthSubscribe(context.Background(), arg, "foo_bar") 275 } 276 check(true, nil) 277 check(true, 1) 278 check(true, (chan int)(nil)) 279 check(true, make(<-chan int)) 280 check(false, make(chan int)) 281 check(false, make(chan<- int)) 282 } 283 284 func TestClientSubscribe(t *testing.T) { 285 server := newTestServer() 286 defer server.Stop() 287 client := DialInProc(server) 288 defer client.Close() 289 290 nc := make(chan int) 291 count := 10 292 sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", count, 0) 293 if err != nil { 294 t.Fatal("can't subscribe:", err) 295 } 296 for i := 0; i < count; i++ { 297 if val := <-nc; val != i { 298 t.Fatalf("value mismatch: got %d, want %d", val, i) 299 } 300 } 301 302 sub.Unsubscribe() 303 select { 304 case v := <-nc: 305 t.Fatal("received value after unsubscribe:", v) 306 case err := <-sub.Err(): 307 if err != nil { 308 t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err) 309 } 310 case <-time.After(1 * time.Second): 311 t.Fatalf("subscription not closed within 1s after unsubscribe") 312 } 313 } 314 315 // In this test, the connection drops while Subscribe is waiting for a response. 316 func TestClientSubscribeClose(t *testing.T) { 317 server := newTestServer() 318 service := ¬ificationTestService{ 319 gotHangSubscriptionReq: make(chan struct{}), 320 unblockHangSubscription: make(chan struct{}), 321 } 322 if err := server.RegisterName("nftest2", service); err != nil { 323 t.Fatal(err) 324 } 325 326 defer server.Stop() 327 client := DialInProc(server) 328 defer client.Close() 329 330 var ( 331 nc = make(chan int) 332 errc = make(chan error, 1) 333 sub *ClientSubscription 334 err error 335 ) 336 go func() { 337 sub, err = client.Subscribe(context.Background(), "nftest2", nc, "hangSubscription", 999) 338 errc <- err 339 }() 340 341 <-service.gotHangSubscriptionReq 342 client.Close() 343 service.unblockHangSubscription <- struct{}{} 344 345 select { 346 case err := <-errc: 347 if err == nil { 348 t.Errorf("Subscribe returned nil error after Close") 349 } 350 if sub != nil { 351 t.Error("Subscribe returned non-nil subscription after Close") 352 } 353 case <-time.After(1 * time.Second): 354 t.Fatalf("Subscribe did not return within 1s after Close") 355 } 356 } 357 358 // This test reproduces https://github.com/kisexp/xdchain/issues/17837 where the 359 // client hangs during shutdown when Unsubscribe races with Client.Close. 360 func TestClientCloseUnsubscribeRace(t *testing.T) { 361 server := newTestServer() 362 defer server.Stop() 363 364 for i := 0; i < 20; i++ { 365 client := DialInProc(server) 366 nc := make(chan int) 367 sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1) 368 if err != nil { 369 t.Fatal(err) 370 } 371 go client.Close() 372 go sub.Unsubscribe() 373 select { 374 case <-sub.Err(): 375 case <-time.After(5 * time.Second): 376 t.Fatal("subscription not closed within timeout") 377 } 378 } 379 } 380 381 // This test checks that Client doesn't lock up when a single subscriber 382 // doesn't read subscription events. 383 func TestClientNotificationStorm(t *testing.T) { 384 server := newTestServer() 385 defer server.Stop() 386 387 doTest := func(count int, wantError bool) { 388 client := DialInProc(server) 389 defer client.Close() 390 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 391 defer cancel() 392 393 // Subscribe on the server. It will start sending many notifications 394 // very quickly. 395 nc := make(chan int) 396 sub, err := client.Subscribe(ctx, "nftest", nc, "someSubscription", count, 0) 397 if err != nil { 398 t.Fatal("can't subscribe:", err) 399 } 400 defer sub.Unsubscribe() 401 402 // Process each notification, try to run a call in between each of them. 403 for i := 0; i < count; i++ { 404 select { 405 case val := <-nc: 406 if val != i { 407 t.Fatalf("(%d/%d) unexpected value %d", i, count, val) 408 } 409 case err := <-sub.Err(): 410 if wantError && err != ErrSubscriptionQueueOverflow { 411 t.Fatalf("(%d/%d) got error %q, want %q", i, count, err, ErrSubscriptionQueueOverflow) 412 } else if !wantError { 413 t.Fatalf("(%d/%d) got unexpected error %q", i, count, err) 414 } 415 return 416 } 417 var r int 418 err := client.CallContext(ctx, &r, "nftest_echo", i) 419 if err != nil { 420 if !wantError { 421 t.Fatalf("(%d/%d) call error: %v", i, count, err) 422 } 423 return 424 } 425 } 426 if wantError { 427 t.Fatalf("didn't get expected error") 428 } 429 } 430 431 doTest(8000, false) 432 doTest(24000, true) 433 } 434 435 func TestClientSetHeader(t *testing.T) { 436 var gotHeader bool 437 srv := newTestServer() 438 httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 439 if r.Header.Get("test") == "ok" { 440 gotHeader = true 441 } 442 srv.ServeHTTP(w, r) 443 })) 444 defer httpsrv.Close() 445 defer srv.Stop() 446 447 client, err := Dial(httpsrv.URL) 448 if err != nil { 449 t.Fatal(err) 450 } 451 defer client.Close() 452 453 client.SetHeader("test", "ok") 454 if _, err := client.SupportedModules(); err != nil { 455 t.Fatal(err) 456 } 457 if !gotHeader { 458 t.Fatal("client did not set custom header") 459 } 460 461 // Check that Content-Type can be replaced. 462 client.SetHeader("content-type", "application/x-garbage") 463 _, err = client.SupportedModules() 464 if err == nil { 465 t.Fatal("no error for invalid content-type header") 466 } else if !strings.Contains(err.Error(), "Unsupported Media Type") { 467 t.Fatalf("error is not related to content-type: %q", err) 468 } 469 } 470 471 func TestClientHTTP(t *testing.T) { 472 server := newTestServer() 473 defer server.Stop() 474 475 client, hs := httpTestClient(server, "http", nil) 476 defer hs.Close() 477 defer client.Close() 478 479 // Launch concurrent requests. 480 var ( 481 results = make([]echoResult, 100) 482 errc = make(chan error, len(results)) 483 wantResult = echoResult{"a", 1, new(echoArgs)} 484 ) 485 defer client.Close() 486 for i := range results { 487 i := i 488 go func() { 489 errc <- client.Call(&results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args) 490 }() 491 } 492 493 // Wait for all of them to complete. 494 timeout := time.NewTimer(5 * time.Second) 495 defer timeout.Stop() 496 for i := range results { 497 select { 498 case err := <-errc: 499 if err != nil { 500 t.Fatal(err) 501 } 502 case <-timeout.C: 503 t.Fatalf("timeout (got %d/%d) results)", i+1, len(results)) 504 } 505 } 506 507 // Check results. 508 for i := range results { 509 if !reflect.DeepEqual(results[i], wantResult) { 510 t.Errorf("result %d mismatch: got %#v, want %#v", i, results[i], wantResult) 511 } 512 } 513 } 514 515 func TestClientReconnect(t *testing.T) { 516 startServer := func(addr string) (*Server, net.Listener) { 517 srv := newTestServer() 518 l, err := net.Listen("tcp", addr) 519 if err != nil { 520 t.Fatal("can't listen:", err) 521 } 522 go http.Serve(l, srv.WebsocketHandler([]string{"*"})) 523 return srv, l 524 } 525 526 ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) 527 defer cancel() 528 529 // Start a server and corresponding client. 530 s1, l1 := startServer("127.0.0.1:0") 531 client, err := DialContext(ctx, "ws://"+l1.Addr().String()) 532 if err != nil { 533 t.Fatal("can't dial", err) 534 } 535 536 // Perform a call. This should work because the server is up. 537 var resp echoResult 538 if err := client.CallContext(ctx, &resp, "test_echo", "", 1, nil); err != nil { 539 t.Fatal(err) 540 } 541 542 // Shut down the server and allow for some cool down time so we can listen on the same 543 // address again. 544 l1.Close() 545 s1.Stop() 546 time.Sleep(2 * time.Second) 547 548 // Try calling again. It shouldn't work. 549 if err := client.CallContext(ctx, &resp, "test_echo", "", 2, nil); err == nil { 550 t.Error("successful call while the server is down") 551 t.Logf("resp: %#v", resp) 552 } 553 554 // Start it up again and call again. The connection should be reestablished. 555 // We spawn multiple calls here to check whether this hangs somehow. 556 s2, l2 := startServer(l1.Addr().String()) 557 defer l2.Close() 558 defer s2.Stop() 559 560 start := make(chan struct{}) 561 errors := make(chan error, 20) 562 for i := 0; i < cap(errors); i++ { 563 go func() { 564 <-start 565 var resp echoResult 566 errors <- client.CallContext(ctx, &resp, "test_echo", "", 3, nil) 567 }() 568 } 569 close(start) 570 errcount := 0 571 for i := 0; i < cap(errors); i++ { 572 if err = <-errors; err != nil { 573 errcount++ 574 } 575 } 576 t.Logf("%d errors, last error: %v", errcount, err) 577 if errcount > 1 { 578 t.Errorf("expected one error after disconnect, got %d", errcount) 579 } 580 } 581 582 func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) { 583 // Create the HTTP server. 584 var hs *httptest.Server 585 switch transport { 586 case "ws": 587 hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"})) 588 case "http": 589 hs = httptest.NewUnstartedServer(srv) 590 default: 591 panic("unknown HTTP transport: " + transport) 592 } 593 // Wrap the listener if required. 594 if fl != nil { 595 fl.Listener = hs.Listener 596 hs.Listener = fl 597 } 598 // Connect the client. 599 hs.Start() 600 client, err := Dial(transport + "://" + hs.Listener.Addr().String()) 601 if err != nil { 602 panic(err) 603 } 604 return client, hs 605 } 606 607 func ipcTestClient(srv *Server, fl *flakeyListener) (*Client, net.Listener) { 608 // Listen on a random endpoint. 609 endpoint := fmt.Sprintf("go-ethereum-test-ipc-%d-%d", os.Getpid(), rand.Int63()) 610 if runtime.GOOS == "windows" { 611 endpoint = `\\.\pipe\` + endpoint 612 } else { 613 endpoint = os.TempDir() + "/" + endpoint 614 } 615 l, err := ipcListen(endpoint) 616 if err != nil { 617 panic(err) 618 } 619 // Connect the listener to the server. 620 if fl != nil { 621 fl.Listener = l 622 l = fl 623 } 624 go srv.ServeListener(l) 625 // Connect the client. 626 client, err := Dial(endpoint) 627 if err != nil { 628 panic(err) 629 } 630 return client, l 631 } 632 633 // flakeyListener kills accepted connections after a random timeout. 634 type flakeyListener struct { 635 net.Listener 636 maxKillTimeout time.Duration 637 maxAcceptDelay time.Duration 638 } 639 640 func (l *flakeyListener) Accept() (net.Conn, error) { 641 delay := time.Duration(rand.Int63n(int64(l.maxAcceptDelay))) 642 time.Sleep(delay) 643 644 c, err := l.Listener.Accept() 645 if err == nil { 646 timeout := time.Duration(rand.Int63n(int64(l.maxKillTimeout))) 647 time.AfterFunc(timeout, func() { 648 log.Debug(fmt.Sprintf("killing conn %v after %v", c.LocalAddr(), timeout)) 649 c.Close() 650 }) 651 } 652 return c, err 653 } 654 655 func TestClient_withCredentials_whenTargetingHTTP(t *testing.T) { 656 server := newTestServer() 657 server.authenticationManager = &stubAuthenticationManager{isEnabled: true} 658 defer server.Stop() 659 fl := &flakeyListener{ 660 maxAcceptDelay: 1 * time.Second, 661 maxKillTimeout: 600 * time.Millisecond, 662 } 663 hs := httptest.NewUnstartedServer(server) 664 fl.Listener = hs.Listener 665 hs.Listener = fl 666 // Connect the client. 667 hs.Start() 668 defer hs.Close() 669 670 c, err := Dial("http://" + hs.Listener.Addr().String()) 671 assert.NoError(t, err) 672 var f HttpCredentialsProviderFunc = func(ctx context.Context) (string, error) { 673 return "Bearer arbitrary_token", nil 674 } 675 authenticatedClient := c.WithHTTPCredentials(f) 676 677 err = authenticatedClient.CallContext(context.Background(), nil, "arbitrary_call") 678 assert.EqualError(t, err, "arbitrary_call - access denied") 679 } 680 681 func TestClient_withCredentials_whenTargetingWS(t *testing.T) { 682 server := newTestServer() 683 server.authenticationManager = &stubAuthenticationManager{isEnabled: true} 684 defer server.Stop() 685 fl := &flakeyListener{ 686 maxAcceptDelay: 1 * time.Second, 687 maxKillTimeout: 600 * time.Millisecond, 688 } 689 hs := httptest.NewUnstartedServer(server.WebsocketHandler([]string{"*"})) 690 fl.Listener = hs.Listener 691 hs.Listener = fl 692 // Connect the client. 693 hs.Start() 694 defer hs.Close() 695 var f HttpCredentialsProviderFunc = func(ctx context.Context) (string, error) { 696 return "Bearer arbitrary_token", nil 697 } 698 ctx := WithCredentialsProvider(context.Background(), f) 699 authenticatedClient, err := DialContext(ctx, "ws://"+hs.Listener.Addr().String()) 700 assert.NoError(t, err) 701 702 err = authenticatedClient.CallContext(context.Background(), nil, "arbitrary_call") 703 assert.EqualError(t, err, "arbitrary_call - access denied") 704 } 705 706 func TestClient_HTTP_WS_whenDefaultPSI(t *testing.T) { 707 for _, transport := range []string{"http", "ws"} { 708 f := func(transport string) { 709 server := newTestServer() 710 defer server.Stop() 711 712 client, hs := httpTestClient(server, transport, nil) 713 defer hs.Close() 714 defer client.Close() 715 716 verifyPSI(t, client, types.DefaultPrivateStateIdentifier) 717 } 718 f(transport) 719 } 720 } 721 722 func TestClient_InProc_whenDefaultPSI(t *testing.T) { 723 server := newTestServer() 724 defer server.Stop() 725 726 client := DialInProc(server) 727 defer client.Close() 728 729 verifyPSI(t, client, types.DefaultPrivateStateIdentifier) 730 } 731 732 func TestClient_IPC_whenDefaultPSI(t *testing.T) { 733 server := newTestServer() 734 defer server.Stop() 735 736 client, l := ipcTestClient(server, nil) 737 defer l.Close() 738 defer client.Close() 739 740 verifyPSI(t, client, types.DefaultPrivateStateIdentifier) 741 } 742 743 func startHTTPTestServer(transport string) (*Server, *httptest.Server) { 744 handler := newTestServer() 745 // Create the HTTP server. 746 var hs *httptest.Server 747 switch transport { 748 case "ws": 749 hs = httptest.NewUnstartedServer(handler.WebsocketHandler([]string{"*"})) 750 case "http": 751 hs = httptest.NewUnstartedServer(handler) 752 default: 753 panic("unknown HTTP transport: " + transport) 754 } 755 // Connect the client. 756 hs.Start() 757 return handler, hs 758 } 759 760 func TestClient_whenProvidingPSIViaURLParam(t *testing.T) { 761 for _, transport := range []string{"http", "ws"} { 762 f := func(transport string) { 763 expectedPSI := "PS1" 764 srvHandler, srvHttp := startHTTPTestServer(transport) 765 defer func() { 766 srvHandler.Stop() 767 srvHttp.Close() 768 }() 769 770 endpoint := fmt.Sprintf("%s://%s?%s=%s", transport, srvHttp.Listener.Addr().String(), QueryPrivateStateIdentifierParamName, expectedPSI) 771 client, err := Dial(endpoint) 772 assert.NoError(t, err, endpoint) 773 774 verifyPSI(t, client, types.PrivateStateIdentifier(expectedPSI), endpoint) 775 } 776 f(transport) 777 } 778 } 779 780 func TestClient_whenProvidingPSIViaEnvVar(t *testing.T) { 781 for _, transport := range []string{"http", "ws"} { 782 f := func(transport string) { 783 expectedPSI := "PS1" 784 assert.NoError(t, os.Setenv(EnvVarPrivateStateIdentifier, expectedPSI)) 785 defer os.Unsetenv(EnvVarPrivateStateIdentifier) 786 srvHandler, srvHttp := startHTTPTestServer(transport) 787 defer func() { 788 srvHandler.Stop() 789 srvHttp.Close() 790 }() 791 792 endpoint := fmt.Sprintf("%s://%s", transport, srvHttp.Listener.Addr().String()) 793 client, err := Dial(endpoint) 794 assert.NoError(t, err, endpoint) 795 796 verifyPSI(t, client, types.PrivateStateIdentifier(expectedPSI), endpoint) 797 } 798 f(transport) 799 } 800 } 801 802 func TestClient_IPC_whenSetupPSIExplicitly(t *testing.T) { 803 expectedPSI := types.ToPrivateStateIdentifier("arbitrary_psi") 804 server := newTestServer() 805 defer server.Stop() 806 807 client, l := ipcTestClient(server, nil) 808 defer l.Close() 809 defer client.Close() 810 811 client.WithPSI(expectedPSI) 812 813 verifyPSI(t, client, expectedPSI) 814 } 815 816 func TestClient_InProc_whenSetupPSIExplicitly(t *testing.T) { 817 expectedPSI := types.ToPrivateStateIdentifier("arbitrary_psi") 818 server := newTestServer() 819 defer server.Stop() 820 821 client := DialInProc(server) 822 defer client.Close() 823 824 client.WithPSI(expectedPSI) 825 826 verifyPSI(t, client, expectedPSI) 827 } 828 829 func verifyPSI(t *testing.T, client *Client, expectedPSI types.PrivateStateIdentifier, msgAndArgs ...interface{}) { 830 var resp echoPSIResult 831 err := client.Call(&resp, "test_echoCtxPSI") 832 assert.NoError(t, err) 833 834 assert.Equal(t, expectedPSI, resp.PSI, msgAndArgs) 835 }