github.com/pion/webrtc/v3@v3.2.24/datachannel_go_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/rand"
    12  	"encoding/binary"
    13  	"io"
    14  	"io/ioutil"
    15  	"math/big"
    16  	"reflect"
    17  	"regexp"
    18  	"strings"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/pion/datachannel"
    25  	"github.com/pion/logging"
    26  	"github.com/pion/transport/v2/test"
    27  	"github.com/stretchr/testify/assert"
    28  )
    29  
    30  func TestDataChannel_EventHandlers(t *testing.T) {
    31  	to := test.TimeOut(time.Second * 20)
    32  	defer to.Stop()
    33  
    34  	report := test.CheckRoutines(t)
    35  	defer report()
    36  
    37  	api := NewAPI()
    38  	dc := &DataChannel{api: api}
    39  
    40  	onDialCalled := make(chan struct{})
    41  	onOpenCalled := make(chan struct{})
    42  	onMessageCalled := make(chan struct{})
    43  
    44  	// Verify that the noop case works
    45  	assert.NotPanics(t, func() { dc.onOpen() })
    46  
    47  	dc.OnDial(func() {
    48  		close(onDialCalled)
    49  	})
    50  
    51  	dc.OnOpen(func() {
    52  		close(onOpenCalled)
    53  	})
    54  
    55  	dc.OnMessage(func(p DataChannelMessage) {
    56  		close(onMessageCalled)
    57  	})
    58  
    59  	// Verify that the set handlers are called
    60  	assert.NotPanics(t, func() { dc.onDial() })
    61  	assert.NotPanics(t, func() { dc.onOpen() })
    62  	assert.NotPanics(t, func() { dc.onMessage(DataChannelMessage{Data: []byte("o hai")}) })
    63  
    64  	// Wait for all handlers to be called
    65  	<-onDialCalled
    66  	<-onOpenCalled
    67  	<-onMessageCalled
    68  }
    69  
    70  func TestDataChannel_MessagesAreOrdered(t *testing.T) {
    71  	report := test.CheckRoutines(t)
    72  	defer report()
    73  
    74  	api := NewAPI()
    75  	dc := &DataChannel{api: api}
    76  
    77  	max := 512
    78  	out := make(chan int)
    79  	inner := func(msg DataChannelMessage) {
    80  		// randomly sleep
    81  		// math/rand a weak RNG, but this does not need to be secure. Ignore with #nosec
    82  		/* #nosec */
    83  		randInt, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
    84  		/* #nosec */ if err != nil {
    85  			t.Fatalf("Failed to get random sleep duration: %s", err)
    86  		}
    87  		time.Sleep(time.Duration(randInt.Int64()) * time.Microsecond)
    88  		s, _ := binary.Varint(msg.Data)
    89  		out <- int(s)
    90  	}
    91  	dc.OnMessage(func(p DataChannelMessage) {
    92  		inner(p)
    93  	})
    94  
    95  	go func() {
    96  		for i := 1; i <= max; i++ {
    97  			buf := make([]byte, 8)
    98  			binary.PutVarint(buf, int64(i))
    99  			dc.onMessage(DataChannelMessage{Data: buf})
   100  			// Change the registered handler a couple of times to make sure
   101  			// that everything continues to work, we don't lose messages, etc.
   102  			if i%2 == 0 {
   103  				handler := func(msg DataChannelMessage) {
   104  					inner(msg)
   105  				}
   106  				dc.OnMessage(handler)
   107  			}
   108  		}
   109  	}()
   110  
   111  	values := make([]int, 0, max)
   112  	for v := range out {
   113  		values = append(values, v)
   114  		if len(values) == max {
   115  			close(out)
   116  		}
   117  	}
   118  
   119  	expected := make([]int, max)
   120  	for i := 1; i <= max; i++ {
   121  		expected[i-1] = i
   122  	}
   123  	assert.EqualValues(t, expected, values)
   124  }
   125  
   126  // Note(albrow): This test includes some features that aren't supported by the
   127  // Wasm bindings (at least for now).
   128  func TestDataChannelParamters_Go(t *testing.T) {
   129  	report := test.CheckRoutines(t)
   130  	defer report()
   131  
   132  	t.Run("MaxPacketLifeTime exchange", func(t *testing.T) {
   133  		ordered := true
   134  		var maxPacketLifeTime uint16 = 3
   135  		options := &DataChannelInit{
   136  			Ordered:           &ordered,
   137  			MaxPacketLifeTime: &maxPacketLifeTime,
   138  		}
   139  
   140  		offerPC, answerPC, dc, done := setUpDataChannelParametersTest(t, options)
   141  
   142  		// Check if parameters are correctly set
   143  		assert.True(t, dc.Ordered(), "Ordered should be set to true")
   144  		if assert.NotNil(t, dc.MaxPacketLifeTime(), "should not be nil") {
   145  			assert.Equal(t, maxPacketLifeTime, *dc.MaxPacketLifeTime(), "should match")
   146  		}
   147  
   148  		answerPC.OnDataChannel(func(d *DataChannel) {
   149  			// Make sure this is the data channel we were looking for. (Not the one
   150  			// created in signalPair).
   151  			if d.Label() != expectedLabel {
   152  				return
   153  			}
   154  
   155  			// Check if parameters are correctly set
   156  			assert.True(t, d.ordered, "Ordered should be set to true")
   157  			if assert.NotNil(t, d.maxPacketLifeTime, "should not be nil") {
   158  				assert.Equal(t, maxPacketLifeTime, *d.maxPacketLifeTime, "should match")
   159  			}
   160  			done <- true
   161  		})
   162  
   163  		closeReliabilityParamTest(t, offerPC, answerPC, done)
   164  	})
   165  
   166  	t.Run("All other property methods", func(t *testing.T) {
   167  		id := uint16(123)
   168  		dc := &DataChannel{}
   169  		dc.id = &id
   170  		dc.label = "mylabel"
   171  		dc.protocol = "myprotocol"
   172  		dc.negotiated = true
   173  
   174  		assert.Equal(t, dc.id, dc.ID(), "should match")
   175  		assert.Equal(t, dc.label, dc.Label(), "should match")
   176  		assert.Equal(t, dc.protocol, dc.Protocol(), "should match")
   177  		assert.Equal(t, dc.negotiated, dc.Negotiated(), "should match")
   178  		assert.Equal(t, uint64(0), dc.BufferedAmount(), "should match")
   179  		dc.SetBufferedAmountLowThreshold(1500)
   180  		assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "should match")
   181  	})
   182  }
   183  
   184  func TestDataChannelBufferedAmount(t *testing.T) {
   185  	t.Run("set before datachannel becomes open", func(t *testing.T) {
   186  		report := test.CheckRoutines(t)
   187  		defer report()
   188  
   189  		var nOfferBufferedAmountLowCbs uint32
   190  		var offerBufferedAmountLowThreshold uint64 = 1500
   191  		var nAnswerBufferedAmountLowCbs uint32
   192  		var answerBufferedAmountLowThreshold uint64 = 1400
   193  
   194  		buf := make([]byte, 1000)
   195  
   196  		offerPC, answerPC, err := newPair()
   197  		if err != nil {
   198  			t.Fatalf("Failed to create a PC pair for testing")
   199  		}
   200  
   201  		nPacketsToSend := int(10)
   202  		var nOfferReceived uint32
   203  		var nAnswerReceived uint32
   204  
   205  		done := make(chan bool)
   206  
   207  		answerPC.OnDataChannel(func(answerDC *DataChannel) {
   208  			// Make sure this is the data channel we were looking for. (Not the one
   209  			// created in signalPair).
   210  			if answerDC.Label() != expectedLabel {
   211  				return
   212  			}
   213  
   214  			answerDC.OnOpen(func() {
   215  				assert.Equal(t, answerBufferedAmountLowThreshold, answerDC.BufferedAmountLowThreshold(), "value mismatch")
   216  
   217  				for i := 0; i < nPacketsToSend; i++ {
   218  					e := answerDC.Send(buf)
   219  					if e != nil {
   220  						t.Fatalf("Failed to send string on data channel")
   221  					}
   222  				}
   223  			})
   224  
   225  			answerDC.OnMessage(func(msg DataChannelMessage) {
   226  				atomic.AddUint32(&nAnswerReceived, 1)
   227  			})
   228  			assert.True(t, answerDC.Ordered(), "Ordered should be set to true")
   229  
   230  			// The value is temporarily stored in the answerDC object
   231  			// until the answerDC gets opened
   232  			answerDC.SetBufferedAmountLowThreshold(answerBufferedAmountLowThreshold)
   233  			// The callback function is temporarily stored in the answerDC object
   234  			// until the answerDC gets opened
   235  			answerDC.OnBufferedAmountLow(func() {
   236  				atomic.AddUint32(&nAnswerBufferedAmountLowCbs, 1)
   237  				if atomic.LoadUint32(&nOfferBufferedAmountLowCbs) > 0 {
   238  					done <- true
   239  				}
   240  			})
   241  		})
   242  
   243  		offerDC, err := offerPC.CreateDataChannel(expectedLabel, nil)
   244  		if err != nil {
   245  			t.Fatalf("Failed to create a PC pair for testing")
   246  		}
   247  
   248  		assert.True(t, offerDC.Ordered(), "Ordered should be set to true")
   249  
   250  		offerDC.OnOpen(func() {
   251  			assert.Equal(t, offerBufferedAmountLowThreshold, offerDC.BufferedAmountLowThreshold(), "value mismatch")
   252  
   253  			for i := 0; i < nPacketsToSend; i++ {
   254  				e := offerDC.Send(buf)
   255  				if e != nil {
   256  					t.Fatalf("Failed to send string on data channel")
   257  				}
   258  				// assert.Equal(t, (i+1)*len(buf), int(offerDC.BufferedAmount()), "unexpected bufferedAmount")
   259  			}
   260  		})
   261  
   262  		offerDC.OnMessage(func(msg DataChannelMessage) {
   263  			atomic.AddUint32(&nOfferReceived, 1)
   264  		})
   265  
   266  		// The value is temporarily stored in the offerDC object
   267  		// until the offerDC gets opened
   268  		offerDC.SetBufferedAmountLowThreshold(offerBufferedAmountLowThreshold)
   269  		// The callback function is temporarily stored in the offerDC object
   270  		// until the offerDC gets opened
   271  		offerDC.OnBufferedAmountLow(func() {
   272  			atomic.AddUint32(&nOfferBufferedAmountLowCbs, 1)
   273  			if atomic.LoadUint32(&nAnswerBufferedAmountLowCbs) > 0 {
   274  				done <- true
   275  			}
   276  		})
   277  
   278  		err = signalPair(offerPC, answerPC)
   279  		if err != nil {
   280  			t.Fatalf("Failed to signal our PC pair for testing")
   281  		}
   282  
   283  		closePair(t, offerPC, answerPC, done)
   284  
   285  		t.Logf("nOfferBufferedAmountLowCbs : %d", nOfferBufferedAmountLowCbs)
   286  		t.Logf("nAnswerBufferedAmountLowCbs: %d", nAnswerBufferedAmountLowCbs)
   287  		assert.True(t, nOfferBufferedAmountLowCbs > uint32(0), "callback should be made at least once")
   288  		assert.True(t, nAnswerBufferedAmountLowCbs > uint32(0), "callback should be made at least once")
   289  	})
   290  
   291  	t.Run("set after datachannel becomes open", func(t *testing.T) {
   292  		report := test.CheckRoutines(t)
   293  		defer report()
   294  
   295  		var nCbs int
   296  		buf := make([]byte, 1000)
   297  
   298  		offerPC, answerPC, err := newPair()
   299  		if err != nil {
   300  			t.Fatalf("Failed to create a PC pair for testing")
   301  		}
   302  
   303  		done := make(chan bool)
   304  
   305  		answerPC.OnDataChannel(func(d *DataChannel) {
   306  			// Make sure this is the data channel we were looking for. (Not the one
   307  			// created in signalPair).
   308  			if d.Label() != expectedLabel {
   309  				return
   310  			}
   311  			var nPacketsReceived int
   312  			d.OnMessage(func(msg DataChannelMessage) {
   313  				nPacketsReceived++
   314  
   315  				if nPacketsReceived == 10 {
   316  					go func() {
   317  						time.Sleep(time.Second)
   318  						done <- true
   319  					}()
   320  				}
   321  			})
   322  			assert.True(t, d.Ordered(), "Ordered should be set to true")
   323  		})
   324  
   325  		dc, err := offerPC.CreateDataChannel(expectedLabel, nil)
   326  		if err != nil {
   327  			t.Fatalf("Failed to create a PC pair for testing")
   328  		}
   329  
   330  		assert.True(t, dc.Ordered(), "Ordered should be set to true")
   331  
   332  		dc.OnOpen(func() {
   333  			// The value should directly be passed to sctp
   334  			dc.SetBufferedAmountLowThreshold(1500)
   335  			// The callback function should directly be passed to sctp
   336  			dc.OnBufferedAmountLow(func() {
   337  				nCbs++
   338  			})
   339  
   340  			for i := 0; i < 10; i++ {
   341  				e := dc.Send(buf)
   342  				if e != nil {
   343  					t.Fatalf("Failed to send string on data channel")
   344  				}
   345  				assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "value mismatch")
   346  				// assert.Equal(t, (i+1)*len(buf), int(dc.BufferedAmount()), "unexpected bufferedAmount")
   347  			}
   348  		})
   349  
   350  		dc.OnMessage(func(msg DataChannelMessage) {
   351  		})
   352  
   353  		err = signalPair(offerPC, answerPC)
   354  		if err != nil {
   355  			t.Fatalf("Failed to signal our PC pair for testing")
   356  		}
   357  
   358  		closePair(t, offerPC, answerPC, done)
   359  
   360  		assert.True(t, nCbs > 0, "callback should be made at least once")
   361  	})
   362  }
   363  
   364  func TestEOF(t *testing.T) {
   365  	report := test.CheckRoutines(t)
   366  	defer report()
   367  
   368  	log := logging.NewDefaultLoggerFactory().NewLogger("test")
   369  	label := "test-channel"
   370  	testData := []byte("this is some test data")
   371  
   372  	t.Run("Detach", func(t *testing.T) {
   373  		// Use Detach data channels mode
   374  		s := SettingEngine{}
   375  		s.DetachDataChannels()
   376  		api := NewAPI(WithSettingEngine(s))
   377  
   378  		// Set up two peer connections.
   379  		config := Configuration{}
   380  		pca, err := api.NewPeerConnection(config)
   381  		if err != nil {
   382  			t.Fatal(err)
   383  		}
   384  		pcb, err := api.NewPeerConnection(config)
   385  		if err != nil {
   386  			t.Fatal(err)
   387  		}
   388  
   389  		defer closePairNow(t, pca, pcb)
   390  
   391  		var wg sync.WaitGroup
   392  
   393  		dcChan := make(chan datachannel.ReadWriteCloser)
   394  		pcb.OnDataChannel(func(dc *DataChannel) {
   395  			if dc.Label() != label {
   396  				return
   397  			}
   398  			log.Debug("OnDataChannel was called")
   399  			dc.OnOpen(func() {
   400  				detached, err2 := dc.Detach()
   401  				if err2 != nil {
   402  					log.Debugf("Detach failed: %s", err2.Error())
   403  					t.Error(err2)
   404  				}
   405  
   406  				dcChan <- detached
   407  			})
   408  		})
   409  
   410  		wg.Add(1)
   411  		go func() {
   412  			defer wg.Done()
   413  
   414  			var msg []byte
   415  
   416  			log.Debug("Waiting for OnDataChannel")
   417  			dc := <-dcChan
   418  			log.Debug("data channel opened")
   419  			defer func() { assert.NoError(t, dc.Close(), "should succeed") }()
   420  
   421  			log.Debug("Waiting for ping...")
   422  			msg, err2 := ioutil.ReadAll(dc)
   423  			log.Debugf("Received ping! \"%s\"", string(msg))
   424  			if err2 != nil {
   425  				t.Error(err2)
   426  			}
   427  
   428  			if !bytes.Equal(msg, testData) {
   429  				t.Errorf("expected %q, got %q", string(msg), string(testData))
   430  			} else {
   431  				log.Debug("Received ping successfully!")
   432  			}
   433  		}()
   434  
   435  		if err = signalPair(pca, pcb); err != nil {
   436  			t.Fatal(err)
   437  		}
   438  
   439  		attached, err := pca.CreateDataChannel(label, nil)
   440  		if err != nil {
   441  			t.Fatal(err)
   442  		}
   443  		log.Debug("Waiting for data channel to open")
   444  		open := make(chan struct{})
   445  		attached.OnOpen(func() {
   446  			open <- struct{}{}
   447  		})
   448  		<-open
   449  		log.Debug("data channel opened")
   450  
   451  		var dc io.ReadWriteCloser
   452  		dc, err = attached.Detach()
   453  		if err != nil {
   454  			t.Fatal(err)
   455  		}
   456  
   457  		wg.Add(1)
   458  		go func() {
   459  			defer wg.Done()
   460  			log.Debug("Sending ping...")
   461  			if _, err2 := dc.Write(testData); err2 != nil {
   462  				t.Error(err2)
   463  			}
   464  			log.Debug("Sent ping")
   465  
   466  			assert.NoError(t, dc.Close(), "should succeed")
   467  
   468  			log.Debug("Wating for EOF")
   469  			ret, err2 := ioutil.ReadAll(dc)
   470  			assert.Nil(t, err2, "should succeed")
   471  			assert.Equal(t, 0, len(ret), "should be empty")
   472  		}()
   473  
   474  		wg.Wait()
   475  	})
   476  
   477  	t.Run("No detach", func(t *testing.T) {
   478  		lim := test.TimeOut(time.Second * 5)
   479  		defer lim.Stop()
   480  
   481  		// Set up two peer connections.
   482  		config := Configuration{}
   483  		pca, err := NewPeerConnection(config)
   484  		if err != nil {
   485  			t.Fatal(err)
   486  		}
   487  		pcb, err := NewPeerConnection(config)
   488  		if err != nil {
   489  			t.Fatal(err)
   490  		}
   491  
   492  		defer closePairNow(t, pca, pcb)
   493  
   494  		var dca, dcb *DataChannel
   495  		dcaClosedCh := make(chan struct{})
   496  		dcbClosedCh := make(chan struct{})
   497  
   498  		pcb.OnDataChannel(func(dc *DataChannel) {
   499  			if dc.Label() != label {
   500  				return
   501  			}
   502  
   503  			log.Debugf("pcb: new datachannel: %s", dc.Label())
   504  
   505  			dcb = dc
   506  			// Register channel opening handling
   507  			dcb.OnOpen(func() {
   508  				log.Debug("pcb: datachannel opened")
   509  			})
   510  
   511  			dcb.OnClose(func() {
   512  				// (2)
   513  				log.Debug("pcb: data channel closed")
   514  				close(dcbClosedCh)
   515  			})
   516  
   517  			// Register the OnMessage to handle incoming messages
   518  			log.Debug("pcb: registering onMessage callback")
   519  			dcb.OnMessage(func(dcMsg DataChannelMessage) {
   520  				log.Debugf("pcb: received ping: %s", string(dcMsg.Data))
   521  				if !reflect.DeepEqual(dcMsg.Data, testData) {
   522  					t.Error("data mismatch")
   523  				}
   524  			})
   525  		})
   526  
   527  		dca, err = pca.CreateDataChannel(label, nil)
   528  		if err != nil {
   529  			t.Fatal(err)
   530  		}
   531  
   532  		dca.OnOpen(func() {
   533  			log.Debug("pca: data channel opened")
   534  			log.Debugf("pca: sending \"%s\"", string(testData))
   535  			if err := dca.Send(testData); err != nil {
   536  				t.Fatal(err)
   537  			}
   538  			log.Debug("pca: sent ping")
   539  			assert.NoError(t, dca.Close(), "should succeed") // <-- dca closes
   540  		})
   541  
   542  		dca.OnClose(func() {
   543  			// (1)
   544  			log.Debug("pca: data channel closed")
   545  			close(dcaClosedCh)
   546  		})
   547  
   548  		// Register the OnMessage to handle incoming messages
   549  		log.Debug("pca: registering onMessage callback")
   550  		dca.OnMessage(func(dcMsg DataChannelMessage) {
   551  			log.Debugf("pca: received pong: %s", string(dcMsg.Data))
   552  			if !reflect.DeepEqual(dcMsg.Data, testData) {
   553  				t.Error("data mismatch")
   554  			}
   555  		})
   556  
   557  		if err := signalPair(pca, pcb); err != nil {
   558  			t.Fatal(err)
   559  		}
   560  
   561  		// When dca closes the channel,
   562  		// (1) dca.Onclose() will fire immediately, then
   563  		// (2) dcb.OnClose will also fire
   564  		<-dcaClosedCh // (1)
   565  		<-dcbClosedCh // (2)
   566  	})
   567  }
   568  
   569  // Assert that a Session Description that doesn't follow
   570  // draft-ietf-mmusic-sctp-sdp is still accepted
   571  func TestDataChannel_NonStandardSessionDescription(t *testing.T) {
   572  	to := test.TimeOut(time.Second * 20)
   573  	defer to.Stop()
   574  
   575  	report := test.CheckRoutines(t)
   576  	defer report()
   577  
   578  	offerPC, answerPC, err := newPair()
   579  	assert.NoError(t, err)
   580  
   581  	_, err = offerPC.CreateDataChannel("foo", nil)
   582  	assert.NoError(t, err)
   583  
   584  	onDataChannelCalled := make(chan struct{})
   585  	answerPC.OnDataChannel(func(_ *DataChannel) {
   586  		close(onDataChannelCalled)
   587  	})
   588  
   589  	offer, err := offerPC.CreateOffer(nil)
   590  	assert.NoError(t, err)
   591  
   592  	offerGatheringComplete := GatheringCompletePromise(offerPC)
   593  	assert.NoError(t, offerPC.SetLocalDescription(offer))
   594  	<-offerGatheringComplete
   595  
   596  	offer = *offerPC.LocalDescription()
   597  
   598  	// Replace with old values
   599  	const (
   600  		oldApplication = "m=application 63743 DTLS/SCTP 5000\r"
   601  		oldAttribute   = "a=sctpmap:5000 webrtc-datachannel 256\r"
   602  	)
   603  
   604  	offer.SDP = regexp.MustCompile(`m=application (.*?)\r`).ReplaceAllString(offer.SDP, oldApplication)
   605  	offer.SDP = regexp.MustCompile(`a=sctp-port(.*?)\r`).ReplaceAllString(offer.SDP, oldAttribute)
   606  
   607  	// Assert that replace worked
   608  	assert.True(t, strings.Contains(offer.SDP, oldApplication))
   609  	assert.True(t, strings.Contains(offer.SDP, oldAttribute))
   610  
   611  	assert.NoError(t, answerPC.SetRemoteDescription(offer))
   612  
   613  	answer, err := answerPC.CreateAnswer(nil)
   614  	assert.NoError(t, err)
   615  
   616  	answerGatheringComplete := GatheringCompletePromise(answerPC)
   617  	assert.NoError(t, answerPC.SetLocalDescription(answer))
   618  	<-answerGatheringComplete
   619  	assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription()))
   620  
   621  	<-onDataChannelCalled
   622  	closePairNow(t, offerPC, answerPC)
   623  }
   624  
   625  func TestDataChannel_Dial(t *testing.T) {
   626  	t.Run("handler should be called once, by dialing peer only", func(t *testing.T) {
   627  		report := test.CheckRoutines(t)
   628  		defer report()
   629  
   630  		dialCalls := make(chan bool, 2)
   631  		wg := new(sync.WaitGroup)
   632  		wg.Add(2)
   633  
   634  		offerPC, answerPC, err := newPair()
   635  		if err != nil {
   636  			t.Fatalf("Failed to create a PC pair for testing")
   637  		}
   638  
   639  		answerPC.OnDataChannel(func(d *DataChannel) {
   640  			if d.Label() != expectedLabel {
   641  				return
   642  			}
   643  
   644  			d.OnDial(func() {
   645  				// only dialing side should fire OnDial
   646  				t.Fatalf("answering side should not call on dial")
   647  			})
   648  
   649  			d.OnOpen(wg.Done)
   650  		})
   651  
   652  		d, err := offerPC.CreateDataChannel(expectedLabel, nil)
   653  		assert.NoError(t, err)
   654  		d.OnDial(func() {
   655  			dialCalls <- true
   656  			wg.Done()
   657  		})
   658  
   659  		assert.NoError(t, signalPair(offerPC, answerPC))
   660  
   661  		wg.Wait()
   662  		closePairNow(t, offerPC, answerPC)
   663  
   664  		assert.Len(t, dialCalls, 1)
   665  	})
   666  
   667  	t.Run("handler should be called immediately if already dialed", func(t *testing.T) {
   668  		report := test.CheckRoutines(t)
   669  		defer report()
   670  
   671  		done := make(chan bool)
   672  
   673  		offerPC, answerPC, err := newPair()
   674  		if err != nil {
   675  			t.Fatalf("Failed to create a PC pair for testing")
   676  		}
   677  
   678  		d, err := offerPC.CreateDataChannel(expectedLabel, nil)
   679  		assert.NoError(t, err)
   680  		d.OnOpen(func() {
   681  			// when the offer DC has been opened, its guaranteed to have dialed since it has
   682  			// received a response to said dial. this test represents an unrealistic usage,
   683  			// but its the best way to guarantee we "missed" the dial event and still invoke
   684  			// the handler.
   685  			d.OnDial(func() {
   686  				done <- true
   687  			})
   688  		})
   689  
   690  		assert.NoError(t, signalPair(offerPC, answerPC))
   691  
   692  		closePair(t, offerPC, answerPC, done)
   693  	})
   694  }