gopkg.in/dedis/onet.v2@v2.0.0-20181115163211-c8f3724038a7/websocket_test.go (about)

     1  package onet
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"crypto/rand"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"crypto/x509/pkix"
    10  	"encoding/pem"
    11  	"io/ioutil"
    12  	"math/big"
    13  	"net"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/dedis/protobuf"
    22  	"github.com/stretchr/testify/require"
    23  	"gopkg.in/dedis/onet.v2/log"
    24  	"gopkg.in/dedis/onet.v2/network"
    25  	"gopkg.in/satori/go.uuid.v1"
    26  )
    27  
    28  func init() {
    29  	RegisterNewService(serviceWebSocket, newServiceWebSocket)
    30  }
    31  
    32  // Adapted from 'https://golang.org/src/crypto/tls/generate_cert.go'
    33  func generateSelfSignedCert() (string, string, error) {
    34  	// Hostname or IP to generate a certificate for
    35  	host := "127.0.0.1"
    36  	// Creation date formatted as Jan 2 15:04:05 2006
    37  	validFrom := time.Now().UTC().Format("Jan 2 15:04:05 2006")
    38  	// Duration that certificate is valid for
    39  	validFor := 365 * 24 * time.Hour
    40  
    41  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    42  	if err != nil {
    43  		return "", "", err
    44  	}
    45  
    46  	notBefore, err := time.Parse("Jan 2 15:04:05 2006", validFrom)
    47  	if err != nil {
    48  		return "", "", err
    49  	}
    50  
    51  	notAfter := notBefore.Add(validFor)
    52  
    53  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
    54  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
    55  	if err != nil {
    56  		return "", "", err
    57  	}
    58  
    59  	template := x509.Certificate{
    60  		SerialNumber: serialNumber,
    61  		Subject: pkix.Name{
    62  			Organization: []string{"DEDIS EPFL"},
    63  		},
    64  		NotBefore: notBefore,
    65  		NotAfter:  notAfter,
    66  
    67  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
    68  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
    69  		BasicConstraintsValid: true,
    70  	}
    71  
    72  	if ip := net.ParseIP(host); ip != nil {
    73  		template.IPAddresses = append(template.IPAddresses, ip)
    74  	} else {
    75  		template.DNSNames = append(template.DNSNames, host)
    76  	}
    77  
    78  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
    79  	if err != nil {
    80  		return "", "", err
    81  	}
    82  
    83  	certOut, err := ioutil.TempFile("", "cert.pem")
    84  	if err != nil {
    85  		return "", "", err
    86  	}
    87  	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
    88  	certOut.Close()
    89  
    90  	keyOut, err := ioutil.TempFile("", "key.pem")
    91  	if err != nil {
    92  		return "", "", err
    93  	}
    94  
    95  	b, err := x509.MarshalECPrivateKey(priv)
    96  	if err != nil {
    97  		return "", "", err
    98  	}
    99  	pemBlockForKey := &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
   100  
   101  	pem.Encode(keyOut, pemBlockForKey)
   102  	keyOut.Close()
   103  
   104  	return certOut.Name(), keyOut.Name(), nil
   105  }
   106  
   107  func getSelfSignedCertificateAndKey() ([]byte, []byte, error) {
   108  	certFilePath, keyFilePath, err := generateSelfSignedCert()
   109  	if err != nil {
   110  		return nil, nil, err
   111  	}
   112  
   113  	cert, err := ioutil.ReadFile(certFilePath)
   114  	if err != nil {
   115  		return nil, nil, err
   116  	}
   117  	err = os.Remove(certFilePath)
   118  	if err != nil {
   119  		return nil, nil, err
   120  	}
   121  
   122  	key, err := ioutil.ReadFile(keyFilePath)
   123  	if err != nil {
   124  		return nil, nil, err
   125  	}
   126  	err = os.Remove(keyFilePath)
   127  	if err != nil {
   128  		return nil, nil, err
   129  	}
   130  
   131  	return cert, key, nil
   132  }
   133  
   134  func TestNewWebSocket(t *testing.T) {
   135  	l := NewLocalTest(tSuite)
   136  	defer l.CloseAll()
   137  
   138  	c := newTCPServer(tSuite, 0, l.path)
   139  	c.StartInBackground()
   140  
   141  	defer c.Close()
   142  	require.Equal(t, len(c.serviceManager.services), len(c.WebSocket.services))
   143  	require.NotEmpty(t, c.WebSocket.services[serviceWebSocket])
   144  	cl := NewClientKeep(tSuite, "WebSocket")
   145  	req := &SimpleResponse{}
   146  	log.Lvlf1("Sending message Request: %x", uuid.UUID(network.MessageType(req)).Bytes())
   147  	buf, err := protobuf.Encode(req)
   148  	require.Nil(t, err)
   149  	rcv, err := cl.Send(c.ServerIdentity, "SimpleResponse", buf)
   150  	require.Nil(t, err)
   151  
   152  	log.Lvlf1("Received reply: %x", rcv)
   153  	rcvMsg := &SimpleResponse{}
   154  	require.Nil(t, protobuf.Decode(rcv, rcvMsg))
   155  	require.Equal(t, 1, rcvMsg.Val)
   156  }
   157  
   158  func TestNewWebSocketTLS(t *testing.T) {
   159  	cert, key, err := getSelfSignedCertificateAndKey()
   160  	require.Nil(t, err)
   161  	CAPool := x509.NewCertPool()
   162  	CAPool.AppendCertsFromPEM(cert)
   163  
   164  	l := NewLocalTest(tSuite)
   165  	defer l.CloseAll()
   166  
   167  	c := newTCPServer(tSuite, 0, l.path)
   168  
   169  	certToAdd, err := tls.X509KeyPair(cert, key)
   170  	if err != nil {
   171  		require.Nil(t, err)
   172  	}
   173  	c.WebSocket.Lock()
   174  	c.WebSocket.TLSConfig = &tls.Config{Certificates: []tls.Certificate{certToAdd}}
   175  	c.WebSocket.Unlock()
   176  	c.StartInBackground()
   177  	defer c.Close()
   178  
   179  	require.Equal(t, len(c.serviceManager.services), len(c.WebSocket.services))
   180  	require.NotEmpty(t, c.WebSocket.services[serviceWebSocket])
   181  
   182  	// Test the traditional host:port+1 way of specifying the websocket server.
   183  	cl := NewClientKeep(tSuite, "WebSocket")
   184  	cl.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   185  	req := &SimpleResponse{}
   186  	log.Lvlf1("Sending message Request: %x", uuid.UUID(network.MessageType(req)).Bytes())
   187  	buf, err := protobuf.Encode(req)
   188  	require.Nil(t, err)
   189  	rcv, err := cl.Send(c.ServerIdentity, "SimpleResponse", buf)
   190  	log.Lvlf1("Received reply: %x", rcv)
   191  	rcvMsg := &SimpleResponse{}
   192  	require.Nil(t, protobuf.Decode(rcv, rcvMsg))
   193  	require.Equal(t, 1, rcvMsg.Val)
   194  	cl.Close()
   195  
   196  	// Set c.ServerIdentity.URL, in order to test the other way of triggering wss:// connection.
   197  	hp, err := getWSHostPort(c.ServerIdentity, false)
   198  	require.NoError(t, err)
   199  	u := &url.URL{Scheme: "https", Host: hp}
   200  	c.ServerIdentity.URL = u.String()
   201  
   202  	log.Lvlf1("Sending message Request: %x", uuid.UUID(network.MessageType(req)).Bytes())
   203  	rcv, err = cl.Send(c.ServerIdentity, "SimpleResponse", buf)
   204  	log.Lvlf1("Received reply: %x", rcv)
   205  	require.Nil(t, protobuf.Decode(rcv, rcvMsg))
   206  	require.Equal(t, 1, rcvMsg.Val)
   207  }
   208  
   209  func TestGetWebHost(t *testing.T) {
   210  	url, err := getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8"}, true)
   211  	require.NotNil(t, err)
   212  	url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8"}, false)
   213  	require.NotNil(t, err)
   214  	url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8:7770"}, true)
   215  	require.Nil(t, err)
   216  	require.Equal(t, "0.0.0.0:7771", url)
   217  	url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8:7770"}, false)
   218  	require.Nil(t, err)
   219  	require.Equal(t, "8.8.8.8:7771", url)
   220  }
   221  
   222  func TestClient_Send(t *testing.T) {
   223  	local := NewTCPTest(tSuite)
   224  	defer local.CloseAll()
   225  
   226  	// register service
   227  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   228  		return &simpleService{
   229  			ctx: c,
   230  		}, nil
   231  	})
   232  	defer ServiceFactory.Unregister(backForthServiceName)
   233  
   234  	// create servers
   235  	servers, el, _ := local.GenTree(4, false)
   236  	client := local.NewClient(backForthServiceName)
   237  
   238  	r := &SimpleRequest{
   239  		ServerIdentities: el,
   240  		Val:              10,
   241  	}
   242  	sr := &SimpleResponse{}
   243  	require.Equal(t, uint64(0), client.Rx())
   244  	require.Equal(t, uint64(0), client.Tx())
   245  	require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
   246  	require.Equal(t, sr.Val, 10)
   247  	require.NotEqual(t, uint64(0), client.Rx())
   248  	require.NotEqual(t, uint64(0), client.Tx())
   249  	require.True(t, client.Tx() > client.Rx())
   250  }
   251  
   252  func TestClientTLS_Send(t *testing.T) {
   253  	cert, key, err := getSelfSignedCertificateAndKey()
   254  	require.Nil(t, err)
   255  	CAPool := x509.NewCertPool()
   256  	CAPool.AppendCertsFromPEM(cert)
   257  
   258  	local := NewTCPTest(tSuite)
   259  	local.webSocketTLSCertificate = cert
   260  	local.webSocketTLSCertificateKey = key
   261  	defer local.CloseAll()
   262  
   263  	// register service
   264  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   265  		return &simpleService{
   266  			ctx: c,
   267  		}, nil
   268  	})
   269  	defer ServiceFactory.Unregister(backForthServiceName)
   270  
   271  	// create servers
   272  	servers, el, _ := local.GenTree(4, false)
   273  	client := local.NewClient(backForthServiceName)
   274  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   275  
   276  	r := &SimpleRequest{
   277  		ServerIdentities: el,
   278  		Val:              10,
   279  	}
   280  	sr := &SimpleResponse{}
   281  	require.Equal(t, uint64(0), client.Rx())
   282  	require.Equal(t, uint64(0), client.Tx())
   283  	require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
   284  	require.Equal(t, sr.Val, 10)
   285  	require.NotEqual(t, uint64(0), client.Rx())
   286  	require.NotEqual(t, uint64(0), client.Tx())
   287  	require.True(t, client.Tx() > client.Rx())
   288  }
   289  
   290  func TestClient_Parallel(t *testing.T) {
   291  	nbrNodes := 4
   292  	nbrParallel := 20
   293  	local := NewTCPTest(tSuite)
   294  	defer local.CloseAll()
   295  
   296  	// register service
   297  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   298  		return &simpleService{
   299  			ctx: c,
   300  		}, nil
   301  	})
   302  	defer ServiceFactory.Unregister(backForthServiceName)
   303  
   304  	// create servers
   305  	servers, el, _ := local.GenTree(nbrNodes, true)
   306  
   307  	wg := sync.WaitGroup{}
   308  	wg.Add(nbrParallel)
   309  	for i := 0; i < nbrParallel; i++ {
   310  		go func(i int) {
   311  			log.Lvl1("Starting message", i)
   312  			r := &SimpleRequest{
   313  				ServerIdentities: el,
   314  				Val:              10 * i,
   315  			}
   316  			client := local.NewClient(backForthServiceName)
   317  			sr := &SimpleResponse{}
   318  			err := client.SendProtobuf(servers[0].ServerIdentity, r, sr)
   319  			require.Nil(t, err)
   320  			require.Equal(t, 10*i, sr.Val)
   321  			log.Lvl1("Done with message", i)
   322  			wg.Done()
   323  		}(i)
   324  	}
   325  	wg.Wait()
   326  }
   327  
   328  func TestClientTLS_Parallel(t *testing.T) {
   329  	cert, key, err := getSelfSignedCertificateAndKey()
   330  	require.Nil(t, err)
   331  	CAPool := x509.NewCertPool()
   332  	CAPool.AppendCertsFromPEM(cert)
   333  
   334  	nbrNodes := 4
   335  	nbrParallel := 20
   336  	local := NewTCPTest(tSuite)
   337  	local.webSocketTLSCertificate = cert
   338  	local.webSocketTLSCertificateKey = key
   339  	defer local.CloseAll()
   340  
   341  	// register service
   342  	RegisterNewService(backForthServiceName, func(c *Context) (Service, error) {
   343  		return &simpleService{
   344  			ctx: c,
   345  		}, nil
   346  	})
   347  	defer ServiceFactory.Unregister(backForthServiceName)
   348  
   349  	// create servers
   350  	servers, el, _ := local.GenTree(nbrNodes, true)
   351  
   352  	wg := sync.WaitGroup{}
   353  	wg.Add(nbrParallel)
   354  	for i := 0; i < nbrParallel; i++ {
   355  		go func(i int) {
   356  			log.Lvl1("Starting message", i)
   357  			r := &SimpleRequest{
   358  				ServerIdentities: el,
   359  				Val:              10 * i,
   360  			}
   361  			client := local.NewClient(backForthServiceName)
   362  			client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   363  			sr := &SimpleResponse{}
   364  			require.Nil(t, client.SendProtobuf(servers[0].ServerIdentity, r, sr))
   365  			require.Equal(t, 10*i, sr.Val)
   366  			log.Lvl1("Done with message", i)
   367  			wg.Done()
   368  		}(i)
   369  	}
   370  	wg.Wait()
   371  }
   372  
   373  func TestNewClientKeep(t *testing.T) {
   374  	c := NewClientKeep(tSuite, serviceWebSocket)
   375  	require.True(t, c.keep)
   376  }
   377  
   378  func TestMultiplePath(t *testing.T) {
   379  	_, err := RegisterNewService(dummyService3Name, func(c *Context) (Service, error) {
   380  		ds := &DummyService3{}
   381  		return ds, nil
   382  	})
   383  	require.Nil(t, err)
   384  	defer UnregisterService(dummyService3Name)
   385  
   386  	local := NewTCPTest(tSuite)
   387  	hs := local.GenServers(2)
   388  	server := hs[0]
   389  	defer local.CloseAll()
   390  	client := NewClientKeep(tSuite, dummyService3Name)
   391  	msg, err := protobuf.Encode(&DummyMsg{})
   392  	require.Nil(t, err)
   393  	path1, path2 := "path1", "path2"
   394  	resp, err := client.Send(server.ServerIdentity, path1, msg)
   395  	require.Nil(t, err)
   396  	require.Equal(t, path1, string(resp))
   397  	resp, err = client.Send(server.ServerIdentity, path2, msg)
   398  	require.Nil(t, err)
   399  	require.Equal(t, path2, string(resp))
   400  }
   401  
   402  func TestMultiplePathTLS(t *testing.T) {
   403  	cert, key, err := getSelfSignedCertificateAndKey()
   404  	require.Nil(t, err)
   405  	CAPool := x509.NewCertPool()
   406  	CAPool.AppendCertsFromPEM(cert)
   407  
   408  	_, err = RegisterNewService(dummyService3Name, func(c *Context) (Service, error) {
   409  		ds := &DummyService3{}
   410  		return ds, nil
   411  	})
   412  	require.Nil(t, err)
   413  	defer UnregisterService(dummyService3Name)
   414  
   415  	local := NewTCPTest(tSuite)
   416  	local.webSocketTLSCertificate = cert
   417  	local.webSocketTLSCertificateKey = key
   418  	hs := local.GenServers(2)
   419  	server := hs[0]
   420  	defer local.CloseAll()
   421  	client := NewClientKeep(tSuite, dummyService3Name)
   422  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   423  	msg, err := protobuf.Encode(&DummyMsg{})
   424  	require.Nil(t, err)
   425  	path1, path2 := "path1", "path2"
   426  	resp, err := client.Send(server.ServerIdentity, path1, msg)
   427  	require.Nil(t, err)
   428  	require.Equal(t, path1, string(resp))
   429  	resp, err = client.Send(server.ServerIdentity, path2, msg)
   430  	require.Nil(t, err)
   431  	require.Equal(t, path2, string(resp))
   432  }
   433  
   434  func TestWebSocket_Error(t *testing.T) {
   435  	client := NewClientKeep(tSuite, dummyService3Name)
   436  	local := NewTCPTest(tSuite)
   437  	hs := local.GenServers(2)
   438  	server := hs[0]
   439  	defer local.CloseAll()
   440  
   441  	lvl := log.DebugVisible()
   442  	log.SetDebugVisible(0)
   443  	log.OutputToBuf()
   444  	_, err := client.Send(server.ServerIdentity, "test", nil)
   445  	log.OutputToOs()
   446  	log.SetDebugVisible(lvl)
   447  	require.NotEqual(t, "websocket: bad handshake", err.Error())
   448  	require.NotEqual(t, "", log.GetStdOut())
   449  }
   450  
   451  func TestWebSocketTLS_Error(t *testing.T) {
   452  	cert, key, err := getSelfSignedCertificateAndKey()
   453  	require.Nil(t, err)
   454  	CAPool := x509.NewCertPool()
   455  	CAPool.AppendCertsFromPEM(cert)
   456  
   457  	client := NewClientKeep(tSuite, dummyService3Name)
   458  	client.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   459  	local := NewTCPTest(tSuite)
   460  	local.webSocketTLSCertificate = cert
   461  	local.webSocketTLSCertificateKey = key
   462  	hs := local.GenServers(2)
   463  	server := hs[0]
   464  	defer local.CloseAll()
   465  
   466  	lvl := log.DebugVisible()
   467  	log.SetDebugVisible(0)
   468  	log.OutputToBuf()
   469  	_, err = client.Send(server.ServerIdentity, "test", nil)
   470  	log.OutputToOs()
   471  	log.SetDebugVisible(lvl)
   472  	require.NotEqual(t, "websocket: bad handshake", err.Error())
   473  	require.NotEqual(t, "", log.GetStdOut())
   474  }
   475  
   476  // TestWebSocket_Streaming performs 3 test cases.
   477  // (1) happy-path, where client reads all messages from the service
   478  // (2) unhappy-path, where client closes early
   479  // (3) unhappy-path, where service closes early
   480  func TestWebSocket_Streaming(t *testing.T) {
   481  	local := NewTCPTest(tSuite)
   482  	defer local.CloseAll()
   483  
   484  	serName := "streamingService"
   485  	serID, err := RegisterNewService(serName, newStreamingService)
   486  	require.NoError(t, err)
   487  	defer UnregisterService(serName)
   488  
   489  	servers, el, _ := local.GenTree(4, false)
   490  	client := local.NewClientKeep(serName)
   491  
   492  	n := 5
   493  	r := &SimpleRequest{
   494  		ServerIdentities: el,
   495  		Val:              n,
   496  	}
   497  
   498  	// (1) happy-path testing
   499  	conn, err := client.Stream(servers[0].ServerIdentity, r)
   500  	require.NoError(t, err)
   501  
   502  	for i := 0; i < n; i++ {
   503  		sr := &SimpleResponse{}
   504  		require.NoError(t, conn.ReadMessage(sr))
   505  		require.Equal(t, sr.Val, n)
   506  	}
   507  
   508  	// Using the same client (connection) to repeat the same request should
   509  	// fail because the connection should be closed by the service when
   510  	// there are no more messages.
   511  	sr := &SimpleResponse{}
   512  	require.Error(t, conn.ReadMessage(sr))
   513  	require.NoError(t, client.Close())
   514  
   515  	// (2) This time, have the client terminate early, the service's
   516  	// go-routine should also terminate.
   517  	client = local.NewClientKeep(serName)
   518  	services := local.GetServices(servers, serID)
   519  	serviceRoot := services[0].(*StreamingService)
   520  	serviceRoot.gotStopChan = make(chan bool, 1)
   521  
   522  	conn, err = client.Stream(servers[0].ServerIdentity, r)
   523  	require.NoError(t, err)
   524  	require.NoError(t, client.Close())
   525  
   526  	select {
   527  	case <-serviceRoot.gotStopChan:
   528  	case <-time.After(time.Second):
   529  		require.Fail(t, "should have got an early finish signal")
   530  	}
   531  
   532  	// (3) Finally, have the service terminate early. The client should
   533  	// stop receiving messages.
   534  	stopAt := 1
   535  	serviceRoot.stopAt = stopAt
   536  	client = local.NewClientKeep(serName)
   537  
   538  	conn, err = client.Stream(servers[0].ServerIdentity, r)
   539  	require.NoError(t, err)
   540  	for i := 0; i < n; i++ {
   541  		if i > stopAt {
   542  			sr := &SimpleResponse{}
   543  			require.Error(t, conn.ReadMessage(sr))
   544  		} else {
   545  			sr := &SimpleResponse{}
   546  			require.NoError(t, conn.ReadMessage(sr))
   547  			require.Equal(t, sr.Val, n)
   548  		}
   549  	}
   550  	require.NoError(t, client.Close())
   551  }
   552  
   553  // TestWebSocket_Streaming_Parallel is essentially the same as
   554  // TestWebSocket_Streaming, except we do it in parallel.
   555  func TestWebSocket_Streaming_Parallel(t *testing.T) {
   556  	local := NewTCPTest(tSuite)
   557  	defer local.CloseAll()
   558  
   559  	serName := "streamingService"
   560  	serID, err := RegisterNewService(serName, newStreamingService)
   561  	require.NoError(t, err)
   562  	defer UnregisterService(serName)
   563  
   564  	servers, el, _ := local.GenTree(4, false)
   565  	services := local.GetServices(servers, serID)
   566  	serviceRoot := services[0].(*StreamingService)
   567  	n := 10
   568  
   569  	// (1) We try to do streaming with 10 clients in parallel. Start with
   570  	// the happy-path where clients read everything.
   571  	clients := make([]*Client, 100)
   572  	for i := range clients {
   573  		clients[i] = local.NewClientKeep(serName)
   574  	}
   575  	var wg sync.WaitGroup
   576  	for _, client := range clients {
   577  		wg.Add(1)
   578  		go func(c *Client) {
   579  			defer wg.Done()
   580  			r := &SimpleRequest{
   581  				ServerIdentities: el,
   582  				Val:              n,
   583  			}
   584  			conn, err := c.Stream(servers[0].ServerIdentity, r)
   585  			require.NoError(t, err)
   586  
   587  			for i := 0; i < n; i++ {
   588  				sr := &SimpleResponse{}
   589  				require.NoError(t, conn.ReadMessage(sr))
   590  				require.Equal(t, sr.Val, n)
   591  			}
   592  		}(client)
   593  	}
   594  	wg.Wait()
   595  	for i := range clients {
   596  		require.NoError(t, clients[i].Close())
   597  	}
   598  
   599  	// (2) Now try the unhappy-path where clients stop early.
   600  	for i := range clients {
   601  		clients[i] = local.NewClientKeep(serName)
   602  	}
   603  	serviceRoot.gotStopChan = make(chan bool, len(clients))
   604  	wg = sync.WaitGroup{}
   605  	for _, client := range clients {
   606  		wg.Add(1)
   607  		go func(c *Client) {
   608  			defer wg.Done()
   609  			r := &SimpleRequest{
   610  				ServerIdentities: el,
   611  				Val:              n,
   612  			}
   613  			conn, err := c.Stream(servers[0].ServerIdentity, r)
   614  			require.NoError(t, err)
   615  
   616  			// read one message instead of n then close
   617  			sr := &SimpleResponse{}
   618  			require.NoError(t, conn.ReadMessage(sr))
   619  			require.Equal(t, sr.Val, n)
   620  			require.NoError(t, c.Close())
   621  		}(client)
   622  	}
   623  
   624  	// we should get close messages
   625  	for i := 0; i < len(clients); i++ {
   626  		select {
   627  		case <-serviceRoot.gotStopChan:
   628  		case <-time.After(time.Second):
   629  			require.Fail(t, "should have got an early finish signal")
   630  		}
   631  	}
   632  	wg.Wait()
   633  
   634  	// (3) The other unhappy-path where the service stops early.
   635  	for i := range clients {
   636  		clients[i] = local.NewClientKeep(serName)
   637  	}
   638  	stopAt := 1
   639  	serviceRoot.stopAt = stopAt
   640  	wg = sync.WaitGroup{}
   641  	for _, client := range clients {
   642  		wg.Add(1)
   643  		go func(c *Client) {
   644  			defer wg.Done()
   645  			r := &SimpleRequest{
   646  				ServerIdentities: el,
   647  				Val:              n,
   648  			}
   649  			conn, err := c.Stream(servers[0].ServerIdentity, r)
   650  			require.NoError(t, err)
   651  
   652  			for i := 0; i < n; i++ {
   653  				if i > stopAt {
   654  					sr := &SimpleResponse{}
   655  					require.Error(t, conn.ReadMessage(sr))
   656  				} else {
   657  					sr := &SimpleResponse{}
   658  					require.NoError(t, conn.ReadMessage(sr))
   659  					require.Equal(t, sr.Val, n)
   660  				}
   661  			}
   662  		}(client)
   663  	}
   664  	wg.Wait()
   665  	for i := range clients {
   666  		require.NoError(t, clients[i].Close())
   667  	}
   668  }
   669  
   670  const serviceWebSocket = "WebSocket"
   671  
   672  type ServiceWebSocket struct {
   673  	*ServiceProcessor
   674  }
   675  
   676  func (i *ServiceWebSocket) SimpleResponse(msg *SimpleResponse) (network.Message, error) {
   677  	return &SimpleResponse{msg.Val + 1}, nil
   678  }
   679  
   680  func newServiceWebSocket(c *Context) (Service, error) {
   681  	s := &ServiceWebSocket{
   682  		ServiceProcessor: NewServiceProcessor(c),
   683  	}
   684  	log.ErrFatal(s.RegisterHandler(s.SimpleResponse))
   685  	return s, nil
   686  }
   687  
   688  const dummyService3Name = "dummyService3"
   689  
   690  type DummyService3 struct {
   691  }
   692  
   693  func (ds *DummyService3) ProcessClientRequest(req *http.Request, path string, buf []byte) ([]byte, *StreamingTunnel, error) {
   694  	log.Lvl2("Got called with path", path, buf)
   695  	return []byte(path), nil, nil
   696  }
   697  
   698  func (ds *DummyService3) NewProtocol(tn *TreeNodeInstance, conf *GenericConfig) (ProtocolInstance, error) {
   699  	return nil, nil
   700  }
   701  
   702  func (ds *DummyService3) Process(env *network.Envelope) {
   703  }
   704  
   705  type StreamingService struct {
   706  	*ServiceProcessor
   707  	stopAt      int
   708  	gotStopChan chan bool
   709  }
   710  
   711  func newStreamingService(c *Context) (Service, error) {
   712  	s := &StreamingService{
   713  		ServiceProcessor: NewServiceProcessor(c),
   714  		stopAt:           -1,
   715  	}
   716  	if err := s.RegisterStreamingHandler(s.StreamValues); err != nil {
   717  		panic(err.Error())
   718  	}
   719  	return s, nil
   720  }
   721  
   722  func (ss *StreamingService) StreamValues(msg *SimpleRequest) (chan *SimpleResponse, chan bool, error) {
   723  	streamingChan := make(chan *SimpleResponse)
   724  	stopChan := make(chan bool)
   725  	go func() {
   726  	outer:
   727  		for i := 0; i < msg.Val; i++ {
   728  			// Add some delay between every message so that we can
   729  			// actually catch the stop signal before everything is
   730  			// sent out.
   731  			time.Sleep(100 * time.Millisecond)
   732  			select {
   733  			case <-stopChan:
   734  				ss.gotStopChan <- true
   735  				break outer
   736  			default:
   737  				streamingChan <- &SimpleResponse{msg.Val}
   738  			}
   739  			if ss.stopAt == i {
   740  				break outer
   741  			}
   742  		}
   743  		close(streamingChan)
   744  	}()
   745  	return streamingChan, stopChan, nil
   746  }