github.com/pion/dtls/v2@v2.2.12/resume_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 package dtls 5 6 import ( 7 "bytes" 8 "crypto/tls" 9 "errors" 10 "fmt" 11 "net" 12 "sync" 13 "testing" 14 "time" 15 16 "github.com/pion/dtls/v2/pkg/crypto/selfsign" 17 "github.com/pion/transport/v2/test" 18 ) 19 20 var errMessageMissmatch = errors.New("messages missmatch") 21 22 func TestResumeClient(t *testing.T) { 23 DoTestResume(t, Client, Server) 24 } 25 26 func TestResumeServer(t *testing.T) { 27 DoTestResume(t, Server, Client) 28 } 29 30 func fatal(t *testing.T, errChan chan error, err error) { 31 close(errChan) 32 t.Fatal(err) 33 } 34 35 func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) { 36 // Limit runtime in case of deadlocks 37 lim := test.TimeOut(time.Second * 20) 38 defer lim.Stop() 39 40 // Check for leaking routines 41 report := test.CheckRoutines(t) 42 defer report() 43 44 certificate, err := selfsign.GenerateSelfSigned() 45 if err != nil { 46 t.Fatal(err) 47 } 48 49 // Generate connections 50 localConn1, rc1 := net.Pipe() 51 localConn2, rc2 := net.Pipe() 52 remoteConn := &backupConn{curr: rc1, next: rc2} 53 54 // Launch remote in another goroutine 55 errChan := make(chan error, 1) 56 defer func() { 57 err = <-errChan 58 if err != nil { 59 t.Fatal(err) 60 } 61 }() 62 config := &Config{ 63 Certificates: []tls.Certificate{certificate}, 64 InsecureSkipVerify: true, 65 ExtendedMasterSecret: RequireExtendedMasterSecret, 66 } 67 go func() { 68 var remote *Conn 69 var errR error 70 remote, errR = newRemote(remoteConn, config) 71 if errR != nil { 72 errChan <- errR 73 } 74 75 // Loop of read write 76 for i := 0; i < 2; i++ { 77 recv := make([]byte, 1024) 78 var n int 79 n, errR = remote.Read(recv) 80 if errR != nil { 81 errChan <- errR 82 } 83 84 if _, errR = remote.Write(recv[:n]); errR != nil { 85 errChan <- errR 86 } 87 } 88 errChan <- nil 89 }() 90 91 var local *Conn 92 local, err = newLocal(localConn1, config) 93 if err != nil { 94 fatal(t, errChan, err) 95 } 96 defer func() { 97 _ = local.Close() 98 }() 99 100 // Test write and read 101 message := []byte("Hello") 102 if _, err = local.Write(message); err != nil { 103 fatal(t, errChan, err) 104 } 105 106 recv := make([]byte, 1024) 107 var n int 108 n, err = local.Read(recv) 109 if err != nil { 110 fatal(t, errChan, err) 111 } 112 113 if !bytes.Equal(message, recv[:n]) { 114 fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n])) 115 } 116 117 if err = localConn1.Close(); err != nil { 118 fatal(t, errChan, err) 119 } 120 121 // Serialize and deserialize state 122 state := local.ConnectionState() 123 var b []byte 124 b, err = state.MarshalBinary() 125 if err != nil { 126 fatal(t, errChan, err) 127 } 128 deserialized := &State{} 129 if err = deserialized.UnmarshalBinary(b); err != nil { 130 fatal(t, errChan, err) 131 } 132 133 // Resume dtls connection 134 var resumed net.Conn 135 resumed, err = Resume(deserialized, localConn2, config) 136 if err != nil { 137 fatal(t, errChan, err) 138 } 139 defer func() { 140 _ = resumed.Close() 141 }() 142 143 // Test write and read on resumed connection 144 if _, err = resumed.Write(message); err != nil { 145 fatal(t, errChan, err) 146 } 147 148 recv = make([]byte, 1024) 149 n, err = resumed.Read(recv) 150 if err != nil { 151 fatal(t, errChan, err) 152 } 153 154 if !bytes.Equal(message, recv[:n]) { 155 fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n])) 156 } 157 } 158 159 type backupConn struct { 160 curr net.Conn 161 next net.Conn 162 mux sync.Mutex 163 } 164 165 func (b *backupConn) Read(data []byte) (n int, err error) { 166 n, err = b.curr.Read(data) 167 if err != nil && b.next != nil { 168 b.mux.Lock() 169 b.curr = b.next 170 b.next = nil 171 b.mux.Unlock() 172 return b.Read(data) 173 } 174 return n, err 175 } 176 177 func (b *backupConn) Write(data []byte) (n int, err error) { 178 n, err = b.curr.Write(data) 179 if err != nil && b.next != nil { 180 b.mux.Lock() 181 b.curr = b.next 182 b.next = nil 183 b.mux.Unlock() 184 return b.Write(data) 185 } 186 return n, err 187 } 188 189 func (b *backupConn) Close() error { 190 return nil 191 } 192 193 func (b *backupConn) LocalAddr() net.Addr { 194 return nil 195 } 196 197 func (b *backupConn) RemoteAddr() net.Addr { 198 return nil 199 } 200 201 func (b *backupConn) SetDeadline(time.Time) error { 202 return nil 203 } 204 205 func (b *backupConn) SetReadDeadline(time.Time) error { 206 return nil 207 } 208 209 func (b *backupConn) SetWriteDeadline(time.Time) error { 210 return nil 211 }