github.com/pion/webrtc/v4@v4.0.1/internal/mux/mux_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package mux
     5  
     6  import (
     7  	"io"
     8  	"net"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/pion/logging"
    13  	"github.com/pion/transport/v3/packetio"
    14  	"github.com/pion/transport/v3/test"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  const testPipeBufferSize = 8192
    19  
    20  func TestNoEndpoints(t *testing.T) {
    21  	// In memory pipe
    22  	ca, cb := net.Pipe()
    23  	require.NoError(t, cb.Close())
    24  
    25  	m := NewMux(Config{
    26  		Conn:          ca,
    27  		BufferSize:    testPipeBufferSize,
    28  		LoggerFactory: logging.NewDefaultLoggerFactory(),
    29  	})
    30  	require.NoError(t, m.dispatch(make([]byte, 1)))
    31  	require.NoError(t, m.Close())
    32  	require.NoError(t, ca.Close())
    33  }
    34  
    35  type muxErrorConnReadResult struct {
    36  	err  error
    37  	data []byte
    38  }
    39  
    40  // muxErrorConn
    41  type muxErrorConn struct {
    42  	net.Conn
    43  	readResults []muxErrorConnReadResult
    44  }
    45  
    46  func (m *muxErrorConn) Read(b []byte) (n int, err error) {
    47  	err = m.readResults[0].err
    48  	copy(b, m.readResults[0].data)
    49  	n = len(m.readResults[0].data)
    50  
    51  	m.readResults = m.readResults[1:]
    52  	return
    53  }
    54  
    55  /*
    56  Don't end the mux readLoop for packetio.ErrTimeout or io.ErrShortBuffer, assert the following
    57  
    58    - io.ErrShortBuffer and packetio.ErrTimeout don't end the read loop
    59  
    60    - io.EOF ends the loop
    61  
    62      pion/webrtc#1720
    63  */
    64  func TestNonFatalRead(t *testing.T) {
    65  	// Limit runtime in case of deadlocks
    66  	lim := test.TimeOut(time.Second * 20)
    67  	defer lim.Stop()
    68  
    69  	expectedData := []byte("expectedData")
    70  
    71  	// In memory pipe
    72  	ca, cb := net.Pipe()
    73  	require.NoError(t, cb.Close())
    74  
    75  	conn := &muxErrorConn{ca, []muxErrorConnReadResult{
    76  		// Non-fatal timeout error
    77  		{packetio.ErrTimeout, nil},
    78  		{nil, expectedData},
    79  		{io.ErrShortBuffer, nil},
    80  		{nil, expectedData},
    81  		{io.EOF, nil},
    82  	}}
    83  
    84  	m := NewMux(Config{
    85  		Conn:          conn,
    86  		BufferSize:    testPipeBufferSize,
    87  		LoggerFactory: logging.NewDefaultLoggerFactory(),
    88  	})
    89  
    90  	e := m.NewEndpoint(MatchAll)
    91  
    92  	buff := make([]byte, testPipeBufferSize)
    93  	n, err := e.Read(buff)
    94  	require.NoError(t, err)
    95  	require.Equal(t, buff[:n], expectedData)
    96  
    97  	n, err = e.Read(buff)
    98  	require.NoError(t, err)
    99  	require.Equal(t, buff[:n], expectedData)
   100  
   101  	<-m.closedCh
   102  	require.NoError(t, m.Close())
   103  	require.NoError(t, ca.Close())
   104  }
   105  
   106  // If a endpoint returns packetio.ErrFull it is a non-fatal error and shouldn't cause
   107  // the mux to be destroyed
   108  // pion/webrtc#2180
   109  func TestNonFatalDispatch(t *testing.T) {
   110  	in, out := net.Pipe()
   111  
   112  	m := NewMux(Config{
   113  		Conn:          out,
   114  		LoggerFactory: logging.NewDefaultLoggerFactory(),
   115  		BufferSize:    1500,
   116  	})
   117  
   118  	e := m.NewEndpoint(MatchSRTP)
   119  	e.buffer.SetLimitSize(1)
   120  
   121  	for i := 0; i <= 25; i++ {
   122  		srtpPacket := []byte{128, 1, 2, 3, 4}
   123  		_, err := in.Write(srtpPacket)
   124  		require.NoError(t, err)
   125  	}
   126  
   127  	require.NoError(t, m.Close())
   128  	require.NoError(t, in.Close())
   129  	require.NoError(t, out.Close())
   130  }
   131  
   132  func BenchmarkDispatch(b *testing.B) {
   133  	m := &Mux{
   134  		endpoints: make(map[*Endpoint]MatchFunc),
   135  		log:       logging.NewDefaultLoggerFactory().NewLogger("mux"),
   136  	}
   137  
   138  	e := m.NewEndpoint(MatchSRTP)
   139  	m.NewEndpoint(MatchSRTCP)
   140  
   141  	buf := []byte{128, 1, 2, 3, 4}
   142  	buf2 := make([]byte, 1200)
   143  
   144  	b.StartTimer()
   145  
   146  	for i := 0; i < b.N; i++ {
   147  		err := m.dispatch(buf)
   148  		if err != nil {
   149  			b.Errorf("dispatch: %v", err)
   150  		}
   151  		_, err = e.buffer.Read(buf2)
   152  		if err != nil {
   153  			b.Errorf("read: %v", err)
   154  		}
   155  	}
   156  }
   157  
   158  func TestPendingQueue(t *testing.T) {
   159  	factory := logging.NewDefaultLoggerFactory()
   160  	factory.DefaultLogLevel = logging.LogLevelDebug
   161  	m := &Mux{
   162  		endpoints: make(map[*Endpoint]MatchFunc),
   163  		log:       factory.NewLogger("mux"),
   164  	}
   165  
   166  	// Assert empty packets don't end up in queue
   167  	require.NoError(t, m.dispatch([]byte{}))
   168  	require.Equal(t, len(m.pendingPackets), 0)
   169  
   170  	// Test Happy Case
   171  	inBuffer := []byte{20, 1, 2, 3, 4}
   172  	outBuffer := make([]byte, len(inBuffer))
   173  
   174  	require.NoError(t, m.dispatch(inBuffer))
   175  
   176  	endpoint := m.NewEndpoint(MatchDTLS)
   177  	require.NotNil(t, endpoint)
   178  
   179  	_, err := endpoint.Read(outBuffer)
   180  	require.NoError(t, err)
   181  
   182  	require.Equal(t, outBuffer, inBuffer)
   183  
   184  	// Assert limit on pendingPackets
   185  	for i := 0; i <= 100; i++ {
   186  		require.NoError(t, m.dispatch([]byte{64, 65, 66}))
   187  	}
   188  	require.Equal(t, len(m.pendingPackets), maxPendingPackets)
   189  }