github.com/pion/webrtc/v4@v4.0.1/sctptransport_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"bytes"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestGenerateDataChannelID(t *testing.T) {
    18  	sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
    19  		ret := &SCTPTransport{
    20  			dataChannels:       []*DataChannel{},
    21  			dataChannelIDsUsed: make(map[uint16]struct{}),
    22  		}
    23  
    24  		for i := range ids {
    25  			id := ids[i]
    26  			ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
    27  			ret.dataChannelIDsUsed[id] = struct{}{}
    28  		}
    29  
    30  		return ret
    31  	}
    32  
    33  	testCases := []struct {
    34  		role   DTLSRole
    35  		s      *SCTPTransport
    36  		result uint16
    37  	}{
    38  		{DTLSRoleClient, sctpTransportWithChannels([]uint16{}), 0},
    39  		{DTLSRoleClient, sctpTransportWithChannels([]uint16{1}), 0},
    40  		{DTLSRoleClient, sctpTransportWithChannels([]uint16{0}), 2},
    41  		{DTLSRoleClient, sctpTransportWithChannels([]uint16{0, 2}), 4},
    42  		{DTLSRoleClient, sctpTransportWithChannels([]uint16{0, 4}), 2},
    43  		{DTLSRoleServer, sctpTransportWithChannels([]uint16{}), 1},
    44  		{DTLSRoleServer, sctpTransportWithChannels([]uint16{0}), 1},
    45  		{DTLSRoleServer, sctpTransportWithChannels([]uint16{1}), 3},
    46  		{DTLSRoleServer, sctpTransportWithChannels([]uint16{1, 3}), 5},
    47  		{DTLSRoleServer, sctpTransportWithChannels([]uint16{1, 5}), 3},
    48  	}
    49  	for _, testCase := range testCases {
    50  		idPtr := new(uint16)
    51  		err := testCase.s.generateAndSetDataChannelID(testCase.role, &idPtr)
    52  		if err != nil {
    53  			t.Errorf("failed to generate id: %v", err)
    54  			return
    55  		}
    56  		if *idPtr != testCase.result {
    57  			t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result)
    58  		}
    59  		if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok {
    60  			t.Errorf("expected new id to be added to the map: %d", *idPtr)
    61  		}
    62  	}
    63  }
    64  
    65  func TestSCTPTransportOnClose(t *testing.T) {
    66  	offerPC, answerPC, err := newPair()
    67  	require.NoError(t, err)
    68  
    69  	defer closePairNow(t, offerPC, answerPC)
    70  
    71  	answerPC.OnDataChannel(func(dc *DataChannel) {
    72  		dc.OnMessage(func(_ DataChannelMessage) {
    73  			if err1 := dc.Send([]byte("hello")); err1 != nil {
    74  				t.Error("failed to send message")
    75  			}
    76  		})
    77  	})
    78  
    79  	recvMsg := make(chan struct{}, 1)
    80  	offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
    81  		if state == PeerConnectionStateConnected {
    82  			defer func() {
    83  				offerPC.OnConnectionStateChange(nil)
    84  			}()
    85  
    86  			dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
    87  			if createErr != nil {
    88  				t.Errorf("Failed to create a PC pair for testing")
    89  				return
    90  			}
    91  			dc.OnMessage(func(msg DataChannelMessage) {
    92  				if !bytes.Equal(msg.Data, []byte("hello")) {
    93  					t.Error("invalid msg received")
    94  				}
    95  				recvMsg <- struct{}{}
    96  			})
    97  			dc.OnOpen(func() {
    98  				if err1 := dc.Send([]byte("hello")); err1 != nil {
    99  					t.Error("failed to send initial msg", err1)
   100  				}
   101  			})
   102  		}
   103  	})
   104  
   105  	err = signalPair(offerPC, answerPC)
   106  	require.NoError(t, err)
   107  
   108  	select {
   109  	case <-recvMsg:
   110  	case <-time.After(5 * time.Second):
   111  		t.Fatal("timed out")
   112  	}
   113  
   114  	// setup SCTP OnClose callback
   115  	ch := make(chan error, 1)
   116  	answerPC.SCTP().OnClose(func(err error) {
   117  		ch <- err
   118  	})
   119  
   120  	err = offerPC.Close() // This will trigger sctp onclose callback on remote
   121  	require.NoError(t, err)
   122  
   123  	select {
   124  	case <-ch:
   125  	case <-time.After(5 * time.Second):
   126  		t.Fatal("timed out")
   127  	}
   128  }