github.com/pion/dtls/v2@v2.2.12/conn_go_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package dtls 8 9 import ( 10 "bytes" 11 "context" 12 "crypto/tls" 13 "errors" 14 "net" 15 "testing" 16 "time" 17 18 "github.com/pion/dtls/v2/pkg/crypto/selfsign" 19 "github.com/pion/transport/v2/dpipe" 20 "github.com/pion/transport/v2/test" 21 ) 22 23 func TestContextConfig(t *testing.T) { 24 // Limit runtime in case of deadlocks 25 lim := test.TimeOut(time.Second * 20) 26 defer lim.Stop() 27 28 report := test.CheckRoutines(t) 29 defer report() 30 31 addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") 32 if err != nil { 33 t.Fatalf("Unexpected error: %v", err) 34 } 35 36 // Dummy listener 37 listen, err := net.ListenUDP("udp", addrListen) 38 if err != nil { 39 t.Fatalf("Unexpected error: %v", err) 40 } 41 defer func() { 42 _ = listen.Close() 43 }() 44 addr, ok := listen.LocalAddr().(*net.UDPAddr) 45 if !ok { 46 t.Fatal("Failed to cast net.UDPAddr") 47 } 48 49 cert, err := selfsign.GenerateSelfSigned() 50 if err != nil { 51 t.Fatalf("Unexpected error: %v", err) 52 } 53 config := &Config{ 54 ConnectContextMaker: func() (context.Context, func()) { 55 return context.WithTimeout(context.Background(), 40*time.Millisecond) 56 }, 57 Certificates: []tls.Certificate{cert}, 58 } 59 60 dials := map[string]struct { 61 f func() (func() (net.Conn, error), func()) 62 order []byte 63 }{ 64 "Dial": { 65 f: func() (func() (net.Conn, error), func()) { 66 return func() (net.Conn, error) { 67 return Dial("udp", addr, config) 68 }, func() { 69 } 70 }, 71 order: []byte{0, 1, 2}, 72 }, 73 "DialWithContext": { 74 f: func() (func() (net.Conn, error), func()) { 75 ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) 76 return func() (net.Conn, error) { 77 return DialWithContext(ctx, "udp", addr, config) 78 }, func() { 79 cancel() 80 } 81 }, 82 order: []byte{0, 2, 1}, 83 }, 84 "Client": { 85 f: func() (func() (net.Conn, error), func()) { 86 ca, _ := dpipe.Pipe() 87 return func() (net.Conn, error) { 88 return Client(ca, config) 89 }, func() { 90 _ = ca.Close() 91 } 92 }, 93 order: []byte{0, 1, 2}, 94 }, 95 "ClientWithContext": { 96 f: func() (func() (net.Conn, error), func()) { 97 ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) 98 ca, _ := dpipe.Pipe() 99 return func() (net.Conn, error) { 100 return ClientWithContext(ctx, ca, config) 101 }, func() { 102 cancel() 103 _ = ca.Close() 104 } 105 }, 106 order: []byte{0, 2, 1}, 107 }, 108 "Server": { 109 f: func() (func() (net.Conn, error), func()) { 110 ca, _ := dpipe.Pipe() 111 return func() (net.Conn, error) { 112 return Server(ca, config) 113 }, func() { 114 _ = ca.Close() 115 } 116 }, 117 order: []byte{0, 1, 2}, 118 }, 119 "ServerWithContext": { 120 f: func() (func() (net.Conn, error), func()) { 121 ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) 122 ca, _ := dpipe.Pipe() 123 return func() (net.Conn, error) { 124 return ServerWithContext(ctx, ca, config) 125 }, func() { 126 cancel() 127 _ = ca.Close() 128 } 129 }, 130 order: []byte{0, 2, 1}, 131 }, 132 } 133 134 for name, dial := range dials { 135 dial := dial 136 t.Run(name, func(t *testing.T) { 137 done := make(chan struct{}) 138 139 go func() { 140 d, cancel := dial.f() 141 conn, err := d() 142 defer cancel() 143 var netError net.Error 144 if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck 145 t.Errorf("Client error exp(Temporary network error) failed(%v)", err) 146 close(done) 147 return 148 } 149 done <- struct{}{} 150 if err == nil { 151 _ = conn.Close() 152 } 153 }() 154 155 var order []byte 156 early := time.After(20 * time.Millisecond) 157 late := time.After(60 * time.Millisecond) 158 func() { 159 for len(order) < 3 { 160 select { 161 case <-early: 162 order = append(order, 0) 163 case _, ok := <-done: 164 if !ok { 165 return 166 } 167 order = append(order, 1) 168 case <-late: 169 order = append(order, 2) 170 } 171 } 172 }() 173 if !bytes.Equal(dial.order, order) { 174 t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order) 175 } 176 }) 177 } 178 }