github.com/pion/dtls/v2@v2.2.12/replayprotection_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "context" 8 "net" 9 "reflect" 10 "sync" 11 "sync/atomic" 12 "testing" 13 "time" 14 15 "github.com/pion/transport/v2/dpipe" 16 "github.com/pion/transport/v2/test" 17 ) 18 19 func TestReplayProtection(t *testing.T) { 20 // Limit runtime in case of deadlocks 21 lim := test.TimeOut(5 * time.Second) 22 defer lim.Stop() 23 24 // Check for leaking routines 25 report := test.CheckRoutines(t) 26 defer report() 27 28 c0, c1 := dpipe.Pipe() 29 c2, c3 := dpipe.Pipe() 30 conn := []net.Conn{c0, c1, c2, c3} 31 32 var wgRoutines sync.WaitGroup 33 var cntReplays int32 = 1 34 35 ctxReplayDone, replayDone := context.WithCancel(context.Background()) 36 37 replaySendDone := func() { 38 cnt := atomic.AddInt32(&cntReplays, -1) 39 if cnt == 0 { 40 replayDone() 41 } 42 } 43 44 replayer := func(ca, cb net.Conn) { 45 defer wgRoutines.Done() 46 // Man in the middle 47 for { 48 b := make([]byte, 2048) 49 n, rerr := ca.Read(b) 50 if rerr != nil { 51 return 52 } 53 if _, werr := cb.Write(b[:n]); werr != nil { 54 t.Error(werr) 55 return 56 } 57 58 atomic.AddInt32(&cntReplays, 1) 59 go func() { 60 defer replaySendDone() 61 // Replay bit later 62 time.Sleep(time.Millisecond) 63 if _, werr := cb.Write(b[:n]); werr != nil { 64 t.Error(werr) 65 } 66 }() 67 } 68 } 69 wgRoutines.Add(2) 70 go replayer(conn[1], conn[2]) 71 go replayer(conn[2], conn[1]) 72 73 ca, cb, err := pipeConn(conn[0], conn[3]) 74 if err != nil { 75 t.Fatal(err) 76 } 77 78 const numMsgs = 10 79 80 var received [2][][]byte 81 for i, c := range []net.Conn{ca, cb} { 82 i := i 83 c := c 84 wgRoutines.Add(1) 85 atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message 86 var lastMsgDone sync.Once 87 go func() { 88 defer wgRoutines.Done() 89 for { 90 b := make([]byte, 2048) 91 n, rerr := c.Read(b) 92 if rerr != nil { 93 return 94 } 95 received[i] = append(received[i], b[:n]) 96 if b[0] == numMsgs-1 { 97 // Final message received 98 lastMsgDone.Do(func() { 99 defer replaySendDone() 100 }) 101 } 102 } 103 }() 104 } 105 106 var sent [][]byte 107 for i := 0; i < numMsgs; i++ { 108 data := []byte{byte(i)} 109 sent = append(sent, data) 110 if _, werr := ca.Write(data); werr != nil { 111 t.Error(werr) 112 return 113 } 114 if _, werr := cb.Write(data); werr != nil { 115 t.Error(werr) 116 return 117 } 118 } 119 120 replaySendDone() 121 <-ctxReplayDone.Done() 122 time.Sleep(10 * time.Millisecond) // Ensure all replayed packets are sent 123 124 for i := 0; i < 4; i++ { 125 if err := conn[i].Close(); err != nil { 126 t.Error(err) 127 } 128 } 129 if err := ca.Close(); err != nil { 130 t.Error(err) 131 } 132 if err := cb.Close(); err != nil { 133 t.Error(err) 134 } 135 wgRoutines.Wait() 136 137 for _, r := range received { 138 if !reflect.DeepEqual(sent, r) { 139 t.Errorf("Received data differs, expected: %v, got: %v", sent, r) 140 } 141 } 142 }