go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket_test.go (about) 1 package onet 2 3 import ( 4 "crypto/ecdsa" 5 "crypto/elliptic" 6 "crypto/rand" 7 "crypto/tls" 8 "crypto/x509" 9 "crypto/x509/pkix" 10 "encoding/pem" 11 "io/ioutil" 12 "math/big" 13 "net" 14 "net/url" 15 "os" 16 "testing" 17 "time" 18 19 "github.com/stretchr/testify/require" 20 "go.dedis.ch/onet/v3/log" 21 "go.dedis.ch/onet/v3/network" 22 "go.dedis.ch/protobuf" 23 "golang.org/x/xerrors" 24 ) 25 26 func init() { 27 RegisterNewService(serviceWebSocket, newServiceWebSocket) 28 } 29 30 // Adapted from 'https://golang.org/src/crypto/tls/generate_cert.go' 31 func generateSelfSignedCert() (string, string, error) { 32 // Hostname or IP to generate a certificate for 33 hosts := []string{ 34 "127.0.0.1", 35 "::", 36 } 37 // Creation date formatted as Jan 2 15:04:05 2006 38 validFrom := time.Now().UTC().Format("Jan 2 15:04:05 2006") 39 // Duration that certificate is valid for 40 validFor := 365 * 24 * time.Hour 41 42 priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 43 if err != nil { 44 return "", "", err 45 } 46 47 notBefore, err := time.Parse("Jan 2 15:04:05 2006", validFrom) 48 if err != nil { 49 return "", "", err 50 } 51 52 notAfter := notBefore.Add(validFor) 53 54 serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 55 serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 56 if err != nil { 57 return "", "", err 58 } 59 60 template := x509.Certificate{ 61 SerialNumber: serialNumber, 62 Subject: pkix.Name{ 63 Organization: []string{"DEDIS EPFL"}, 64 }, 65 NotBefore: notBefore, 66 NotAfter: notAfter, 67 68 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 69 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 70 BasicConstraintsValid: true, 71 } 72 73 for _, host := range hosts { 74 if ip := net.ParseIP(host); ip != nil { 75 template.IPAddresses = append(template.IPAddresses, ip) 76 } else { 77 template.DNSNames = append(template.DNSNames, host) 78 } 79 } 80 81 derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) 82 if err != nil { 83 return "", "", err 84 } 85 86 certOut, err := ioutil.TempFile("", "cert.pem") 87 if err != nil { 88 return "", "", err 89 } 90 pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) 91 certOut.Close() 92 93 keyOut, err := ioutil.TempFile("", "key.pem") 94 if err != nil { 95 return "", "", err 96 } 97 98 b, err := x509.MarshalECPrivateKey(priv) 99 if err != nil { 100 return "", "", err 101 } 102 pemBlockForKey := &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} 103 104 pem.Encode(keyOut, pemBlockForKey) 105 keyOut.Close() 106 107 return certOut.Name(), keyOut.Name(), nil 108 } 109 110 func getSelfSignedCertificateAndKey() ([]byte, []byte, error) { 111 certFilePath, keyFilePath, err := generateSelfSignedCert() 112 if err != nil { 113 return nil, nil, err 114 } 115 116 cert, err := ioutil.ReadFile(certFilePath) 117 if err != nil { 118 return nil, nil, err 119 } 120 err = os.Remove(certFilePath) 121 if err != nil { 122 return nil, nil, err 123 } 124 125 key, err := ioutil.ReadFile(keyFilePath) 126 if err != nil { 127 return nil, nil, err 128 } 129 err = os.Remove(keyFilePath) 130 if err != nil { 131 return nil, nil, err 132 } 133 134 return cert, key, nil 135 } 136 137 func TestNewWebSocket(t *testing.T) { 138 l := NewLocalTest(tSuite) 139 defer l.CloseAll() 140 141 c := l.NewServer(tSuite, 2050) 142 143 defer c.Close() 144 require.Equal(t, len(c.serviceManager.services), len(c.WebSocket.services)) 145 require.NotEmpty(t, c.WebSocket.services[serviceWebSocket]) 146 cl := NewClientKeep(tSuite, "WebSocket") 147 req := &SimpleResponse{} 148 msgTypeID := network.MessageType(req) 149 log.Lvlf1("Sending message Request: %x", msgTypeID[:]) 150 buf, err := protobuf.Encode(req) 151 require.Nil(t, err) 152 rcv, err := cl.Send(c.ServerIdentity, "SimpleResponse", buf) 153 require.Nil(t, err) 154 155 log.Lvlf1("Received reply: %x", rcv) 156 rcvMsg := &SimpleResponse{} 157 require.Nil(t, protobuf.Decode(rcv, rcvMsg)) 158 require.Equal(t, int64(1), rcvMsg.Val) 159 } 160 161 func TestNewWebSocketTLS(t *testing.T) { 162 cert, key, err := getSelfSignedCertificateAndKey() 163 require.Nil(t, err) 164 CAPool := x509.NewCertPool() 165 CAPool.AppendCertsFromPEM(cert) 166 167 l := NewTCPTest(tSuite) 168 l.webSocketTLSCertificate = cert 169 l.webSocketTLSCertificateKey = key 170 defer l.CloseAll() 171 172 c := l.NewServer(tSuite, 2050) 173 require.Equal(t, len(c.serviceManager.services), len(c.WebSocket.services)) 174 require.NotEmpty(t, c.WebSocket.services[serviceWebSocket]) 175 176 // Test the traditional host:port+1 way of specifying the websocket server. 177 cl := NewClientKeep(tSuite, "WebSocket") 178 defer cl.Close() 179 cl.TLSClientConfig = &tls.Config{RootCAs: CAPool} 180 req := &SimpleResponse{} 181 msgTypeID := network.MessageType(req) 182 log.Lvlf1("Sending message Request: %x", msgTypeID[:]) 183 buf, err := protobuf.Encode(req) 184 require.Nil(t, err) 185 rcv, err := cl.Send(c.ServerIdentity, "SimpleResponse", buf) 186 log.Lvlf1("Received reply: %x", rcv) 187 rcvMsg := &SimpleResponse{} 188 require.Nil(t, protobuf.Decode(rcv, rcvMsg)) 189 require.Equal(t, int64(1), rcvMsg.Val) 190 191 // Set c.ServerIdentity.URL, in order to test the other way of triggering wss:// connection. 192 hp, err := getWSHostPort(c.ServerIdentity, false) 193 require.NoError(t, err) 194 u := &url.URL{Scheme: "https", Host: hp} 195 c.ServerIdentity.URL = u.String() 196 197 log.Lvlf1("Sending message Request: %x", msgTypeID[:]) 198 rcv, err = cl.Send(c.ServerIdentity, "SimpleResponse", buf) 199 log.Lvlf1("Received reply: %x", rcv) 200 require.Nil(t, protobuf.Decode(rcv, rcvMsg)) 201 require.Equal(t, int64(1), rcvMsg.Val) 202 } 203 204 // Test the certificate reloader for websocket over TLS. 205 func TestCertificateReloader(t *testing.T) { 206 certPath, keyPath, err := generateSelfSignedCert() 207 require.NoError(t, err) 208 defer func() { 209 os.Remove(certPath) 210 os.Remove(keyPath) 211 }() 212 213 reloader, err := NewCertificateReloader(certPath, keyPath) 214 require.NoError(t, err) 215 216 cert, err := reloader.GetCertificateFunc()(nil) 217 require.NoError(t, err) 218 require.NotNil(t, cert) 219 220 reloader.certPath = "" 221 reloader.keyPath = "" 222 223 // It should work as the certificate is cached. 224 cert, err = reloader.GetCertificateFunc()(nil) 225 require.NoError(t, err) 226 require.NotNil(t, cert) 227 228 // Try with an expired certificate 229 // thus expecting an error. 230 cert.Leaf.NotAfter = time.Now().Add(30 * time.Minute) 231 _, err = reloader.GetCertificateFunc()(nil) 232 require.Error(t, err) 233 require.Contains(t, err.Error(), "no such file or directory") 234 235 // And finally it should reload the new cert 236 reloader.certPath = certPath 237 reloader.keyPath = keyPath 238 cert, err = reloader.GetCertificateFunc()(nil) 239 require.NoError(t, err) 240 require.NotNil(t, cert) 241 } 242 243 func TestGetWebHost(t *testing.T) { 244 url, err := getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8"}, true) 245 require.Error(t, err) 246 247 url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8"}, false) 248 require.Error(t, err) 249 250 url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8:7770"}, true) 251 require.NoError(t, err) 252 require.Equal(t, "0.0.0.0:7771", url) 253 254 url, err = getWSHostPort(&network.ServerIdentity{Address: "tcp://8.8.8.8:7770"}, false) 255 require.NoError(t, err) 256 require.Equal(t, "8.8.8.8:7771", url) 257 258 url, err = getWSHostPort(&network.ServerIdentity{ 259 Address: "tcp://irrelevant:7770", 260 URL: "wrong url", 261 }, false) 262 require.Error(t, err) 263 264 url, err = getWSHostPort(&network.ServerIdentity{ 265 Address: "tcp://irrelevant:7770", 266 URL: "http://8.8.8.8:8888", 267 }, false) 268 require.NoError(t, err) 269 require.Equal(t, "8.8.8.8:8888", url) 270 271 url, err = getWSHostPort(&network.ServerIdentity{ 272 Address: "tcp://irrelevant:7770", 273 URL: "http://8.8.8.8", 274 }, false) 275 require.NoError(t, err) 276 require.Equal(t, "8.8.8.8:80", url) 277 278 url, err = getWSHostPort(&network.ServerIdentity{ 279 Address: "tcp://irrelevant:7770", 280 URL: "https://8.8.8.8", 281 }, false) 282 require.NoError(t, err) 283 require.Equal(t, "8.8.8.8:443", url) 284 285 url, err = getWSHostPort(&network.ServerIdentity{ 286 Address: "tcp://irrelevant:7770", 287 URL: "invalid://8.8.8.8:8888", 288 }, false) 289 require.Error(t, err) 290 } 291 292 const serviceWebSocket = "WebSocket" 293 294 type ServiceWebSocket struct { 295 *ServiceProcessor 296 Errors int 297 } 298 299 func (i *ServiceWebSocket) SimpleResponse(msg *SimpleResponse) (network.Message, error) { 300 return &SimpleResponse{msg.Val + 1}, nil 301 } 302 303 type ErrorRequest struct { 304 Roster Roster 305 Flags int 306 } 307 308 func (i *ServiceWebSocket) ErrorRequest(msg *ErrorRequest) (network.Message, error) { 309 i.Errors = 1 310 index, _ := msg.Roster.Search(i.ServerIdentity().ID) 311 if index < 0 { 312 return nil, xerrors.New("not in roster") 313 } 314 if msg.Flags&(1<<uint(index)) > 0 { 315 return nil, xerrors.New("found in flags: " + i.ServerIdentity().String()) 316 } 317 i.Errors = 0 318 return &SimpleResponse{}, nil 319 } 320 321 func newServiceWebSocket(c *Context) (Service, error) { 322 s := &ServiceWebSocket{ 323 ServiceProcessor: NewServiceProcessor(c), 324 } 325 log.ErrFatal(s.RegisterHandlers(s.SimpleResponse, s.ErrorRequest)) 326 return s, nil 327 }