github.com/pion/dtls/v2@v2.2.12/conn_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto"
    10  	"crypto/ecdsa"
    11  	cryptoElliptic "crypto/elliptic"
    12  	"crypto/rand"
    13  	"crypto/rsa"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"encoding/hex"
    17  	"errors"
    18  	"fmt"
    19  	"io"
    20  	"net"
    21  	"strings"
    22  	"sync"
    23  	"sync/atomic"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/pion/dtls/v2/internal/ciphersuite"
    28  	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
    29  	"github.com/pion/dtls/v2/pkg/crypto/hash"
    30  	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
    31  	"github.com/pion/dtls/v2/pkg/crypto/signature"
    32  	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
    33  	"github.com/pion/dtls/v2/pkg/protocol"
    34  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    35  	"github.com/pion/dtls/v2/pkg/protocol/extension"
    36  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    37  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    38  	"github.com/pion/logging"
    39  	"github.com/pion/transport/v2/dpipe"
    40  	"github.com/pion/transport/v2/test"
    41  )
    42  
    43  var (
    44  	errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity")
    45  	errPSKRejected            = errors.New("PSK Rejected")
    46  	errNotExpectedChain       = errors.New("not expected chain")
    47  	errExpecedChain           = errors.New("expected chain")
    48  	errWrongCert              = errors.New("wrong cert")
    49  )
    50  
    51  func TestStressDuplex(t *testing.T) {
    52  	// Limit runtime in case of deadlocks
    53  	lim := test.TimeOut(time.Second * 20)
    54  	defer lim.Stop()
    55  
    56  	// Check for leaking routines
    57  	report := test.CheckRoutines(t)
    58  	defer report()
    59  
    60  	// Run the test
    61  	stressDuplex(t)
    62  }
    63  
    64  func stressDuplex(t *testing.T) {
    65  	ca, cb, err := pipeMemory()
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  
    70  	defer func() {
    71  		err = ca.Close()
    72  		if err != nil {
    73  			t.Fatal(err)
    74  		}
    75  		err = cb.Close()
    76  		if err != nil {
    77  			t.Fatal(err)
    78  		}
    79  	}()
    80  
    81  	opt := test.Options{
    82  		MsgSize:  2048,
    83  		MsgCount: 100,
    84  	}
    85  
    86  	err = test.StressDuplex(ca, cb, opt)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  }
    91  
    92  func TestRoutineLeakOnClose(t *testing.T) {
    93  	// Limit runtime in case of deadlocks
    94  	lim := test.TimeOut(5 * time.Second)
    95  	defer lim.Stop()
    96  
    97  	// Check for leaking routines
    98  	report := test.CheckRoutines(t)
    99  	defer report()
   100  
   101  	ca, cb, err := pipeMemory()
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  
   106  	if _, err := ca.Write(make([]byte, 100)); err != nil {
   107  		t.Fatal(err)
   108  	}
   109  	if err := cb.Close(); err != nil {
   110  		t.Fatal(err)
   111  	}
   112  	if err := ca.Close(); err != nil {
   113  		t.Fatal(err)
   114  	}
   115  	// Packet is sent, but not read.
   116  	// inboundLoop routine should not be leaked.
   117  }
   118  
   119  func TestReadWriteDeadline(t *testing.T) {
   120  	// Limit runtime in case of deadlocks
   121  	lim := test.TimeOut(5 * time.Second)
   122  	defer lim.Stop()
   123  
   124  	// Check for leaking routines
   125  	report := test.CheckRoutines(t)
   126  	defer report()
   127  
   128  	var e net.Error
   129  
   130  	ca, cb, err := pipeMemory()
   131  	if err != nil {
   132  		t.Fatal(err)
   133  	}
   134  
   135  	if err := ca.SetDeadline(time.Unix(0, 1)); err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	_, werr := ca.Write(make([]byte, 100))
   139  	if errors.As(werr, &e) {
   140  		if !e.Timeout() {
   141  			t.Error("Deadline exceeded Write must return Timeout error")
   142  		}
   143  		if !e.Temporary() { //nolint:staticcheck
   144  			t.Error("Deadline exceeded Write must return Temporary error")
   145  		}
   146  	} else {
   147  		t.Error("Write must return net.Error error")
   148  	}
   149  	_, rerr := ca.Read(make([]byte, 100))
   150  	if errors.As(rerr, &e) {
   151  		if !e.Timeout() {
   152  			t.Error("Deadline exceeded Read must return Timeout error")
   153  		}
   154  		if !e.Temporary() { //nolint:staticcheck
   155  			t.Error("Deadline exceeded Read must return Temporary error")
   156  		}
   157  	} else {
   158  		t.Error("Read must return net.Error error")
   159  	}
   160  	if err := ca.SetDeadline(time.Time{}); err != nil {
   161  		t.Error(err)
   162  	}
   163  
   164  	if err := ca.Close(); err != nil {
   165  		t.Error(err)
   166  	}
   167  	if err := cb.Close(); err != nil {
   168  		t.Error(err)
   169  	}
   170  
   171  	if _, err := ca.Write(make([]byte, 100)); !errors.Is(err, ErrConnClosed) {
   172  		t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err)
   173  	}
   174  	if _, err := ca.Read(make([]byte, 100)); !errors.Is(err, io.EOF) {
   175  		t.Errorf("Read must return %v after close, got %v", io.EOF, err)
   176  	}
   177  }
   178  
   179  func TestSequenceNumberOverflow(t *testing.T) {
   180  	// Limit runtime in case of deadlocks
   181  	lim := test.TimeOut(5 * time.Second)
   182  	defer lim.Stop()
   183  
   184  	// Check for leaking routines
   185  	report := test.CheckRoutines(t)
   186  	defer report()
   187  
   188  	t.Run("ApplicationData", func(t *testing.T) {
   189  		ca, cb, err := pipeMemory()
   190  		if err != nil {
   191  			t.Fatal(err)
   192  		}
   193  
   194  		atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber)
   195  		if _, werr := ca.Write(make([]byte, 100)); werr != nil {
   196  			t.Errorf("Write must send message with maximum sequence number, but errord: %v", werr)
   197  		}
   198  		if _, werr := ca.Write(make([]byte, 100)); !errors.Is(werr, errSequenceNumberOverflow) {
   199  			t.Errorf("Write must abandonsend message with maximum sequence number, but errord: %v", werr)
   200  		}
   201  
   202  		if err := ca.Close(); err != nil {
   203  			t.Error(err)
   204  		}
   205  		if err := cb.Close(); err != nil {
   206  			t.Error(err)
   207  		}
   208  	})
   209  	t.Run("Handshake", func(t *testing.T) {
   210  		ca, cb, err := pipeMemory()
   211  		if err != nil {
   212  			t.Fatal(err)
   213  		}
   214  
   215  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   216  		defer cancel()
   217  
   218  		atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1)
   219  
   220  		// Try to send handshake packet.
   221  		if werr := ca.writePackets(ctx, []*packet{
   222  			{
   223  				record: &recordlayer.RecordLayer{
   224  					Header: recordlayer.Header{
   225  						Version: protocol.Version1_2,
   226  					},
   227  					Content: &handshake.Handshake{
   228  						Message: &handshake.MessageClientHello{
   229  							Version:            protocol.Version1_2,
   230  							Cookie:             make([]byte, 64),
   231  							CipherSuiteIDs:     cipherSuiteIDs(defaultCipherSuites()),
   232  							CompressionMethods: defaultCompressionMethods(),
   233  						},
   234  					},
   235  				},
   236  			},
   237  		}); !errors.Is(werr, errSequenceNumberOverflow) {
   238  			t.Errorf("Connection must fail on handshake packet reaches maximum sequence number")
   239  		}
   240  
   241  		if err := ca.Close(); err != nil {
   242  			t.Error(err)
   243  		}
   244  		if err := cb.Close(); err != nil {
   245  			t.Error(err)
   246  		}
   247  	})
   248  }
   249  
   250  func pipeMemory() (*Conn, *Conn, error) {
   251  	// In memory pipe
   252  	ca, cb := dpipe.Pipe()
   253  	return pipeConn(ca, cb)
   254  }
   255  
   256  func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) {
   257  	type result struct {
   258  		c   *Conn
   259  		err error
   260  	}
   261  
   262  	c := make(chan result)
   263  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   264  	defer cancel()
   265  
   266  	// Setup client
   267  	go func() {
   268  		client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
   269  		c <- result{client, err}
   270  	}()
   271  
   272  	// Setup server
   273  	server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true)
   274  	if err != nil {
   275  		return nil, nil, err
   276  	}
   277  
   278  	// Receive client
   279  	res := <-c
   280  	if res.err != nil {
   281  		_ = server.Close()
   282  		return nil, nil, res.err
   283  	}
   284  
   285  	return res.c, server, nil
   286  }
   287  
   288  func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
   289  	if generateCertificate {
   290  		clientCert, err := selfsign.GenerateSelfSigned()
   291  		if err != nil {
   292  			return nil, err
   293  		}
   294  		cfg.Certificates = []tls.Certificate{clientCert}
   295  	}
   296  	cfg.InsecureSkipVerify = true
   297  	return ClientWithContext(ctx, c, cfg)
   298  }
   299  
   300  func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) {
   301  	if generateCertificate {
   302  		serverCert, err := selfsign.GenerateSelfSigned()
   303  		if err != nil {
   304  			return nil, err
   305  		}
   306  		cfg.Certificates = []tls.Certificate{serverCert}
   307  	}
   308  	return ServerWithContext(ctx, c, cfg)
   309  }
   310  
   311  func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error {
   312  	packet, err := (&recordlayer.RecordLayer{
   313  		Header: recordlayer.Header{
   314  			Version:        protocol.Version1_2,
   315  			SequenceNumber: sequenceNumber,
   316  		},
   317  		Content: &handshake.Handshake{
   318  			Header: handshake.Header{
   319  				MessageSequence: uint16(sequenceNumber),
   320  			},
   321  			Message: &handshake.MessageClientHello{
   322  				Version:            protocol.Version1_2,
   323  				Cookie:             cookie,
   324  				CipherSuiteIDs:     cipherSuiteIDs(defaultCipherSuites()),
   325  				CompressionMethods: defaultCompressionMethods(),
   326  				Extensions:         extensions,
   327  			},
   328  		},
   329  	}).Marshal()
   330  	if err != nil {
   331  		return err
   332  	}
   333  
   334  	if _, err = ca.Write(packet); err != nil {
   335  		return err
   336  	}
   337  	return nil
   338  }
   339  
   340  func TestHandshakeWithAlert(t *testing.T) {
   341  	// Limit runtime in case of deadlocks
   342  	lim := test.TimeOut(time.Second * 20)
   343  	defer lim.Stop()
   344  
   345  	// Check for leaking routines
   346  	report := test.CheckRoutines(t)
   347  	defer report()
   348  
   349  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   350  	defer cancel()
   351  
   352  	cases := map[string]struct {
   353  		configServer, configClient *Config
   354  		errServer, errClient       error
   355  	}{
   356  		"CipherSuiteNoIntersection": {
   357  			configServer: &Config{
   358  				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   359  			},
   360  			configClient: &Config{
   361  				CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   362  			},
   363  			errServer: errCipherSuiteNoIntersection,
   364  			errClient: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
   365  		},
   366  		"SignatureSchemesNoIntersection": {
   367  			configServer: &Config{
   368  				CipherSuites:     []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   369  				SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256},
   370  			},
   371  			configClient: &Config{
   372  				CipherSuites:     []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   373  				SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512},
   374  			},
   375  			errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
   376  			errClient: errNoAvailableSignatureSchemes,
   377  		},
   378  	}
   379  
   380  	for name, testCase := range cases {
   381  		testCase := testCase
   382  		t.Run(name, func(t *testing.T) {
   383  			clientErr := make(chan error, 1)
   384  
   385  			ca, cb := dpipe.Pipe()
   386  			go func() {
   387  				_, err := testClient(ctx, ca, testCase.configClient, true)
   388  				clientErr <- err
   389  			}()
   390  
   391  			_, errServer := testServer(ctx, cb, testCase.configServer, true)
   392  			if !errors.Is(errServer, testCase.errServer) {
   393  				t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer)
   394  			}
   395  
   396  			errClient := <-clientErr
   397  			if !errors.Is(errClient, testCase.errClient) {
   398  				t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient)
   399  			}
   400  		})
   401  	}
   402  }
   403  
   404  func TestHandshakeWithInvalidRecord(t *testing.T) {
   405  	// Limit runtime in case of deadlocks
   406  	lim := test.TimeOut(time.Second * 20)
   407  	defer lim.Stop()
   408  
   409  	// Check for leaking routines
   410  	report := test.CheckRoutines(t)
   411  	defer report()
   412  
   413  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   414  	defer cancel()
   415  
   416  	type result struct {
   417  		c   *Conn
   418  		err error
   419  	}
   420  	clientErr := make(chan result, 1)
   421  	ca, cb := dpipe.Pipe()
   422  	caWithInvalidRecord := &connWithCallback{Conn: ca}
   423  
   424  	var msgSeq atomic.Int32
   425  	// Send invalid record after first message
   426  	caWithInvalidRecord.onWrite = func(b []byte) {
   427  		if msgSeq.Add(1) == 2 {
   428  			if _, err := ca.Write([]byte{0x01, 0x02}); err != nil {
   429  				t.Fatal(err)
   430  			}
   431  		}
   432  	}
   433  	go func() {
   434  		client, err := testClient(ctx, caWithInvalidRecord, &Config{
   435  			CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   436  		}, true)
   437  		clientErr <- result{client, err}
   438  	}()
   439  
   440  	server, errServer := testServer(ctx, cb, &Config{
   441  		CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
   442  	}, true)
   443  
   444  	errClient := <-clientErr
   445  
   446  	defer func() {
   447  		if server != nil {
   448  			if err := server.Close(); err != nil {
   449  				t.Fatal(err)
   450  			}
   451  		}
   452  
   453  		if errClient.c != nil {
   454  			if err := errClient.c.Close(); err != nil {
   455  				t.Fatal(err)
   456  			}
   457  		}
   458  	}()
   459  
   460  	if errServer != nil {
   461  		t.Fatalf("Server failed(%v)", errServer)
   462  	}
   463  
   464  	if errClient.err != nil {
   465  		t.Fatalf("Client failed(%v)", errClient.err)
   466  	}
   467  }
   468  
   469  func TestExportKeyingMaterial(t *testing.T) {
   470  	// Check for leaking routines
   471  	report := test.CheckRoutines(t)
   472  	defer report()
   473  
   474  	var rand [28]byte
   475  	exportLabel := "EXTRACTOR-dtls_srtp"
   476  
   477  	expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b}
   478  	expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77}
   479  
   480  	c := &Conn{
   481  		state: State{
   482  			localRandom:         handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand},
   483  			remoteRandom:        handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand},
   484  			localSequenceNumber: []uint64{0, 0},
   485  			cipherSuite:         &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
   486  		},
   487  	}
   488  	c.setLocalEpoch(0)
   489  	c.setRemoteEpoch(0)
   490  
   491  	state := c.ConnectionState()
   492  	_, err := state.ExportKeyingMaterial(exportLabel, nil, 0)
   493  	if !errors.Is(err, errHandshakeInProgress) {
   494  		t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err)
   495  	}
   496  
   497  	c.setLocalEpoch(1)
   498  	state = c.ConnectionState()
   499  	_, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
   500  	if !errors.Is(err, errContextUnsupported) {
   501  		t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err)
   502  	}
   503  
   504  	for k := range invalidKeyingLabels() {
   505  		state = c.ConnectionState()
   506  		_, err = state.ExportKeyingMaterial(k, nil, 0)
   507  		if !errors.Is(err, errReservedExportKeyingMaterial) {
   508  			t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err)
   509  		}
   510  	}
   511  
   512  	state = c.ConnectionState()
   513  	keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10)
   514  	if err != nil {
   515  		t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
   516  	} else if !bytes.Equal(keyingMaterial, expectedServerKey) {
   517  		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial)
   518  	}
   519  
   520  	c.state.isClient = true
   521  	state = c.ConnectionState()
   522  	keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10)
   523  	if err != nil {
   524  		t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
   525  	} else if !bytes.Equal(keyingMaterial, expectedClientKey) {
   526  		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial)
   527  	}
   528  }
   529  
   530  func TestPSK(t *testing.T) {
   531  	// Limit runtime in case of deadlocks
   532  	lim := test.TimeOut(time.Second * 20)
   533  	defer lim.Stop()
   534  
   535  	// Check for leaking routines
   536  	report := test.CheckRoutines(t)
   537  	defer report()
   538  
   539  	for _, test := range []struct {
   540  		Name                   string
   541  		ServerIdentity         []byte
   542  		CipherSuites           []CipherSuiteID
   543  		ClientVerifyConnection func(*State) error
   544  		ServerVerifyConnection func(*State) error
   545  		WantFail               bool
   546  		ExpectedServerErr      string
   547  		ExpectedClientErr      string
   548  	}{
   549  		{
   550  			Name:           "Server identity specified",
   551  			ServerIdentity: []byte("Test Identity"),
   552  			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
   553  		},
   554  		{
   555  			Name:           "Server identity specified - Server verify connection fails",
   556  			ServerIdentity: []byte("Test Identity"),
   557  			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
   558  			ServerVerifyConnection: func(s *State) error {
   559  				return errExample
   560  			},
   561  			WantFail:          true,
   562  			ExpectedServerErr: errExample.Error(),
   563  			ExpectedClientErr: alert.BadCertificate.String(),
   564  		},
   565  		{
   566  			Name:           "Server identity specified - Client verify connection fails",
   567  			ServerIdentity: []byte("Test Identity"),
   568  			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
   569  			ClientVerifyConnection: func(s *State) error {
   570  				return errExample
   571  			},
   572  			WantFail:          true,
   573  			ExpectedServerErr: alert.BadCertificate.String(),
   574  			ExpectedClientErr: errExample.Error(),
   575  		},
   576  		{
   577  			Name:           "Server identity nil",
   578  			ServerIdentity: nil,
   579  			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
   580  		},
   581  		{
   582  			Name:           "TLS_PSK_WITH_AES_128_CBC_SHA256",
   583  			ServerIdentity: nil,
   584  			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256},
   585  		},
   586  		{
   587  			Name:           "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256",
   588  			ServerIdentity: nil,
   589  			CipherSuites:   []CipherSuiteID{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256},
   590  		},
   591  	} {
   592  		test := test
   593  		t.Run(test.Name, func(t *testing.T) {
   594  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   595  			defer cancel()
   596  
   597  			clientIdentity := []byte("Client Identity")
   598  			type result struct {
   599  				c   *Conn
   600  				err error
   601  			}
   602  			clientRes := make(chan result, 1)
   603  
   604  			ca, cb := dpipe.Pipe()
   605  			go func() {
   606  				conf := &Config{
   607  					PSK: func(hint []byte) ([]byte, error) {
   608  						if !bytes.Equal(test.ServerIdentity, hint) {
   609  							return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113
   610  						}
   611  
   612  						return []byte{0xAB, 0xC1, 0x23}, nil
   613  					},
   614  					PSKIdentityHint:  clientIdentity,
   615  					CipherSuites:     test.CipherSuites,
   616  					VerifyConnection: test.ClientVerifyConnection,
   617  				}
   618  
   619  				c, err := testClient(ctx, ca, conf, false)
   620  				clientRes <- result{c, err}
   621  			}()
   622  
   623  			config := &Config{
   624  				PSK: func(hint []byte) ([]byte, error) {
   625  					if !bytes.Equal(clientIdentity, hint) {
   626  						return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint)
   627  					}
   628  					return []byte{0xAB, 0xC1, 0x23}, nil
   629  				},
   630  				PSKIdentityHint:  test.ServerIdentity,
   631  				CipherSuites:     test.CipherSuites,
   632  				VerifyConnection: test.ServerVerifyConnection,
   633  			}
   634  
   635  			server, err := testServer(ctx, cb, config, false)
   636  			if test.WantFail {
   637  				res := <-clientRes
   638  				if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) {
   639  					t.Fatalf("TestPSK: Server expected(%v) actual(%v)", test.ExpectedServerErr, err)
   640  				}
   641  				if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) {
   642  					t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err)
   643  				}
   644  				return
   645  			}
   646  			if err != nil {
   647  				t.Fatalf("TestPSK: Server failed(%v)", err)
   648  			}
   649  
   650  			actualPSKIdentityHint := server.ConnectionState().IdentityHint
   651  			if !bytes.Equal(actualPSKIdentityHint, clientIdentity) {
   652  				t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, clientIdentity, actualPSKIdentityHint)
   653  			}
   654  
   655  			defer func() {
   656  				_ = server.Close()
   657  			}()
   658  
   659  			res := <-clientRes
   660  			if res.err != nil {
   661  				t.Fatal(res.err)
   662  			}
   663  			_ = res.c.Close()
   664  		})
   665  	}
   666  }
   667  
   668  func TestPSKHintFail(t *testing.T) {
   669  	// Check for leaking routines
   670  	report := test.CheckRoutines(t)
   671  	defer report()
   672  
   673  	serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
   674  	pskRejected := errPSKRejected
   675  
   676  	// Limit runtime in case of deadlocks
   677  	lim := test.TimeOut(time.Second * 20)
   678  	defer lim.Stop()
   679  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   680  	defer cancel()
   681  
   682  	clientErr := make(chan error, 1)
   683  
   684  	ca, cb := dpipe.Pipe()
   685  	go func() {
   686  		conf := &Config{
   687  			PSK: func(hint []byte) ([]byte, error) {
   688  				return nil, pskRejected
   689  			},
   690  			PSKIdentityHint: []byte{},
   691  			CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
   692  		}
   693  
   694  		_, err := testClient(ctx, ca, conf, false)
   695  		clientErr <- err
   696  	}()
   697  
   698  	config := &Config{
   699  		PSK: func(hint []byte) ([]byte, error) {
   700  			return nil, pskRejected
   701  		},
   702  		PSKIdentityHint: []byte{},
   703  		CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
   704  	}
   705  
   706  	if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) {
   707  		t.Fatalf("TestPSK: Server error exp(%v) failed(%v)", serverAlertError, err)
   708  	}
   709  
   710  	if err := <-clientErr; !errors.Is(err, pskRejected) {
   711  		t.Fatalf("TestPSK: Client error exp(%v) failed(%v)", pskRejected, err)
   712  	}
   713  }
   714  
   715  func TestClientTimeout(t *testing.T) {
   716  	// Limit runtime in case of deadlocks
   717  	lim := test.TimeOut(time.Second * 20)
   718  	defer lim.Stop()
   719  
   720  	// Check for leaking routines
   721  	report := test.CheckRoutines(t)
   722  	defer report()
   723  
   724  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   725  	defer cancel()
   726  
   727  	clientErr := make(chan error, 1)
   728  
   729  	ca, _ := dpipe.Pipe()
   730  	go func() {
   731  		conf := &Config{}
   732  
   733  		c, err := testClient(ctx, ca, conf, true)
   734  		if err == nil {
   735  			_ = c.Close() //nolint:contextcheck
   736  		}
   737  		clientErr <- err
   738  	}()
   739  
   740  	// no server!
   741  	err := <-clientErr
   742  	var netErr net.Error
   743  	if !errors.As(err, &netErr) || !netErr.Timeout() {
   744  		t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
   745  	}
   746  }
   747  
   748  func TestSRTPConfiguration(t *testing.T) {
   749  	// Check for leaking routines
   750  	report := test.CheckRoutines(t)
   751  	defer report()
   752  
   753  	for _, test := range []struct {
   754  		Name            string
   755  		ClientSRTP      []SRTPProtectionProfile
   756  		ServerSRTP      []SRTPProtectionProfile
   757  		ExpectedProfile SRTPProtectionProfile
   758  		WantClientError error
   759  		WantServerError error
   760  	}{
   761  		{
   762  			Name:            "No SRTP in use",
   763  			ClientSRTP:      nil,
   764  			ServerSRTP:      nil,
   765  			ExpectedProfile: 0,
   766  			WantClientError: nil,
   767  			WantServerError: nil,
   768  		},
   769  		{
   770  			Name:            "SRTP both ends",
   771  			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
   772  			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
   773  			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
   774  			WantClientError: nil,
   775  			WantServerError: nil,
   776  		},
   777  		{
   778  			Name:            "SRTP client only",
   779  			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
   780  			ServerSRTP:      nil,
   781  			ExpectedProfile: 0,
   782  			WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
   783  			WantServerError: errServerNoMatchingSRTPProfile,
   784  		},
   785  		{
   786  			Name:            "SRTP server only",
   787  			ClientSRTP:      nil,
   788  			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
   789  			ExpectedProfile: 0,
   790  			WantClientError: nil,
   791  			WantServerError: nil,
   792  		},
   793  		{
   794  			Name:            "Multiple Suites",
   795  			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
   796  			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
   797  			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
   798  			WantClientError: nil,
   799  			WantServerError: nil,
   800  		},
   801  		{
   802  			Name:            "Multiple Suites, Client Chooses",
   803  			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
   804  			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80},
   805  			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
   806  			WantClientError: nil,
   807  			WantServerError: nil,
   808  		},
   809  	} {
   810  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   811  		defer cancel()
   812  
   813  		ca, cb := dpipe.Pipe()
   814  		type result struct {
   815  			c   *Conn
   816  			err error
   817  		}
   818  		c := make(chan result)
   819  
   820  		go func() {
   821  			client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true)
   822  			c <- result{client, err}
   823  		}()
   824  
   825  		server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
   826  		if !errors.Is(err, test.WantServerError) {
   827  			t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
   828  		}
   829  		if err == nil {
   830  			defer func() {
   831  				_ = server.Close()
   832  			}()
   833  		}
   834  
   835  		res := <-c
   836  		if res.err == nil {
   837  			defer func() {
   838  				_ = res.c.Close()
   839  			}()
   840  		}
   841  		if !errors.Is(res.err, test.WantClientError) {
   842  			t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
   843  		}
   844  		if res.c == nil {
   845  			return
   846  		}
   847  
   848  		actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile()
   849  		if actualClientSRTP != test.ExpectedProfile {
   850  			t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP)
   851  		}
   852  
   853  		actualServerSRTP, _ := server.SelectedSRTPProtectionProfile()
   854  		if actualServerSRTP != test.ExpectedProfile {
   855  			t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
   856  		}
   857  	}
   858  }
   859  
   860  func TestClientCertificate(t *testing.T) {
   861  	// Check for leaking routines
   862  	report := test.CheckRoutines(t)
   863  	defer report()
   864  
   865  	srvCert, err := selfsign.GenerateSelfSigned()
   866  	if err != nil {
   867  		t.Fatal(err)
   868  	}
   869  	srvCAPool := x509.NewCertPool()
   870  	srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0])
   871  	if err != nil {
   872  		t.Fatal(err)
   873  	}
   874  	srvCAPool.AddCert(srvCertificate)
   875  
   876  	cert, err := selfsign.GenerateSelfSigned()
   877  	if err != nil {
   878  		t.Fatal(err)
   879  	}
   880  	certificate, err := x509.ParseCertificate(cert.Certificate[0])
   881  	if err != nil {
   882  		t.Fatal(err)
   883  	}
   884  	caPool := x509.NewCertPool()
   885  	caPool.AddCert(certificate)
   886  
   887  	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
   888  		tests := map[string]struct {
   889  			clientCfg *Config
   890  			serverCfg *Config
   891  			wantErr   bool
   892  		}{
   893  			"NoClientCert": {
   894  				clientCfg: &Config{RootCAs: srvCAPool},
   895  				serverCfg: &Config{
   896  					Certificates: []tls.Certificate{srvCert},
   897  					ClientAuth:   NoClientCert,
   898  					ClientCAs:    caPool,
   899  				},
   900  			},
   901  			"NoClientCert_ServerVerifyConnectionFails": {
   902  				clientCfg: &Config{RootCAs: srvCAPool},
   903  				serverCfg: &Config{
   904  					Certificates: []tls.Certificate{srvCert},
   905  					ClientAuth:   NoClientCert,
   906  					ClientCAs:    caPool,
   907  					VerifyConnection: func(s *State) error {
   908  						return errExample
   909  					},
   910  				},
   911  				wantErr: true,
   912  			},
   913  			"NoClientCert_ClientVerifyConnectionFails": {
   914  				clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(s *State) error {
   915  					return errExample
   916  				}},
   917  				serverCfg: &Config{
   918  					Certificates: []tls.Certificate{srvCert},
   919  					ClientAuth:   NoClientCert,
   920  					ClientCAs:    caPool,
   921  				},
   922  				wantErr: true,
   923  			},
   924  			"NoClientCert_cert": {
   925  				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
   926  				serverCfg: &Config{
   927  					Certificates: []tls.Certificate{srvCert},
   928  					ClientAuth:   RequireAnyClientCert,
   929  				},
   930  			},
   931  			"RequestClientCert_cert": {
   932  				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
   933  				serverCfg: &Config{
   934  					Certificates: []tls.Certificate{srvCert},
   935  					ClientAuth:   RequestClientCert,
   936  				},
   937  			},
   938  			"RequestClientCert_no_cert": {
   939  				clientCfg: &Config{RootCAs: srvCAPool},
   940  				serverCfg: &Config{
   941  					Certificates: []tls.Certificate{srvCert},
   942  					ClientAuth:   RequestClientCert,
   943  					ClientCAs:    caPool,
   944  				},
   945  			},
   946  			"RequireAnyClientCert": {
   947  				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
   948  				serverCfg: &Config{
   949  					Certificates: []tls.Certificate{srvCert},
   950  					ClientAuth:   RequireAnyClientCert,
   951  				},
   952  			},
   953  			"RequireAnyClientCert_error": {
   954  				clientCfg: &Config{RootCAs: srvCAPool},
   955  				serverCfg: &Config{
   956  					Certificates: []tls.Certificate{srvCert},
   957  					ClientAuth:   RequireAnyClientCert,
   958  				},
   959  				wantErr: true,
   960  			},
   961  			"VerifyClientCertIfGiven_no_cert": {
   962  				clientCfg: &Config{RootCAs: srvCAPool},
   963  				serverCfg: &Config{
   964  					Certificates: []tls.Certificate{srvCert},
   965  					ClientAuth:   VerifyClientCertIfGiven,
   966  					ClientCAs:    caPool,
   967  				},
   968  			},
   969  			"VerifyClientCertIfGiven_cert": {
   970  				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
   971  				serverCfg: &Config{
   972  					Certificates: []tls.Certificate{srvCert},
   973  					ClientAuth:   VerifyClientCertIfGiven,
   974  					ClientCAs:    caPool,
   975  				},
   976  			},
   977  			"VerifyClientCertIfGiven_error": {
   978  				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
   979  				serverCfg: &Config{
   980  					Certificates: []tls.Certificate{srvCert},
   981  					ClientAuth:   VerifyClientCertIfGiven,
   982  				},
   983  				wantErr: true,
   984  			},
   985  			"RequireAndVerifyClientCert": {
   986  				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error {
   987  					if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok {
   988  						return errExample
   989  					}
   990  					return nil
   991  				}},
   992  				serverCfg: &Config{
   993  					Certificates: []tls.Certificate{srvCert},
   994  					ClientAuth:   RequireAndVerifyClientCert,
   995  					ClientCAs:    caPool,
   996  					VerifyConnection: func(s *State) error {
   997  						if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok {
   998  							return errExample
   999  						}
  1000  						return nil
  1001  					},
  1002  				},
  1003  			},
  1004  			"RequireAndVerifyClientCert_callbacks": {
  1005  				clientCfg: &Config{
  1006  					RootCAs: srvCAPool,
  1007  					// Certificates:   []tls.Certificate{cert},
  1008  					GetClientCertificate: func(cri *CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil },
  1009  				},
  1010  				serverCfg: &Config{
  1011  					GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil },
  1012  					// Certificates:   []tls.Certificate{srvCert},
  1013  					ClientAuth: RequireAndVerifyClientCert,
  1014  					ClientCAs:  caPool,
  1015  				},
  1016  			},
  1017  		}
  1018  		for name, tt := range tests {
  1019  			tt := tt
  1020  			t.Run(name, func(t *testing.T) {
  1021  				ca, cb := dpipe.Pipe()
  1022  				type result struct {
  1023  					c   *Conn
  1024  					err error
  1025  				}
  1026  				c := make(chan result)
  1027  
  1028  				go func() {
  1029  					client, err := Client(ca, tt.clientCfg)
  1030  					c <- result{client, err}
  1031  				}()
  1032  
  1033  				server, err := Server(cb, tt.serverCfg)
  1034  				res := <-c
  1035  				defer func() {
  1036  					if err == nil {
  1037  						_ = server.Close()
  1038  					}
  1039  					if res.err == nil {
  1040  						_ = res.c.Close()
  1041  					}
  1042  				}()
  1043  
  1044  				if tt.wantErr {
  1045  					if err != nil {
  1046  						// Error expected, test succeeded
  1047  						return
  1048  					}
  1049  					t.Error("Error expected")
  1050  				}
  1051  				if err != nil {
  1052  					t.Errorf("Server failed(%v)", err)
  1053  				}
  1054  
  1055  				if res.err != nil {
  1056  					t.Errorf("Client failed(%v)", res.err)
  1057  				}
  1058  
  1059  				actualClientCert := server.ConnectionState().PeerCertificates
  1060  				if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
  1061  					if actualClientCert == nil {
  1062  						t.Errorf("Client did not provide a certificate")
  1063  					}
  1064  
  1065  					var cfgCert [][]byte
  1066  					if len(tt.clientCfg.Certificates) > 0 {
  1067  						cfgCert = tt.clientCfg.Certificates[0].Certificate
  1068  					}
  1069  					if tt.clientCfg.GetClientCertificate != nil {
  1070  						crt, err := tt.clientCfg.GetClientCertificate(&CertificateRequestInfo{})
  1071  						if err != nil {
  1072  							t.Errorf("Server configuration did not provide a certificate")
  1073  						}
  1074  						cfgCert = crt.Certificate
  1075  					}
  1076  					if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualClientCert[0]) {
  1077  						t.Errorf("Client certificate was not communicated correctly")
  1078  					}
  1079  				}
  1080  				if tt.serverCfg.ClientAuth == NoClientCert {
  1081  					if actualClientCert != nil {
  1082  						t.Errorf("Client certificate wasn't expected")
  1083  					}
  1084  				}
  1085  
  1086  				actualServerCert := res.c.ConnectionState().PeerCertificates
  1087  				if actualServerCert == nil {
  1088  					t.Errorf("Server did not provide a certificate")
  1089  				}
  1090  				var cfgCert [][]byte
  1091  				if len(tt.serverCfg.Certificates) > 0 {
  1092  					cfgCert = tt.serverCfg.Certificates[0].Certificate
  1093  				}
  1094  				if tt.serverCfg.GetCertificate != nil {
  1095  					crt, err := tt.serverCfg.GetCertificate(&ClientHelloInfo{})
  1096  					if err != nil {
  1097  						t.Errorf("Server configuration did not provide a certificate")
  1098  					}
  1099  					cfgCert = crt.Certificate
  1100  				}
  1101  				if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualServerCert[0]) {
  1102  					t.Errorf("Server certificate was not communicated correctly")
  1103  				}
  1104  			})
  1105  		}
  1106  	})
  1107  }
  1108  
  1109  func TestExtendedMasterSecret(t *testing.T) {
  1110  	// Check for leaking routines
  1111  	report := test.CheckRoutines(t)
  1112  	defer report()
  1113  
  1114  	tests := map[string]struct {
  1115  		clientCfg         *Config
  1116  		serverCfg         *Config
  1117  		expectedClientErr error
  1118  		expectedServerErr error
  1119  	}{
  1120  		"Request_Request_ExtendedMasterSecret": {
  1121  			clientCfg: &Config{
  1122  				ExtendedMasterSecret: RequestExtendedMasterSecret,
  1123  			},
  1124  			serverCfg: &Config{
  1125  				ExtendedMasterSecret: RequestExtendedMasterSecret,
  1126  			},
  1127  			expectedClientErr: nil,
  1128  			expectedServerErr: nil,
  1129  		},
  1130  		"Request_Require_ExtendedMasterSecret": {
  1131  			clientCfg: &Config{
  1132  				ExtendedMasterSecret: RequestExtendedMasterSecret,
  1133  			},
  1134  			serverCfg: &Config{
  1135  				ExtendedMasterSecret: RequireExtendedMasterSecret,
  1136  			},
  1137  			expectedClientErr: nil,
  1138  			expectedServerErr: nil,
  1139  		},
  1140  		"Request_Disable_ExtendedMasterSecret": {
  1141  			clientCfg: &Config{
  1142  				ExtendedMasterSecret: RequestExtendedMasterSecret,
  1143  			},
  1144  			serverCfg: &Config{
  1145  				ExtendedMasterSecret: DisableExtendedMasterSecret,
  1146  			},
  1147  			expectedClientErr: nil,
  1148  			expectedServerErr: nil,
  1149  		},
  1150  		"Require_Request_ExtendedMasterSecret": {
  1151  			clientCfg: &Config{
  1152  				ExtendedMasterSecret: RequireExtendedMasterSecret,
  1153  			},
  1154  			serverCfg: &Config{
  1155  				ExtendedMasterSecret: RequestExtendedMasterSecret,
  1156  			},
  1157  			expectedClientErr: nil,
  1158  			expectedServerErr: nil,
  1159  		},
  1160  		"Require_Require_ExtendedMasterSecret": {
  1161  			clientCfg: &Config{
  1162  				ExtendedMasterSecret: RequireExtendedMasterSecret,
  1163  			},
  1164  			serverCfg: &Config{
  1165  				ExtendedMasterSecret: RequireExtendedMasterSecret,
  1166  			},
  1167  			expectedClientErr: nil,
  1168  			expectedServerErr: nil,
  1169  		},
  1170  		"Require_Disable_ExtendedMasterSecret": {
  1171  			clientCfg: &Config{
  1172  				ExtendedMasterSecret: RequireExtendedMasterSecret,
  1173  			},
  1174  			serverCfg: &Config{
  1175  				ExtendedMasterSecret: DisableExtendedMasterSecret,
  1176  			},
  1177  			expectedClientErr: errClientRequiredButNoServerEMS,
  1178  			expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
  1179  		},
  1180  		"Disable_Request_ExtendedMasterSecret": {
  1181  			clientCfg: &Config{
  1182  				ExtendedMasterSecret: DisableExtendedMasterSecret,
  1183  			},
  1184  			serverCfg: &Config{
  1185  				ExtendedMasterSecret: RequestExtendedMasterSecret,
  1186  			},
  1187  			expectedClientErr: nil,
  1188  			expectedServerErr: nil,
  1189  		},
  1190  		"Disable_Require_ExtendedMasterSecret": {
  1191  			clientCfg: &Config{
  1192  				ExtendedMasterSecret: DisableExtendedMasterSecret,
  1193  			},
  1194  			serverCfg: &Config{
  1195  				ExtendedMasterSecret: RequireExtendedMasterSecret,
  1196  			},
  1197  			expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
  1198  			expectedServerErr: errServerRequiredButNoClientEMS,
  1199  		},
  1200  		"Disable_Disable_ExtendedMasterSecret": {
  1201  			clientCfg: &Config{
  1202  				ExtendedMasterSecret: DisableExtendedMasterSecret,
  1203  			},
  1204  			serverCfg: &Config{
  1205  				ExtendedMasterSecret: DisableExtendedMasterSecret,
  1206  			},
  1207  			expectedClientErr: nil,
  1208  			expectedServerErr: nil,
  1209  		},
  1210  	}
  1211  	for name, tt := range tests {
  1212  		tt := tt
  1213  		t.Run(name, func(t *testing.T) {
  1214  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1215  			defer cancel()
  1216  
  1217  			ca, cb := dpipe.Pipe()
  1218  			type result struct {
  1219  				c   *Conn
  1220  				err error
  1221  			}
  1222  			c := make(chan result)
  1223  
  1224  			go func() {
  1225  				client, err := testClient(ctx, ca, tt.clientCfg, true)
  1226  				c <- result{client, err}
  1227  			}()
  1228  
  1229  			server, err := testServer(ctx, cb, tt.serverCfg, true)
  1230  			res := <-c
  1231  			defer func() {
  1232  				if err == nil {
  1233  					_ = server.Close()
  1234  				}
  1235  				if res.err == nil {
  1236  					_ = res.c.Close()
  1237  				}
  1238  			}()
  1239  
  1240  			if !errors.Is(res.err, tt.expectedClientErr) {
  1241  				t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
  1242  			}
  1243  
  1244  			if !errors.Is(err, tt.expectedServerErr) {
  1245  				t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
  1246  			}
  1247  		})
  1248  	}
  1249  }
  1250  
  1251  func TestServerCertificate(t *testing.T) {
  1252  	// Check for leaking routines
  1253  	report := test.CheckRoutines(t)
  1254  	defer report()
  1255  
  1256  	cert, err := selfsign.GenerateSelfSigned()
  1257  	if err != nil {
  1258  		t.Fatal(err)
  1259  	}
  1260  	certificate, err := x509.ParseCertificate(cert.Certificate[0])
  1261  	if err != nil {
  1262  		t.Fatal(err)
  1263  	}
  1264  	caPool := x509.NewCertPool()
  1265  	caPool.AddCert(certificate)
  1266  
  1267  	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
  1268  		tests := map[string]struct {
  1269  			clientCfg *Config
  1270  			serverCfg *Config
  1271  			wantErr   bool
  1272  		}{
  1273  			"no_ca": {
  1274  				clientCfg: &Config{},
  1275  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
  1276  				wantErr:   true,
  1277  			},
  1278  			"good_ca": {
  1279  				clientCfg: &Config{RootCAs: caPool},
  1280  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
  1281  			},
  1282  			"no_ca_skip_verify": {
  1283  				clientCfg: &Config{InsecureSkipVerify: true},
  1284  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
  1285  			},
  1286  			"good_ca_skip_verify_custom_verify_peer": {
  1287  				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
  1288  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
  1289  					if len(chain) != 0 {
  1290  						return errNotExpectedChain
  1291  					}
  1292  					return nil
  1293  				}},
  1294  			},
  1295  			"good_ca_verify_custom_verify_peer": {
  1296  				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
  1297  				serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
  1298  					if len(chain) == 0 {
  1299  						return errExpecedChain
  1300  					}
  1301  					return nil
  1302  				}},
  1303  			},
  1304  			"good_ca_custom_verify_peer": {
  1305  				clientCfg: &Config{
  1306  					RootCAs: caPool,
  1307  					VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
  1308  						return errWrongCert
  1309  					},
  1310  				},
  1311  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
  1312  				wantErr:   true,
  1313  			},
  1314  			"server_name": {
  1315  				clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
  1316  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
  1317  			},
  1318  			"server_name_error": {
  1319  				clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
  1320  				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
  1321  				wantErr:   true,
  1322  			},
  1323  		}
  1324  		for name, tt := range tests {
  1325  			tt := tt
  1326  			t.Run(name, func(t *testing.T) {
  1327  				ca, cb := dpipe.Pipe()
  1328  
  1329  				type result struct {
  1330  					c   *Conn
  1331  					err error
  1332  				}
  1333  				srvCh := make(chan result)
  1334  				go func() {
  1335  					s, err := Server(cb, tt.serverCfg)
  1336  					srvCh <- result{s, err}
  1337  				}()
  1338  
  1339  				cli, err := Client(ca, tt.clientCfg)
  1340  				if err == nil {
  1341  					_ = cli.Close()
  1342  				}
  1343  				if !tt.wantErr && err != nil {
  1344  					t.Errorf("Client failed(%v)", err)
  1345  				}
  1346  				if tt.wantErr && err == nil {
  1347  					t.Fatal("Error expected")
  1348  				}
  1349  
  1350  				srv := <-srvCh
  1351  				if srv.err == nil {
  1352  					_ = srv.c.Close()
  1353  				}
  1354  			})
  1355  		}
  1356  	})
  1357  }
  1358  
  1359  func TestCipherSuiteConfiguration(t *testing.T) {
  1360  	// Check for leaking routines
  1361  	report := test.CheckRoutines(t)
  1362  	defer report()
  1363  
  1364  	for _, test := range []struct {
  1365  		Name                    string
  1366  		ClientCipherSuites      []CipherSuiteID
  1367  		ServerCipherSuites      []CipherSuiteID
  1368  		WantClientError         error
  1369  		WantServerError         error
  1370  		WantSelectedCipherSuite CipherSuiteID
  1371  	}{
  1372  		{
  1373  			Name:               "No CipherSuites specified",
  1374  			ClientCipherSuites: nil,
  1375  			ServerCipherSuites: nil,
  1376  			WantClientError:    nil,
  1377  			WantServerError:    nil,
  1378  		},
  1379  		{
  1380  			Name:               "Invalid CipherSuite",
  1381  			ClientCipherSuites: []CipherSuiteID{0x00},
  1382  			ServerCipherSuites: []CipherSuiteID{0x00},
  1383  			WantClientError:    &invalidCipherSuiteError{0x00},
  1384  			WantServerError:    &invalidCipherSuiteError{0x00},
  1385  		},
  1386  		{
  1387  			Name:                    "Valid CipherSuites specified",
  1388  			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  1389  			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  1390  			WantClientError:         nil,
  1391  			WantServerError:         nil,
  1392  			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  1393  		},
  1394  		{
  1395  			Name:               "CipherSuites mismatch",
  1396  			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  1397  			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
  1398  			WantClientError:    &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
  1399  			WantServerError:    errCipherSuiteNoIntersection,
  1400  		},
  1401  		{
  1402  			Name:                    "Valid CipherSuites CCM specified",
  1403  			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
  1404  			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
  1405  			WantClientError:         nil,
  1406  			WantServerError:         nil,
  1407  			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
  1408  		},
  1409  		{
  1410  			Name:                    "Valid CipherSuites CCM-8 specified",
  1411  			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
  1412  			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
  1413  			WantClientError:         nil,
  1414  			WantServerError:         nil,
  1415  			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
  1416  		},
  1417  		{
  1418  			Name:                    "Server supports subset of client suites",
  1419  			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
  1420  			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
  1421  			WantClientError:         nil,
  1422  			WantServerError:         nil,
  1423  			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
  1424  		},
  1425  	} {
  1426  		test := test
  1427  		t.Run(test.Name, func(t *testing.T) {
  1428  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1429  			defer cancel()
  1430  
  1431  			ca, cb := dpipe.Pipe()
  1432  			type result struct {
  1433  				c   *Conn
  1434  				err error
  1435  			}
  1436  			c := make(chan result)
  1437  
  1438  			go func() {
  1439  				client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true)
  1440  				c <- result{client, err}
  1441  			}()
  1442  
  1443  			server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true)
  1444  			if err == nil {
  1445  				defer func() {
  1446  					_ = server.Close()
  1447  				}()
  1448  			}
  1449  			if !errors.Is(err, test.WantServerError) {
  1450  				t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
  1451  			}
  1452  
  1453  			res := <-c
  1454  			if res.err == nil {
  1455  				_ = server.Close()
  1456  				_ = res.c.Close()
  1457  			}
  1458  			if !errors.Is(res.err, test.WantClientError) {
  1459  				t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
  1460  			}
  1461  			if test.WantSelectedCipherSuite != 0x00 && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite {
  1462  				t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID())
  1463  			}
  1464  		})
  1465  	}
  1466  }
  1467  
  1468  func TestCertificateAndPSKServer(t *testing.T) {
  1469  	// Check for leaking routines
  1470  	report := test.CheckRoutines(t)
  1471  	defer report()
  1472  
  1473  	for _, test := range []struct {
  1474  		Name      string
  1475  		ClientPSK bool
  1476  	}{
  1477  		{
  1478  			Name:      "Client uses PKI",
  1479  			ClientPSK: false,
  1480  		},
  1481  		{
  1482  			Name:      "Client uses PSK",
  1483  			ClientPSK: true,
  1484  		},
  1485  	} {
  1486  		test := test
  1487  		t.Run(test.Name, func(t *testing.T) {
  1488  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1489  			defer cancel()
  1490  
  1491  			ca, cb := dpipe.Pipe()
  1492  			type result struct {
  1493  				c   *Conn
  1494  				err error
  1495  			}
  1496  			c := make(chan result)
  1497  
  1498  			go func() {
  1499  				config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
  1500  				if test.ClientPSK {
  1501  					config.PSK = func([]byte) ([]byte, error) {
  1502  						return []byte{0x00, 0x01, 0x02}, nil
  1503  					}
  1504  					config.PSKIdentityHint = []byte{0x00}
  1505  					config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256}
  1506  				}
  1507  
  1508  				client, err := testClient(ctx, ca, config, false)
  1509  				c <- result{client, err}
  1510  			}()
  1511  
  1512  			config := &Config{
  1513  				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
  1514  				PSK: func([]byte) ([]byte, error) {
  1515  					return []byte{0x00, 0x01, 0x02}, nil
  1516  				},
  1517  			}
  1518  
  1519  			server, err := testServer(ctx, cb, config, true)
  1520  			if err == nil {
  1521  				defer func() {
  1522  					_ = server.Close()
  1523  				}()
  1524  			} else {
  1525  				t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err)
  1526  			}
  1527  
  1528  			res := <-c
  1529  			if res.err == nil {
  1530  				_ = server.Close()
  1531  				_ = res.c.Close()
  1532  			} else {
  1533  				t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, res.err)
  1534  			}
  1535  		})
  1536  	}
  1537  }
  1538  
  1539  func TestPSKConfiguration(t *testing.T) {
  1540  	// Check for leaking routines
  1541  	report := test.CheckRoutines(t)
  1542  	defer report()
  1543  
  1544  	for _, test := range []struct {
  1545  		Name                 string
  1546  		ClientHasCertificate bool
  1547  		ServerHasCertificate bool
  1548  		ClientPSK            PSKCallback
  1549  		ServerPSK            PSKCallback
  1550  		ClientPSKIdentity    []byte
  1551  		ServerPSKIdentity    []byte
  1552  		WantClientError      error
  1553  		WantServerError      error
  1554  	}{
  1555  		{
  1556  			Name:                 "PSK and no certificate specified",
  1557  			ClientHasCertificate: false,
  1558  			ServerHasCertificate: false,
  1559  			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
  1560  			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
  1561  			ClientPSKIdentity:    []byte{0x00},
  1562  			ServerPSKIdentity:    []byte{0x00},
  1563  			WantClientError:      errNoAvailablePSKCipherSuite,
  1564  			WantServerError:      errNoAvailablePSKCipherSuite,
  1565  		},
  1566  		{
  1567  			Name:                 "PSK and certificate specified",
  1568  			ClientHasCertificate: true,
  1569  			ServerHasCertificate: true,
  1570  			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
  1571  			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
  1572  			ClientPSKIdentity:    []byte{0x00},
  1573  			ServerPSKIdentity:    []byte{0x00},
  1574  			WantClientError:      errNoAvailablePSKCipherSuite,
  1575  			WantServerError:      errNoAvailablePSKCipherSuite,
  1576  		},
  1577  		{
  1578  			Name:                 "PSK and no identity specified",
  1579  			ClientHasCertificate: false,
  1580  			ServerHasCertificate: false,
  1581  			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
  1582  			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
  1583  			ClientPSKIdentity:    nil,
  1584  			ServerPSKIdentity:    nil,
  1585  			WantClientError:      errPSKAndIdentityMustBeSetForClient,
  1586  			WantServerError:      errNoAvailablePSKCipherSuite,
  1587  		},
  1588  		{
  1589  			Name:                 "No PSK and identity specified",
  1590  			ClientHasCertificate: false,
  1591  			ServerHasCertificate: false,
  1592  			ClientPSK:            nil,
  1593  			ServerPSK:            nil,
  1594  			ClientPSKIdentity:    []byte{0x00},
  1595  			ServerPSKIdentity:    []byte{0x00},
  1596  			WantClientError:      errIdentityNoPSK,
  1597  			WantServerError:      errIdentityNoPSK,
  1598  		},
  1599  	} {
  1600  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1601  		defer cancel()
  1602  
  1603  		ca, cb := dpipe.Pipe()
  1604  		type result struct {
  1605  			c   *Conn
  1606  			err error
  1607  		}
  1608  		c := make(chan result)
  1609  
  1610  		go func() {
  1611  			client, err := testClient(ctx, ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate)
  1612  			c <- result{client, err}
  1613  		}()
  1614  
  1615  		_, err := testServer(ctx, cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate)
  1616  		if err != nil || test.WantServerError != nil {
  1617  			if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
  1618  				t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
  1619  			}
  1620  		}
  1621  
  1622  		res := <-c
  1623  		if res.err != nil || test.WantClientError != nil {
  1624  			if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
  1625  				t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
  1626  			}
  1627  		}
  1628  	}
  1629  }
  1630  
  1631  func TestServerTimeout(t *testing.T) {
  1632  	// Limit runtime in case of deadlocks
  1633  	lim := test.TimeOut(time.Second * 20)
  1634  	defer lim.Stop()
  1635  
  1636  	// Check for leaking routines
  1637  	report := test.CheckRoutines(t)
  1638  	defer report()
  1639  
  1640  	cookie := make([]byte, 20)
  1641  	_, err := rand.Read(cookie)
  1642  	if err != nil {
  1643  		t.Fatal(err)
  1644  	}
  1645  
  1646  	var rand [28]byte
  1647  	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}
  1648  
  1649  	cipherSuites := []CipherSuite{
  1650  		&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
  1651  		&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
  1652  	}
  1653  
  1654  	extensions := []extension.Extension{
  1655  		&extension.SupportedSignatureAlgorithms{
  1656  			SignatureHashAlgorithms: []signaturehash.Algorithm{
  1657  				{Hash: hash.SHA256, Signature: signature.ECDSA},
  1658  				{Hash: hash.SHA384, Signature: signature.ECDSA},
  1659  				{Hash: hash.SHA512, Signature: signature.ECDSA},
  1660  				{Hash: hash.SHA256, Signature: signature.RSA},
  1661  				{Hash: hash.SHA384, Signature: signature.RSA},
  1662  				{Hash: hash.SHA512, Signature: signature.RSA},
  1663  			},
  1664  		},
  1665  		&extension.SupportedEllipticCurves{
  1666  			EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
  1667  		},
  1668  		&extension.SupportedPointFormats{
  1669  			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  1670  		},
  1671  	}
  1672  
  1673  	record := &recordlayer.RecordLayer{
  1674  		Header: recordlayer.Header{
  1675  			SequenceNumber: 0,
  1676  			Version:        protocol.Version1_2,
  1677  		},
  1678  		Content: &handshake.Handshake{
  1679  			// sequenceNumber and messageSequence line up, may need to be re-evaluated
  1680  			Header: handshake.Header{
  1681  				MessageSequence: 0,
  1682  			},
  1683  			Message: &handshake.MessageClientHello{
  1684  				Version:            protocol.Version1_2,
  1685  				Cookie:             cookie,
  1686  				Random:             random,
  1687  				CipherSuiteIDs:     cipherSuiteIDs(cipherSuites),
  1688  				CompressionMethods: defaultCompressionMethods(),
  1689  				Extensions:         extensions,
  1690  			},
  1691  		},
  1692  	}
  1693  
  1694  	packet, err := record.Marshal()
  1695  	if err != nil {
  1696  		t.Fatal(err)
  1697  	}
  1698  
  1699  	ca, cb := dpipe.Pipe()
  1700  	defer func() {
  1701  		err := ca.Close()
  1702  		if err != nil {
  1703  			t.Fatal(err)
  1704  		}
  1705  	}()
  1706  
  1707  	// Client reader
  1708  	caReadChan := make(chan []byte, 1000)
  1709  	go func() {
  1710  		for {
  1711  			data := make([]byte, 8192)
  1712  			n, err := ca.Read(data)
  1713  			if err != nil {
  1714  				return
  1715  			}
  1716  
  1717  			caReadChan <- data[:n]
  1718  		}
  1719  	}()
  1720  
  1721  	// Start sending ClientHello packets until server responds with first packet
  1722  	go func() {
  1723  		for {
  1724  			select {
  1725  			case <-time.After(10 * time.Millisecond):
  1726  				_, err := ca.Write(packet)
  1727  				if err != nil {
  1728  					return
  1729  				}
  1730  			case <-caReadChan:
  1731  				// Once we receive the first reply from the server, stop
  1732  				return
  1733  			}
  1734  		}
  1735  	}()
  1736  
  1737  	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
  1738  	defer cancel()
  1739  
  1740  	config := &Config{
  1741  		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  1742  		FlightInterval: 100 * time.Millisecond,
  1743  	}
  1744  
  1745  	_, serverErr := testServer(ctx, cb, config, true)
  1746  	var netErr net.Error
  1747  	if !errors.As(serverErr, &netErr) || !netErr.Timeout() {
  1748  		t.Fatalf("Client error exp(Temporary network error) failed(%v)", serverErr)
  1749  	}
  1750  
  1751  	// Wait a little longer to ensure no additional messages have been sent by the server
  1752  	time.Sleep(300 * time.Millisecond)
  1753  	select {
  1754  	case msg := <-caReadChan:
  1755  		t.Fatalf("Expected no additional messages from server, got: %+v", msg)
  1756  	default:
  1757  	}
  1758  }
  1759  
  1760  func TestProtocolVersionValidation(t *testing.T) {
  1761  	// Limit runtime in case of deadlocks
  1762  	lim := test.TimeOut(time.Second * 20)
  1763  	defer lim.Stop()
  1764  
  1765  	// Check for leaking routines
  1766  	report := test.CheckRoutines(t)
  1767  	defer report()
  1768  
  1769  	cookie := make([]byte, 20)
  1770  	if _, err := rand.Read(cookie); err != nil {
  1771  		t.Fatal(err)
  1772  	}
  1773  
  1774  	var rand [28]byte
  1775  	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}
  1776  
  1777  	config := &Config{
  1778  		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  1779  		FlightInterval: 100 * time.Millisecond,
  1780  	}
  1781  
  1782  	t.Run("Server", func(t *testing.T) {
  1783  		serverCases := map[string]struct {
  1784  			records []*recordlayer.RecordLayer
  1785  		}{
  1786  			"ClientHelloVersion": {
  1787  				records: []*recordlayer.RecordLayer{
  1788  					{
  1789  						Header: recordlayer.Header{
  1790  							Version: protocol.Version1_2,
  1791  						},
  1792  						Content: &handshake.Handshake{
  1793  							Message: &handshake.MessageClientHello{
  1794  								Version:            protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
  1795  								Cookie:             cookie,
  1796  								Random:             random,
  1797  								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
  1798  								CompressionMethods: defaultCompressionMethods(),
  1799  							},
  1800  						},
  1801  					},
  1802  				},
  1803  			},
  1804  			"SecondsClientHelloVersion": {
  1805  				records: []*recordlayer.RecordLayer{
  1806  					{
  1807  						Header: recordlayer.Header{
  1808  							Version: protocol.Version1_2,
  1809  						},
  1810  						Content: &handshake.Handshake{
  1811  							Message: &handshake.MessageClientHello{
  1812  								Version:            protocol.Version1_2,
  1813  								Cookie:             cookie,
  1814  								Random:             random,
  1815  								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
  1816  								CompressionMethods: defaultCompressionMethods(),
  1817  							},
  1818  						},
  1819  					},
  1820  					{
  1821  						Header: recordlayer.Header{
  1822  							Version:        protocol.Version1_2,
  1823  							SequenceNumber: 1,
  1824  						},
  1825  						Content: &handshake.Handshake{
  1826  							Header: handshake.Header{
  1827  								MessageSequence: 1,
  1828  							},
  1829  							Message: &handshake.MessageClientHello{
  1830  								Version:            protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
  1831  								Cookie:             cookie,
  1832  								Random:             random,
  1833  								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
  1834  								CompressionMethods: defaultCompressionMethods(),
  1835  							},
  1836  						},
  1837  					},
  1838  				},
  1839  			},
  1840  		}
  1841  		for name, c := range serverCases {
  1842  			c := c
  1843  			t.Run(name, func(t *testing.T) {
  1844  				ca, cb := dpipe.Pipe()
  1845  				defer func() {
  1846  					err := ca.Close()
  1847  					if err != nil {
  1848  						t.Error(err)
  1849  					}
  1850  				}()
  1851  
  1852  				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1853  				defer cancel()
  1854  
  1855  				var wg sync.WaitGroup
  1856  				wg.Add(1)
  1857  				defer wg.Wait()
  1858  				go func() {
  1859  					defer wg.Done()
  1860  					if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
  1861  						t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
  1862  					}
  1863  				}()
  1864  
  1865  				time.Sleep(50 * time.Millisecond)
  1866  
  1867  				resp := make([]byte, 1024)
  1868  				for _, record := range c.records {
  1869  					packet, err := record.Marshal()
  1870  					if err != nil {
  1871  						t.Fatal(err)
  1872  					}
  1873  					if _, werr := ca.Write(packet); werr != nil {
  1874  						t.Fatal(werr)
  1875  					}
  1876  					n, rerr := ca.Read(resp[:cap(resp)])
  1877  					if rerr != nil {
  1878  						t.Fatal(rerr)
  1879  					}
  1880  					resp = resp[:n]
  1881  				}
  1882  
  1883  				h := &recordlayer.Header{}
  1884  				if err := h.Unmarshal(resp); err != nil {
  1885  					t.Fatal("Failed to unmarshal response")
  1886  				}
  1887  				if h.ContentType != protocol.ContentTypeAlert {
  1888  					t.Errorf("Peer must return alert to unsupported protocol version")
  1889  				}
  1890  			})
  1891  		}
  1892  	})
  1893  
  1894  	t.Run("Client", func(t *testing.T) {
  1895  		clientCases := map[string]struct {
  1896  			records []*recordlayer.RecordLayer
  1897  		}{
  1898  			"ServerHelloVersion": {
  1899  				records: []*recordlayer.RecordLayer{
  1900  					{
  1901  						Header: recordlayer.Header{
  1902  							Version: protocol.Version1_2,
  1903  						},
  1904  						Content: &handshake.Handshake{
  1905  							Message: &handshake.MessageHelloVerifyRequest{
  1906  								Version: protocol.Version1_2,
  1907  								Cookie:  cookie,
  1908  							},
  1909  						},
  1910  					},
  1911  					{
  1912  						Header: recordlayer.Header{
  1913  							Version:        protocol.Version1_2,
  1914  							SequenceNumber: 1,
  1915  						},
  1916  						Content: &handshake.Handshake{
  1917  							Header: handshake.Header{
  1918  								MessageSequence: 1,
  1919  							},
  1920  							Message: &handshake.MessageServerHello{
  1921  								Version:           protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
  1922  								Random:            random,
  1923  								CipherSuiteID:     func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(),
  1924  								CompressionMethod: defaultCompressionMethods()[0],
  1925  							},
  1926  						},
  1927  					},
  1928  				},
  1929  			},
  1930  		}
  1931  		for name, c := range clientCases {
  1932  			c := c
  1933  			t.Run(name, func(t *testing.T) {
  1934  				ca, cb := dpipe.Pipe()
  1935  				defer func() {
  1936  					err := ca.Close()
  1937  					if err != nil {
  1938  						t.Error(err)
  1939  					}
  1940  				}()
  1941  
  1942  				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  1943  				defer cancel()
  1944  
  1945  				var wg sync.WaitGroup
  1946  				wg.Add(1)
  1947  				defer wg.Wait()
  1948  				go func() {
  1949  					defer wg.Done()
  1950  					if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
  1951  						t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
  1952  					}
  1953  				}()
  1954  
  1955  				time.Sleep(50 * time.Millisecond)
  1956  
  1957  				for _, record := range c.records {
  1958  					if _, err := ca.Read(make([]byte, 1024)); err != nil {
  1959  						t.Fatal(err)
  1960  					}
  1961  
  1962  					packet, err := record.Marshal()
  1963  					if err != nil {
  1964  						t.Fatal(err)
  1965  					}
  1966  					if _, err := ca.Write(packet); err != nil {
  1967  						t.Fatal(err)
  1968  					}
  1969  				}
  1970  				resp := make([]byte, 1024)
  1971  				n, err := ca.Read(resp)
  1972  				if err != nil {
  1973  					t.Fatal(err)
  1974  				}
  1975  				resp = resp[:n]
  1976  
  1977  				h := &recordlayer.Header{}
  1978  				if err := h.Unmarshal(resp); err != nil {
  1979  					t.Fatal("Failed to unmarshal response")
  1980  				}
  1981  				if h.ContentType != protocol.ContentTypeAlert {
  1982  					t.Errorf("Peer must return alert to unsupported protocol version")
  1983  				}
  1984  			})
  1985  		}
  1986  	})
  1987  }
  1988  
  1989  func TestMultipleHelloVerifyRequest(t *testing.T) {
  1990  	// Limit runtime in case of deadlocks
  1991  	lim := test.TimeOut(time.Second * 20)
  1992  	defer lim.Stop()
  1993  
  1994  	// Check for leaking routines
  1995  	report := test.CheckRoutines(t)
  1996  	defer report()
  1997  
  1998  	cookies := [][]byte{
  1999  		// first clientHello contains an empty cookie
  2000  		{},
  2001  	}
  2002  	var packets [][]byte
  2003  	for i := 0; i < 2; i++ {
  2004  		cookie := make([]byte, 20)
  2005  		if _, err := rand.Read(cookie); err != nil {
  2006  			t.Fatal(err)
  2007  		}
  2008  		cookies = append(cookies, cookie)
  2009  
  2010  		record := &recordlayer.RecordLayer{
  2011  			Header: recordlayer.Header{
  2012  				SequenceNumber: uint64(i),
  2013  				Version:        protocol.Version1_2,
  2014  			},
  2015  			Content: &handshake.Handshake{
  2016  				Header: handshake.Header{
  2017  					MessageSequence: uint16(i),
  2018  				},
  2019  				Message: &handshake.MessageHelloVerifyRequest{
  2020  					Version: protocol.Version1_2,
  2021  					Cookie:  cookie,
  2022  				},
  2023  			},
  2024  		}
  2025  		packet, err := record.Marshal()
  2026  		if err != nil {
  2027  			t.Fatal(err)
  2028  		}
  2029  		packets = append(packets, packet)
  2030  	}
  2031  
  2032  	ca, cb := dpipe.Pipe()
  2033  	defer func() {
  2034  		err := ca.Close()
  2035  		if err != nil {
  2036  			t.Error(err)
  2037  		}
  2038  	}()
  2039  
  2040  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
  2041  	defer cancel()
  2042  
  2043  	var wg sync.WaitGroup
  2044  	wg.Add(1)
  2045  	defer wg.Wait()
  2046  	go func() {
  2047  		defer wg.Done()
  2048  		_, _ = testClient(ctx, ca, &Config{}, false)
  2049  	}()
  2050  
  2051  	for i, cookie := range cookies {
  2052  		// read client hello
  2053  		resp := make([]byte, 1024)
  2054  		n, err := cb.Read(resp)
  2055  		if err != nil {
  2056  			t.Fatal(err)
  2057  		}
  2058  		record := &recordlayer.RecordLayer{}
  2059  		if err := record.Unmarshal(resp[:n]); err != nil {
  2060  			t.Fatal(err)
  2061  		}
  2062  		clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
  2063  		if !ok {
  2064  			t.Fatal("Failed to cast MessageClientHello")
  2065  		}
  2066  
  2067  		if !bytes.Equal(clientHello.Cookie, cookie) {
  2068  			t.Fatalf("Wrong cookie, expected: %x, got: %x", clientHello.Cookie, cookie)
  2069  		}
  2070  		if len(packets) <= i {
  2071  			break
  2072  		}
  2073  		// write hello verify request
  2074  		if _, err := cb.Write(packets[i]); err != nil {
  2075  			t.Fatal(err)
  2076  		}
  2077  	}
  2078  	cancel()
  2079  }
  2080  
  2081  // Assert that a DTLS Server always responds with RenegotiationInfo if
  2082  // a ClientHello contained that extension or not
  2083  func TestRenegotationInfo(t *testing.T) {
  2084  	// Limit runtime in case of deadlocks
  2085  	lim := test.TimeOut(10 * time.Second)
  2086  	defer lim.Stop()
  2087  
  2088  	// Check for leaking routines
  2089  	report := test.CheckRoutines(t)
  2090  	defer report()
  2091  
  2092  	resp := make([]byte, 1024)
  2093  
  2094  	for _, testCase := range []struct {
  2095  		Name                  string
  2096  		SendRenegotiationInfo bool
  2097  	}{
  2098  		{
  2099  			"Include RenegotiationInfo",
  2100  			true,
  2101  		},
  2102  		{
  2103  			"No RenegotiationInfo",
  2104  			false,
  2105  		},
  2106  	} {
  2107  		test := testCase
  2108  		t.Run(test.Name, func(t *testing.T) {
  2109  			ca, cb := dpipe.Pipe()
  2110  			defer func() {
  2111  				if err := ca.Close(); err != nil {
  2112  					t.Error(err)
  2113  				}
  2114  			}()
  2115  
  2116  			ctx, cancel := context.WithCancel(context.Background())
  2117  			defer cancel()
  2118  
  2119  			go func() {
  2120  				if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) {
  2121  					t.Error(err)
  2122  				}
  2123  			}()
  2124  
  2125  			time.Sleep(50 * time.Millisecond)
  2126  
  2127  			extensions := []extension.Extension{}
  2128  			if test.SendRenegotiationInfo {
  2129  				extensions = append(extensions, &extension.RenegotiationInfo{
  2130  					RenegotiatedConnection: 0,
  2131  				})
  2132  			}
  2133  			err := sendClientHello([]byte{}, ca, 0, extensions)
  2134  			if err != nil {
  2135  				t.Fatal(err)
  2136  			}
  2137  			n, err := ca.Read(resp)
  2138  			if err != nil {
  2139  				t.Fatal(err)
  2140  			}
  2141  			r := &recordlayer.RecordLayer{}
  2142  			if err = r.Unmarshal(resp[:n]); err != nil {
  2143  				t.Fatal(err)
  2144  			}
  2145  
  2146  			helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
  2147  			if !ok {
  2148  				t.Fatal("Failed to cast MessageHelloVerifyRequest")
  2149  			}
  2150  
  2151  			err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
  2152  			if err != nil {
  2153  				t.Fatal(err)
  2154  			}
  2155  			if n, err = ca.Read(resp); err != nil {
  2156  				t.Fatal(err)
  2157  			}
  2158  
  2159  			messages, err := recordlayer.UnpackDatagram(resp[:n])
  2160  			if err != nil {
  2161  				t.Fatal(err)
  2162  			}
  2163  
  2164  			if err := r.Unmarshal(messages[0]); err != nil {
  2165  				t.Fatal(err)
  2166  			}
  2167  
  2168  			serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
  2169  			if !ok {
  2170  				t.Fatal("Failed to cast MessageServerHello")
  2171  			}
  2172  
  2173  			gotNegotationInfo := false
  2174  			for _, v := range serverHello.Extensions {
  2175  				if _, ok := v.(*extension.RenegotiationInfo); ok {
  2176  					gotNegotationInfo = true
  2177  				}
  2178  			}
  2179  
  2180  			if !gotNegotationInfo {
  2181  				t.Fatalf("Received ServerHello without RenegotiationInfo")
  2182  			}
  2183  		})
  2184  	}
  2185  }
  2186  
  2187  func TestServerNameIndicationExtension(t *testing.T) {
  2188  	// Limit runtime in case of deadlocks
  2189  	lim := test.TimeOut(time.Second * 20)
  2190  	defer lim.Stop()
  2191  
  2192  	// Check for leaking routines
  2193  	report := test.CheckRoutines(t)
  2194  	defer report()
  2195  
  2196  	for _, test := range []struct {
  2197  		Name       string
  2198  		ServerName string
  2199  		Expected   []byte
  2200  		IncludeSNI bool
  2201  	}{
  2202  		{
  2203  			Name:       "Server name is a valid hostname",
  2204  			ServerName: "example.com",
  2205  			Expected:   []byte("example.com"),
  2206  			IncludeSNI: true,
  2207  		},
  2208  		{
  2209  			Name:       "Server name is an IP literal",
  2210  			ServerName: "1.2.3.4",
  2211  			Expected:   []byte(""),
  2212  			IncludeSNI: false,
  2213  		},
  2214  		{
  2215  			Name:       "Server name is empty",
  2216  			ServerName: "",
  2217  			Expected:   []byte(""),
  2218  			IncludeSNI: false,
  2219  		},
  2220  	} {
  2221  		test := test
  2222  		t.Run(test.Name, func(t *testing.T) {
  2223  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2224  			defer cancel()
  2225  
  2226  			ca, cb := dpipe.Pipe()
  2227  			go func() {
  2228  				conf := &Config{
  2229  					ServerName: test.ServerName,
  2230  				}
  2231  
  2232  				_, _ = testClient(ctx, ca, conf, false)
  2233  			}()
  2234  
  2235  			// Receive ClientHello
  2236  			resp := make([]byte, 1024)
  2237  			n, err := cb.Read(resp)
  2238  			if err != nil {
  2239  				t.Fatal(err)
  2240  			}
  2241  			r := &recordlayer.RecordLayer{}
  2242  			if err = r.Unmarshal(resp[:n]); err != nil {
  2243  				t.Fatal(err)
  2244  			}
  2245  
  2246  			clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
  2247  			if !ok {
  2248  				t.Fatal("Failed to cast MessageClientHello")
  2249  			}
  2250  
  2251  			gotSNI := false
  2252  			var actualServerName string
  2253  			for _, v := range clientHello.Extensions {
  2254  				if _, ok := v.(*extension.ServerName); ok {
  2255  					gotSNI = true
  2256  					extensionServerName, ok := v.(*extension.ServerName)
  2257  					if !ok {
  2258  						t.Fatal("Failed to cast extension.ServerName")
  2259  					}
  2260  
  2261  					actualServerName = extensionServerName.ServerName
  2262  				}
  2263  			}
  2264  
  2265  			if gotSNI != test.IncludeSNI {
  2266  				t.Errorf("TestSNI: unexpected SNI inclusion '%s': expected(%v) actual(%v)", test.Name, test.IncludeSNI, gotSNI)
  2267  			}
  2268  
  2269  			if !bytes.Equal([]byte(actualServerName), test.Expected) {
  2270  				t.Errorf("TestSNI: server name mismatch '%s': expected(%v) actual(%v)", test.Name, test.Expected, actualServerName)
  2271  			}
  2272  		})
  2273  	}
  2274  }
  2275  
  2276  func TestALPNExtension(t *testing.T) {
  2277  	// Limit runtime in case of deadlocks
  2278  	lim := test.TimeOut(time.Second * 20)
  2279  	defer lim.Stop()
  2280  
  2281  	// Check for leaking routines
  2282  	report := test.CheckRoutines(t)
  2283  	defer report()
  2284  
  2285  	for _, test := range []struct {
  2286  		Name                   string
  2287  		ClientProtocolNameList []string
  2288  		ServerProtocolNameList []string
  2289  		ExpectedProtocol       string
  2290  		ExpectAlertFromClient  bool
  2291  		ExpectAlertFromServer  bool
  2292  		Alert                  alert.Description
  2293  	}{
  2294  		{
  2295  			Name:                   "Negotiate a protocol",
  2296  			ClientProtocolNameList: []string{"http/1.1", "spd/1"},
  2297  			ServerProtocolNameList: []string{"spd/1"},
  2298  			ExpectedProtocol:       "spd/1",
  2299  			ExpectAlertFromClient:  false,
  2300  			ExpectAlertFromServer:  false,
  2301  			Alert:                  0,
  2302  		},
  2303  		{
  2304  			Name:                   "Server doesn't support any",
  2305  			ClientProtocolNameList: []string{"http/1.1", "spd/1"},
  2306  			ServerProtocolNameList: []string{},
  2307  			ExpectedProtocol:       "",
  2308  			ExpectAlertFromClient:  false,
  2309  			ExpectAlertFromServer:  false,
  2310  			Alert:                  0,
  2311  		},
  2312  		{
  2313  			Name:                   "Negotiate with higher server precedence",
  2314  			ClientProtocolNameList: []string{"http/1.1", "spd/1", "http/3"},
  2315  			ServerProtocolNameList: []string{"ssh/2", "http/3", "spd/1"},
  2316  			ExpectedProtocol:       "http/3",
  2317  			ExpectAlertFromClient:  false,
  2318  			ExpectAlertFromServer:  false,
  2319  			Alert:                  0,
  2320  		},
  2321  		{
  2322  			Name:                   "Empty intersection",
  2323  			ClientProtocolNameList: []string{"http/1.1", "http/3"},
  2324  			ServerProtocolNameList: []string{"ssh/2", "spd/1"},
  2325  			ExpectedProtocol:       "",
  2326  			ExpectAlertFromClient:  false,
  2327  			ExpectAlertFromServer:  true,
  2328  			Alert:                  alert.NoApplicationProtocol,
  2329  		},
  2330  		{
  2331  			Name:                   "Multiple protocols in ServerHello",
  2332  			ClientProtocolNameList: []string{"http/1.1"},
  2333  			ServerProtocolNameList: []string{"http/1.1"},
  2334  			ExpectedProtocol:       "http/1.1",
  2335  			ExpectAlertFromClient:  true,
  2336  			ExpectAlertFromServer:  false,
  2337  			Alert:                  alert.InternalError,
  2338  		},
  2339  	} {
  2340  		test := test
  2341  		t.Run(test.Name, func(t *testing.T) {
  2342  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2343  			defer cancel()
  2344  
  2345  			ca, cb := dpipe.Pipe()
  2346  			go func() {
  2347  				conf := &Config{
  2348  					SupportedProtocols: test.ClientProtocolNameList,
  2349  				}
  2350  				_, _ = testClient(ctx, ca, conf, false)
  2351  			}()
  2352  
  2353  			// Receive ClientHello
  2354  			resp := make([]byte, 1024)
  2355  			n, err := cb.Read(resp)
  2356  			if err != nil {
  2357  				t.Fatal(err)
  2358  			}
  2359  
  2360  			ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second)
  2361  			defer cancel2()
  2362  
  2363  			ca2, cb2 := dpipe.Pipe()
  2364  			go func() {
  2365  				conf := &Config{
  2366  					SupportedProtocols: test.ServerProtocolNameList,
  2367  				}
  2368  				if _, err2 := testServer(ctx2, cb2, conf, true); !errors.Is(err2, context.Canceled) {
  2369  					if test.ExpectAlertFromServer { //nolint
  2370  						// Assert the error type?
  2371  					} else {
  2372  						t.Error(err2)
  2373  					}
  2374  				}
  2375  			}()
  2376  
  2377  			time.Sleep(50 * time.Millisecond)
  2378  
  2379  			// Forward ClientHello
  2380  			if _, err = ca2.Write(resp[:n]); err != nil {
  2381  				t.Fatal(err)
  2382  			}
  2383  
  2384  			// Receive HelloVerify
  2385  			resp2 := make([]byte, 1024)
  2386  			n, err = ca2.Read(resp2)
  2387  			if err != nil {
  2388  				t.Fatal(err)
  2389  			}
  2390  
  2391  			// Forward HelloVerify
  2392  			if _, err = cb.Write(resp2[:n]); err != nil {
  2393  				t.Fatal(err)
  2394  			}
  2395  
  2396  			// Receive ClientHello
  2397  			resp3 := make([]byte, 1024)
  2398  			n, err = cb.Read(resp3)
  2399  			if err != nil {
  2400  				t.Fatal(err)
  2401  			}
  2402  
  2403  			// Forward ClientHello
  2404  			if _, err = ca2.Write(resp3[:n]); err != nil {
  2405  				t.Fatal(err)
  2406  			}
  2407  
  2408  			// Receive ServerHello
  2409  			resp4 := make([]byte, 1024)
  2410  			n, err = ca2.Read(resp4)
  2411  			if err != nil {
  2412  				t.Fatal(err)
  2413  			}
  2414  
  2415  			messages, err := recordlayer.UnpackDatagram(resp4[:n])
  2416  			if err != nil {
  2417  				t.Fatal(err)
  2418  			}
  2419  
  2420  			r := &recordlayer.RecordLayer{}
  2421  			if err := r.Unmarshal(messages[0]); err != nil {
  2422  				t.Fatal(err)
  2423  			}
  2424  
  2425  			if test.ExpectAlertFromServer {
  2426  				a, ok := r.Content.(*alert.Alert)
  2427  				if !ok {
  2428  					t.Fatal("Failed to cast alert.Alert")
  2429  				}
  2430  
  2431  				if a.Description != test.Alert {
  2432  					t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description)
  2433  				}
  2434  			} else {
  2435  				serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
  2436  				if !ok {
  2437  					t.Fatal("Failed to cast handshake.MessageServerHello")
  2438  				}
  2439  
  2440  				var negotiatedProtocol string
  2441  				for _, v := range serverHello.Extensions {
  2442  					if _, ok := v.(*extension.ALPN); ok {
  2443  						e, ok := v.(*extension.ALPN)
  2444  						if !ok {
  2445  							t.Fatal("Failed to cast extension.ALPN")
  2446  						}
  2447  
  2448  						negotiatedProtocol = e.ProtocolNameList[0]
  2449  
  2450  						// Manipulate ServerHello
  2451  						if test.ExpectAlertFromClient {
  2452  							e.ProtocolNameList = append(e.ProtocolNameList, "oops")
  2453  						}
  2454  					}
  2455  				}
  2456  
  2457  				if negotiatedProtocol != test.ExpectedProtocol {
  2458  					t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.ExpectedProtocol, negotiatedProtocol)
  2459  				}
  2460  
  2461  				s, err := r.Marshal()
  2462  				if err != nil {
  2463  					t.Fatal(err)
  2464  				}
  2465  
  2466  				// Forward ServerHello
  2467  				if _, err = cb.Write(s); err != nil {
  2468  					t.Fatal(err)
  2469  				}
  2470  
  2471  				if test.ExpectAlertFromClient {
  2472  					resp5 := make([]byte, 1024)
  2473  					n, err = cb.Read(resp5)
  2474  					if err != nil {
  2475  						t.Fatal(err)
  2476  					}
  2477  
  2478  					r2 := &recordlayer.RecordLayer{}
  2479  					if err := r2.Unmarshal(resp5[:n]); err != nil {
  2480  						t.Fatal(err)
  2481  					}
  2482  
  2483  					a, ok := r2.Content.(*alert.Alert)
  2484  					if !ok {
  2485  						t.Fatal("Failed to cast alert.Alert")
  2486  					}
  2487  
  2488  					if a.Description != test.Alert {
  2489  						t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description)
  2490  					}
  2491  				}
  2492  			}
  2493  
  2494  			time.Sleep(50 * time.Millisecond) // Give some time for returned errors
  2495  		})
  2496  	}
  2497  }
  2498  
  2499  // Make sure the supported_groups extension is not included in the ServerHello
  2500  func TestSupportedGroupsExtension(t *testing.T) {
  2501  	// Limit runtime in case of deadlocks
  2502  	lim := test.TimeOut(time.Second * 20)
  2503  	defer lim.Stop()
  2504  
  2505  	// Check for leaking routines
  2506  	report := test.CheckRoutines(t)
  2507  	defer report()
  2508  
  2509  	t.Run("ServerHello Supported Groups", func(t *testing.T) {
  2510  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2511  		defer cancel()
  2512  
  2513  		ca, cb := dpipe.Pipe()
  2514  		go func() {
  2515  			if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) {
  2516  				t.Error(err)
  2517  			}
  2518  		}()
  2519  		extensions := []extension.Extension{
  2520  			&extension.SupportedEllipticCurves{
  2521  				EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
  2522  			},
  2523  			&extension.SupportedPointFormats{
  2524  				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  2525  			},
  2526  		}
  2527  
  2528  		time.Sleep(50 * time.Millisecond)
  2529  
  2530  		resp := make([]byte, 1024)
  2531  		err := sendClientHello([]byte{}, ca, 0, extensions)
  2532  		if err != nil {
  2533  			t.Fatal(err)
  2534  		}
  2535  
  2536  		// Receive ServerHello
  2537  		n, err := ca.Read(resp)
  2538  		if err != nil {
  2539  			t.Fatal(err)
  2540  		}
  2541  		r := &recordlayer.RecordLayer{}
  2542  		if err = r.Unmarshal(resp[:n]); err != nil {
  2543  			t.Fatal(err)
  2544  		}
  2545  
  2546  		helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest)
  2547  		if !ok {
  2548  			t.Fatal("Failed to cast MessageHelloVerifyRequest")
  2549  		}
  2550  
  2551  		err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions)
  2552  		if err != nil {
  2553  			t.Fatal(err)
  2554  		}
  2555  		if n, err = ca.Read(resp); err != nil {
  2556  			t.Fatal(err)
  2557  		}
  2558  
  2559  		messages, err := recordlayer.UnpackDatagram(resp[:n])
  2560  		if err != nil {
  2561  			t.Fatal(err)
  2562  		}
  2563  
  2564  		if err := r.Unmarshal(messages[0]); err != nil {
  2565  			t.Fatal(err)
  2566  		}
  2567  
  2568  		serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello)
  2569  		if !ok {
  2570  			t.Fatal("Failed to cast MessageServerHello")
  2571  		}
  2572  
  2573  		gotGroups := false
  2574  		for _, v := range serverHello.Extensions {
  2575  			if _, ok := v.(*extension.SupportedEllipticCurves); ok {
  2576  				gotGroups = true
  2577  			}
  2578  		}
  2579  
  2580  		if gotGroups {
  2581  			t.Errorf("TestSupportedGroups: supported_groups extension was sent in ServerHello")
  2582  		}
  2583  	})
  2584  }
  2585  
  2586  func TestSessionResume(t *testing.T) {
  2587  	// Limit runtime in case of deadlocks
  2588  	lim := test.TimeOut(time.Second * 20)
  2589  	defer lim.Stop()
  2590  
  2591  	// Check for leaking routines
  2592  	report := test.CheckRoutines(t)
  2593  	defer report()
  2594  
  2595  	t.Run("resumed", func(t *testing.T) {
  2596  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2597  		defer cancel()
  2598  
  2599  		type result struct {
  2600  			c   *Conn
  2601  			err error
  2602  		}
  2603  		clientRes := make(chan result, 1)
  2604  
  2605  		ss := &memSessStore{}
  2606  
  2607  		id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306")
  2608  		secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7")
  2609  
  2610  		s := Session{ID: id, Secret: secret}
  2611  
  2612  		ca, cb := dpipe.Pipe()
  2613  
  2614  		_ = ss.Set(id, s)
  2615  		_ = ss.Set([]byte(ca.RemoteAddr().String()+"_example.com"), s)
  2616  
  2617  		go func() {
  2618  			config := &Config{
  2619  				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  2620  				ServerName:   "example.com",
  2621  				SessionStore: ss,
  2622  				MTU:          100,
  2623  			}
  2624  			c, err := testClient(ctx, ca, config, false)
  2625  			clientRes <- result{c, err}
  2626  		}()
  2627  
  2628  		config := &Config{
  2629  			CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  2630  			ServerName:   "example.com",
  2631  			SessionStore: ss,
  2632  			MTU:          100,
  2633  		}
  2634  		server, err := testServer(ctx, cb, config, true)
  2635  		if err != nil {
  2636  			t.Fatalf("TestSessionResume: Server failed(%v)", err)
  2637  		}
  2638  
  2639  		actualSessionID := server.ConnectionState().SessionID
  2640  		actualMasterSecret := server.ConnectionState().masterSecret
  2641  		if !bytes.Equal(actualSessionID, id) {
  2642  			t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID)
  2643  		}
  2644  		if !bytes.Equal(actualMasterSecret, secret) {
  2645  			t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret)
  2646  		}
  2647  
  2648  		defer func() {
  2649  			_ = server.Close()
  2650  		}()
  2651  
  2652  		res := <-clientRes
  2653  		if res.err != nil {
  2654  			t.Fatal(res.err)
  2655  		}
  2656  		_ = res.c.Close()
  2657  	})
  2658  
  2659  	t.Run("new session", func(t *testing.T) {
  2660  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2661  		defer cancel()
  2662  
  2663  		type result struct {
  2664  			c   *Conn
  2665  			err error
  2666  		}
  2667  		clientRes := make(chan result, 1)
  2668  
  2669  		s1 := &memSessStore{}
  2670  		s2 := &memSessStore{}
  2671  
  2672  		ca, cb := dpipe.Pipe()
  2673  		go func() {
  2674  			config := &Config{
  2675  				ServerName:   "example.com",
  2676  				SessionStore: s1,
  2677  			}
  2678  			c, err := testClient(ctx, ca, config, false)
  2679  			clientRes <- result{c, err}
  2680  		}()
  2681  
  2682  		config := &Config{
  2683  			SessionStore: s2,
  2684  		}
  2685  		server, err := testServer(ctx, cb, config, true)
  2686  		if err != nil {
  2687  			t.Fatalf("TestSessionResumetion: Server failed(%v)", err)
  2688  		}
  2689  
  2690  		actualSessionID := server.ConnectionState().SessionID
  2691  		actualMasterSecret := server.ConnectionState().masterSecret
  2692  		ss, _ := s2.Get(actualSessionID)
  2693  		if !bytes.Equal(actualMasterSecret, ss.Secret) {
  2694  			t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
  2695  		}
  2696  
  2697  		defer func() {
  2698  			_ = server.Close()
  2699  		}()
  2700  
  2701  		res := <-clientRes
  2702  		if res.err != nil {
  2703  			t.Fatal(res.err)
  2704  		}
  2705  		cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com"))
  2706  		if !bytes.Equal(actualMasterSecret, cs.Secret) {
  2707  			t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
  2708  		}
  2709  		_ = res.c.Close()
  2710  	})
  2711  }
  2712  
  2713  type memSessStore struct {
  2714  	sync.Map
  2715  }
  2716  
  2717  func (ms *memSessStore) Set(key []byte, s Session) error {
  2718  	k := hex.EncodeToString(key)
  2719  	ms.Store(k, s)
  2720  
  2721  	return nil
  2722  }
  2723  
  2724  func (ms *memSessStore) Get(key []byte) (Session, error) {
  2725  	k := hex.EncodeToString(key)
  2726  
  2727  	v, ok := ms.Load(k)
  2728  	if !ok {
  2729  		return Session{}, nil
  2730  	}
  2731  
  2732  	s, ok := v.(Session)
  2733  	if !ok {
  2734  		return Session{}, nil
  2735  	}
  2736  
  2737  	return s, nil
  2738  }
  2739  
  2740  func (ms *memSessStore) Del(key []byte) error {
  2741  	k := hex.EncodeToString(key)
  2742  	ms.Delete(k)
  2743  
  2744  	return nil
  2745  }
  2746  
  2747  // Assert that the server only uses CipherSuites with a hash+signature that matches
  2748  // the certificate. As specified in rfc5246#section-7.4.3
  2749  func TestCipherSuiteMatchesCertificateType(t *testing.T) {
  2750  	// Limit runtime in case of deadlocks
  2751  	lim := test.TimeOut(time.Second * 20)
  2752  	defer lim.Stop()
  2753  
  2754  	// Check for leaking routines
  2755  	report := test.CheckRoutines(t)
  2756  	defer report()
  2757  
  2758  	for _, test := range []struct {
  2759  		Name           string
  2760  		cipherList     []CipherSuiteID
  2761  		expectedCipher CipherSuiteID
  2762  		generateRSA    bool
  2763  	}{
  2764  		{
  2765  			Name:           "ECDSA Certificate with RSA CipherSuite first",
  2766  			cipherList:     []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  2767  			expectedCipher: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  2768  		},
  2769  		{
  2770  			Name:           "RSA Certificate with ECDSA CipherSuite first",
  2771  			cipherList:     []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  2772  			expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  2773  			generateRSA:    true,
  2774  		},
  2775  	} {
  2776  		test := test
  2777  		t.Run(test.Name, func(t *testing.T) {
  2778  			clientErr := make(chan error, 1)
  2779  			client := make(chan *Conn, 1)
  2780  
  2781  			ca, cb := dpipe.Pipe()
  2782  			go func() {
  2783  				c, err := testClient(context.TODO(), ca, &Config{CipherSuites: test.cipherList}, false)
  2784  				clientErr <- err
  2785  				client <- c
  2786  			}()
  2787  
  2788  			var (
  2789  				priv crypto.PrivateKey
  2790  				err  error
  2791  			)
  2792  
  2793  			if test.generateRSA {
  2794  				if priv, err = rsa.GenerateKey(rand.Reader, 2048); err != nil {
  2795  					t.Fatal(err)
  2796  				}
  2797  			} else {
  2798  				if priv, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil {
  2799  					t.Fatal(err)
  2800  				}
  2801  			}
  2802  
  2803  			serverCert, err := selfsign.SelfSign(priv)
  2804  			if err != nil {
  2805  				t.Fatal(err)
  2806  			}
  2807  
  2808  			if s, err := testServer(context.TODO(), cb, &Config{
  2809  				CipherSuites: test.cipherList,
  2810  				Certificates: []tls.Certificate{serverCert},
  2811  			}, false); err != nil {
  2812  				t.Fatal(err)
  2813  			} else if err = s.Close(); err != nil {
  2814  				t.Fatal(err)
  2815  			}
  2816  
  2817  			if c, err := <-client, <-clientErr; err != nil {
  2818  				t.Fatal(err)
  2819  			} else if err := c.Close(); err != nil {
  2820  				t.Fatal(err)
  2821  			} else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher {
  2822  				t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID())
  2823  			}
  2824  		})
  2825  	}
  2826  }
  2827  
  2828  // Test that we return the proper certificate if we are serving multiple ServerNames on a single Server
  2829  func TestMultipleServerCertificates(t *testing.T) {
  2830  	fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo")
  2831  	if err != nil {
  2832  		t.Fatal(err)
  2833  	}
  2834  
  2835  	barCert, err := selfsign.GenerateSelfSignedWithDNS("bar")
  2836  	if err != nil {
  2837  		t.Fatal(err)
  2838  	}
  2839  
  2840  	caPool := x509.NewCertPool()
  2841  	for _, cert := range []tls.Certificate{fooCert, barCert} {
  2842  		certificate, err := x509.ParseCertificate(cert.Certificate[0])
  2843  		if err != nil {
  2844  			t.Fatal(err)
  2845  		}
  2846  		caPool.AddCert(certificate)
  2847  	}
  2848  
  2849  	for _, test := range []struct {
  2850  		RequestServerName string
  2851  		ExpectedDNSName   string
  2852  	}{
  2853  		{
  2854  			"foo",
  2855  			"foo",
  2856  		},
  2857  		{
  2858  			"bar",
  2859  			"bar",
  2860  		},
  2861  		{
  2862  			"invalid",
  2863  			"foo",
  2864  		},
  2865  	} {
  2866  		test := test
  2867  		t.Run(test.RequestServerName, func(t *testing.T) {
  2868  			clientErr := make(chan error, 2)
  2869  			client := make(chan *Conn, 1)
  2870  
  2871  			ca, cb := dpipe.Pipe()
  2872  			go func() {
  2873  				c, err := testClient(context.TODO(), ca, &Config{
  2874  					RootCAs:    caPool,
  2875  					ServerName: test.RequestServerName,
  2876  					VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
  2877  						certificate, err := x509.ParseCertificate(rawCerts[0])
  2878  						if err != nil {
  2879  							return err
  2880  						}
  2881  
  2882  						if certificate.DNSNames[0] != test.ExpectedDNSName {
  2883  							return errWrongCert
  2884  						}
  2885  
  2886  						return nil
  2887  					},
  2888  				}, false)
  2889  				clientErr <- err
  2890  				client <- c
  2891  			}()
  2892  
  2893  			if s, err := testServer(context.TODO(), cb, &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil {
  2894  				t.Fatal(err)
  2895  			} else if err = s.Close(); err != nil {
  2896  				t.Fatal(err)
  2897  			}
  2898  
  2899  			if c, err := <-client, <-clientErr; err != nil {
  2900  				t.Fatal(err)
  2901  			} else if err := c.Close(); err != nil {
  2902  				t.Fatal(err)
  2903  			}
  2904  		})
  2905  	}
  2906  }
  2907  
  2908  func TestEllipticCurveConfiguration(t *testing.T) {
  2909  	// Check for leaking routines
  2910  	report := test.CheckRoutines(t)
  2911  	defer report()
  2912  
  2913  	for _, test := range []struct {
  2914  		Name            string
  2915  		ConfigCurves    []elliptic.Curve
  2916  		HadnshakeCurves []elliptic.Curve
  2917  	}{
  2918  		{
  2919  			Name:            "Curve defaulting",
  2920  			ConfigCurves:    nil,
  2921  			HadnshakeCurves: defaultCurves,
  2922  		},
  2923  		{
  2924  			Name:            "Single curve",
  2925  			ConfigCurves:    []elliptic.Curve{elliptic.X25519},
  2926  			HadnshakeCurves: []elliptic.Curve{elliptic.X25519},
  2927  		},
  2928  		{
  2929  			Name:            "Multiple curves",
  2930  			ConfigCurves:    []elliptic.Curve{elliptic.P384, elliptic.X25519},
  2931  			HadnshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519},
  2932  		},
  2933  	} {
  2934  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2935  		defer cancel()
  2936  
  2937  		ca, cb := dpipe.Pipe()
  2938  		type result struct {
  2939  			c   *Conn
  2940  			err error
  2941  		}
  2942  		c := make(chan result)
  2943  
  2944  		go func() {
  2945  			client, err := testClient(ctx, ca, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true)
  2946  			c <- result{client, err}
  2947  		}()
  2948  
  2949  		server, err := testServer(ctx, cb, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true)
  2950  		if err != nil {
  2951  			t.Fatalf("Server error: %v", err)
  2952  		}
  2953  
  2954  		if len(test.ConfigCurves) == 0 && len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) {
  2955  			t.Fatalf("Failed to default Elliptic curves, expected %d, got: %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves))
  2956  		}
  2957  
  2958  		if len(test.ConfigCurves) != 0 {
  2959  			if len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) {
  2960  				t.Fatalf("Failed to configure Elliptic curves, expect %d, got %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves))
  2961  			}
  2962  			for i, c := range test.ConfigCurves {
  2963  				if c != server.fsm.cfg.ellipticCurves[i] {
  2964  					t.Fatalf("Failed to maintain Elliptic curve order, expected %s, got %s", c, server.fsm.cfg.ellipticCurves[i])
  2965  				}
  2966  			}
  2967  		}
  2968  
  2969  		res := <-c
  2970  		if res.err != nil {
  2971  			t.Fatalf("Client error; %v", err)
  2972  		}
  2973  
  2974  		defer func() {
  2975  			err = server.Close()
  2976  			if err != nil {
  2977  				t.Fatal(err)
  2978  			}
  2979  			err = res.c.Close()
  2980  			if err != nil {
  2981  				t.Fatal(err)
  2982  			}
  2983  		}()
  2984  	}
  2985  }
  2986  
  2987  func TestSkipHelloVerify(t *testing.T) {
  2988  	report := test.CheckRoutines(t)
  2989  	defer report()
  2990  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  2991  	defer cancel()
  2992  
  2993  	ca, cb := dpipe.Pipe()
  2994  	certificate, err := selfsign.GenerateSelfSigned()
  2995  	if err != nil {
  2996  		t.Fatal(err)
  2997  	}
  2998  	gotHello := make(chan struct{})
  2999  
  3000  	go func() {
  3001  		server, sErr := testServer(ctx, cb, &Config{
  3002  			Certificates:            []tls.Certificate{certificate},
  3003  			LoggerFactory:           logging.NewDefaultLoggerFactory(),
  3004  			InsecureSkipVerifyHello: true,
  3005  		}, false)
  3006  		if sErr != nil {
  3007  			t.Error(sErr)
  3008  			return
  3009  		}
  3010  		buf := make([]byte, 1024)
  3011  		if _, sErr = server.Read(buf); sErr != nil {
  3012  			t.Error(sErr)
  3013  		}
  3014  		gotHello <- struct{}{}
  3015  		if sErr = server.Close(); sErr != nil { //nolint:contextcheck
  3016  			t.Error(sErr)
  3017  		}
  3018  	}()
  3019  
  3020  	client, err := testClient(ctx, ca, &Config{
  3021  		LoggerFactory:      logging.NewDefaultLoggerFactory(),
  3022  		InsecureSkipVerify: true,
  3023  	}, false)
  3024  	if err != nil {
  3025  		t.Fatal(err)
  3026  	}
  3027  	if _, err = client.Write([]byte("hello")); err != nil {
  3028  		t.Error(err)
  3029  	}
  3030  	select {
  3031  	case <-gotHello:
  3032  		// OK
  3033  	case <-time.After(time.Second * 5):
  3034  		t.Error("timeout")
  3035  	}
  3036  
  3037  	if err = client.Close(); err != nil {
  3038  		t.Error(err)
  3039  	}
  3040  }
  3041  
  3042  type connWithCallback struct {
  3043  	net.Conn
  3044  	onWrite func([]byte)
  3045  }
  3046  
  3047  func (c *connWithCallback) Write(b []byte) (int, error) {
  3048  	if c.onWrite != nil {
  3049  		c.onWrite(b)
  3050  	}
  3051  	return c.Conn.Write(b)
  3052  }
  3053  
  3054  func TestApplicationDataQueueLimited(t *testing.T) {
  3055  	// Limit runtime in case of deadlocks
  3056  	lim := test.TimeOut(time.Second * 20)
  3057  	defer lim.Stop()
  3058  
  3059  	// Check for leaking routines
  3060  	report := test.CheckRoutines(t)
  3061  	defer report()
  3062  
  3063  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  3064  	defer cancel()
  3065  
  3066  	ca, cb := dpipe.Pipe()
  3067  	defer ca.Close()
  3068  	defer cb.Close()
  3069  
  3070  	done := make(chan struct{})
  3071  	go func() {
  3072  		serverCert, err := selfsign.GenerateSelfSigned()
  3073  		if err != nil {
  3074  			t.Error(err)
  3075  			return
  3076  		}
  3077  		cfg := &Config{}
  3078  		cfg.Certificates = []tls.Certificate{serverCert}
  3079  
  3080  		dconn, err := createConn(cb, cfg, false)
  3081  		if err != nil {
  3082  			t.Error(err)
  3083  			return
  3084  		}
  3085  		go func() {
  3086  			for i := 0; i < 5; i++ {
  3087  				dconn.lock.RLock()
  3088  				qlen := len(dconn.encryptedPackets)
  3089  				dconn.lock.RUnlock()
  3090  				if qlen > maxAppDataPacketQueueSize {
  3091  					t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
  3092  				}
  3093  				t.Log(qlen)
  3094  				time.Sleep(1 * time.Second)
  3095  			}
  3096  
  3097  		}()
  3098  		if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
  3099  			t.Error("expected handshake to fail")
  3100  		}
  3101  		close(done)
  3102  	}()
  3103  	extensions := []extension.Extension{}
  3104  
  3105  	time.Sleep(50 * time.Millisecond)
  3106  
  3107  	err := sendClientHello([]byte{}, ca, 0, extensions)
  3108  	if err != nil {
  3109  		t.Fatal(err)
  3110  	}
  3111  
  3112  	time.Sleep(50 * time.Millisecond)
  3113  
  3114  	for i := 0; i < 1000; i++ {
  3115  		// Send an application data packet
  3116  		packet, err := (&recordlayer.RecordLayer{
  3117  			Header: recordlayer.Header{
  3118  				Version:        protocol.Version1_2,
  3119  				SequenceNumber: uint64(3),
  3120  				Epoch:          1, // use an epoch greater than 0
  3121  			},
  3122  			Content: &protocol.ApplicationData{
  3123  				Data: []byte{1, 2, 3, 4},
  3124  			},
  3125  		}).Marshal()
  3126  		if err != nil {
  3127  			t.Fatal(err)
  3128  		}
  3129  		ca.Write(packet)
  3130  		if i%100 == 0 {
  3131  			time.Sleep(10 * time.Millisecond)
  3132  		}
  3133  	}
  3134  	time.Sleep(1 * time.Second)
  3135  	ca.Close()
  3136  	<-done
  3137  }
  3138  
  3139  func TestApplicationDataWithClientHelloRejected(t *testing.T) {
  3140  	// Limit runtime in case of deadlocks
  3141  	lim := test.TimeOut(time.Second * 20)
  3142  	defer lim.Stop()
  3143  
  3144  	// Check for leaking routines
  3145  	report := test.CheckRoutines(t)
  3146  	defer report()
  3147  
  3148  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  3149  	defer cancel()
  3150  
  3151  	ca, cb := dpipe.Pipe()
  3152  	defer ca.Close()
  3153  	defer cb.Close()
  3154  
  3155  	done := make(chan struct{})
  3156  	go func() {
  3157  		if _, err := testServer(ctx, cb, &Config{}, true); err == nil {
  3158  			t.Error("expected handshake to fail")
  3159  		}
  3160  		close(done)
  3161  	}()
  3162  	extensions := []extension.Extension{}
  3163  
  3164  	time.Sleep(50 * time.Millisecond)
  3165  
  3166  	err := sendClientHello([]byte{}, ca, 0, extensions)
  3167  	if err != nil {
  3168  		t.Fatal(err)
  3169  	}
  3170  
  3171  	// Send an application data packet
  3172  	packet, err := (&recordlayer.RecordLayer{
  3173  		Header: recordlayer.Header{
  3174  			Version:        protocol.Version1_2,
  3175  			SequenceNumber: uint64(3),
  3176  			Epoch:          0,
  3177  		},
  3178  		Content: &protocol.ApplicationData{
  3179  			Data: []byte{1, 2, 3, 4},
  3180  		},
  3181  	}).Marshal()
  3182  	if err != nil {
  3183  		t.Fatal(err)
  3184  	}
  3185  	ca.Write(packet)
  3186  	<-done
  3187  }