github.com/pion/webrtc/v4@v4.0.1/datachannel_go_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/rand"
    12  	"encoding/binary"
    13  	"io"
    14  	"math/big"
    15  	"reflect"
    16  	"regexp"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/pion/datachannel"
    24  	"github.com/pion/logging"
    25  	"github.com/pion/transport/v3/test"
    26  	"github.com/stretchr/testify/assert"
    27  )
    28  
    29  func TestDataChannel_EventHandlers(t *testing.T) {
    30  	to := test.TimeOut(time.Second * 20)
    31  	defer to.Stop()
    32  
    33  	report := test.CheckRoutines(t)
    34  	defer report()
    35  
    36  	api := NewAPI()
    37  	dc := &DataChannel{api: api}
    38  
    39  	onDialCalled := make(chan struct{})
    40  	onOpenCalled := make(chan struct{})
    41  	onMessageCalled := make(chan struct{})
    42  
    43  	// Verify that the noop case works
    44  	assert.NotPanics(t, func() { dc.onOpen() })
    45  
    46  	dc.OnDial(func() {
    47  		close(onDialCalled)
    48  	})
    49  
    50  	dc.OnOpen(func() {
    51  		close(onOpenCalled)
    52  	})
    53  
    54  	dc.OnMessage(func(DataChannelMessage) {
    55  		close(onMessageCalled)
    56  	})
    57  
    58  	// Verify that the set handlers are called
    59  	assert.NotPanics(t, func() { dc.onDial() })
    60  	assert.NotPanics(t, func() { dc.onOpen() })
    61  	assert.NotPanics(t, func() { dc.onMessage(DataChannelMessage{Data: []byte("o hai")}) })
    62  
    63  	// Wait for all handlers to be called
    64  	<-onDialCalled
    65  	<-onOpenCalled
    66  	<-onMessageCalled
    67  }
    68  
    69  func TestDataChannel_MessagesAreOrdered(t *testing.T) {
    70  	report := test.CheckRoutines(t)
    71  	defer report()
    72  
    73  	api := NewAPI()
    74  	dc := &DataChannel{api: api}
    75  
    76  	maxVal := 512
    77  	out := make(chan int)
    78  	inner := func(msg DataChannelMessage) {
    79  		// randomly sleep
    80  		// math/rand a weak RNG, but this does not need to be secure. Ignore with #nosec
    81  		/* #nosec */
    82  		randInt, err := rand.Int(rand.Reader, big.NewInt(int64(maxVal)))
    83  		/* #nosec */ if err != nil {
    84  			t.Fatalf("Failed to get random sleep duration: %s", err)
    85  		}
    86  		time.Sleep(time.Duration(randInt.Int64()) * time.Microsecond)
    87  		s, _ := binary.Varint(msg.Data)
    88  		out <- int(s)
    89  	}
    90  	dc.OnMessage(func(p DataChannelMessage) {
    91  		inner(p)
    92  	})
    93  
    94  	go func() {
    95  		for i := 1; i <= maxVal; i++ {
    96  			buf := make([]byte, 8)
    97  			binary.PutVarint(buf, int64(i))
    98  			dc.onMessage(DataChannelMessage{Data: buf})
    99  			// Change the registered handler a couple of times to make sure
   100  			// that everything continues to work, we don't lose messages, etc.
   101  			if i%2 == 0 {
   102  				handler := func(msg DataChannelMessage) {
   103  					inner(msg)
   104  				}
   105  				dc.OnMessage(handler)
   106  			}
   107  		}
   108  	}()
   109  
   110  	values := make([]int, 0, maxVal)
   111  	for v := range out {
   112  		values = append(values, v)
   113  		if len(values) == maxVal {
   114  			close(out)
   115  		}
   116  	}
   117  
   118  	expected := make([]int, maxVal)
   119  	for i := 1; i <= maxVal; i++ {
   120  		expected[i-1] = i
   121  	}
   122  	assert.EqualValues(t, expected, values)
   123  }
   124  
   125  // Note(albrow): This test includes some features that aren't supported by the
   126  // Wasm bindings (at least for now).
   127  func TestDataChannelParamters_Go(t *testing.T) {
   128  	report := test.CheckRoutines(t)
   129  	defer report()
   130  
   131  	t.Run("MaxPacketLifeTime exchange", func(t *testing.T) {
   132  		ordered := true
   133  		var maxPacketLifeTime uint16 = 3
   134  		options := &DataChannelInit{
   135  			Ordered:           &ordered,
   136  			MaxPacketLifeTime: &maxPacketLifeTime,
   137  		}
   138  
   139  		offerPC, answerPC, dc, done := setUpDataChannelParametersTest(t, options)
   140  
   141  		// Check if parameters are correctly set
   142  		assert.True(t, dc.Ordered(), "Ordered should be set to true")
   143  		if assert.NotNil(t, dc.MaxPacketLifeTime(), "should not be nil") {
   144  			assert.Equal(t, maxPacketLifeTime, *dc.MaxPacketLifeTime(), "should match")
   145  		}
   146  
   147  		answerPC.OnDataChannel(func(d *DataChannel) {
   148  			// Make sure this is the data channel we were looking for. (Not the one
   149  			// created in signalPair).
   150  			if d.Label() != expectedLabel {
   151  				return
   152  			}
   153  
   154  			// Check if parameters are correctly set
   155  			assert.True(t, d.ordered, "Ordered should be set to true")
   156  			if assert.NotNil(t, d.maxPacketLifeTime, "should not be nil") {
   157  				assert.Equal(t, maxPacketLifeTime, *d.maxPacketLifeTime, "should match")
   158  			}
   159  			done <- true
   160  		})
   161  
   162  		closeReliabilityParamTest(t, offerPC, answerPC, done)
   163  	})
   164  
   165  	t.Run("All other property methods", func(t *testing.T) {
   166  		id := uint16(123)
   167  		dc := &DataChannel{}
   168  		dc.id = &id
   169  		dc.label = "mylabel"
   170  		dc.protocol = "myprotocol"
   171  		dc.negotiated = true
   172  
   173  		assert.Equal(t, dc.id, dc.ID(), "should match")
   174  		assert.Equal(t, dc.label, dc.Label(), "should match")
   175  		assert.Equal(t, dc.protocol, dc.Protocol(), "should match")
   176  		assert.Equal(t, dc.negotiated, dc.Negotiated(), "should match")
   177  		assert.Equal(t, uint64(0), dc.BufferedAmount(), "should match")
   178  		dc.SetBufferedAmountLowThreshold(1500)
   179  		assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "should match")
   180  	})
   181  }
   182  
   183  func TestDataChannelBufferedAmount(t *testing.T) {
   184  	t.Run("set before datachannel becomes open", func(t *testing.T) {
   185  		report := test.CheckRoutines(t)
   186  		defer report()
   187  
   188  		var nOfferBufferedAmountLowCbs uint32
   189  		var offerBufferedAmountLowThreshold uint64 = 1500
   190  		var nAnswerBufferedAmountLowCbs uint32
   191  		var answerBufferedAmountLowThreshold uint64 = 1400
   192  
   193  		buf := make([]byte, 1000)
   194  
   195  		offerPC, answerPC, err := newPair()
   196  		if err != nil {
   197  			t.Fatalf("Failed to create a PC pair for testing")
   198  		}
   199  
   200  		nPacketsToSend := int(10)
   201  		var nOfferReceived uint32
   202  		var nAnswerReceived uint32
   203  
   204  		done := make(chan bool)
   205  
   206  		answerPC.OnDataChannel(func(answerDC *DataChannel) {
   207  			// Make sure this is the data channel we were looking for. (Not the one
   208  			// created in signalPair).
   209  			if answerDC.Label() != expectedLabel {
   210  				return
   211  			}
   212  
   213  			answerDC.OnOpen(func() {
   214  				assert.Equal(t, answerBufferedAmountLowThreshold, answerDC.BufferedAmountLowThreshold(), "value mismatch")
   215  
   216  				for i := 0; i < nPacketsToSend; i++ {
   217  					e := answerDC.Send(buf)
   218  					if e != nil {
   219  						t.Fatalf("Failed to send string on data channel")
   220  					}
   221  				}
   222  			})
   223  
   224  			answerDC.OnMessage(func(DataChannelMessage) {
   225  				atomic.AddUint32(&nAnswerReceived, 1)
   226  			})
   227  			assert.True(t, answerDC.Ordered(), "Ordered should be set to true")
   228  
   229  			// The value is temporarily stored in the answerDC object
   230  			// until the answerDC gets opened
   231  			answerDC.SetBufferedAmountLowThreshold(answerBufferedAmountLowThreshold)
   232  			// The callback function is temporarily stored in the answerDC object
   233  			// until the answerDC gets opened
   234  			answerDC.OnBufferedAmountLow(func() {
   235  				atomic.AddUint32(&nAnswerBufferedAmountLowCbs, 1)
   236  				if atomic.LoadUint32(&nOfferBufferedAmountLowCbs) > 0 {
   237  					done <- true
   238  				}
   239  			})
   240  		})
   241  
   242  		offerDC, err := offerPC.CreateDataChannel(expectedLabel, nil)
   243  		if err != nil {
   244  			t.Fatalf("Failed to create a PC pair for testing")
   245  		}
   246  
   247  		assert.True(t, offerDC.Ordered(), "Ordered should be set to true")
   248  
   249  		offerDC.OnOpen(func() {
   250  			assert.Equal(t, offerBufferedAmountLowThreshold, offerDC.BufferedAmountLowThreshold(), "value mismatch")
   251  
   252  			for i := 0; i < nPacketsToSend; i++ {
   253  				e := offerDC.Send(buf)
   254  				if e != nil {
   255  					t.Fatalf("Failed to send string on data channel")
   256  				}
   257  				// assert.Equal(t, (i+1)*len(buf), int(offerDC.BufferedAmount()), "unexpected bufferedAmount")
   258  			}
   259  		})
   260  
   261  		offerDC.OnMessage(func(DataChannelMessage) {
   262  			atomic.AddUint32(&nOfferReceived, 1)
   263  		})
   264  
   265  		// The value is temporarily stored in the offerDC object
   266  		// until the offerDC gets opened
   267  		offerDC.SetBufferedAmountLowThreshold(offerBufferedAmountLowThreshold)
   268  		// The callback function is temporarily stored in the offerDC object
   269  		// until the offerDC gets opened
   270  		offerDC.OnBufferedAmountLow(func() {
   271  			atomic.AddUint32(&nOfferBufferedAmountLowCbs, 1)
   272  			if atomic.LoadUint32(&nAnswerBufferedAmountLowCbs) > 0 {
   273  				done <- true
   274  			}
   275  		})
   276  
   277  		err = signalPair(offerPC, answerPC)
   278  		if err != nil {
   279  			t.Fatalf("Failed to signal our PC pair for testing")
   280  		}
   281  
   282  		closePair(t, offerPC, answerPC, done)
   283  
   284  		t.Logf("nOfferBufferedAmountLowCbs : %d", nOfferBufferedAmountLowCbs)
   285  		t.Logf("nAnswerBufferedAmountLowCbs: %d", nAnswerBufferedAmountLowCbs)
   286  		assert.True(t, nOfferBufferedAmountLowCbs > uint32(0), "callback should be made at least once")
   287  		assert.True(t, nAnswerBufferedAmountLowCbs > uint32(0), "callback should be made at least once")
   288  	})
   289  
   290  	t.Run("set after datachannel becomes open", func(t *testing.T) {
   291  		report := test.CheckRoutines(t)
   292  		defer report()
   293  
   294  		var nCbs int
   295  		buf := make([]byte, 1000)
   296  
   297  		offerPC, answerPC, err := newPair()
   298  		if err != nil {
   299  			t.Fatalf("Failed to create a PC pair for testing")
   300  		}
   301  
   302  		done := make(chan bool)
   303  
   304  		answerPC.OnDataChannel(func(d *DataChannel) {
   305  			// Make sure this is the data channel we were looking for. (Not the one
   306  			// created in signalPair).
   307  			if d.Label() != expectedLabel {
   308  				return
   309  			}
   310  			var nPacketsReceived int
   311  			d.OnMessage(func(DataChannelMessage) {
   312  				nPacketsReceived++
   313  
   314  				if nPacketsReceived == 10 {
   315  					go func() {
   316  						time.Sleep(time.Second)
   317  						done <- true
   318  					}()
   319  				}
   320  			})
   321  			assert.True(t, d.Ordered(), "Ordered should be set to true")
   322  		})
   323  
   324  		dc, err := offerPC.CreateDataChannel(expectedLabel, nil)
   325  		if err != nil {
   326  			t.Fatalf("Failed to create a PC pair for testing")
   327  		}
   328  
   329  		assert.True(t, dc.Ordered(), "Ordered should be set to true")
   330  
   331  		dc.OnOpen(func() {
   332  			// The value should directly be passed to sctp
   333  			dc.SetBufferedAmountLowThreshold(1500)
   334  			// The callback function should directly be passed to sctp
   335  			dc.OnBufferedAmountLow(func() {
   336  				nCbs++
   337  			})
   338  
   339  			for i := 0; i < 10; i++ {
   340  				e := dc.Send(buf)
   341  				if e != nil {
   342  					t.Fatalf("Failed to send string on data channel")
   343  				}
   344  				assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "value mismatch")
   345  				// assert.Equal(t, (i+1)*len(buf), int(dc.BufferedAmount()), "unexpected bufferedAmount")
   346  			}
   347  		})
   348  
   349  		dc.OnMessage(func(DataChannelMessage) {
   350  		})
   351  
   352  		err = signalPair(offerPC, answerPC)
   353  		if err != nil {
   354  			t.Fatalf("Failed to signal our PC pair for testing")
   355  		}
   356  
   357  		closePair(t, offerPC, answerPC, done)
   358  
   359  		assert.True(t, nCbs > 0, "callback should be made at least once")
   360  	})
   361  }
   362  
   363  func TestEOF(t *testing.T) {
   364  	report := test.CheckRoutines(t)
   365  	defer report()
   366  
   367  	log := logging.NewDefaultLoggerFactory().NewLogger("test")
   368  	label := "test-channel"
   369  	testData := []byte("this is some test data")
   370  
   371  	t.Run("Detach", func(t *testing.T) {
   372  		// Use Detach data channels mode
   373  		s := SettingEngine{}
   374  		s.DetachDataChannels()
   375  		api := NewAPI(WithSettingEngine(s))
   376  
   377  		// Set up two peer connections.
   378  		config := Configuration{}
   379  		pca, err := api.NewPeerConnection(config)
   380  		if err != nil {
   381  			t.Fatal(err)
   382  		}
   383  		pcb, err := api.NewPeerConnection(config)
   384  		if err != nil {
   385  			t.Fatal(err)
   386  		}
   387  
   388  		defer closePairNow(t, pca, pcb)
   389  
   390  		var wg sync.WaitGroup
   391  
   392  		dcChan := make(chan datachannel.ReadWriteCloser)
   393  		pcb.OnDataChannel(func(dc *DataChannel) {
   394  			if dc.Label() != label {
   395  				return
   396  			}
   397  			log.Debug("OnDataChannel was called")
   398  			dc.OnOpen(func() {
   399  				detached, err2 := dc.Detach()
   400  				if err2 != nil {
   401  					log.Debugf("Detach failed: %s", err2.Error())
   402  					t.Error(err2)
   403  				}
   404  
   405  				dcChan <- detached
   406  			})
   407  		})
   408  
   409  		wg.Add(1)
   410  		go func() {
   411  			defer wg.Done()
   412  
   413  			var msg []byte
   414  
   415  			log.Debug("Waiting for OnDataChannel")
   416  			dc := <-dcChan
   417  			log.Debug("data channel opened")
   418  			defer func() { assert.NoError(t, dc.Close(), "should succeed") }()
   419  
   420  			log.Debug("Waiting for ping...")
   421  			msg, err2 := io.ReadAll(dc)
   422  			log.Debugf("Received ping! \"%s\"", string(msg))
   423  			if err2 != nil {
   424  				t.Error(err2)
   425  			}
   426  
   427  			if !bytes.Equal(msg, testData) {
   428  				t.Errorf("expected %q, got %q", string(msg), string(testData))
   429  			} else {
   430  				log.Debug("Received ping successfully!")
   431  			}
   432  		}()
   433  
   434  		if err = signalPair(pca, pcb); err != nil {
   435  			t.Fatal(err)
   436  		}
   437  
   438  		attached, err := pca.CreateDataChannel(label, nil)
   439  		if err != nil {
   440  			t.Fatal(err)
   441  		}
   442  		log.Debug("Waiting for data channel to open")
   443  		open := make(chan struct{})
   444  		attached.OnOpen(func() {
   445  			open <- struct{}{}
   446  		})
   447  		<-open
   448  		log.Debug("data channel opened")
   449  
   450  		var dc io.ReadWriteCloser
   451  		dc, err = attached.Detach()
   452  		if err != nil {
   453  			t.Fatal(err)
   454  		}
   455  
   456  		wg.Add(1)
   457  		go func() {
   458  			defer wg.Done()
   459  			log.Debug("Sending ping...")
   460  			if _, err2 := dc.Write(testData); err2 != nil {
   461  				t.Error(err2)
   462  			}
   463  			log.Debug("Sent ping")
   464  
   465  			assert.NoError(t, dc.Close(), "should succeed")
   466  
   467  			log.Debug("Wating for EOF")
   468  			ret, err2 := io.ReadAll(dc)
   469  			assert.Nil(t, err2, "should succeed")
   470  			assert.Equal(t, 0, len(ret), "should be empty")
   471  		}()
   472  
   473  		wg.Wait()
   474  	})
   475  
   476  	t.Run("No detach", func(t *testing.T) {
   477  		lim := test.TimeOut(time.Second * 5)
   478  		defer lim.Stop()
   479  
   480  		// Set up two peer connections.
   481  		config := Configuration{}
   482  		pca, err := NewPeerConnection(config)
   483  		if err != nil {
   484  			t.Fatal(err)
   485  		}
   486  		pcb, err := NewPeerConnection(config)
   487  		if err != nil {
   488  			t.Fatal(err)
   489  		}
   490  
   491  		defer closePairNow(t, pca, pcb)
   492  
   493  		var dca, dcb *DataChannel
   494  		dcaClosedCh := make(chan struct{})
   495  		dcbClosedCh := make(chan struct{})
   496  
   497  		pcb.OnDataChannel(func(dc *DataChannel) {
   498  			if dc.Label() != label {
   499  				return
   500  			}
   501  
   502  			log.Debugf("pcb: new datachannel: %s", dc.Label())
   503  
   504  			dcb = dc
   505  			// Register channel opening handling
   506  			dcb.OnOpen(func() {
   507  				log.Debug("pcb: datachannel opened")
   508  			})
   509  
   510  			dcb.OnClose(func() {
   511  				// (2)
   512  				log.Debug("pcb: data channel closed")
   513  				close(dcbClosedCh)
   514  			})
   515  
   516  			// Register the OnMessage to handle incoming messages
   517  			log.Debug("pcb: registering onMessage callback")
   518  			dcb.OnMessage(func(dcMsg DataChannelMessage) {
   519  				log.Debugf("pcb: received ping: %s", string(dcMsg.Data))
   520  				if !reflect.DeepEqual(dcMsg.Data, testData) {
   521  					t.Error("data mismatch")
   522  				}
   523  			})
   524  		})
   525  
   526  		dca, err = pca.CreateDataChannel(label, nil)
   527  		if err != nil {
   528  			t.Fatal(err)
   529  		}
   530  
   531  		dca.OnOpen(func() {
   532  			log.Debug("pca: data channel opened")
   533  			log.Debugf("pca: sending \"%s\"", string(testData))
   534  			if err := dca.Send(testData); err != nil {
   535  				t.Fatal(err)
   536  			}
   537  			log.Debug("pca: sent ping")
   538  			assert.NoError(t, dca.Close(), "should succeed") // <-- dca closes
   539  		})
   540  
   541  		dca.OnClose(func() {
   542  			// (1)
   543  			log.Debug("pca: data channel closed")
   544  			close(dcaClosedCh)
   545  		})
   546  
   547  		// Register the OnMessage to handle incoming messages
   548  		log.Debug("pca: registering onMessage callback")
   549  		dca.OnMessage(func(dcMsg DataChannelMessage) {
   550  			log.Debugf("pca: received pong: %s", string(dcMsg.Data))
   551  			if !reflect.DeepEqual(dcMsg.Data, testData) {
   552  				t.Error("data mismatch")
   553  			}
   554  		})
   555  
   556  		if err := signalPair(pca, pcb); err != nil {
   557  			t.Fatal(err)
   558  		}
   559  
   560  		// When dca closes the channel,
   561  		// (1) dca.Onclose() will fire immediately, then
   562  		// (2) dcb.OnClose will also fire
   563  		<-dcaClosedCh // (1)
   564  		<-dcbClosedCh // (2)
   565  	})
   566  }
   567  
   568  // Assert that a Session Description that doesn't follow
   569  // draft-ietf-mmusic-sctp-sdp is still accepted
   570  func TestDataChannel_NonStandardSessionDescription(t *testing.T) {
   571  	to := test.TimeOut(time.Second * 20)
   572  	defer to.Stop()
   573  
   574  	report := test.CheckRoutines(t)
   575  	defer report()
   576  
   577  	offerPC, answerPC, err := newPair()
   578  	assert.NoError(t, err)
   579  
   580  	_, err = offerPC.CreateDataChannel("foo", nil)
   581  	assert.NoError(t, err)
   582  
   583  	onDataChannelCalled := make(chan struct{})
   584  	answerPC.OnDataChannel(func(_ *DataChannel) {
   585  		close(onDataChannelCalled)
   586  	})
   587  
   588  	offer, err := offerPC.CreateOffer(nil)
   589  	assert.NoError(t, err)
   590  
   591  	offerGatheringComplete := GatheringCompletePromise(offerPC)
   592  	assert.NoError(t, offerPC.SetLocalDescription(offer))
   593  	<-offerGatheringComplete
   594  
   595  	offer = *offerPC.LocalDescription()
   596  
   597  	// Replace with old values
   598  	const (
   599  		oldApplication = "m=application 63743 DTLS/SCTP 5000\r"
   600  		oldAttribute   = "a=sctpmap:5000 webrtc-datachannel 256\r"
   601  	)
   602  
   603  	offer.SDP = regexp.MustCompile(`m=application (.*?)\r`).ReplaceAllString(offer.SDP, oldApplication)
   604  	offer.SDP = regexp.MustCompile(`a=sctp-port(.*?)\r`).ReplaceAllString(offer.SDP, oldAttribute)
   605  
   606  	// Assert that replace worked
   607  	assert.True(t, strings.Contains(offer.SDP, oldApplication))
   608  	assert.True(t, strings.Contains(offer.SDP, oldAttribute))
   609  
   610  	assert.NoError(t, answerPC.SetRemoteDescription(offer))
   611  
   612  	answer, err := answerPC.CreateAnswer(nil)
   613  	assert.NoError(t, err)
   614  
   615  	answerGatheringComplete := GatheringCompletePromise(answerPC)
   616  	assert.NoError(t, answerPC.SetLocalDescription(answer))
   617  	<-answerGatheringComplete
   618  	assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription()))
   619  
   620  	<-onDataChannelCalled
   621  	closePairNow(t, offerPC, answerPC)
   622  }
   623  
   624  func TestDataChannel_Dial(t *testing.T) {
   625  	t.Run("handler should be called once, by dialing peer only", func(t *testing.T) {
   626  		report := test.CheckRoutines(t)
   627  		defer report()
   628  
   629  		dialCalls := make(chan bool, 2)
   630  		wg := new(sync.WaitGroup)
   631  		wg.Add(2)
   632  
   633  		offerPC, answerPC, err := newPair()
   634  		if err != nil {
   635  			t.Fatalf("Failed to create a PC pair for testing")
   636  		}
   637  
   638  		answerPC.OnDataChannel(func(d *DataChannel) {
   639  			if d.Label() != expectedLabel {
   640  				return
   641  			}
   642  
   643  			d.OnDial(func() {
   644  				// only dialing side should fire OnDial
   645  				t.Fatalf("answering side should not call on dial")
   646  			})
   647  
   648  			d.OnOpen(wg.Done)
   649  		})
   650  
   651  		d, err := offerPC.CreateDataChannel(expectedLabel, nil)
   652  		assert.NoError(t, err)
   653  		d.OnDial(func() {
   654  			dialCalls <- true
   655  			wg.Done()
   656  		})
   657  
   658  		assert.NoError(t, signalPair(offerPC, answerPC))
   659  
   660  		wg.Wait()
   661  		closePairNow(t, offerPC, answerPC)
   662  
   663  		assert.Len(t, dialCalls, 1)
   664  	})
   665  
   666  	t.Run("handler should be called immediately if already dialed", func(t *testing.T) {
   667  		report := test.CheckRoutines(t)
   668  		defer report()
   669  
   670  		done := make(chan bool)
   671  
   672  		offerPC, answerPC, err := newPair()
   673  		if err != nil {
   674  			t.Fatalf("Failed to create a PC pair for testing")
   675  		}
   676  
   677  		d, err := offerPC.CreateDataChannel(expectedLabel, nil)
   678  		assert.NoError(t, err)
   679  		d.OnOpen(func() {
   680  			// when the offer DC has been opened, its guaranteed to have dialed since it has
   681  			// received a response to said dial. this test represents an unrealistic usage,
   682  			// but its the best way to guarantee we "missed" the dial event and still invoke
   683  			// the handler.
   684  			d.OnDial(func() {
   685  				done <- true
   686  			})
   687  		})
   688  
   689  		assert.NoError(t, signalPair(offerPC, answerPC))
   690  
   691  		closePair(t, offerPC, answerPC, done)
   692  	})
   693  }
   694  
   695  func TestDetachRemovesDatachannelReference(t *testing.T) {
   696  	// Use Detach data channels mode
   697  	s := SettingEngine{}
   698  	s.DetachDataChannels()
   699  	api := NewAPI(WithSettingEngine(s))
   700  
   701  	// Set up two peer connections.
   702  	config := Configuration{}
   703  	pca, err := api.NewPeerConnection(config)
   704  	if err != nil {
   705  		t.Fatal(err)
   706  	}
   707  	pcb, err := api.NewPeerConnection(config)
   708  	if err != nil {
   709  		t.Fatal(err)
   710  	}
   711  
   712  	defer closePairNow(t, pca, pcb)
   713  
   714  	dcChan := make(chan *DataChannel, 1)
   715  	pcb.OnDataChannel(func(d *DataChannel) {
   716  		d.OnOpen(func() {
   717  			if _, detachErr := d.Detach(); detachErr != nil {
   718  				t.Error(detachErr)
   719  			}
   720  
   721  			dcChan <- d
   722  		})
   723  	})
   724  
   725  	if err = signalPair(pca, pcb); err != nil {
   726  		t.Fatal(err)
   727  	}
   728  
   729  	attached, err := pca.CreateDataChannel("", nil)
   730  	if err != nil {
   731  		t.Fatal(err)
   732  	}
   733  	open := make(chan struct{}, 1)
   734  	attached.OnOpen(func() {
   735  		open <- struct{}{}
   736  	})
   737  	<-open
   738  
   739  	d := <-dcChan
   740  	d.sctpTransport.lock.RLock()
   741  	defer d.sctpTransport.lock.RUnlock()
   742  	for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] {
   743  		if dc == d {
   744  			t.Errorf("expected sctpTransport to drop reference to datachannel")
   745  		}
   746  	}
   747  }