github.com/kisexp/xdchain@v0.0.0-20211206025815-490d6b732aa7/rpc/client_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"math/rand"
    23  	"net"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"os"
    27  	"reflect"
    28  	"runtime"
    29  	"strings"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/davecgh/go-spew/spew"
    35  	"github.com/kisexp/xdchain/core/types"
    36  	"github.com/kisexp/xdchain/log"
    37  	"github.com/stretchr/testify/assert"
    38  )
    39  
    40  func TestClientRequest(t *testing.T) {
    41  	server := newTestServer()
    42  	defer server.Stop()
    43  	client := DialInProc(server)
    44  	defer client.Close()
    45  
    46  	var resp echoResult
    47  	if err := client.Call(&resp, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
    48  		t.Fatal(err)
    49  	}
    50  	if !reflect.DeepEqual(resp, echoResult{"hello", 10, &echoArgs{"world"}}) {
    51  		t.Errorf("incorrect result %#v", resp)
    52  	}
    53  }
    54  
    55  func TestClientResponseType(t *testing.T) {
    56  	server := newTestServer()
    57  	defer server.Stop()
    58  	client := DialInProc(server)
    59  	defer client.Close()
    60  
    61  	if err := client.Call(nil, "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
    62  		t.Errorf("Passing nil as result should be fine, but got an error: %v", err)
    63  	}
    64  	var resultVar echoResult
    65  	// Note: passing the var, not a ref
    66  	err := client.Call(resultVar, "test_echo", "hello", 10, &echoArgs{"world"})
    67  	if err == nil {
    68  		t.Error("Passing a var as result should be an error")
    69  	}
    70  }
    71  
    72  // This test checks that server-returned errors with code and data come out of Client.Call.
    73  func TestClientErrorData(t *testing.T) {
    74  	server := newTestServer()
    75  	defer server.Stop()
    76  	client := DialInProc(server)
    77  	defer client.Close()
    78  
    79  	var resp interface{}
    80  	err := client.Call(&resp, "test_returnError")
    81  	if err == nil {
    82  		t.Fatal("expected error")
    83  	}
    84  
    85  	// Check code.
    86  	if e, ok := err.(Error); !ok {
    87  		t.Fatalf("client did not return rpc.Error, got %#v", e)
    88  	} else if e.ErrorCode() != (testError{}.ErrorCode()) {
    89  		t.Fatalf("wrong error code %d, want %d", e.ErrorCode(), testError{}.ErrorCode())
    90  	}
    91  	// Check data.
    92  	if e, ok := err.(DataError); !ok {
    93  		t.Fatalf("client did not return rpc.DataError, got %#v", e)
    94  	} else if e.ErrorData() != (testError{}.ErrorData()) {
    95  		t.Fatalf("wrong error data %#v, want %#v", e.ErrorData(), testError{}.ErrorData())
    96  	}
    97  }
    98  
    99  func TestClientBatchRequest(t *testing.T) {
   100  	server := newTestServer()
   101  	defer server.Stop()
   102  	client := DialInProc(server)
   103  	defer client.Close()
   104  
   105  	batch := []BatchElem{
   106  		{
   107  			Method: "test_echo",
   108  			Args:   []interface{}{"hello", 10, &echoArgs{"world"}},
   109  			Result: new(echoResult),
   110  		},
   111  		{
   112  			Method: "test_echo",
   113  			Args:   []interface{}{"hello2", 11, &echoArgs{"world"}},
   114  			Result: new(echoResult),
   115  		},
   116  		{
   117  			Method: "no_such_method",
   118  			Args:   []interface{}{1, 2, 3},
   119  			Result: new(int),
   120  		},
   121  	}
   122  	if err := client.BatchCall(batch); err != nil {
   123  		t.Fatal(err)
   124  	}
   125  	wantResult := []BatchElem{
   126  		{
   127  			Method: "test_echo",
   128  			Args:   []interface{}{"hello", 10, &echoArgs{"world"}},
   129  			Result: &echoResult{"hello", 10, &echoArgs{"world"}},
   130  		},
   131  		{
   132  			Method: "test_echo",
   133  			Args:   []interface{}{"hello2", 11, &echoArgs{"world"}},
   134  			Result: &echoResult{"hello2", 11, &echoArgs{"world"}},
   135  		},
   136  		{
   137  			Method: "no_such_method",
   138  			Args:   []interface{}{1, 2, 3},
   139  			Result: new(int),
   140  			Error:  &jsonError{Code: -32601, Message: "the method no_such_method does not exist/is not available"},
   141  		},
   142  	}
   143  	if !reflect.DeepEqual(batch, wantResult) {
   144  		t.Errorf("batch results mismatch:\ngot %swant %s", spew.Sdump(batch), spew.Sdump(wantResult))
   145  	}
   146  }
   147  
   148  func TestClientNotify(t *testing.T) {
   149  	server := newTestServer()
   150  	defer server.Stop()
   151  	client := DialInProc(server)
   152  	defer client.Close()
   153  
   154  	if err := client.Notify(context.Background(), "test_echo", "hello", 10, &echoArgs{"world"}); err != nil {
   155  		t.Fatal(err)
   156  	}
   157  }
   158  
   159  // func TestClientCancelInproc(t *testing.T) { testClientCancel("inproc", t) }
   160  func TestClientCancelWebsocket(t *testing.T) { testClientCancel("ws", t) }
   161  func TestClientCancelHTTP(t *testing.T)      { testClientCancel("http", t) }
   162  func TestClientCancelIPC(t *testing.T)       { testClientCancel("ipc", t) }
   163  
   164  // This test checks that requests made through CallContext can be canceled by canceling
   165  // the context.
   166  func testClientCancel(transport string, t *testing.T) {
   167  	// These tests take a lot of time, run them all at once.
   168  	// You probably want to run with -parallel 1 or comment out
   169  	// the call to t.Parallel if you enable the logging.
   170  	t.Parallel()
   171  
   172  	server := newTestServer()
   173  	defer server.Stop()
   174  
   175  	// What we want to achieve is that the context gets canceled
   176  	// at various stages of request processing. The interesting cases
   177  	// are:
   178  	//  - cancel during dial
   179  	//  - cancel while performing a HTTP request
   180  	//  - cancel while waiting for a response
   181  	//
   182  	// To trigger those, the times are chosen such that connections
   183  	// are killed within the deadline for every other call (maxKillTimeout
   184  	// is 2x maxCancelTimeout).
   185  	//
   186  	// Once a connection is dead, there is a fair chance it won't connect
   187  	// successfully because the accept is delayed by 1s.
   188  	maxContextCancelTimeout := 300 * time.Millisecond
   189  	fl := &flakeyListener{
   190  		maxAcceptDelay: 1 * time.Second,
   191  		maxKillTimeout: 600 * time.Millisecond,
   192  	}
   193  
   194  	var client *Client
   195  	switch transport {
   196  	case "ws", "http":
   197  		c, hs := httpTestClient(server, transport, fl)
   198  		defer hs.Close()
   199  		client = c
   200  	case "ipc":
   201  		c, l := ipcTestClient(server, fl)
   202  		defer l.Close()
   203  		client = c
   204  	default:
   205  		panic("unknown transport: " + transport)
   206  	}
   207  
   208  	// The actual test starts here.
   209  	var (
   210  		wg       sync.WaitGroup
   211  		nreqs    = 10
   212  		ncallers = 10
   213  	)
   214  	caller := func(index int) {
   215  		defer wg.Done()
   216  		for i := 0; i < nreqs; i++ {
   217  			var (
   218  				ctx     context.Context
   219  				cancel  func()
   220  				timeout = time.Duration(rand.Int63n(int64(maxContextCancelTimeout)))
   221  			)
   222  			if index < ncallers/2 {
   223  				// For half of the callers, create a context without deadline
   224  				// and cancel it later.
   225  				ctx, cancel = context.WithCancel(context.Background())
   226  				time.AfterFunc(timeout, cancel)
   227  			} else {
   228  				// For the other half, create a context with a deadline instead. This is
   229  				// different because the context deadline is used to set the socket write
   230  				// deadline.
   231  				ctx, cancel = context.WithTimeout(context.Background(), timeout)
   232  			}
   233  
   234  			// Now perform a call with the context.
   235  			// The key thing here is that no call will ever complete successfully.
   236  			err := client.CallContext(ctx, nil, "test_block")
   237  			switch {
   238  			case err == nil:
   239  				_, hasDeadline := ctx.Deadline()
   240  				t.Errorf("no error for call with %v wait time (deadline: %v)", timeout, hasDeadline)
   241  				// default:
   242  				// 	t.Logf("got expected error with %v wait time: %v", timeout, err)
   243  			}
   244  			cancel()
   245  		}
   246  	}
   247  	wg.Add(ncallers)
   248  	for i := 0; i < ncallers; i++ {
   249  		go caller(i)
   250  	}
   251  	wg.Wait()
   252  }
   253  
   254  func TestClientSubscribeInvalidArg(t *testing.T) {
   255  	server := newTestServer()
   256  	defer server.Stop()
   257  	client := DialInProc(server)
   258  	defer client.Close()
   259  
   260  	check := func(shouldPanic bool, arg interface{}) {
   261  		defer func() {
   262  			err := recover()
   263  			if shouldPanic && err == nil {
   264  				t.Errorf("EthSubscribe should've panicked for %#v", arg)
   265  			}
   266  			if !shouldPanic && err != nil {
   267  				t.Errorf("EthSubscribe shouldn't have panicked for %#v", arg)
   268  				buf := make([]byte, 1024*1024)
   269  				buf = buf[:runtime.Stack(buf, false)]
   270  				t.Error(err)
   271  				t.Error(string(buf))
   272  			}
   273  		}()
   274  		client.EthSubscribe(context.Background(), arg, "foo_bar")
   275  	}
   276  	check(true, nil)
   277  	check(true, 1)
   278  	check(true, (chan int)(nil))
   279  	check(true, make(<-chan int))
   280  	check(false, make(chan int))
   281  	check(false, make(chan<- int))
   282  }
   283  
   284  func TestClientSubscribe(t *testing.T) {
   285  	server := newTestServer()
   286  	defer server.Stop()
   287  	client := DialInProc(server)
   288  	defer client.Close()
   289  
   290  	nc := make(chan int)
   291  	count := 10
   292  	sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", count, 0)
   293  	if err != nil {
   294  		t.Fatal("can't subscribe:", err)
   295  	}
   296  	for i := 0; i < count; i++ {
   297  		if val := <-nc; val != i {
   298  			t.Fatalf("value mismatch: got %d, want %d", val, i)
   299  		}
   300  	}
   301  
   302  	sub.Unsubscribe()
   303  	select {
   304  	case v := <-nc:
   305  		t.Fatal("received value after unsubscribe:", v)
   306  	case err := <-sub.Err():
   307  		if err != nil {
   308  			t.Fatalf("Err returned a non-nil error after explicit unsubscribe: %q", err)
   309  		}
   310  	case <-time.After(1 * time.Second):
   311  		t.Fatalf("subscription not closed within 1s after unsubscribe")
   312  	}
   313  }
   314  
   315  // In this test, the connection drops while Subscribe is waiting for a response.
   316  func TestClientSubscribeClose(t *testing.T) {
   317  	server := newTestServer()
   318  	service := &notificationTestService{
   319  		gotHangSubscriptionReq:  make(chan struct{}),
   320  		unblockHangSubscription: make(chan struct{}),
   321  	}
   322  	if err := server.RegisterName("nftest2", service); err != nil {
   323  		t.Fatal(err)
   324  	}
   325  
   326  	defer server.Stop()
   327  	client := DialInProc(server)
   328  	defer client.Close()
   329  
   330  	var (
   331  		nc   = make(chan int)
   332  		errc = make(chan error, 1)
   333  		sub  *ClientSubscription
   334  		err  error
   335  	)
   336  	go func() {
   337  		sub, err = client.Subscribe(context.Background(), "nftest2", nc, "hangSubscription", 999)
   338  		errc <- err
   339  	}()
   340  
   341  	<-service.gotHangSubscriptionReq
   342  	client.Close()
   343  	service.unblockHangSubscription <- struct{}{}
   344  
   345  	select {
   346  	case err := <-errc:
   347  		if err == nil {
   348  			t.Errorf("Subscribe returned nil error after Close")
   349  		}
   350  		if sub != nil {
   351  			t.Error("Subscribe returned non-nil subscription after Close")
   352  		}
   353  	case <-time.After(1 * time.Second):
   354  		t.Fatalf("Subscribe did not return within 1s after Close")
   355  	}
   356  }
   357  
   358  // This test reproduces https://github.com/kisexp/xdchain/issues/17837 where the
   359  // client hangs during shutdown when Unsubscribe races with Client.Close.
   360  func TestClientCloseUnsubscribeRace(t *testing.T) {
   361  	server := newTestServer()
   362  	defer server.Stop()
   363  
   364  	for i := 0; i < 20; i++ {
   365  		client := DialInProc(server)
   366  		nc := make(chan int)
   367  		sub, err := client.Subscribe(context.Background(), "nftest", nc, "someSubscription", 3, 1)
   368  		if err != nil {
   369  			t.Fatal(err)
   370  		}
   371  		go client.Close()
   372  		go sub.Unsubscribe()
   373  		select {
   374  		case <-sub.Err():
   375  		case <-time.After(5 * time.Second):
   376  			t.Fatal("subscription not closed within timeout")
   377  		}
   378  	}
   379  }
   380  
   381  // This test checks that Client doesn't lock up when a single subscriber
   382  // doesn't read subscription events.
   383  func TestClientNotificationStorm(t *testing.T) {
   384  	server := newTestServer()
   385  	defer server.Stop()
   386  
   387  	doTest := func(count int, wantError bool) {
   388  		client := DialInProc(server)
   389  		defer client.Close()
   390  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   391  		defer cancel()
   392  
   393  		// Subscribe on the server. It will start sending many notifications
   394  		// very quickly.
   395  		nc := make(chan int)
   396  		sub, err := client.Subscribe(ctx, "nftest", nc, "someSubscription", count, 0)
   397  		if err != nil {
   398  			t.Fatal("can't subscribe:", err)
   399  		}
   400  		defer sub.Unsubscribe()
   401  
   402  		// Process each notification, try to run a call in between each of them.
   403  		for i := 0; i < count; i++ {
   404  			select {
   405  			case val := <-nc:
   406  				if val != i {
   407  					t.Fatalf("(%d/%d) unexpected value %d", i, count, val)
   408  				}
   409  			case err := <-sub.Err():
   410  				if wantError && err != ErrSubscriptionQueueOverflow {
   411  					t.Fatalf("(%d/%d) got error %q, want %q", i, count, err, ErrSubscriptionQueueOverflow)
   412  				} else if !wantError {
   413  					t.Fatalf("(%d/%d) got unexpected error %q", i, count, err)
   414  				}
   415  				return
   416  			}
   417  			var r int
   418  			err := client.CallContext(ctx, &r, "nftest_echo", i)
   419  			if err != nil {
   420  				if !wantError {
   421  					t.Fatalf("(%d/%d) call error: %v", i, count, err)
   422  				}
   423  				return
   424  			}
   425  		}
   426  		if wantError {
   427  			t.Fatalf("didn't get expected error")
   428  		}
   429  	}
   430  
   431  	doTest(8000, false)
   432  	doTest(24000, true)
   433  }
   434  
   435  func TestClientSetHeader(t *testing.T) {
   436  	var gotHeader bool
   437  	srv := newTestServer()
   438  	httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   439  		if r.Header.Get("test") == "ok" {
   440  			gotHeader = true
   441  		}
   442  		srv.ServeHTTP(w, r)
   443  	}))
   444  	defer httpsrv.Close()
   445  	defer srv.Stop()
   446  
   447  	client, err := Dial(httpsrv.URL)
   448  	if err != nil {
   449  		t.Fatal(err)
   450  	}
   451  	defer client.Close()
   452  
   453  	client.SetHeader("test", "ok")
   454  	if _, err := client.SupportedModules(); err != nil {
   455  		t.Fatal(err)
   456  	}
   457  	if !gotHeader {
   458  		t.Fatal("client did not set custom header")
   459  	}
   460  
   461  	// Check that Content-Type can be replaced.
   462  	client.SetHeader("content-type", "application/x-garbage")
   463  	_, err = client.SupportedModules()
   464  	if err == nil {
   465  		t.Fatal("no error for invalid content-type header")
   466  	} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
   467  		t.Fatalf("error is not related to content-type: %q", err)
   468  	}
   469  }
   470  
   471  func TestClientHTTP(t *testing.T) {
   472  	server := newTestServer()
   473  	defer server.Stop()
   474  
   475  	client, hs := httpTestClient(server, "http", nil)
   476  	defer hs.Close()
   477  	defer client.Close()
   478  
   479  	// Launch concurrent requests.
   480  	var (
   481  		results    = make([]echoResult, 100)
   482  		errc       = make(chan error, len(results))
   483  		wantResult = echoResult{"a", 1, new(echoArgs)}
   484  	)
   485  	defer client.Close()
   486  	for i := range results {
   487  		i := i
   488  		go func() {
   489  			errc <- client.Call(&results[i], "test_echo", wantResult.String, wantResult.Int, wantResult.Args)
   490  		}()
   491  	}
   492  
   493  	// Wait for all of them to complete.
   494  	timeout := time.NewTimer(5 * time.Second)
   495  	defer timeout.Stop()
   496  	for i := range results {
   497  		select {
   498  		case err := <-errc:
   499  			if err != nil {
   500  				t.Fatal(err)
   501  			}
   502  		case <-timeout.C:
   503  			t.Fatalf("timeout (got %d/%d) results)", i+1, len(results))
   504  		}
   505  	}
   506  
   507  	// Check results.
   508  	for i := range results {
   509  		if !reflect.DeepEqual(results[i], wantResult) {
   510  			t.Errorf("result %d mismatch: got %#v, want %#v", i, results[i], wantResult)
   511  		}
   512  	}
   513  }
   514  
   515  func TestClientReconnect(t *testing.T) {
   516  	startServer := func(addr string) (*Server, net.Listener) {
   517  		srv := newTestServer()
   518  		l, err := net.Listen("tcp", addr)
   519  		if err != nil {
   520  			t.Fatal("can't listen:", err)
   521  		}
   522  		go http.Serve(l, srv.WebsocketHandler([]string{"*"}))
   523  		return srv, l
   524  	}
   525  
   526  	ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second)
   527  	defer cancel()
   528  
   529  	// Start a server and corresponding client.
   530  	s1, l1 := startServer("127.0.0.1:0")
   531  	client, err := DialContext(ctx, "ws://"+l1.Addr().String())
   532  	if err != nil {
   533  		t.Fatal("can't dial", err)
   534  	}
   535  
   536  	// Perform a call. This should work because the server is up.
   537  	var resp echoResult
   538  	if err := client.CallContext(ctx, &resp, "test_echo", "", 1, nil); err != nil {
   539  		t.Fatal(err)
   540  	}
   541  
   542  	// Shut down the server and allow for some cool down time so we can listen on the same
   543  	// address again.
   544  	l1.Close()
   545  	s1.Stop()
   546  	time.Sleep(2 * time.Second)
   547  
   548  	// Try calling again. It shouldn't work.
   549  	if err := client.CallContext(ctx, &resp, "test_echo", "", 2, nil); err == nil {
   550  		t.Error("successful call while the server is down")
   551  		t.Logf("resp: %#v", resp)
   552  	}
   553  
   554  	// Start it up again and call again. The connection should be reestablished.
   555  	// We spawn multiple calls here to check whether this hangs somehow.
   556  	s2, l2 := startServer(l1.Addr().String())
   557  	defer l2.Close()
   558  	defer s2.Stop()
   559  
   560  	start := make(chan struct{})
   561  	errors := make(chan error, 20)
   562  	for i := 0; i < cap(errors); i++ {
   563  		go func() {
   564  			<-start
   565  			var resp echoResult
   566  			errors <- client.CallContext(ctx, &resp, "test_echo", "", 3, nil)
   567  		}()
   568  	}
   569  	close(start)
   570  	errcount := 0
   571  	for i := 0; i < cap(errors); i++ {
   572  		if err = <-errors; err != nil {
   573  			errcount++
   574  		}
   575  	}
   576  	t.Logf("%d errors, last error: %v", errcount, err)
   577  	if errcount > 1 {
   578  		t.Errorf("expected one error after disconnect, got %d", errcount)
   579  	}
   580  }
   581  
   582  func httpTestClient(srv *Server, transport string, fl *flakeyListener) (*Client, *httptest.Server) {
   583  	// Create the HTTP server.
   584  	var hs *httptest.Server
   585  	switch transport {
   586  	case "ws":
   587  		hs = httptest.NewUnstartedServer(srv.WebsocketHandler([]string{"*"}))
   588  	case "http":
   589  		hs = httptest.NewUnstartedServer(srv)
   590  	default:
   591  		panic("unknown HTTP transport: " + transport)
   592  	}
   593  	// Wrap the listener if required.
   594  	if fl != nil {
   595  		fl.Listener = hs.Listener
   596  		hs.Listener = fl
   597  	}
   598  	// Connect the client.
   599  	hs.Start()
   600  	client, err := Dial(transport + "://" + hs.Listener.Addr().String())
   601  	if err != nil {
   602  		panic(err)
   603  	}
   604  	return client, hs
   605  }
   606  
   607  func ipcTestClient(srv *Server, fl *flakeyListener) (*Client, net.Listener) {
   608  	// Listen on a random endpoint.
   609  	endpoint := fmt.Sprintf("go-ethereum-test-ipc-%d-%d", os.Getpid(), rand.Int63())
   610  	if runtime.GOOS == "windows" {
   611  		endpoint = `\\.\pipe\` + endpoint
   612  	} else {
   613  		endpoint = os.TempDir() + "/" + endpoint
   614  	}
   615  	l, err := ipcListen(endpoint)
   616  	if err != nil {
   617  		panic(err)
   618  	}
   619  	// Connect the listener to the server.
   620  	if fl != nil {
   621  		fl.Listener = l
   622  		l = fl
   623  	}
   624  	go srv.ServeListener(l)
   625  	// Connect the client.
   626  	client, err := Dial(endpoint)
   627  	if err != nil {
   628  		panic(err)
   629  	}
   630  	return client, l
   631  }
   632  
   633  // flakeyListener kills accepted connections after a random timeout.
   634  type flakeyListener struct {
   635  	net.Listener
   636  	maxKillTimeout time.Duration
   637  	maxAcceptDelay time.Duration
   638  }
   639  
   640  func (l *flakeyListener) Accept() (net.Conn, error) {
   641  	delay := time.Duration(rand.Int63n(int64(l.maxAcceptDelay)))
   642  	time.Sleep(delay)
   643  
   644  	c, err := l.Listener.Accept()
   645  	if err == nil {
   646  		timeout := time.Duration(rand.Int63n(int64(l.maxKillTimeout)))
   647  		time.AfterFunc(timeout, func() {
   648  			log.Debug(fmt.Sprintf("killing conn %v after %v", c.LocalAddr(), timeout))
   649  			c.Close()
   650  		})
   651  	}
   652  	return c, err
   653  }
   654  
   655  func TestClient_withCredentials_whenTargetingHTTP(t *testing.T) {
   656  	server := newTestServer()
   657  	server.authenticationManager = &stubAuthenticationManager{isEnabled: true}
   658  	defer server.Stop()
   659  	fl := &flakeyListener{
   660  		maxAcceptDelay: 1 * time.Second,
   661  		maxKillTimeout: 600 * time.Millisecond,
   662  	}
   663  	hs := httptest.NewUnstartedServer(server)
   664  	fl.Listener = hs.Listener
   665  	hs.Listener = fl
   666  	// Connect the client.
   667  	hs.Start()
   668  	defer hs.Close()
   669  
   670  	c, err := Dial("http://" + hs.Listener.Addr().String())
   671  	assert.NoError(t, err)
   672  	var f HttpCredentialsProviderFunc = func(ctx context.Context) (string, error) {
   673  		return "Bearer arbitrary_token", nil
   674  	}
   675  	authenticatedClient := c.WithHTTPCredentials(f)
   676  
   677  	err = authenticatedClient.CallContext(context.Background(), nil, "arbitrary_call")
   678  	assert.EqualError(t, err, "arbitrary_call - access denied")
   679  }
   680  
   681  func TestClient_withCredentials_whenTargetingWS(t *testing.T) {
   682  	server := newTestServer()
   683  	server.authenticationManager = &stubAuthenticationManager{isEnabled: true}
   684  	defer server.Stop()
   685  	fl := &flakeyListener{
   686  		maxAcceptDelay: 1 * time.Second,
   687  		maxKillTimeout: 600 * time.Millisecond,
   688  	}
   689  	hs := httptest.NewUnstartedServer(server.WebsocketHandler([]string{"*"}))
   690  	fl.Listener = hs.Listener
   691  	hs.Listener = fl
   692  	// Connect the client.
   693  	hs.Start()
   694  	defer hs.Close()
   695  	var f HttpCredentialsProviderFunc = func(ctx context.Context) (string, error) {
   696  		return "Bearer arbitrary_token", nil
   697  	}
   698  	ctx := WithCredentialsProvider(context.Background(), f)
   699  	authenticatedClient, err := DialContext(ctx, "ws://"+hs.Listener.Addr().String())
   700  	assert.NoError(t, err)
   701  
   702  	err = authenticatedClient.CallContext(context.Background(), nil, "arbitrary_call")
   703  	assert.EqualError(t, err, "arbitrary_call - access denied")
   704  }
   705  
   706  func TestClient_HTTP_WS_whenDefaultPSI(t *testing.T) {
   707  	for _, transport := range []string{"http", "ws"} {
   708  		f := func(transport string) {
   709  			server := newTestServer()
   710  			defer server.Stop()
   711  
   712  			client, hs := httpTestClient(server, transport, nil)
   713  			defer hs.Close()
   714  			defer client.Close()
   715  
   716  			verifyPSI(t, client, types.DefaultPrivateStateIdentifier)
   717  		}
   718  		f(transport)
   719  	}
   720  }
   721  
   722  func TestClient_InProc_whenDefaultPSI(t *testing.T) {
   723  	server := newTestServer()
   724  	defer server.Stop()
   725  
   726  	client := DialInProc(server)
   727  	defer client.Close()
   728  
   729  	verifyPSI(t, client, types.DefaultPrivateStateIdentifier)
   730  }
   731  
   732  func TestClient_IPC_whenDefaultPSI(t *testing.T) {
   733  	server := newTestServer()
   734  	defer server.Stop()
   735  
   736  	client, l := ipcTestClient(server, nil)
   737  	defer l.Close()
   738  	defer client.Close()
   739  
   740  	verifyPSI(t, client, types.DefaultPrivateStateIdentifier)
   741  }
   742  
   743  func startHTTPTestServer(transport string) (*Server, *httptest.Server) {
   744  	handler := newTestServer()
   745  	// Create the HTTP server.
   746  	var hs *httptest.Server
   747  	switch transport {
   748  	case "ws":
   749  		hs = httptest.NewUnstartedServer(handler.WebsocketHandler([]string{"*"}))
   750  	case "http":
   751  		hs = httptest.NewUnstartedServer(handler)
   752  	default:
   753  		panic("unknown HTTP transport: " + transport)
   754  	}
   755  	// Connect the client.
   756  	hs.Start()
   757  	return handler, hs
   758  }
   759  
   760  func TestClient_whenProvidingPSIViaURLParam(t *testing.T) {
   761  	for _, transport := range []string{"http", "ws"} {
   762  		f := func(transport string) {
   763  			expectedPSI := "PS1"
   764  			srvHandler, srvHttp := startHTTPTestServer(transport)
   765  			defer func() {
   766  				srvHandler.Stop()
   767  				srvHttp.Close()
   768  			}()
   769  
   770  			endpoint := fmt.Sprintf("%s://%s?%s=%s", transport, srvHttp.Listener.Addr().String(), QueryPrivateStateIdentifierParamName, expectedPSI)
   771  			client, err := Dial(endpoint)
   772  			assert.NoError(t, err, endpoint)
   773  
   774  			verifyPSI(t, client, types.PrivateStateIdentifier(expectedPSI), endpoint)
   775  		}
   776  		f(transport)
   777  	}
   778  }
   779  
   780  func TestClient_whenProvidingPSIViaEnvVar(t *testing.T) {
   781  	for _, transport := range []string{"http", "ws"} {
   782  		f := func(transport string) {
   783  			expectedPSI := "PS1"
   784  			assert.NoError(t, os.Setenv(EnvVarPrivateStateIdentifier, expectedPSI))
   785  			defer os.Unsetenv(EnvVarPrivateStateIdentifier)
   786  			srvHandler, srvHttp := startHTTPTestServer(transport)
   787  			defer func() {
   788  				srvHandler.Stop()
   789  				srvHttp.Close()
   790  			}()
   791  
   792  			endpoint := fmt.Sprintf("%s://%s", transport, srvHttp.Listener.Addr().String())
   793  			client, err := Dial(endpoint)
   794  			assert.NoError(t, err, endpoint)
   795  
   796  			verifyPSI(t, client, types.PrivateStateIdentifier(expectedPSI), endpoint)
   797  		}
   798  		f(transport)
   799  	}
   800  }
   801  
   802  func TestClient_IPC_whenSetupPSIExplicitly(t *testing.T) {
   803  	expectedPSI := types.ToPrivateStateIdentifier("arbitrary_psi")
   804  	server := newTestServer()
   805  	defer server.Stop()
   806  
   807  	client, l := ipcTestClient(server, nil)
   808  	defer l.Close()
   809  	defer client.Close()
   810  
   811  	client.WithPSI(expectedPSI)
   812  
   813  	verifyPSI(t, client, expectedPSI)
   814  }
   815  
   816  func TestClient_InProc_whenSetupPSIExplicitly(t *testing.T) {
   817  	expectedPSI := types.ToPrivateStateIdentifier("arbitrary_psi")
   818  	server := newTestServer()
   819  	defer server.Stop()
   820  
   821  	client := DialInProc(server)
   822  	defer client.Close()
   823  
   824  	client.WithPSI(expectedPSI)
   825  
   826  	verifyPSI(t, client, expectedPSI)
   827  }
   828  
   829  func verifyPSI(t *testing.T, client *Client, expectedPSI types.PrivateStateIdentifier, msgAndArgs ...interface{}) {
   830  	var resp echoPSIResult
   831  	err := client.Call(&resp, "test_echoCtxPSI")
   832  	assert.NoError(t, err)
   833  
   834  	assert.Equal(t, expectedPSI, resp.PSI, msgAndArgs)
   835  }