go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket_test.go (about)

     1  package onet
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"crypto/rand"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"crypto/x509/pkix"
    10  	"encoding/pem"
    11  	"io/ioutil"
    12  	"math/big"
    13  	"net"
    14  	"net/url"
    15  	"os"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/stretchr/testify/require"
    20  	"go.dedis.ch/onet/v3/log"
    21  	"go.dedis.ch/onet/v3/network"
    22  	"go.dedis.ch/protobuf"
    23  	"golang.org/x/xerrors"
    24  )
    25  
    26  func init() {
    27  	RegisterNewService(serviceWebSocket, newServiceWebSocket)
    28  }
    29  
    30  // Adapted from 'https://golang.org/src/crypto/tls/generate_cert.go'
    31  func generateSelfSignedCert() (string, string, error) {
    32  	// Hostname or IP to generate a certificate for
    33  	hosts := []string{
    34  		"127.0.0.1",
    35  		"::",
    36  	}
    37  	// Creation date formatted as Jan 2 15:04:05 2006
    38  	validFrom := time.Now().UTC().Format("Jan 2 15:04:05 2006")
    39  	// Duration that certificate is valid for
    40  	validFor := 365 * 24 * time.Hour
    41  
    42  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    43  	if err != nil {
    44  		return "", "", err
    45  	}
    46  
    47  	notBefore, err := time.Parse("Jan 2 15:04:05 2006", validFrom)
    48  	if err != nil {
    49  		return "", "", err
    50  	}
    51  
    52  	notAfter := notBefore.Add(validFor)
    53  
    54  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
    55  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
    56  	if err != nil {
    57  		return "", "", err
    58  	}
    59  
    60  	template := x509.Certificate{
    61  		SerialNumber: serialNumber,
    62  		Subject: pkix.Name{
    63  			Organization: []string{"DEDIS EPFL"},
    64  		},
    65  		NotBefore: notBefore,
    66  		NotAfter:  notAfter,
    67  
    68  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
    69  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
    70  		BasicConstraintsValid: true,
    71  	}
    72  
    73  	for _, host := range hosts {
    74  		if ip := net.ParseIP(host); ip != nil {
    75  			template.IPAddresses = append(template.IPAddresses, ip)
    76  		} else {
    77  			template.DNSNames = append(template.DNSNames, host)
    78  		}
    79  	}
    80  
    81  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
    82  	if err != nil {
    83  		return "", "", err
    84  	}
    85  
    86  	certOut, err := ioutil.TempFile("", "cert.pem")
    87  	if err != nil {
    88  		return "", "", err
    89  	}
    90  	pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
    91  	certOut.Close()
    92  
    93  	keyOut, err := ioutil.TempFile("", "key.pem")
    94  	if err != nil {
    95  		return "", "", err
    96  	}
    97  
    98  	b, err := x509.MarshalECPrivateKey(priv)
    99  	if err != nil {
   100  		return "", "", err
   101  	}
   102  	pemBlockForKey := &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
   103  
   104  	pem.Encode(keyOut, pemBlockForKey)
   105  	keyOut.Close()
   106  
   107  	return certOut.Name(), keyOut.Name(), nil
   108  }
   109  
   110  func getSelfSignedCertificateAndKey() ([]byte, []byte, error) {
   111  	certFilePath, keyFilePath, err := generateSelfSignedCert()
   112  	if err != nil {
   113  		return nil, nil, err
   114  	}
   115  
   116  	cert, err := ioutil.ReadFile(certFilePath)
   117  	if err != nil {
   118  		return nil, nil, err
   119  	}
   120  	err = os.Remove(certFilePath)
   121  	if err != nil {
   122  		return nil, nil, err
   123  	}
   124  
   125  	key, err := ioutil.ReadFile(keyFilePath)
   126  	if err != nil {
   127  		return nil, nil, err
   128  	}
   129  	err = os.Remove(keyFilePath)
   130  	if err != nil {
   131  		return nil, nil, err
   132  	}
   133  
   134  	return cert, key, nil
   135  }
   136  
   137  func TestNewWebSocket(t *testing.T) {
   138  	l := NewLocalTest(tSuite)
   139  	defer l.CloseAll()
   140  
   141  	c := l.NewServer(tSuite, 2050)
   142  
   143  	defer c.Close()
   144  	require.Equal(t, len(c.serviceManager.services), len(c.WebSocket.services))
   145  	require.NotEmpty(t, c.WebSocket.services[serviceWebSocket])
   146  	cl := NewClientKeep(tSuite, "WebSocket")
   147  	req := &SimpleResponse{}
   148  	msgTypeID := network.MessageType(req)
   149  	log.Lvlf1("Sending message Request: %x", msgTypeID[:])
   150  	buf, err := protobuf.Encode(req)
   151  	require.Nil(t, err)
   152  	rcv, err := cl.Send(c.ServerIdentity, "SimpleResponse", buf)
   153  	require.Nil(t, err)
   154  
   155  	log.Lvlf1("Received reply: %x", rcv)
   156  	rcvMsg := &SimpleResponse{}
   157  	require.Nil(t, protobuf.Decode(rcv, rcvMsg))
   158  	require.Equal(t, int64(1), rcvMsg.Val)
   159  }
   160  
   161  func TestNewWebSocketTLS(t *testing.T) {
   162  	cert, key, err := getSelfSignedCertificateAndKey()
   163  	require.Nil(t, err)
   164  	CAPool := x509.NewCertPool()
   165  	CAPool.AppendCertsFromPEM(cert)
   166  
   167  	l := NewTCPTest(tSuite)
   168  	l.webSocketTLSCertificate = cert
   169  	l.webSocketTLSCertificateKey = key
   170  	defer l.CloseAll()
   171  
   172  	c := l.NewServer(tSuite, 2050)
   173  	require.Equal(t, len(c.serviceManager.services), len(c.WebSocket.services))
   174  	require.NotEmpty(t, c.WebSocket.services[serviceWebSocket])
   175  
   176  	// Test the traditional host:port+1 way of specifying the websocket server.
   177  	cl := NewClientKeep(tSuite, "WebSocket")
   178  	defer cl.Close()
   179  	cl.TLSClientConfig = &tls.Config{RootCAs: CAPool}
   180  	req := &SimpleResponse{}
   181  	msgTypeID := network.MessageType(req)
   182  	log.Lvlf1("Sending message Request: %x", msgTypeID[:])
   183  	buf, err := protobuf.Encode(req)
   184  	require.Nil(t, err)
   185  	rcv, err := cl.Send(c.ServerIdentity, "SimpleResponse", buf)
   186  	log.Lvlf1("Received reply: %x", rcv)
   187  	rcvMsg := &SimpleResponse{}
   188  	require.Nil(t, protobuf.Decode(rcv, rcvMsg))
   189  	require.Equal(t, int64(1), rcvMsg.Val)
   190  
   191  	// Set c.ServerIdentity.URL, in order to test the other way of triggering wss:// connection.
   192  	hp, err := getWSHostPort(c.ServerIdentity, false)
   193  	require.NoError(t, err)
   194  	u := &url.URL{Scheme: "https", Host: hp}
   195  	c.ServerIdentity.URL = u.String()
   196  
   197  	log.Lvlf1("Sending message Request: %x", msgTypeID[:])
   198  	rcv, err = cl.Send(c.ServerIdentity, "SimpleResponse", buf)
   199  	log.Lvlf1("Received reply: %x", rcv)
   200  	require.Nil(t, protobuf.Decode(rcv, rcvMsg))
   201  	require.Equal(t, int64(1), rcvMsg.Val)
   202  }
   203  
// TestCertificateReloader tests the certificate reloader used for websocket
// over TLS. It checks that the reloader:
//   - loads a freshly generated certificate from disk;
//   - keeps serving it from its cache after the file paths are cleared;
//   - fails to reload from the cleared paths once the cached certificate's
//     NotAfter is moved close to now;
//   - reloads successfully once the real paths are restored.
func TestCertificateReloader(t *testing.T) {
	certPath, keyPath, err := generateSelfSignedCert()
	require.NoError(t, err)
	// Best-effort cleanup of the temporary PEM files; errors are ignored.
	defer func() {
		os.Remove(certPath)
		os.Remove(keyPath)
	}()

	reloader, err := NewCertificateReloader(certPath, keyPath)
	require.NoError(t, err)

	cert, err := reloader.GetCertificateFunc()(nil)
	require.NoError(t, err)
	require.NotNil(t, cert)

	// Clear the paths so that any attempt to reload from disk fails.
	reloader.certPath = ""
	reloader.keyPath = ""

	// It should work as the certificate is cached.
	cert, err = reloader.GetCertificateFunc()(nil)
	require.NoError(t, err)
	require.NotNil(t, cert)

	// Make the cached certificate look close to expiry; it is not actually
	// expired, since NotAfter is set 30 minutes in the future. The test
	// expects this to trigger a reload, which fails because the paths are
	// empty.
	// NOTE(review): this relies on two assumptions to confirm against
	// NewCertificateReloader: the reloader's refresh threshold exceeds 30
	// minutes, and cert is the reloader's cached certificate rather than a
	// copy. The matched text is the Unix ENOENT message and may differ on
	// other platforms.
	cert.Leaf.NotAfter = time.Now().Add(30 * time.Minute)
	_, err = reloader.GetCertificateFunc()(nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "no such file or directory")

	// And finally it should reload the new cert
	reloader.certPath = certPath
	reloader.keyPath = keyPath
	cert, err = reloader.GetCertificateFunc()(nil)
	require.NoError(t, err)
	require.NotNil(t, cert)
}
   242  
   243  func TestGetWebHost(t *testing.T) {
   244  	url, err := getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8"}, true)
   245  	require.Error(t, err)
   246  
   247  	url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8"}, false)
   248  	require.Error(t, err)
   249  
   250  	url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8:7770"}, true)
   251  	require.NoError(t, err)
   252  	require.Equal(t, "0.0.0.0:7771", url)
   253  
   254  	url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8:7770"}, false)
   255  	require.NoError(t, err)
   256  	require.Equal(t, "8.8.8.8:7771", url)
   257  
   258  	url, err = getWSHostPort(&network.ServerIdentity{
   259  		Address: "tcp://irrelevant:7770",
   260  		URL:     "wrong url",
   261  	}, false)
   262  	require.Error(t, err)
   263  
   264  	url, err = getWSHostPort(&network.ServerIdentity{
   265  		Address: "tcp://irrelevant:7770",
   266  		URL:     "http://8.8.8.8:8888",
   267  	}, false)
   268  	require.NoError(t, err)
   269  	require.Equal(t, "8.8.8.8:8888", url)
   270  
   271  	url, err = getWSHostPort(&network.ServerIdentity{
   272  		Address: "tcp://irrelevant:7770",
   273  		URL:     "http://8.8.8.8",
   274  	}, false)
   275  	require.NoError(t, err)
   276  	require.Equal(t, "8.8.8.8:80", url)
   277  
   278  	url, err = getWSHostPort(&network.ServerIdentity{
   279  		Address: "tcp://irrelevant:7770",
   280  		URL:     "https://8.8.8.8",
   281  	}, false)
   282  	require.NoError(t, err)
   283  	require.Equal(t, "8.8.8.8:443", url)
   284  
   285  	url, err = getWSHostPort(&network.ServerIdentity{
   286  		Address: "tcp://irrelevant:7770",
   287  		URL:     "invalid://8.8.8.8:8888",
   288  	}, false)
   289  	require.Error(t, err)
   290  }
   291  
   292  const serviceWebSocket = "WebSocket"
   293  
// ServiceWebSocket is the test service registered under serviceWebSocket.
// Its SimpleResponse and ErrorRequest handlers are registered by
// newServiceWebSocket.
type ServiceWebSocket struct {
	*ServiceProcessor
	// Errors is 1 after ErrorRequest handled a request that failed and 0
	// after one that succeeded.
	// NOTE(review): written without synchronization. Confirm that handlers
	// are not invoked concurrently if tests read this field.
	Errors int
}
   298  
   299  func (i *ServiceWebSocket) SimpleResponse(msg *SimpleResponse) (network.Message, error) {
   300  	return &SimpleResponse{msg.Val + 1}, nil
   301  }
   302  
// ErrorRequest asks ServiceWebSocket.ErrorRequest to fail on selected nodes.
// The handler fails on a node that is not in Roster. It also fails on a node
// at roster index i when bit i of Flags is set.
type ErrorRequest struct {
	Roster Roster
	Flags  int
}
   307  
   308  func (i *ServiceWebSocket) ErrorRequest(msg *ErrorRequest) (network.Message, error) {
   309  	i.Errors = 1
   310  	index, _ := msg.Roster.Search(i.ServerIdentity().ID)
   311  	if index < 0 {
   312  		return nil, xerrors.New("not in roster")
   313  	}
   314  	if msg.Flags&(1<<uint(index)) > 0 {
   315  		return nil, xerrors.New("found in flags: " + i.ServerIdentity().String())
   316  	}
   317  	i.Errors = 0
   318  	return &SimpleResponse{}, nil
   319  }
   320  
   321  func newServiceWebSocket(c *Context) (Service, error) {
   322  	s := &ServiceWebSocket{
   323  		ServiceProcessor: NewServiceProcessor(c),
   324  	}
   325  	log.ErrFatal(s.RegisterHandlers(s.SimpleResponse, s.ErrorRequest))
   326  	return s, nil
   327  }