// Source: github.com/pion/dtls/v2@v2.2.12/handshaker_test.go

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/tls"
    10  	"errors"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
    16  	"github.com/pion/dtls/v2/pkg/crypto/signaturehash"
    17  	"github.com/pion/dtls/v2/pkg/protocol/alert"
    18  	"github.com/pion/dtls/v2/pkg/protocol/handshake"
    19  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    20  	"github.com/pion/logging"
    21  	"github.com/pion/transport/v2/test"
    22  )
    23  
    24  const nonZeroRetransmitInterval = 100 * time.Millisecond
    25  
    26  // Test that writes to the key log are in the correct format and only applies
    27  // when a key log writer is given.
    28  func TestWriteKeyLog(t *testing.T) {
    29  	var buf bytes.Buffer
    30  	cfg := handshakeConfig{
    31  		keyLogWriter: &buf,
    32  	}
    33  	cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
    34  
    35  	// Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret>
    36  	// https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
    37  	want := "LABEL aabbcc ddeeff\n"
    38  	if buf.String() != want {
    39  		t.Fatalf("Got %s want %s", buf.String(), want)
    40  	}
    41  
    42  	// no key log writer = no writes
    43  	cfg = handshakeConfig{}
    44  	cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
    45  }
    46  
    47  func TestHandshaker(t *testing.T) {
    48  	// Check for leaking routines
    49  	report := test.CheckRoutines(t)
    50  	defer report()
    51  
    52  	loggerFactory := logging.NewDefaultLoggerFactory()
    53  	logger := loggerFactory.NewLogger("dtls")
    54  
    55  	cipherSuites, err := parseCipherSuites(nil, nil, true, false)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  	clientCert, err := selfsign.GenerateSelfSigned()
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  
    64  	genFilters := map[string]func() (TestEndpoint, TestEndpoint, func(t *testing.T)){
    65  		"PassThrough": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
    66  			return TestEndpoint{}, TestEndpoint{}, nil
    67  		},
    68  
    69  		"HelloVerifyRequestLost": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
    70  			var (
    71  				cntHelloVerifyRequest  = 0
    72  				cntClientHelloNoCookie = 0
    73  			)
    74  			const helloVerifyDrop = 5
    75  
    76  			clientEndpoint := TestEndpoint{
    77  				Filter: func(p *packet) bool {
    78  					h, ok := p.record.Content.(*handshake.Handshake)
    79  					if !ok {
    80  						return true
    81  					}
    82  					if hmch, ok := h.Message.(*handshake.MessageClientHello); ok {
    83  						if len(hmch.Cookie) == 0 {
    84  							cntClientHelloNoCookie++
    85  						}
    86  					}
    87  					return true
    88  				},
    89  			}
    90  
    91  			serverEndpoint := TestEndpoint{
    92  				Filter: func(p *packet) bool {
    93  					h, ok := p.record.Content.(*handshake.Handshake)
    94  					if !ok {
    95  						return true
    96  					}
    97  					if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok {
    98  						cntHelloVerifyRequest++
    99  						return cntHelloVerifyRequest > helloVerifyDrop
   100  					}
   101  					return true
   102  				},
   103  			}
   104  
   105  			report := func(t *testing.T) {
   106  				if cntHelloVerifyRequest != helloVerifyDrop+1 {
   107  					t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest)
   108  				}
   109  				if cntClientHelloNoCookie != cntHelloVerifyRequest {
   110  					t.Errorf(
   111  						"HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
   112  						cntHelloVerifyRequest, cntClientHelloNoCookie,
   113  					)
   114  				}
   115  			}
   116  
   117  			return clientEndpoint, serverEndpoint, report
   118  		},
   119  
   120  		"NoLatencyTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
   121  			var (
   122  				cntClientFinished = 0
   123  				cntServerFinished = 0
   124  			)
   125  
   126  			clientEndpoint := TestEndpoint{
   127  				Filter: func(p *packet) bool {
   128  					h, ok := p.record.Content.(*handshake.Handshake)
   129  					if !ok {
   130  						return true
   131  					}
   132  					if _, ok := h.Message.(*handshake.MessageFinished); ok {
   133  						cntClientFinished++
   134  					}
   135  					return true
   136  				},
   137  			}
   138  
   139  			serverEndpoint := TestEndpoint{
   140  				Filter: func(p *packet) bool {
   141  					h, ok := p.record.Content.(*handshake.Handshake)
   142  					if !ok {
   143  						return true
   144  					}
   145  					if _, ok := h.Message.(*handshake.MessageFinished); ok {
   146  						cntServerFinished++
   147  					}
   148  					return true
   149  				},
   150  			}
   151  
   152  			report := func(t *testing.T) {
   153  				if cntClientFinished != 1 {
   154  					t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished)
   155  				}
   156  				if cntServerFinished != 1 {
   157  					t.Errorf("Number of server finished is wrong, expected: %d times, got: %d times", 1, cntServerFinished)
   158  				}
   159  			}
   160  
   161  			return clientEndpoint, serverEndpoint, report
   162  		},
   163  
   164  		"SlowServerTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
   165  			var (
   166  				cntClientFinished               = 0
   167  				isClientFinished                = false
   168  				cntClientFinishedLastRetransmit = 0
   169  				cntServerFinished               = 0
   170  				isServerFinished                = false
   171  				cntServerFinishedLastRetransmit = 0
   172  			)
   173  
   174  			clientEndpoint := TestEndpoint{
   175  				Filter: func(p *packet) bool {
   176  					h, ok := p.record.Content.(*handshake.Handshake)
   177  					if !ok {
   178  						return true
   179  					}
   180  					if _, ok := h.Message.(*handshake.MessageFinished); ok {
   181  						if isClientFinished {
   182  							cntClientFinishedLastRetransmit++
   183  						} else {
   184  							cntClientFinished++
   185  						}
   186  					}
   187  					return true
   188  				},
   189  				Delay: 0,
   190  				OnFinished: func() {
   191  					isClientFinished = true
   192  				},
   193  				FinishWait: 2000 * time.Millisecond,
   194  			}
   195  
   196  			serverEndpoint := TestEndpoint{
   197  				Filter: func(p *packet) bool {
   198  					h, ok := p.record.Content.(*handshake.Handshake)
   199  					if !ok {
   200  						return true
   201  					}
   202  					if _, ok := h.Message.(*handshake.MessageFinished); ok {
   203  						if isServerFinished {
   204  							cntServerFinishedLastRetransmit++
   205  						} else {
   206  							cntServerFinished++
   207  						}
   208  					}
   209  					return true
   210  				},
   211  				Delay: 1000 * time.Millisecond,
   212  				OnFinished: func() {
   213  					isServerFinished = true
   214  				},
   215  				FinishWait: 2000 * time.Millisecond,
   216  			}
   217  
   218  			report := func(t *testing.T) {
   219  				// with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client
   220  				// using a range of 9 - 11 for checking
   221  				if cntClientFinished < 8 || cntClientFinished > 11 {
   222  					t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished)
   223  				}
   224  				if !isClientFinished {
   225  					t.Errorf("Client is not finished")
   226  				}
   227  				// there should be no `Finished` last retransmit from client
   228  				if cntClientFinishedLastRetransmit != 0 {
   229  					t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit)
   230  				}
   231  				if cntServerFinished < 1 {
   232  					t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished)
   233  				}
   234  				if !isServerFinished {
   235  					t.Errorf("Server is not finished")
   236  				}
   237  				// there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`.
   238  				if cntServerFinishedLastRetransmit < 1 {
   239  					t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit)
   240  				}
   241  			}
   242  
   243  			return clientEndpoint, serverEndpoint, report
   244  		},
   245  	}
   246  
   247  	for name, filters := range genFilters {
   248  		clientEndpoint, serverEndpoint, report := filters()
   249  		t.Run(name, func(t *testing.T) {
   250  			ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
   251  			defer cancel()
   252  
   253  			if report != nil {
   254  				defer report(t)
   255  			}
   256  
   257  			ca, cb := flightTestPipe(ctx, clientEndpoint, serverEndpoint)
   258  			ca.state.isClient = true
   259  
   260  			var wg sync.WaitGroup
   261  			wg.Add(2)
   262  
   263  			ctxCliFinished, cancelCli := context.WithCancel(ctx)
   264  			ctxSrvFinished, cancelSrv := context.WithCancel(ctx)
   265  			go func() {
   266  				defer wg.Done()
   267  				cfg := &handshakeConfig{
   268  					localCipherSuites:     cipherSuites,
   269  					localCertificates:     []tls.Certificate{clientCert},
   270  					ellipticCurves:        defaultCurves,
   271  					localSignatureSchemes: signaturehash.Algorithms(),
   272  					insecureSkipVerify:    true,
   273  					log:                   logger,
   274  					onFlightState: func(f flightVal, s handshakeState) {
   275  						if s == handshakeFinished {
   276  							if clientEndpoint.OnFinished != nil {
   277  								clientEndpoint.OnFinished()
   278  							}
   279  							time.AfterFunc(clientEndpoint.FinishWait, func() {
   280  								cancelCli()
   281  							})
   282  						}
   283  					},
   284  					retransmitInterval: nonZeroRetransmitInterval,
   285  				}
   286  
   287  				fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
   288  				err := fsm.Run(ctx, ca, handshakePreparing)
   289  				switch {
   290  				case errors.Is(err, context.Canceled):
   291  				case errors.Is(err, context.DeadlineExceeded):
   292  					t.Error("Timeout")
   293  				default:
   294  					t.Error(err)
   295  				}
   296  			}()
   297  
   298  			go func() {
   299  				defer wg.Done()
   300  				cfg := &handshakeConfig{
   301  					localCipherSuites:     cipherSuites,
   302  					localCertificates:     []tls.Certificate{clientCert},
   303  					ellipticCurves:        defaultCurves,
   304  					localSignatureSchemes: signaturehash.Algorithms(),
   305  					insecureSkipVerify:    true,
   306  					log:                   logger,
   307  					onFlightState: func(f flightVal, s handshakeState) {
   308  						if s == handshakeFinished {
   309  							if serverEndpoint.OnFinished != nil {
   310  								serverEndpoint.OnFinished()
   311  							}
   312  							time.AfterFunc(serverEndpoint.FinishWait, func() {
   313  								cancelSrv()
   314  							})
   315  						}
   316  					},
   317  					retransmitInterval: nonZeroRetransmitInterval,
   318  				}
   319  
   320  				fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
   321  				err := fsm.Run(ctx, cb, handshakePreparing)
   322  				switch {
   323  				case errors.Is(err, context.Canceled):
   324  				case errors.Is(err, context.DeadlineExceeded):
   325  					t.Error("Timeout")
   326  				default:
   327  					t.Error(err)
   328  				}
   329  			}()
   330  
   331  			<-ctxCliFinished.Done()
   332  			<-ctxSrvFinished.Done()
   333  
   334  			cancel()
   335  			wg.Wait()
   336  		})
   337  	}
   338  }
   339  
   340  type packetFilter func(p *packet) bool
   341  
   342  type TestEndpoint struct {
   343  	Filter     packetFilter
   344  	Delay      time.Duration
   345  	OnFinished func()
   346  	FinishWait time.Duration
   347  }
   348  
   349  func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) {
   350  	ca := newHandshakeCache()
   351  	cb := newHandshakeCache()
   352  	chA := make(chan chan struct{})
   353  	chB := make(chan chan struct{})
   354  	return &flightTestConn{
   355  			handshakeCache: ca,
   356  			otherEndCache:  cb,
   357  			recv:           chA,
   358  			otherEndRecv:   chB,
   359  			done:           ctx.Done(),
   360  			filter:         clientEndpoint.Filter,
   361  			delay:          clientEndpoint.Delay,
   362  		}, &flightTestConn{
   363  			handshakeCache: cb,
   364  			otherEndCache:  ca,
   365  			recv:           chB,
   366  			otherEndRecv:   chA,
   367  			done:           ctx.Done(),
   368  			filter:         serverEndpoint.Filter,
   369  			delay:          serverEndpoint.Delay,
   370  		}
   371  }
   372  
   373  type flightTestConn struct {
   374  	state          State
   375  	handshakeCache *handshakeCache
   376  	recv           chan chan struct{}
   377  	done           <-chan struct{}
   378  	epoch          uint16
   379  
   380  	filter packetFilter
   381  
   382  	delay time.Duration
   383  
   384  	otherEndCache *handshakeCache
   385  	otherEndRecv  chan chan struct{}
   386  }
   387  
   388  func (c *flightTestConn) recvHandshake() <-chan chan struct{} {
   389  	return c.recv
   390  }
   391  
   392  func (c *flightTestConn) setLocalEpoch(epoch uint16) {
   393  	c.epoch = epoch
   394  }
   395  
   396  func (c *flightTestConn) notify(context.Context, alert.Level, alert.Description) error {
   397  	return nil
   398  }
   399  
   400  func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error {
   401  	time.Sleep(c.delay)
   402  	for _, p := range pkts {
   403  		if c.filter != nil && !c.filter(p) {
   404  			continue
   405  		}
   406  		if h, ok := p.record.Content.(*handshake.Handshake); ok {
   407  			handshakeRaw, err := p.record.Marshal()
   408  			if err != nil {
   409  				return err
   410  			}
   411  
   412  			c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
   413  
   414  			content, err := h.Message.Marshal()
   415  			if err != nil {
   416  				return err
   417  			}
   418  			h.Header.Length = uint32(len(content))
   419  			h.Header.FragmentLength = uint32(len(content))
   420  			hdr, err := h.Header.Marshal()
   421  			if err != nil {
   422  				return err
   423  			}
   424  			c.otherEndCache.push(
   425  				append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
   426  		}
   427  	}
   428  	go func() {
   429  		select {
   430  		case c.otherEndRecv <- make(chan struct{}):
   431  		case <-c.done:
   432  		}
   433  	}()
   434  
   435  	// Avoid deadlock on JS/WASM environment due to context switch problem.
   436  	time.Sleep(10 * time.Millisecond)
   437  
   438  	return nil
   439  }
   440  
   441  func (c *flightTestConn) handleQueuedPackets(context.Context) error {
   442  	return nil
   443  }
   444  
   445  func (c *flightTestConn) sessionKey() []byte {
   446  	return nil
   447  }