github.com/pion/dtls/v2@v2.2.12/e2e/e2e_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package e2e
     8  
     9  import (
    10  	"context"
    11  	"crypto/ed25519"
    12  	"crypto/rand"
    13  	"crypto/rsa"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"errors"
    17  	"fmt"
    18  	"io"
    19  	"net"
    20  	"sync"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/pion/dtls/v2"
    26  	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
    27  	"github.com/pion/transport/v2/test"
    28  )
    29  
    30  const (
    31  	testMessage   = "Hello World"
    32  	testTimeLimit = 5 * time.Second
    33  	messageRetry  = 200 * time.Millisecond
    34  )
    35  
    36  var errServerTimeout = errors.New("waiting on serverReady err: timeout")
    37  
    38  func randomPort(t testing.TB) int {
    39  	t.Helper()
    40  	conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
    41  	if err != nil {
    42  		t.Fatalf("failed to pickPort: %v", err)
    43  	}
    44  	defer func() {
    45  		_ = conn.Close()
    46  	}()
    47  	switch addr := conn.LocalAddr().(type) {
    48  	case *net.UDPAddr:
    49  		return addr.Port
    50  	default:
    51  		t.Fatalf("unknown addr type %T", addr)
    52  		return 0
    53  	}
    54  }
    55  
    56  func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
    57  	go func() {
    58  		buffer := make([]byte, 8192)
    59  		n, err := conn.Read(buffer)
    60  		if err != nil {
    61  			errChan <- err
    62  			return
    63  		}
    64  
    65  		outChan <- string(buffer[:n])
    66  		atomic.AddUint64(messageRecvCount, 1)
    67  	}()
    68  
    69  	for {
    70  		if atomic.LoadUint64(messageRecvCount) == 2 {
    71  			break
    72  		} else if _, err := conn.Write([]byte(testMessage)); err != nil {
    73  			errChan <- err
    74  			break
    75  		}
    76  
    77  		time.Sleep(messageRetry)
    78  	}
    79  }
    80  
    81  type comm struct {
    82  	ctx                        context.Context
    83  	clientConfig, serverConfig *dtls.Config
    84  	serverPort                 int
    85  	messageRecvCount           *uint64 // Counter to make sure both sides got a message
    86  	clientMutex                *sync.Mutex
    87  	clientConn                 net.Conn
    88  	serverMutex                *sync.Mutex
    89  	serverConn                 net.Conn
    90  	serverListener             net.Listener
    91  	serverReady                chan struct{}
    92  	errChan                    chan error
    93  	clientChan                 chan string
    94  	serverChan                 chan string
    95  	client                     func(*comm)
    96  	server                     func(*comm)
    97  }
    98  
    99  func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
   100  	messageRecvCount := uint64(0)
   101  	c := &comm{
   102  		ctx:              ctx,
   103  		clientConfig:     clientConfig,
   104  		serverConfig:     serverConfig,
   105  		serverPort:       serverPort,
   106  		messageRecvCount: &messageRecvCount,
   107  		clientMutex:      &sync.Mutex{},
   108  		serverMutex:      &sync.Mutex{},
   109  		serverReady:      make(chan struct{}),
   110  		errChan:          make(chan error),
   111  		clientChan:       make(chan string),
   112  		serverChan:       make(chan string),
   113  		server:           server,
   114  		client:           client,
   115  	}
   116  	return c
   117  }
   118  
   119  func (c *comm) assert(t *testing.T) {
   120  	// DTLS Client
   121  	go c.client(c)
   122  
   123  	// DTLS Server
   124  	go c.server(c)
   125  
   126  	defer func() {
   127  		if c.clientConn != nil {
   128  			if err := c.clientConn.Close(); err != nil {
   129  				t.Fatal(err)
   130  			}
   131  		}
   132  		if c.serverConn != nil {
   133  			if err := c.serverConn.Close(); err != nil {
   134  				t.Fatal(err)
   135  			}
   136  		}
   137  		if c.serverListener != nil {
   138  			if err := c.serverListener.Close(); err != nil {
   139  				t.Fatal(err)
   140  			}
   141  		}
   142  	}()
   143  
   144  	func() {
   145  		seenClient, seenServer := false, false
   146  		for {
   147  			select {
   148  			case err := <-c.errChan:
   149  				t.Fatal(err)
   150  			case <-time.After(testTimeLimit):
   151  				t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
   152  			case clientMsg := <-c.clientChan:
   153  				if clientMsg != testMessage {
   154  					t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
   155  				}
   156  
   157  				seenClient = true
   158  				if seenClient && seenServer {
   159  					return
   160  				}
   161  			case serverMsg := <-c.serverChan:
   162  				if serverMsg != testMessage {
   163  					t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
   164  				}
   165  
   166  				seenServer = true
   167  				if seenClient && seenServer {
   168  					return
   169  				}
   170  			}
   171  		}
   172  	}()
   173  }
   174  
   175  func clientPion(c *comm) {
   176  	select {
   177  	case <-c.serverReady:
   178  		// OK
   179  	case <-time.After(time.Second):
   180  		c.errChan <- errServerTimeout
   181  	}
   182  
   183  	c.clientMutex.Lock()
   184  	defer c.clientMutex.Unlock()
   185  
   186  	var err error
   187  	c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
   188  		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
   189  		c.clientConfig,
   190  	)
   191  	if err != nil {
   192  		c.errChan <- err
   193  		return
   194  	}
   195  
   196  	simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
   197  }
   198  
   199  func serverPion(c *comm) {
   200  	c.serverMutex.Lock()
   201  	defer c.serverMutex.Unlock()
   202  
   203  	var err error
   204  	c.serverListener, err = dtls.Listen("udp",
   205  		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
   206  		c.serverConfig,
   207  	)
   208  	if err != nil {
   209  		c.errChan <- err
   210  		return
   211  	}
   212  	c.serverReady <- struct{}{}
   213  	c.serverConn, err = c.serverListener.Accept()
   214  	if err != nil {
   215  		c.errChan <- err
   216  		return
   217  	}
   218  
   219  	simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
   220  }
   221  
   222  /*
   223  	  Simple DTLS Client/Server can communicate
   224  	    - Assert that you can send messages both ways
   225  		- Assert that Close() on both ends work
   226  		- Assert that no Goroutines are leaked
   227  */
   228  func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
   229  	lim := test.TimeOut(time.Second * 30)
   230  	defer lim.Stop()
   231  
   232  	report := test.CheckRoutines(t)
   233  	defer report()
   234  
   235  	for _, cipherSuite := range []dtls.CipherSuiteID{
   236  		dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   237  		dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   238  		dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
   239  	} {
   240  		cipherSuite := cipherSuite
   241  		t.Run(cipherSuite.String(), func(t *testing.T) {
   242  			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   243  			defer cancel()
   244  
   245  			cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
   246  			if err != nil {
   247  				t.Fatal(err)
   248  			}
   249  
   250  			cfg := &dtls.Config{
   251  				Certificates:       []tls.Certificate{cert},
   252  				CipherSuites:       []dtls.CipherSuiteID{cipherSuite},
   253  				InsecureSkipVerify: true,
   254  			}
   255  			serverPort := randomPort(t)
   256  			comm := newComm(ctx, cfg, cfg, serverPort, server, client)
   257  			comm.assert(t)
   258  		})
   259  	}
   260  }
   261  
   262  func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
   263  	lim := test.TimeOut(time.Second * 30)
   264  	defer lim.Stop()
   265  
   266  	report := test.CheckRoutines(t)
   267  	defer report()
   268  
   269  	for _, cipherSuite := range []dtls.CipherSuiteID{
   270  		dtls.TLS_PSK_WITH_AES_128_CCM,
   271  		dtls.TLS_PSK_WITH_AES_128_CCM_8,
   272  		dtls.TLS_PSK_WITH_AES_256_CCM_8,
   273  		dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
   274  		dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
   275  	} {
   276  		cipherSuite := cipherSuite
   277  		t.Run(cipherSuite.String(), func(t *testing.T) {
   278  			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   279  			defer cancel()
   280  
   281  			cfg := &dtls.Config{
   282  				PSK: func(hint []byte) ([]byte, error) {
   283  					return []byte{0xAB, 0xC1, 0x23}, nil
   284  				},
   285  				PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
   286  				CipherSuites:    []dtls.CipherSuiteID{cipherSuite},
   287  			}
   288  			serverPort := randomPort(t)
   289  			comm := newComm(ctx, cfg, cfg, serverPort, server, client)
   290  			comm.assert(t)
   291  		})
   292  	}
   293  }
   294  
   295  func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
   296  	lim := test.TimeOut(time.Second * 30)
   297  	defer lim.Stop()
   298  
   299  	report := test.CheckRoutines(t)
   300  	defer report()
   301  
   302  	for _, mtu := range []int{
   303  		10000,
   304  		1000,
   305  		100,
   306  	} {
   307  		mtu := mtu
   308  		t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) {
   309  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   310  			defer cancel()
   311  
   312  			cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
   313  			if err != nil {
   314  				t.Fatal(err)
   315  			}
   316  
   317  			cfg := &dtls.Config{
   318  				Certificates:       []tls.Certificate{cert},
   319  				CipherSuites:       []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   320  				InsecureSkipVerify: true,
   321  				MTU:                mtu,
   322  			}
   323  			serverPort := randomPort(t)
   324  			comm := newComm(ctx, cfg, cfg, serverPort, server, client)
   325  			comm.assert(t)
   326  		})
   327  	}
   328  }
   329  
   330  func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) {
   331  	lim := test.TimeOut(time.Second * 30)
   332  	defer lim.Stop()
   333  
   334  	report := test.CheckRoutines(t)
   335  	defer report()
   336  
   337  	for _, cipherSuite := range []dtls.CipherSuiteID{
   338  		dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
   339  		dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
   340  		dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   341  		dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
   342  		dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
   343  	} {
   344  		cipherSuite := cipherSuite
   345  		t.Run(cipherSuite.String(), func(t *testing.T) {
   346  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   347  			defer cancel()
   348  
   349  			_, key, err := ed25519.GenerateKey(rand.Reader)
   350  			if err != nil {
   351  				t.Fatal(err)
   352  			}
   353  			cert, err := selfsign.SelfSign(key)
   354  			if err != nil {
   355  				t.Fatal(err)
   356  			}
   357  
   358  			cfg := &dtls.Config{
   359  				Certificates:       []tls.Certificate{cert},
   360  				CipherSuites:       []dtls.CipherSuiteID{cipherSuite},
   361  				InsecureSkipVerify: true,
   362  			}
   363  			serverPort := randomPort(t)
   364  			comm := newComm(ctx, cfg, cfg, serverPort, server, client)
   365  			comm.assert(t)
   366  		})
   367  	}
   368  }
   369  
   370  func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm)) {
   371  	lim := test.TimeOut(time.Second * 30)
   372  	defer lim.Stop()
   373  
   374  	report := test.CheckRoutines(t)
   375  	defer report()
   376  
   377  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   378  	defer cancel()
   379  
   380  	_, skey, err := ed25519.GenerateKey(rand.Reader)
   381  	if err != nil {
   382  		t.Fatal(err)
   383  	}
   384  	scert, err := selfsign.SelfSign(skey)
   385  	if err != nil {
   386  		t.Fatal(err)
   387  	}
   388  
   389  	_, ckey, err := ed25519.GenerateKey(rand.Reader)
   390  	if err != nil {
   391  		t.Fatal(err)
   392  	}
   393  	ccert, err := selfsign.SelfSign(ckey)
   394  	if err != nil {
   395  		t.Fatal(err)
   396  	}
   397  
   398  	scfg := &dtls.Config{
   399  		Certificates: []tls.Certificate{scert},
   400  		CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   401  		ClientAuth:   dtls.RequireAnyClientCert,
   402  	}
   403  	ccfg := &dtls.Config{
   404  		Certificates:       []tls.Certificate{ccert},
   405  		CipherSuites:       []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   406  		InsecureSkipVerify: true,
   407  	}
   408  	serverPort := randomPort(t)
   409  	comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
   410  	comm.assert(t)
   411  }
   412  
   413  func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) {
   414  	lim := test.TimeOut(time.Second * 30)
   415  	defer lim.Stop()
   416  
   417  	report := test.CheckRoutines(t)
   418  	defer report()
   419  
   420  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   421  	defer cancel()
   422  
   423  	scert, err := selfsign.GenerateSelfSigned()
   424  	if err != nil {
   425  		t.Fatal(err)
   426  	}
   427  
   428  	ccert, err := selfsign.GenerateSelfSigned()
   429  	if err != nil {
   430  		t.Fatal(err)
   431  	}
   432  
   433  	clientCAs := x509.NewCertPool()
   434  	caCert, err := x509.ParseCertificate(ccert.Certificate[0])
   435  	if err != nil {
   436  		t.Fatal(err)
   437  	}
   438  	clientCAs.AddCert(caCert)
   439  
   440  	scfg := &dtls.Config{
   441  		ClientCAs:    clientCAs,
   442  		Certificates: []tls.Certificate{scert},
   443  		CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   444  		ClientAuth:   dtls.RequireAnyClientCert,
   445  	}
   446  	ccfg := &dtls.Config{
   447  		Certificates:       []tls.Certificate{ccert},
   448  		CipherSuites:       []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   449  		InsecureSkipVerify: true,
   450  	}
   451  	serverPort := randomPort(t)
   452  	comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
   453  	comm.assert(t)
   454  }
   455  
   456  func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) {
   457  	lim := test.TimeOut(time.Second * 30)
   458  	defer lim.Stop()
   459  
   460  	report := test.CheckRoutines(t)
   461  	defer report()
   462  
   463  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   464  	defer cancel()
   465  
   466  	spriv, err := rsa.GenerateKey(rand.Reader, 2048)
   467  	if err != nil {
   468  		t.Fatal(err)
   469  	}
   470  	scert, err := selfsign.SelfSign(spriv)
   471  	if err != nil {
   472  		t.Fatal(err)
   473  	}
   474  
   475  	cpriv, err := rsa.GenerateKey(rand.Reader, 2048)
   476  	if err != nil {
   477  		t.Fatal(err)
   478  	}
   479  	ccert, err := selfsign.SelfSign(cpriv)
   480  	if err != nil {
   481  		t.Fatal(err)
   482  	}
   483  
   484  	scfg := &dtls.Config{
   485  		Certificates: []tls.Certificate{scert},
   486  		CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   487  		ClientAuth:   dtls.RequireAnyClientCert,
   488  	}
   489  	ccfg := &dtls.Config{
   490  		Certificates:       []tls.Certificate{ccert},
   491  		CipherSuites:       []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   492  		InsecureSkipVerify: true,
   493  	}
   494  	serverPort := randomPort(t)
   495  	comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
   496  	comm.assert(t)
   497  }
   498  
   499  func TestPionE2ESimple(t *testing.T) {
   500  	testPionE2ESimple(t, serverPion, clientPion)
   501  }
   502  
   503  func TestPionE2ESimplePSK(t *testing.T) {
   504  	testPionE2ESimplePSK(t, serverPion, clientPion)
   505  }
   506  
   507  func TestPionE2EMTUs(t *testing.T) {
   508  	testPionE2EMTUs(t, serverPion, clientPion)
   509  }
   510  
   511  func TestPionE2ESimpleED25519(t *testing.T) {
   512  	testPionE2ESimpleED25519(t, serverPion, clientPion)
   513  }
   514  
   515  func TestPionE2ESimpleED25519ClientCert(t *testing.T) {
   516  	testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion)
   517  }
   518  
   519  func TestPionE2ESimpleECDSAClientCert(t *testing.T) {
   520  	testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion)
   521  }
   522  
   523  func TestPionE2ESimpleRSAClientCert(t *testing.T) {
   524  	testPionE2ESimpleRSAClientCert(t, serverPion, clientPion)
   525  }