github.com/pion/dtls/v2@v2.2.12/e2e/e2e_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package e2e 8 9 import ( 10 "context" 11 "crypto/ed25519" 12 "crypto/rand" 13 "crypto/rsa" 14 "crypto/tls" 15 "crypto/x509" 16 "errors" 17 "fmt" 18 "io" 19 "net" 20 "sync" 21 "sync/atomic" 22 "testing" 23 "time" 24 25 "github.com/pion/dtls/v2" 26 "github.com/pion/dtls/v2/pkg/crypto/selfsign" 27 "github.com/pion/transport/v2/test" 28 ) 29 30 const ( 31 testMessage = "Hello World" 32 testTimeLimit = 5 * time.Second 33 messageRetry = 200 * time.Millisecond 34 ) 35 36 var errServerTimeout = errors.New("waiting on serverReady err: timeout") 37 38 func randomPort(t testing.TB) int { 39 t.Helper() 40 conn, err := net.ListenPacket("udp4", "127.0.0.1:0") 41 if err != nil { 42 t.Fatalf("failed to pickPort: %v", err) 43 } 44 defer func() { 45 _ = conn.Close() 46 }() 47 switch addr := conn.LocalAddr().(type) { 48 case *net.UDPAddr: 49 return addr.Port 50 default: 51 t.Fatalf("unknown addr type %T", addr) 52 return 0 53 } 54 } 55 56 func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) { 57 go func() { 58 buffer := make([]byte, 8192) 59 n, err := conn.Read(buffer) 60 if err != nil { 61 errChan <- err 62 return 63 } 64 65 outChan <- string(buffer[:n]) 66 atomic.AddUint64(messageRecvCount, 1) 67 }() 68 69 for { 70 if atomic.LoadUint64(messageRecvCount) == 2 { 71 break 72 } else if _, err := conn.Write([]byte(testMessage)); err != nil { 73 errChan <- err 74 break 75 } 76 77 time.Sleep(messageRetry) 78 } 79 } 80 81 type comm struct { 82 ctx context.Context 83 clientConfig, serverConfig *dtls.Config 84 serverPort int 85 messageRecvCount *uint64 // Counter to make sure both sides got a message 86 clientMutex *sync.Mutex 87 clientConn net.Conn 88 serverMutex *sync.Mutex 89 serverConn net.Conn 90 serverListener net.Listener 91 serverReady chan struct{} 92 errChan chan error 93 clientChan chan string 94 serverChan chan string 95 client func(*comm) 96 server func(*comm) 97 } 98 99 func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm { 100 messageRecvCount := uint64(0) 101 c := &comm{ 102 ctx: ctx, 103 clientConfig: clientConfig, 104 serverConfig: serverConfig, 105 serverPort: serverPort, 106 messageRecvCount: &messageRecvCount, 107 clientMutex: &sync.Mutex{}, 108 serverMutex: &sync.Mutex{}, 109 serverReady: make(chan struct{}), 110 errChan: make(chan error), 111 clientChan: make(chan string), 112 serverChan: make(chan string), 113 server: server, 114 client: client, 115 } 116 return c 117 } 118 119 func (c *comm) assert(t *testing.T) { 120 // DTLS Client 121 go c.client(c) 122 123 // DTLS Server 124 go c.server(c) 125 126 defer func() { 127 if c.clientConn != nil { 128 if err := c.clientConn.Close(); err != nil { 129 t.Fatal(err) 130 } 131 } 132 if c.serverConn != nil { 133 if err := c.serverConn.Close(); err != nil { 134 t.Fatal(err) 135 } 136 } 137 if c.serverListener != nil { 138 if err := c.serverListener.Close(); err != nil { 139 t.Fatal(err) 140 } 141 } 142 }() 143 144 func() { 145 seenClient, seenServer := false, false 146 for { 147 select { 148 case err := <-c.errChan: 149 t.Fatal(err) 150 case <-time.After(testTimeLimit): 151 t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer) 152 case clientMsg := <-c.clientChan: 153 if clientMsg != testMessage { 154 t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage) 155 } 156 157 seenClient = true 158 if seenClient && seenServer { 159 return 160 } 161 case serverMsg := <-c.serverChan: 162 if serverMsg != testMessage { 163 t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage) 164 } 165 166 seenServer = true 167 if seenClient && seenServer { 168 return 169 } 170 } 171 } 172 }() 173 } 174 175 func clientPion(c *comm) { 176 select { 177 case <-c.serverReady: 178 // OK 179 case <-time.After(time.Second): 180 c.errChan <- errServerTimeout 181 } 182 183 c.clientMutex.Lock() 184 defer c.clientMutex.Unlock() 185 186 var err error 187 c.clientConn, err = dtls.DialWithContext(c.ctx, "udp", 188 &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, 189 c.clientConfig, 190 ) 191 if err != nil { 192 c.errChan <- err 193 return 194 } 195 196 simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) 197 } 198 199 func serverPion(c *comm) { 200 c.serverMutex.Lock() 201 defer c.serverMutex.Unlock() 202 203 var err error 204 c.serverListener, err = dtls.Listen("udp", 205 &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, 206 c.serverConfig, 207 ) 208 if err != nil { 209 c.errChan <- err 210 return 211 } 212 c.serverReady <- struct{}{} 213 c.serverConn, err = c.serverListener.Accept() 214 if err != nil { 215 c.errChan <- err 216 return 217 } 218 219 simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) 220 } 221 222 /* 223 Simple DTLS Client/Server can communicate 224 - Assert that you can send messages both ways 225 - Assert that Close() on both ends work 226 - Assert that no Goroutines are leaked 227 */ 228 func testPionE2ESimple(t *testing.T, server, client func(*comm)) { 229 lim := test.TimeOut(time.Second * 30) 230 defer lim.Stop() 231 232 report := test.CheckRoutines(t) 233 defer report() 234 235 for _, cipherSuite := range []dtls.CipherSuiteID{ 236 dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 237 dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 238 dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 239 } { 240 cipherSuite := cipherSuite 241 t.Run(cipherSuite.String(), func(t *testing.T) { 242 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 243 defer cancel() 244 245 cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") 246 if err != nil { 247 t.Fatal(err) 248 } 249 250 cfg := &dtls.Config{ 251 Certificates: []tls.Certificate{cert}, 252 CipherSuites: []dtls.CipherSuiteID{cipherSuite}, 253 InsecureSkipVerify: true, 254 } 255 serverPort := randomPort(t) 256 comm := newComm(ctx, cfg, cfg, serverPort, server, client) 257 comm.assert(t) 258 }) 259 } 260 } 261 262 func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) { 263 lim := test.TimeOut(time.Second * 30) 264 defer lim.Stop() 265 266 report := test.CheckRoutines(t) 267 defer report() 268 269 for _, cipherSuite := range []dtls.CipherSuiteID{ 270 dtls.TLS_PSK_WITH_AES_128_CCM, 271 dtls.TLS_PSK_WITH_AES_128_CCM_8, 272 dtls.TLS_PSK_WITH_AES_256_CCM_8, 273 dtls.TLS_PSK_WITH_AES_128_GCM_SHA256, 274 dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, 275 } { 276 cipherSuite := cipherSuite 277 t.Run(cipherSuite.String(), func(t *testing.T) { 278 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 279 defer cancel() 280 281 cfg := &dtls.Config{ 282 PSK: func(hint []byte) ([]byte, error) { 283 return []byte{0xAB, 0xC1, 0x23}, nil 284 }, 285 PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, 286 CipherSuites: []dtls.CipherSuiteID{cipherSuite}, 287 } 288 serverPort := randomPort(t) 289 comm := newComm(ctx, cfg, cfg, serverPort, server, client) 290 comm.assert(t) 291 }) 292 } 293 } 294 295 func testPionE2EMTUs(t *testing.T, server, client func(*comm)) { 296 lim := test.TimeOut(time.Second * 30) 297 defer lim.Stop() 298 299 report := test.CheckRoutines(t) 300 defer report() 301 302 for _, mtu := range []int{ 303 10000, 304 1000, 305 100, 306 } { 307 mtu := mtu 308 t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) { 309 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 310 defer cancel() 311 312 cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") 313 if err != nil { 314 t.Fatal(err) 315 } 316 317 cfg := &dtls.Config{ 318 Certificates: []tls.Certificate{cert}, 319 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 320 InsecureSkipVerify: true, 321 MTU: mtu, 322 } 323 serverPort := randomPort(t) 324 comm := newComm(ctx, cfg, cfg, serverPort, server, client) 325 comm.assert(t) 326 }) 327 } 328 } 329 330 func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) { 331 lim := test.TimeOut(time.Second * 30) 332 defer lim.Stop() 333 334 report := test.CheckRoutines(t) 335 defer report() 336 337 for _, cipherSuite := range []dtls.CipherSuiteID{ 338 dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, 339 dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, 340 dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 341 dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 342 dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 343 } { 344 cipherSuite := cipherSuite 345 t.Run(cipherSuite.String(), func(t *testing.T) { 346 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 347 defer cancel() 348 349 _, key, err := ed25519.GenerateKey(rand.Reader) 350 if err != nil { 351 t.Fatal(err) 352 } 353 cert, err := selfsign.SelfSign(key) 354 if err != nil { 355 t.Fatal(err) 356 } 357 358 cfg := &dtls.Config{ 359 Certificates: []tls.Certificate{cert}, 360 CipherSuites: []dtls.CipherSuiteID{cipherSuite}, 361 InsecureSkipVerify: true, 362 } 363 serverPort := randomPort(t) 364 comm := newComm(ctx, cfg, cfg, serverPort, server, client) 365 comm.assert(t) 366 }) 367 } 368 } 369 370 func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm)) { 371 lim := test.TimeOut(time.Second * 30) 372 defer lim.Stop() 373 374 report := test.CheckRoutines(t) 375 defer report() 376 377 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 378 defer cancel() 379 380 _, skey, err := ed25519.GenerateKey(rand.Reader) 381 if err != nil { 382 t.Fatal(err) 383 } 384 scert, err := selfsign.SelfSign(skey) 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 _, ckey, err := ed25519.GenerateKey(rand.Reader) 390 if err != nil { 391 t.Fatal(err) 392 } 393 ccert, err := selfsign.SelfSign(ckey) 394 if err != nil { 395 t.Fatal(err) 396 } 397 398 scfg := &dtls.Config{ 399 Certificates: []tls.Certificate{scert}, 400 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 401 ClientAuth: dtls.RequireAnyClientCert, 402 } 403 ccfg := &dtls.Config{ 404 Certificates: []tls.Certificate{ccert}, 405 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 406 InsecureSkipVerify: true, 407 } 408 serverPort := randomPort(t) 409 comm := newComm(ctx, ccfg, scfg, serverPort, server, client) 410 comm.assert(t) 411 } 412 413 func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) { 414 lim := test.TimeOut(time.Second * 30) 415 defer lim.Stop() 416 417 report := test.CheckRoutines(t) 418 defer report() 419 420 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 421 defer cancel() 422 423 scert, err := selfsign.GenerateSelfSigned() 424 if err != nil { 425 t.Fatal(err) 426 } 427 428 ccert, err := selfsign.GenerateSelfSigned() 429 if err != nil { 430 t.Fatal(err) 431 } 432 433 clientCAs := x509.NewCertPool() 434 caCert, err := x509.ParseCertificate(ccert.Certificate[0]) 435 if err != nil { 436 t.Fatal(err) 437 } 438 clientCAs.AddCert(caCert) 439 440 scfg := &dtls.Config{ 441 ClientCAs: clientCAs, 442 Certificates: []tls.Certificate{scert}, 443 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 444 ClientAuth: dtls.RequireAnyClientCert, 445 } 446 ccfg := &dtls.Config{ 447 Certificates: []tls.Certificate{ccert}, 448 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, 449 InsecureSkipVerify: true, 450 } 451 serverPort := randomPort(t) 452 comm := newComm(ctx, ccfg, scfg, serverPort, server, client) 453 comm.assert(t) 454 } 455 456 func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) { 457 lim := test.TimeOut(time.Second * 30) 458 defer lim.Stop() 459 460 report := test.CheckRoutines(t) 461 defer report() 462 463 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 464 defer cancel() 465 466 spriv, err := rsa.GenerateKey(rand.Reader, 2048) 467 if err != nil { 468 t.Fatal(err) 469 } 470 scert, err := selfsign.SelfSign(spriv) 471 if err != nil { 472 t.Fatal(err) 473 } 474 475 cpriv, err := rsa.GenerateKey(rand.Reader, 2048) 476 if err != nil { 477 t.Fatal(err) 478 } 479 ccert, err := selfsign.SelfSign(cpriv) 480 if err != nil { 481 t.Fatal(err) 482 } 483 484 scfg := &dtls.Config{ 485 Certificates: []tls.Certificate{scert}, 486 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, 487 ClientAuth: dtls.RequireAnyClientCert, 488 } 489 ccfg := &dtls.Config{ 490 Certificates: []tls.Certificate{ccert}, 491 CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, 492 InsecureSkipVerify: true, 493 } 494 serverPort := randomPort(t) 495 comm := newComm(ctx, ccfg, scfg, serverPort, server, client) 496 comm.assert(t) 497 } 498 499 func TestPionE2ESimple(t *testing.T) { 500 testPionE2ESimple(t, serverPion, clientPion) 501 } 502 503 func TestPionE2ESimplePSK(t *testing.T) { 504 testPionE2ESimplePSK(t, serverPion, clientPion) 505 } 506 507 func TestPionE2EMTUs(t *testing.T) { 508 testPionE2EMTUs(t, serverPion, clientPion) 509 } 510 511 func TestPionE2ESimpleED25519(t *testing.T) { 512 testPionE2ESimpleED25519(t, serverPion, clientPion) 513 } 514 515 func TestPionE2ESimpleED25519ClientCert(t *testing.T) { 516 testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion) 517 } 518 519 func TestPionE2ESimpleECDSAClientCert(t *testing.T) { 520 testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion) 521 } 522 523 func TestPionE2ESimpleRSAClientCert(t *testing.T) { 524 testPionE2ESimpleRSAClientCert(t, serverPion, clientPion) 525 }