github.com/pion/dtls/v2@v2.2.12/conn_go_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package dtls
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"errors"
    14  	"net"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
    19  	"github.com/pion/transport/v2/dpipe"
    20  	"github.com/pion/transport/v2/test"
    21  )
    22  
    23  func TestContextConfig(t *testing.T) {
    24  	// Limit runtime in case of deadlocks
    25  	lim := test.TimeOut(time.Second * 20)
    26  	defer lim.Stop()
    27  
    28  	report := test.CheckRoutines(t)
    29  	defer report()
    30  
    31  	addrListen, err := net.ResolveUDPAddr("udp", "localhost:0")
    32  	if err != nil {
    33  		t.Fatalf("Unexpected error: %v", err)
    34  	}
    35  
    36  	// Dummy listener
    37  	listen, err := net.ListenUDP("udp", addrListen)
    38  	if err != nil {
    39  		t.Fatalf("Unexpected error: %v", err)
    40  	}
    41  	defer func() {
    42  		_ = listen.Close()
    43  	}()
    44  	addr, ok := listen.LocalAddr().(*net.UDPAddr)
    45  	if !ok {
    46  		t.Fatal("Failed to cast net.UDPAddr")
    47  	}
    48  
    49  	cert, err := selfsign.GenerateSelfSigned()
    50  	if err != nil {
    51  		t.Fatalf("Unexpected error: %v", err)
    52  	}
    53  	config := &Config{
    54  		ConnectContextMaker: func() (context.Context, func()) {
    55  			return context.WithTimeout(context.Background(), 40*time.Millisecond)
    56  		},
    57  		Certificates: []tls.Certificate{cert},
    58  	}
    59  
    60  	dials := map[string]struct {
    61  		f     func() (func() (net.Conn, error), func())
    62  		order []byte
    63  	}{
    64  		"Dial": {
    65  			f: func() (func() (net.Conn, error), func()) {
    66  				return func() (net.Conn, error) {
    67  						return Dial("udp", addr, config)
    68  					}, func() {
    69  					}
    70  			},
    71  			order: []byte{0, 1, 2},
    72  		},
    73  		"DialWithContext": {
    74  			f: func() (func() (net.Conn, error), func()) {
    75  				ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
    76  				return func() (net.Conn, error) {
    77  						return DialWithContext(ctx, "udp", addr, config)
    78  					}, func() {
    79  						cancel()
    80  					}
    81  			},
    82  			order: []byte{0, 2, 1},
    83  		},
    84  		"Client": {
    85  			f: func() (func() (net.Conn, error), func()) {
    86  				ca, _ := dpipe.Pipe()
    87  				return func() (net.Conn, error) {
    88  						return Client(ca, config)
    89  					}, func() {
    90  						_ = ca.Close()
    91  					}
    92  			},
    93  			order: []byte{0, 1, 2},
    94  		},
    95  		"ClientWithContext": {
    96  			f: func() (func() (net.Conn, error), func()) {
    97  				ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
    98  				ca, _ := dpipe.Pipe()
    99  				return func() (net.Conn, error) {
   100  						return ClientWithContext(ctx, ca, config)
   101  					}, func() {
   102  						cancel()
   103  						_ = ca.Close()
   104  					}
   105  			},
   106  			order: []byte{0, 2, 1},
   107  		},
   108  		"Server": {
   109  			f: func() (func() (net.Conn, error), func()) {
   110  				ca, _ := dpipe.Pipe()
   111  				return func() (net.Conn, error) {
   112  						return Server(ca, config)
   113  					}, func() {
   114  						_ = ca.Close()
   115  					}
   116  			},
   117  			order: []byte{0, 1, 2},
   118  		},
   119  		"ServerWithContext": {
   120  			f: func() (func() (net.Conn, error), func()) {
   121  				ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
   122  				ca, _ := dpipe.Pipe()
   123  				return func() (net.Conn, error) {
   124  						return ServerWithContext(ctx, ca, config)
   125  					}, func() {
   126  						cancel()
   127  						_ = ca.Close()
   128  					}
   129  			},
   130  			order: []byte{0, 2, 1},
   131  		},
   132  	}
   133  
   134  	for name, dial := range dials {
   135  		dial := dial
   136  		t.Run(name, func(t *testing.T) {
   137  			done := make(chan struct{})
   138  
   139  			go func() {
   140  				d, cancel := dial.f()
   141  				conn, err := d()
   142  				defer cancel()
   143  				var netError net.Error
   144  				if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck
   145  					t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
   146  					close(done)
   147  					return
   148  				}
   149  				done <- struct{}{}
   150  				if err == nil {
   151  					_ = conn.Close()
   152  				}
   153  			}()
   154  
   155  			var order []byte
   156  			early := time.After(20 * time.Millisecond)
   157  			late := time.After(60 * time.Millisecond)
   158  			func() {
   159  				for len(order) < 3 {
   160  					select {
   161  					case <-early:
   162  						order = append(order, 0)
   163  					case _, ok := <-done:
   164  						if !ok {
   165  							return
   166  						}
   167  						order = append(order, 1)
   168  					case <-late:
   169  						order = append(order, 2)
   170  					}
   171  				}
   172  			}()
   173  			if !bytes.Equal(dial.order, order) {
   174  				t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order)
   175  			}
   176  		})
   177  	}
   178  }