github.com/pion/dtls/v2@v2.2.12/resume_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package dtls
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/tls"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/pion/dtls/v2/pkg/crypto/selfsign"
    17  	"github.com/pion/transport/v2/test"
    18  )
    19  
    20  var errMessageMissmatch = errors.New("messages missmatch")
    21  
    22  func TestResumeClient(t *testing.T) {
    23  	DoTestResume(t, Client, Server)
    24  }
    25  
    26  func TestResumeServer(t *testing.T) {
    27  	DoTestResume(t, Server, Client)
    28  }
    29  
    30  func fatal(t *testing.T, errChan chan error, err error) {
    31  	close(errChan)
    32  	t.Fatal(err)
    33  }
    34  
    35  func DoTestResume(t *testing.T, newLocal, newRemote func(net.Conn, *Config) (*Conn, error)) {
    36  	// Limit runtime in case of deadlocks
    37  	lim := test.TimeOut(time.Second * 20)
    38  	defer lim.Stop()
    39  
    40  	// Check for leaking routines
    41  	report := test.CheckRoutines(t)
    42  	defer report()
    43  
    44  	certificate, err := selfsign.GenerateSelfSigned()
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  
    49  	// Generate connections
    50  	localConn1, rc1 := net.Pipe()
    51  	localConn2, rc2 := net.Pipe()
    52  	remoteConn := &backupConn{curr: rc1, next: rc2}
    53  
    54  	// Launch remote in another goroutine
    55  	errChan := make(chan error, 1)
    56  	defer func() {
    57  		err = <-errChan
    58  		if err != nil {
    59  			t.Fatal(err)
    60  		}
    61  	}()
    62  	config := &Config{
    63  		Certificates:         []tls.Certificate{certificate},
    64  		InsecureSkipVerify:   true,
    65  		ExtendedMasterSecret: RequireExtendedMasterSecret,
    66  	}
    67  	go func() {
    68  		var remote *Conn
    69  		var errR error
    70  		remote, errR = newRemote(remoteConn, config)
    71  		if errR != nil {
    72  			errChan <- errR
    73  		}
    74  
    75  		// Loop of read write
    76  		for i := 0; i < 2; i++ {
    77  			recv := make([]byte, 1024)
    78  			var n int
    79  			n, errR = remote.Read(recv)
    80  			if errR != nil {
    81  				errChan <- errR
    82  			}
    83  
    84  			if _, errR = remote.Write(recv[:n]); errR != nil {
    85  				errChan <- errR
    86  			}
    87  		}
    88  		errChan <- nil
    89  	}()
    90  
    91  	var local *Conn
    92  	local, err = newLocal(localConn1, config)
    93  	if err != nil {
    94  		fatal(t, errChan, err)
    95  	}
    96  	defer func() {
    97  		_ = local.Close()
    98  	}()
    99  
   100  	// Test write and read
   101  	message := []byte("Hello")
   102  	if _, err = local.Write(message); err != nil {
   103  		fatal(t, errChan, err)
   104  	}
   105  
   106  	recv := make([]byte, 1024)
   107  	var n int
   108  	n, err = local.Read(recv)
   109  	if err != nil {
   110  		fatal(t, errChan, err)
   111  	}
   112  
   113  	if !bytes.Equal(message, recv[:n]) {
   114  		fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
   115  	}
   116  
   117  	if err = localConn1.Close(); err != nil {
   118  		fatal(t, errChan, err)
   119  	}
   120  
   121  	// Serialize and deserialize state
   122  	state := local.ConnectionState()
   123  	var b []byte
   124  	b, err = state.MarshalBinary()
   125  	if err != nil {
   126  		fatal(t, errChan, err)
   127  	}
   128  	deserialized := &State{}
   129  	if err = deserialized.UnmarshalBinary(b); err != nil {
   130  		fatal(t, errChan, err)
   131  	}
   132  
   133  	// Resume dtls connection
   134  	var resumed net.Conn
   135  	resumed, err = Resume(deserialized, localConn2, config)
   136  	if err != nil {
   137  		fatal(t, errChan, err)
   138  	}
   139  	defer func() {
   140  		_ = resumed.Close()
   141  	}()
   142  
   143  	// Test write and read on resumed connection
   144  	if _, err = resumed.Write(message); err != nil {
   145  		fatal(t, errChan, err)
   146  	}
   147  
   148  	recv = make([]byte, 1024)
   149  	n, err = resumed.Read(recv)
   150  	if err != nil {
   151  		fatal(t, errChan, err)
   152  	}
   153  
   154  	if !bytes.Equal(message, recv[:n]) {
   155  		fatal(t, errChan, fmt.Errorf("%w: %s != %s", errMessageMissmatch, message, recv[:n]))
   156  	}
   157  }
   158  
   159  type backupConn struct {
   160  	curr net.Conn
   161  	next net.Conn
   162  	mux  sync.Mutex
   163  }
   164  
   165  func (b *backupConn) Read(data []byte) (n int, err error) {
   166  	n, err = b.curr.Read(data)
   167  	if err != nil && b.next != nil {
   168  		b.mux.Lock()
   169  		b.curr = b.next
   170  		b.next = nil
   171  		b.mux.Unlock()
   172  		return b.Read(data)
   173  	}
   174  	return n, err
   175  }
   176  
   177  func (b *backupConn) Write(data []byte) (n int, err error) {
   178  	n, err = b.curr.Write(data)
   179  	if err != nil && b.next != nil {
   180  		b.mux.Lock()
   181  		b.curr = b.next
   182  		b.next = nil
   183  		b.mux.Unlock()
   184  		return b.Write(data)
   185  	}
   186  	return n, err
   187  }
   188  
   189  func (b *backupConn) Close() error {
   190  	return nil
   191  }
   192  
   193  func (b *backupConn) LocalAddr() net.Addr {
   194  	return nil
   195  }
   196  
   197  func (b *backupConn) RemoteAddr() net.Addr {
   198  	return nil
   199  }
   200  
   201  func (b *backupConn) SetDeadline(time.Time) error {
   202  	return nil
   203  }
   204  
   205  func (b *backupConn) SetReadDeadline(time.Time) error {
   206  	return nil
   207  }
   208  
   209  func (b *backupConn) SetWriteDeadline(time.Time) error {
   210  	return nil
   211  }