go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket_client_test.go (about)

     1  package onet
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"os"
    10  	"sort"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/gorilla/websocket"
    16  	"github.com/stretchr/testify/require"
    17  	"go.dedis.ch/onet/v3/log"
    18  	"go.dedis.ch/onet/v3/network"
    19  	"go.dedis.ch/protobuf"
    20  	"golang.org/x/xerrors"
    21  )
    22  
    23  func TestClient_Send(t *testing.T) {
    24  	local := NewTCPTest(tSuite)
    25  	defer local.CloseAll()
    26  
    27  	// register service
    28  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
    29  		return &simpleService{
    30  			ctx: c,
    31  		}, nil
    32  	})
    33  	defer ServiceFactory.Unregister(backForthServiceName)
    34  
    35  	// create servers
    36  	servers, el, _ := local.GenTree(4, false)
    37  	client := local.NewClient(backForthServiceName)
    38  
    39  	r := &SimpleRequest{
    40  		ServerIdentities: el,
    41  		Val:              10,
    42  	}
    43  	sr := &SimpleResponse{}
    44  	require.Equal(t, uint64(0), client.Rx())
    45  	require.Equal(t, uint64(0), client.Tx())
    46  	require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
    47  	require.Equal(t, sr.Val, int64(10))
    48  	require.NotEqual(t, uint64(0), client.Rx())
    49  	require.NotEqual(t, uint64(0), client.Tx())
    50  	require.True(t, client.Tx() > client.Rx())
    51  }
    52  
    53  func TestClientTLS_Send(t *testing.T) {
    54  	cert, key, err := getSelfSignedCertificateAndKey()
    55  	require.Nil(t, err)
    56  	CAPool := x509.NewCertPool()
    57  	CAPool.AppendCertsFromPEM(cert)
    58  
    59  	local := NewTCPTest(tSuite)
    60  	local.webSocketTLSCertificate = cert
    61  	local.webSocketTLSCertificateKey = key
    62  	defer local.CloseAll()
    63  
    64  	// register service
    65  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
    66  		return &simpleService{
    67  			ctx: c,
    68  		}, nil
    69  	})
    70  	defer ServiceFactory.Unregister(backForthServiceName)
    71  
    72  	// create servers
    73  	servers, el, _ := local.GenTree(4, false)
    74  	client := local.NewClient(backForthServiceName)
    75  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
    76  
    77  	r := &SimpleRequest{
    78  		ServerIdentities: el,
    79  		Val:              10,
    80  	}
    81  	sr := &SimpleResponse{}
    82  	require.Equal(t, uint64(0), client.Rx())
    83  	require.Equal(t, uint64(0), client.Tx())
    84  	require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
    85  	require.Equal(t, sr.Val, int64(10))
    86  	require.NotEqual(t, uint64(0), client.Rx())
    87  	require.NotEqual(t, uint64(0), client.Tx())
    88  	require.True(t, client.Tx() > client.Rx())
    89  }
    90  
    91  func TestClientTLS_certfile_Send(t *testing.T) {
    92  	// like TestClientTLSfile_Send, but uses cert and key from a file
    93  	// to solve issue 583.
    94  	cert, key, err := getSelfSignedCertificateAndKey()
    95  	require.Nil(t, err)
    96  	CAPool := x509.NewCertPool()
    97  	CAPool.AppendCertsFromPEM(cert)
    98  
    99  	f1, err := ioutil.TempFile("", "cert")
   100  	require.NoError(t, err)
   101  	defer os.Remove(f1.Name())
   102  	f1.Write(cert)
   103  	f1.Close()
   104  
   105  	f2, err := ioutil.TempFile("", "key")
   106  	require.NoError(t, err)
   107  	defer os.Remove(f2.Name())
   108  	f2.Write(key)
   109  	f2.Close()
   110  
   111  	local := NewTCPTest(tSuite)
   112  	local.webSocketTLSCertificate = []byte(f1.Name())
   113  	local.webSocketTLSCertificateKey = []byte(f2.Name())
   114  	local.webSocketTLSReadFiles = true
   115  	defer local.CloseAll()
   116  
   117  	// register service
   118  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   119  		return &simpleService{
   120  			ctx: c,
   121  		}, nil
   122  	})
   123  	defer ServiceFactory.Unregister(backForthServiceName)
   124  
   125  	// create servers
   126  	servers, el, _ := local.GenTree(4, false)
   127  	client := local.NewClient(backForthServiceName)
   128  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   129  
   130  	r := &SimpleRequest{
   131  		ServerIdentities: el,
   132  		Val:              10,
   133  	}
   134  	sr := &SimpleResponse{}
   135  	require.Equal(t, uint64(0), client.Rx())
   136  	require.Equal(t, uint64(0), client.Tx())
   137  	require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
   138  	require.Equal(t, sr.Val, int64(10))
   139  	require.NotEqual(t, uint64(0), client.Rx())
   140  	require.NotEqual(t, uint64(0), client.Tx())
   141  	require.True(t, client.Tx() > client.Rx())
   142  }
   143  
   144  func TestClient_Parallel(t *testing.T) {
   145  	nbrNodes := 4
   146  	nbrParallel := 20
   147  	local := NewTCPTest(tSuite)
   148  	defer local.CloseAll()
   149  
   150  	// register service
   151  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   152  		return &simpleService{
   153  			ctx: c,
   154  		}, nil
   155  	})
   156  	defer ServiceFactory.Unregister(backForthServiceName)
   157  
   158  	// create servers
   159  	servers, el, _ := local.GenTree(nbrNodes, true)
   160  
   161  	wg := sync.WaitGroup{}
   162  	wg.Add(nbrParallel)
   163  	for i := 0; i < nbrParallel; i++ {
   164  		go func(i int) {
   165  			defer wg.Done()
   166  			log.Lvl1("Starting message", i)
   167  			r := &SimpleRequest{
   168  				ServerIdentities: el,
   169  				Val:              int64(10 * i),
   170  			}
   171  			client := local.NewClient(backForthServiceName)
   172  			sr := &SimpleResponse{}
   173  			err := client.SendProtobuf(servers[0].ServerIdentity, r, sr)
   174  			require.Nil(t, err)
   175  			require.Equal(t, int64(10*i), sr.Val)
   176  			log.Lvl1("Done with message", i)
   177  		}(i)
   178  	}
   179  	wg.Wait()
   180  }
   181  
   182  func TestClientTLS_Parallel(t *testing.T) {
   183  	cert, key, err := getSelfSignedCertificateAndKey()
   184  	require.Nil(t, err)
   185  	CAPool := x509.NewCertPool()
   186  	CAPool.AppendCertsFromPEM(cert)
   187  
   188  	nbrNodes := 4
   189  	nbrParallel := 20
   190  	local := NewTCPTest(tSuite)
   191  	local.webSocketTLSCertificate = cert
   192  	local.webSocketTLSCertificateKey = key
   193  	defer local.CloseAll()
   194  
   195  	// register service
   196  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   197  		return &simpleService{
   198  			ctx: c,
   199  		}, nil
   200  	})
   201  	defer ServiceFactory.Unregister(backForthServiceName)
   202  
   203  	// create servers
   204  	servers, el, _ := local.GenTree(nbrNodes, true)
   205  
   206  	wg := sync.WaitGroup{}
   207  	wg.Add(nbrParallel)
   208  	for i := 0; i < nbrParallel; i++ {
   209  		go func(i int) {
   210  			defer wg.Done()
   211  			log.Lvl1("Starting message", i)
   212  			r := &SimpleRequest{
   213  				ServerIdentities: el,
   214  				Val:              int64(10 * i),
   215  			}
   216  			client := local.NewClient(backForthServiceName)
   217  			client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   218  			sr := &SimpleResponse{}
   219  			require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
   220  			require.Equal(t, int64(10*i), sr.Val)
   221  			log.Lvl1("Done with message", i)
   222  		}(i)
   223  	}
   224  	wg.Wait()
   225  }
   226  
   227  func TestNewClientKeep(t *testing.T) {
   228  	c := NewClientKeep(tSuite, serviceWebSocket)
   229  	require.True(t, c.keep)
   230  }
   231  
   232  func TestMultiplePath(t *testing.T) {
   233  	_, err := RegisterNewService(dummyService3Name, func(c *Context) (Service, error) {
   234  		ds := &DummyService3{}
   235  		return ds, nil
   236  	})
   237  	require.Nil(t, err)
   238  	defer UnregisterService(dummyService3Name)
   239  
   240  	local := NewTCPTest(tSuite)
   241  	hs := local.GenServers(2)
   242  	server := hs[0]
   243  	defer local.CloseAll()
   244  	client := NewClientKeep(tSuite, dummyService3Name)
   245  	msg, err := protobuf.Encode(&DummyMsg{})
   246  	require.Nil(t, err)
   247  	path1, path2 := "path1", "path2"
   248  	resp, err := client.Send(server.ServerIdentity, path1, msg)
   249  	require.Nil(t, err)
   250  	require.Equal(t, path1, string(resp))
   251  	resp, err = client.Send(server.ServerIdentity, path2, msg)
   252  	require.Nil(t, err)
   253  	require.Equal(t, path2, string(resp))
   254  }
   255  
   256  func TestMultiplePathTLS(t *testing.T) {
   257  	cert, key, err := getSelfSignedCertificateAndKey()
   258  	require.Nil(t, err)
   259  	CAPool := x509.NewCertPool()
   260  	CAPool.AppendCertsFromPEM(cert)
   261  
   262  	_, err = RegisterNewService(dummyService3Name, func(c *Context) (Service, error) {
   263  		ds := &DummyService3{}
   264  		return ds, nil
   265  	})
   266  	require.Nil(t, err)
   267  	defer UnregisterService(dummyService3Name)
   268  
   269  	local := NewTCPTest(tSuite)
   270  	local.webSocketTLSCertificate = cert
   271  	local.webSocketTLSCertificateKey = key
   272  	hs := local.GenServers(2)
   273  	server := hs[0]
   274  	defer local.CloseAll()
   275  	client := NewClientKeep(tSuite, dummyService3Name)
   276  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   277  	msg, err := protobuf.Encode(&DummyMsg{})
   278  	require.Nil(t, err)
   279  	path1, path2 := "path1", "path2"
   280  	resp, err := client.Send(server.ServerIdentity, path1, msg)
   281  	require.Nil(t, err)
   282  	require.Equal(t, path1, string(resp))
   283  	resp, err = client.Send(server.ServerIdentity, path2, msg)
   284  	require.Nil(t, err)
   285  	require.Equal(t, path2, string(resp))
   286  }
   287  
   288  func TestWebSocket_Error(t *testing.T) {
   289  	client := NewClientKeep(tSuite, dummyService3Name)
   290  	local := NewTCPTest(tSuite)
   291  	hs := local.GenServers(2)
   292  	server := hs[0]
   293  	defer local.CloseAll()
   294  
   295  	lvl := log.DebugVisible()
   296  	log.SetDebugVisible(0)
   297  	log.OutputToBuf()
   298  	_, err := client.Send(server.ServerIdentity, "test", nil)
   299  	log.OutputToOs()
   300  	log.SetDebugVisible(lvl)
   301  	require.NotEqual(t, websocket.ErrBadHandshake, err)
   302  	require.NotEqual(t, "", log.GetStdErr())
   303  }
   304  
   305  func TestWebSocketTLS_Error(t *testing.T) {
   306  	cert, key, err := getSelfSignedCertificateAndKey()
   307  	require.Nil(t, err)
   308  	CAPool := x509.NewCertPool()
   309  	CAPool.AppendCertsFromPEM(cert)
   310  
   311  	client := NewClientKeep(tSuite, dummyService3Name)
   312  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   313  	local := NewTCPTest(tSuite)
   314  	local.webSocketTLSCertificate = cert
   315  	local.webSocketTLSCertificateKey = key
   316  	hs := local.GenServers(2)
   317  	server := hs[0]
   318  	defer local.CloseAll()
   319  
   320  	lvl := log.DebugVisible()
   321  	log.SetDebugVisible(0)
   322  	log.OutputToBuf()
   323  	_, err = client.Send(server.ServerIdentity, "test", nil)
   324  	log.OutputToOs()
   325  	log.SetDebugVisible(lvl)
   326  	require.NotEqual(t, websocket.ErrBadHandshake, err)
   327  	require.NotEqual(t, "", log.GetStdErr())
   328  }
   329  
   330  // TestWebSocket_Streaming_normal reads all messages from the service
   331  func TestWebSocket_Streaming_normal(t *testing.T) {
   332  	local := NewTCPTest(tSuite)
   333  	defer local.CloseAll()
   334  
   335  	serName := "streamingService"
   336  	_, err := RegisterNewService(serName, newStreamingService)
   337  	require.NoError(t, err)
   338  	defer UnregisterService(serName)
   339  
   340  	servers, el, _ := local.GenTree(4, false)
   341  	client := local.NewClientKeep(serName)
   342  
   343  	n := 5
   344  	r := &SimpleRequest{
   345  		ServerIdentities: el,
   346  		Val:              int64(n),
   347  	}
   348  
   349  	log.Lvl1("Happy-path testing")
   350  	conn, err := client.Stream(servers[0].ServerIdentity, r)
   351  	require.NoError(t, err)
   352  
   353  	for i := 0; i < n; i++ {
   354  		sr := &SimpleResponse{}
   355  		require.NoError(t, conn.ReadMessage(sr))
   356  		require.Equal(t, sr.Val, int64(n))
   357  	}
   358  
   359  	// Using the same client (connection) to repeat the same request should
   360  	// fail because the connection should be closed by the service when
   361  	// there are no more messages.
   362  	log.Lvl1("Fail on re-use")
   363  	sr := &SimpleResponse{}
   364  	require.Error(t, conn.ReadMessage(sr))
   365  	require.NoError(t, client.Close())
   366  }
   367  
   368  // TestWebSocket_Streaming_Parallel_normal
   369  func TestWebSocket_Streaming_Parallel_normal(t *testing.T) {
   370  	local := NewTCPTest(tSuite)
   371  	defer local.CloseAll()
   372  
   373  	serName := "streamingService"
   374  	_, err := RegisterNewService(serName, newStreamingService)
   375  	require.NoError(t, err)
   376  	defer UnregisterService(serName)
   377  
   378  	servers, el, _ := local.GenTree(4, false)
   379  	n := 10
   380  
   381  	// Do streaming with 10 clients in parallel. Happy-path where clients read
   382  	// everything.
   383  	clients := make([]*Client, 100)
   384  	for i := range clients {
   385  		clients[i] = local.NewClientKeep(serName)
   386  	}
   387  	var wg sync.WaitGroup
   388  	for _, client := range clients {
   389  		wg.Add(1)
   390  		go func(c *Client) {
   391  			defer wg.Done()
   392  			r := &SimpleRequest{
   393  				ServerIdentities: el,
   394  				Val:              int64(n),
   395  			}
   396  			conn, err := c.Stream(servers[0].ServerIdentity, r)
   397  			require.NoError(t, err)
   398  
   399  			for i := 0; i < n; i++ {
   400  				sr := &SimpleResponse{}
   401  				require.NoError(t, conn.ReadMessage(sr))
   402  				require.Equal(t, sr.Val, int64(n))
   403  			}
   404  		}(client)
   405  	}
   406  	wg.Wait()
   407  	for i := range clients {
   408  		require.NoError(t, clients[i].Close())
   409  	}
   410  }
   411  
// TestWebSocket_Streaming_bi_normal sends two consecutive streaming requests
// over the same client to a custom streaming handler. It reads all the
// messages of each stream.
func TestWebSocket_Streaming_bi_normal(t *testing.T) {
	local := NewTCPTest(tSuite)
	defer local.CloseAll()

	// State shared by every invocation of the handler h below. All streams
	// write to the same outChan and wait on the same closeChan; once ensures
	// outChan is closed a single time even though several handler
	// go-routines may try.
	serviceStruct := struct {
		once      sync.Once
		outChan   chan *SimpleResponse
		closeChan chan bool
	}{
		outChan:   make(chan *SimpleResponse, 10),
		closeChan: make(chan bool),
	}

	// h sends m.Val responses carrying 0..m.Val-1, 100ms apart. It then
	// blocks on closeChan before closing outChan.
	// NOTE(review): closeChan is presumably signalled or closed by the
	// streaming framework when the stream ends — confirm.
	h := func(m *SimpleRequest) (chan *SimpleResponse, chan bool, error) {
		go func() {
			for i := 0; i < int(m.Val); i++ {
				time.Sleep(100 * time.Millisecond)
				serviceStruct.outChan <- &SimpleResponse{int64(i)}
			}
			<-serviceStruct.closeChan
			serviceStruct.once.Do(func() {
				close(serviceStruct.outChan)
			})
		}()
		return serviceStruct.outChan, serviceStruct.closeChan, nil
	}

	// newCustomStreamingService builds a StreamingService that uses h instead
	// of StreamValues; stopAt -1 disables the early stop of StreamingService.
	newCustomStreamingService := func(c *Context) (Service, error) {
		s := &StreamingService{
			ServiceProcessor: NewServiceProcessor(c),
			stopAt:           -1,
		}
		if err := s.RegisterStreamingHandler(h); err != nil {
			panic(err.Error())
		}
		return s, nil
	}
	serName := "biStreamingService"
	_, err := RegisterNewService(serName, newCustomStreamingService)
	require.NoError(t, err)
	defer UnregisterService(serName)

	servers, el, _ := local.GenTree(4, false)
	client := local.NewClientKeep(serName)

	// A first request to the service
	n := 5
	r := &SimpleRequest{
		ServerIdentities: el,
		Val:              int64(n),
	}

	conn, err := client.Stream(servers[0].ServerIdentity, r)
	require.NoError(t, err)

	// h numbers its responses, so message i must carry i.
	for i := 0; i < n; i++ {
		sr := &SimpleResponse{}
		require.NoError(t, conn.ReadMessage(sr))
		require.Equal(t, sr.Val, int64(i))
	}

	// Lets perform a second request
	n = 5
	r = &SimpleRequest{
		ServerIdentities: el,
		Val:              int64(n),
	}

	conn, err = client.Stream(servers[0].ServerIdentity, r)
	require.NoError(t, err)

	for i := 0; i < n; i++ {
		sr := &SimpleResponse{}
		require.NoError(t, conn.ReadMessage(sr))
		require.Equal(t, sr.Val, int64(i))
	}

	// NOTE(review): the error from Close is ignored here, unlike in the other
	// streaming tests. The reason for the final sleep is not visible;
	// presumably it lets the handler go-routines finish before CloseAll.
	client.Close()
	time.Sleep(time.Second)
}
   494  
   495  // TestWebSocket_Streaming_early_client makes the client close early.
   496  func TestWebSocket_Streaming_early_client(t *testing.T) {
   497  	local := NewTCPTest(tSuite)
   498  	defer local.CloseAll()
   499  
   500  	serName := "streamingService"
   501  	serID, err := RegisterNewService(serName, newStreamingService)
   502  	require.NoError(t, err)
   503  	defer UnregisterService(serName)
   504  
   505  	servers, el, _ := local.GenTree(4, false)
   506  	client := local.NewClientKeep(serName)
   507  
   508  	n := 5
   509  	r := &SimpleRequest{
   510  		ServerIdentities: el,
   511  		Val:              int64(n),
   512  	}
   513  
   514  	// go-routine should also terminate.
   515  	client = local.NewClientKeep(serName)
   516  	services := local.GetServices(servers, serID)
   517  	serviceRoot := services[0].(*StreamingService)
   518  	serviceRoot.gotStopChan = make(chan bool, 1)
   519  
   520  	_, err = client.Stream(servers[0].ServerIdentity, r)
   521  	require.NoError(t, err)
   522  	require.NoError(t, client.Close())
   523  
   524  	select {
   525  	case <-serviceRoot.gotStopChan:
   526  	case <-time.After(time.Second):
   527  		require.Fail(t, "should have got an early finish signal")
   528  	}
   529  
   530  }
   531  
   532  // TestWebSocket_Streaming_Parallel_early_client
   533  func TestWebSocket_Streaming_Parallel_early_client2(t *testing.T) {
   534  	local := NewTCPTest(tSuite)
   535  	defer local.CloseAll()
   536  
   537  	serName := "streamingService"
   538  	serID, err := RegisterNewService(serName, newStreamingService)
   539  	require.NoError(t, err)
   540  	defer UnregisterService(serName)
   541  
   542  	servers, el, _ := local.GenTree(4, false)
   543  	services := local.GetServices(servers, serID)
   544  	serviceRoot := services[0].(*StreamingService)
   545  	n := 10
   546  
   547  	// Unhappy-path where clients stop early.
   548  	clients := make([]*Client, 100)
   549  	for i := range clients {
   550  		clients[i] = local.NewClientKeep(serName)
   551  	}
   552  	serviceRoot.gotStopChan = make(chan bool, len(clients))
   553  	wg := sync.WaitGroup{}
   554  	for _, client := range clients {
   555  		wg.Add(1)
   556  		go func(c *Client) {
   557  			defer wg.Done()
   558  			r := &SimpleRequest{
   559  				ServerIdentities: el,
   560  				Val:              int64(n),
   561  			}
   562  			conn, err := c.Stream(servers[0].ServerIdentity, r)
   563  			require.NoError(t, err)
   564  
   565  			// read one message instead of n then close
   566  			sr := &SimpleResponse{}
   567  			require.NoError(t, conn.ReadMessage(sr))
   568  			require.Equal(t, sr.Val, int64(n))
   569  			require.NoError(t, c.Close())
   570  		}(client)
   571  	}
   572  
   573  	// we should get close messages
   574  	for i := 0; i < len(clients); i++ {
   575  		select {
   576  		case <-serviceRoot.gotStopChan:
   577  		case <-time.After(time.Second):
   578  			require.Fail(t, "should have got an early finish signal")
   579  		}
   580  	}
   581  	wg.Wait()
   582  }
   583  
   584  // TestWebSocket_Streaming_early_service closes the service early
   585  func TestWebSocket_Streaming_early_service(t *testing.T) {
   586  	local := NewTCPTest(tSuite)
   587  	defer local.CloseAll()
   588  
   589  	serName := "streamingService"
   590  	serID, err := RegisterNewService(serName, newStreamingService)
   591  	require.NoError(t, err)
   592  	defer UnregisterService(serName)
   593  
   594  	servers, el, _ := local.GenTree(4, false)
   595  	client := local.NewClientKeep(serName)
   596  
   597  	n := 5
   598  	r := &SimpleRequest{
   599  		ServerIdentities: el,
   600  		Val:              int64(n),
   601  	}
   602  
   603  	// Have the service terminate early. The client should stop receiving
   604  	// messages.
   605  	log.Lvl1("Service terminate early")
   606  	stopAt := 1
   607  	client = local.NewClientKeep(serName)
   608  	services := local.GetServices(servers, serID)
   609  	serviceRoot := services[0].(*StreamingService)
   610  	serviceRoot.stopAt = stopAt
   611  
   612  	conn, err := client.Stream(servers[0].ServerIdentity, r)
   613  	require.NoError(t, err)
   614  	for i := 0; i < n; i++ {
   615  		if i > stopAt {
   616  			sr := &SimpleResponse{}
   617  			require.Error(t, conn.ReadMessage(sr))
   618  		} else {
   619  			sr := &SimpleResponse{}
   620  			require.NoError(t, conn.ReadMessage(sr))
   621  			require.Equal(t, sr.Val, int64(n))
   622  		}
   623  	}
   624  	require.NoError(t, client.Close())
   625  }
   626  
   627  // TestWebSocket_Streaming_Parallel_early_service
   628  func TestWebSocket_Streaming_Parallel_ealry_service(t *testing.T) {
   629  	local := NewTCPTest(tSuite)
   630  	defer local.CloseAll()
   631  
   632  	serName := "streamingService"
   633  	serID, err := RegisterNewService(serName, newStreamingService)
   634  	require.NoError(t, err)
   635  	defer UnregisterService(serName)
   636  
   637  	servers, el, _ := local.GenTree(4, false)
   638  	services := local.GetServices(servers, serID)
   639  	serviceRoot := services[0].(*StreamingService)
   640  	n := 10
   641  
   642  	// The other unhappy-path where the service stops early.
   643  	clients := make([]*Client, 100)
   644  	for i := range clients {
   645  		clients[i] = local.NewClientKeep(serName)
   646  	}
   647  	stopAt := 1
   648  	serviceRoot.stopAt = stopAt
   649  	wg := sync.WaitGroup{}
   650  	for _, client := range clients {
   651  		wg.Add(1)
   652  		go func(c *Client) {
   653  			defer wg.Done()
   654  			r := &SimpleRequest{
   655  				ServerIdentities: el,
   656  				Val:              int64(n),
   657  			}
   658  			conn, err := c.Stream(servers[0].ServerIdentity, r)
   659  			require.NoError(t, err)
   660  
   661  			for i := 0; i < n; i++ {
   662  				if i > stopAt {
   663  					sr := &SimpleResponse{}
   664  					require.Error(t, conn.ReadMessage(sr))
   665  				} else {
   666  					sr := &SimpleResponse{}
   667  					require.NoError(t, conn.ReadMessage(sr))
   668  					require.Equal(t, sr.Val, int64(n))
   669  				}
   670  			}
   671  		}(client)
   672  	}
   673  	wg.Wait()
   674  	for i := range clients {
   675  		require.NoError(t, clients[i].Close())
   676  	}
   677  }
   678  
   679  // Tests the correct returning of values depending on the ParallelOptions structure
   680  func TestParallelOptions_GetList(t *testing.T) {
   681  	l := NewLocalTest(tSuite)
   682  	defer l.CloseAll()
   683  
   684  	var po *ParallelOptions
   685  	_, roster, _ := l.GenTree(3, false)
   686  	nodes := roster.List
   687  
   688  	count, list := po.GetList(nodes)
   689  	require.Equal(t, 2, count)
   690  	require.Equal(t, 3, len(list))
   691  	require.False(t, po.Quit())
   692  
   693  	po = &ParallelOptions{}
   694  	count, list = po.GetList(nodes)
   695  	require.Equal(t, 2, count)
   696  	require.Equal(t, 3, len(list))
   697  	require.False(t, po.Quit())
   698  
   699  	first := 0
   700  	for i := 0; i < 32; i++ {
   701  		_, list := po.GetList(nodes)
   702  		if (<-list).Equal(nodes[0]) {
   703  			first++
   704  		}
   705  	}
   706  	require.NotEqual(t, 0, first)
   707  	require.NotEqual(t, 32, first)
   708  	po.DontShuffle = true
   709  	first = 0
   710  	for i := 0; i < 32; i++ {
   711  		_, list := po.GetList(nodes)
   712  		if (<-list).Equal(nodes[0]) {
   713  			first++
   714  		}
   715  	}
   716  	require.Equal(t, 32, first)
   717  
   718  	po.IgnoreNodes = append(po.IgnoreNodes, nodes[0])
   719  	count, list = po.GetList(nodes)
   720  	require.Equal(t, 2, count)
   721  	require.Equal(t, 2, len(list))
   722  
   723  	po.IgnoreNodes = append(po.IgnoreNodes, nodes[1])
   724  	count, list = po.GetList(nodes)
   725  	require.Equal(t, 2, count)
   726  	require.Equal(t, 1, len(list))
   727  
   728  	po.IgnoreNodes = po.IgnoreNodes[0:1]
   729  	po.QuitError = true
   730  	require.True(t, po.Quit())
   731  
   732  	po.AskNodes = 1
   733  	count, list = po.GetList(nodes)
   734  	require.Equal(t, 1, count)
   735  	require.Equal(t, 1, len(list))
   736  
   737  	po.StartNode = 1
   738  	count, list = po.GetList(nodes)
   739  	require.Equal(t, 1, count)
   740  	require.Equal(t, 1, len(list))
   741  }
   742  
   743  func TestClient_SendProtobufParallel(t *testing.T) {
   744  	l := NewLocalTest(tSuite)
   745  	defer l.CloseAll()
   746  
   747  	servers, roster, _ := l.GenTree(3, false)
   748  	cl := NewClient(tSuite, serviceWebSocket)
   749  	tests := 10
   750  	firstNodes := make([]*network.ServerIdentity, tests)
   751  	for i := 0; i < tests; i++ {
   752  		log.Lvl1("Sending", i)
   753  		var err error
   754  		firstNodes[i], err = cl.SendProtobufParallel(roster.List, &SimpleResponse{}, nil, nil)
   755  		require.Nil(t, err)
   756  	}
   757  
   758  	for flags := 0; flags < 8; flags++ {
   759  		log.Lvl1("Count errors over all services with error-flags", flags)
   760  		_, err := cl.SendProtobufParallel(roster.List, &ErrorRequest{
   761  			Roster: *roster,
   762  			Flags:  flags,
   763  		}, nil, nil)
   764  		if flags == 7 {
   765  			require.Error(t, err)
   766  		} else {
   767  			require.NoError(t, err)
   768  		}
   769  		// Need to close here to make sure that all messages are being sent
   770  		// before going on to the next stage.
   771  		require.NoError(t, cl.Close())
   772  	}
   773  	var errs int
   774  	for _, server := range servers {
   775  		errs += server.Service(serviceWebSocket).(*ServiceWebSocket).Errors
   776  	}
   777  	require.Equal(t, 3, errs)
   778  
   779  	sort.Slice(firstNodes, func(i, j int) bool {
   780  		return bytes.Compare(firstNodes[i].ID[:], firstNodes[j].ID[:]) < 0
   781  	})
   782  	require.False(t, firstNodes[0].Equal(firstNodes[tests-1]))
   783  }
   784  
   785  func TestClient_SendProtobufParallelWithDecoder(t *testing.T) {
   786  	l := NewLocalTest(tSuite)
   787  	defer l.CloseAll()
   788  
   789  	_, roster, _ := l.GenTree(3, false)
   790  	cl := NewClient(tSuite, serviceWebSocket)
   791  
   792  	decoderWithError := func(data []byte, ret interface{}) error {
   793  		// As an example, the decoder should first decode the response, and it can then make
   794  		// further verification like the latest block index.
   795  		return xerrors.New("decoder error")
   796  	}
   797  
   798  	_, err := cl.SendProtobufParallelWithDecoder(roster.List, &SimpleResponse{}, &SimpleResponse{}, nil, decoderWithError)
   799  	require.Error(t, err)
   800  	require.Contains(t, err.Error(), "decoder error")
   801  
   802  	decoderNoError := func(data []byte, ret interface{}) error {
   803  		return nil
   804  	}
   805  
   806  	_, err = cl.SendProtobufParallelWithDecoder(roster.List, &SimpleResponse{}, &SimpleResponse{}, nil, decoderNoError)
   807  	require.NoError(t, err)
   808  }
   809  
   810  const dummyService3Name = "dummyService3"
   811  
   812  type DummyService3 struct {
   813  }
   814  
   815  func (ds *DummyService3) ProcessClientRequest(req *http.Request, path string, buf []byte) ([]byte, *StreamingTunnel, error) {
   816  	log.Lvl2("Got called with path", path, buf)
   817  	return []byte(path), nil, nil
   818  }
   819  
   820  func (ds *DummyService3) NewProtocol(tn *TreeNodeInstance, conf *GenericConfig) (ProtocolInstance, error) {
   821  	return nil, nil
   822  }
   823  
   824  func (ds *DummyService3) Process(env *network.Envelope) {
   825  }
   826  
   827  type StreamingService struct {
   828  	*ServiceProcessor
   829  	stopAt      int
   830  	gotStopChan chan bool
   831  }
   832  
   833  func newStreamingService(c *Context) (Service, error) {
   834  	s := &StreamingService{
   835  		ServiceProcessor: NewServiceProcessor(c),
   836  		stopAt:           -1,
   837  	}
   838  	if err := s.RegisterStreamingHandler(s.StreamValues); err != nil {
   839  		panic(err.Error())
   840  	}
   841  	return s, nil
   842  }
   843  
   844  func (ss *StreamingService) StreamValues(msg *SimpleRequest) (chan *SimpleResponse, chan bool, error) {
   845  	streamingChan := make(chan *SimpleResponse)
   846  	stopChan := make(chan bool)
   847  	go func() {
   848  	outer:
   849  		for i := 0; i < int(msg.Val); i++ {
   850  			// Add some delay between every message so that we can
   851  			// actually catch the stop signal before everything is
   852  			// sent out.
   853  			time.Sleep(100 * time.Millisecond)
   854  			select {
   855  			case <-stopChan:
   856  				ss.gotStopChan <- true
   857  				break outer
   858  			default:
   859  				streamingChan <- &SimpleResponse{msg.Val}
   860  			}
   861  			if ss.stopAt == i {
   862  				break outer
   863  			}
   864  		}
   865  		close(streamingChan)
   866  	}()
   867  	return streamingChan, stopChan, nil
   868  }