github.com/pion/dtls/v2@v2.2.12/replayprotection_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"context"
     8  	"net"
     9  	"reflect"
    10  	"sync"
    11  	"sync/atomic"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/pion/transport/v2/dpipe"
    16  	"github.com/pion/transport/v2/test"
    17  )
    18  
    19  func TestReplayProtection(t *testing.T) {
    20  	// Limit runtime in case of deadlocks
    21  	lim := test.TimeOut(5 * time.Second)
    22  	defer lim.Stop()
    23  
    24  	// Check for leaking routines
    25  	report := test.CheckRoutines(t)
    26  	defer report()
    27  
    28  	c0, c1 := dpipe.Pipe()
    29  	c2, c3 := dpipe.Pipe()
    30  	conn := []net.Conn{c0, c1, c2, c3}
    31  
    32  	var wgRoutines sync.WaitGroup
    33  	var cntReplays int32 = 1
    34  
    35  	ctxReplayDone, replayDone := context.WithCancel(context.Background())
    36  
    37  	replaySendDone := func() {
    38  		cnt := atomic.AddInt32(&cntReplays, -1)
    39  		if cnt == 0 {
    40  			replayDone()
    41  		}
    42  	}
    43  
    44  	replayer := func(ca, cb net.Conn) {
    45  		defer wgRoutines.Done()
    46  		// Man in the middle
    47  		for {
    48  			b := make([]byte, 2048)
    49  			n, rerr := ca.Read(b)
    50  			if rerr != nil {
    51  				return
    52  			}
    53  			if _, werr := cb.Write(b[:n]); werr != nil {
    54  				t.Error(werr)
    55  				return
    56  			}
    57  
    58  			atomic.AddInt32(&cntReplays, 1)
    59  			go func() {
    60  				defer replaySendDone()
    61  				// Replay bit later
    62  				time.Sleep(time.Millisecond)
    63  				if _, werr := cb.Write(b[:n]); werr != nil {
    64  					t.Error(werr)
    65  				}
    66  			}()
    67  		}
    68  	}
    69  	wgRoutines.Add(2)
    70  	go replayer(conn[1], conn[2])
    71  	go replayer(conn[2], conn[1])
    72  
    73  	ca, cb, err := pipeConn(conn[0], conn[3])
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  
    78  	const numMsgs = 10
    79  
    80  	var received [2][][]byte
    81  	for i, c := range []net.Conn{ca, cb} {
    82  		i := i
    83  		c := c
    84  		wgRoutines.Add(1)
    85  		atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message
    86  		var lastMsgDone sync.Once
    87  		go func() {
    88  			defer wgRoutines.Done()
    89  			for {
    90  				b := make([]byte, 2048)
    91  				n, rerr := c.Read(b)
    92  				if rerr != nil {
    93  					return
    94  				}
    95  				received[i] = append(received[i], b[:n])
    96  				if b[0] == numMsgs-1 {
    97  					// Final message received
    98  					lastMsgDone.Do(func() {
    99  						defer replaySendDone()
   100  					})
   101  				}
   102  			}
   103  		}()
   104  	}
   105  
   106  	var sent [][]byte
   107  	for i := 0; i < numMsgs; i++ {
   108  		data := []byte{byte(i)}
   109  		sent = append(sent, data)
   110  		if _, werr := ca.Write(data); werr != nil {
   111  			t.Error(werr)
   112  			return
   113  		}
   114  		if _, werr := cb.Write(data); werr != nil {
   115  			t.Error(werr)
   116  			return
   117  		}
   118  	}
   119  
   120  	replaySendDone()
   121  	<-ctxReplayDone.Done()
   122  	time.Sleep(10 * time.Millisecond) // Ensure all replayed packets are sent
   123  
   124  	for i := 0; i < 4; i++ {
   125  		if err := conn[i].Close(); err != nil {
   126  			t.Error(err)
   127  		}
   128  	}
   129  	if err := ca.Close(); err != nil {
   130  		t.Error(err)
   131  	}
   132  	if err := cb.Close(); err != nil {
   133  		t.Error(err)
   134  	}
   135  	wgRoutines.Wait()
   136  
   137  	for _, r := range received {
   138  		if !reflect.DeepEqual(sent, r) {
   139  			t.Errorf("Received data differs, expected: %v, got: %v", sent, r)
   140  		}
   141  	}
   142  }