github.com/pion/webrtc/v4@v4.0.1/datachannel_go_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 import ( 10 "bytes" 11 "crypto/rand" 12 "encoding/binary" 13 "io" 14 "math/big" 15 "reflect" 16 "regexp" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "testing" 21 "time" 22 23 "github.com/pion/datachannel" 24 "github.com/pion/logging" 25 "github.com/pion/transport/v3/test" 26 "github.com/stretchr/testify/assert" 27 ) 28 29 func TestDataChannel_EventHandlers(t *testing.T) { 30 to := test.TimeOut(time.Second * 20) 31 defer to.Stop() 32 33 report := test.CheckRoutines(t) 34 defer report() 35 36 api := NewAPI() 37 dc := &DataChannel{api: api} 38 39 onDialCalled := make(chan struct{}) 40 onOpenCalled := make(chan struct{}) 41 onMessageCalled := make(chan struct{}) 42 43 // Verify that the noop case works 44 assert.NotPanics(t, func() { dc.onOpen() }) 45 46 dc.OnDial(func() { 47 close(onDialCalled) 48 }) 49 50 dc.OnOpen(func() { 51 close(onOpenCalled) 52 }) 53 54 dc.OnMessage(func(DataChannelMessage) { 55 close(onMessageCalled) 56 }) 57 58 // Verify that the set handlers are called 59 assert.NotPanics(t, func() { dc.onDial() }) 60 assert.NotPanics(t, func() { dc.onOpen() }) 61 assert.NotPanics(t, func() { dc.onMessage(DataChannelMessage{Data: []byte("o hai")}) }) 62 63 // Wait for all handlers to be called 64 <-onDialCalled 65 <-onOpenCalled 66 <-onMessageCalled 67 } 68 69 func TestDataChannel_MessagesAreOrdered(t *testing.T) { 70 report := test.CheckRoutines(t) 71 defer report() 72 73 api := NewAPI() 74 dc := &DataChannel{api: api} 75 76 maxVal := 512 77 out := make(chan int) 78 inner := func(msg DataChannelMessage) { 79 // randomly sleep 80 // math/rand a weak RNG, but this does not need to be secure. Ignore with #nosec 81 /* #nosec */ 82 randInt, err := rand.Int(rand.Reader, big.NewInt(int64(maxVal))) 83 /* #nosec */ if err != nil { 84 t.Fatalf("Failed to get random sleep duration: %s", err) 85 } 86 time.Sleep(time.Duration(randInt.Int64()) * time.Microsecond) 87 s, _ := binary.Varint(msg.Data) 88 out <- int(s) 89 } 90 dc.OnMessage(func(p DataChannelMessage) { 91 inner(p) 92 }) 93 94 go func() { 95 for i := 1; i <= maxVal; i++ { 96 buf := make([]byte, 8) 97 binary.PutVarint(buf, int64(i)) 98 dc.onMessage(DataChannelMessage{Data: buf}) 99 // Change the registered handler a couple of times to make sure 100 // that everything continues to work, we don't lose messages, etc. 101 if i%2 == 0 { 102 handler := func(msg DataChannelMessage) { 103 inner(msg) 104 } 105 dc.OnMessage(handler) 106 } 107 } 108 }() 109 110 values := make([]int, 0, maxVal) 111 for v := range out { 112 values = append(values, v) 113 if len(values) == maxVal { 114 close(out) 115 } 116 } 117 118 expected := make([]int, maxVal) 119 for i := 1; i <= maxVal; i++ { 120 expected[i-1] = i 121 } 122 assert.EqualValues(t, expected, values) 123 } 124 125 // Note(albrow): This test includes some features that aren't supported by the 126 // Wasm bindings (at least for now). 127 func TestDataChannelParamters_Go(t *testing.T) { 128 report := test.CheckRoutines(t) 129 defer report() 130 131 t.Run("MaxPacketLifeTime exchange", func(t *testing.T) { 132 ordered := true 133 var maxPacketLifeTime uint16 = 3 134 options := &DataChannelInit{ 135 Ordered: &ordered, 136 MaxPacketLifeTime: &maxPacketLifeTime, 137 } 138 139 offerPC, answerPC, dc, done := setUpDataChannelParametersTest(t, options) 140 141 // Check if parameters are correctly set 142 assert.True(t, dc.Ordered(), "Ordered should be set to true") 143 if assert.NotNil(t, dc.MaxPacketLifeTime(), "should not be nil") { 144 assert.Equal(t, maxPacketLifeTime, *dc.MaxPacketLifeTime(), "should match") 145 } 146 147 answerPC.OnDataChannel(func(d *DataChannel) { 148 // Make sure this is the data channel we were looking for. (Not the one 149 // created in signalPair). 150 if d.Label() != expectedLabel { 151 return 152 } 153 154 // Check if parameters are correctly set 155 assert.True(t, d.ordered, "Ordered should be set to true") 156 if assert.NotNil(t, d.maxPacketLifeTime, "should not be nil") { 157 assert.Equal(t, maxPacketLifeTime, *d.maxPacketLifeTime, "should match") 158 } 159 done <- true 160 }) 161 162 closeReliabilityParamTest(t, offerPC, answerPC, done) 163 }) 164 165 t.Run("All other property methods", func(t *testing.T) { 166 id := uint16(123) 167 dc := &DataChannel{} 168 dc.id = &id 169 dc.label = "mylabel" 170 dc.protocol = "myprotocol" 171 dc.negotiated = true 172 173 assert.Equal(t, dc.id, dc.ID(), "should match") 174 assert.Equal(t, dc.label, dc.Label(), "should match") 175 assert.Equal(t, dc.protocol, dc.Protocol(), "should match") 176 assert.Equal(t, dc.negotiated, dc.Negotiated(), "should match") 177 assert.Equal(t, uint64(0), dc.BufferedAmount(), "should match") 178 dc.SetBufferedAmountLowThreshold(1500) 179 assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "should match") 180 }) 181 } 182 183 func TestDataChannelBufferedAmount(t *testing.T) { 184 t.Run("set before datachannel becomes open", func(t *testing.T) { 185 report := test.CheckRoutines(t) 186 defer report() 187 188 var nOfferBufferedAmountLowCbs uint32 189 var offerBufferedAmountLowThreshold uint64 = 1500 190 var nAnswerBufferedAmountLowCbs uint32 191 var answerBufferedAmountLowThreshold uint64 = 1400 192 193 buf := make([]byte, 1000) 194 195 offerPC, answerPC, err := newPair() 196 if err != nil { 197 t.Fatalf("Failed to create a PC pair for testing") 198 } 199 200 nPacketsToSend := int(10) 201 var nOfferReceived uint32 202 var nAnswerReceived uint32 203 204 done := make(chan bool) 205 206 answerPC.OnDataChannel(func(answerDC *DataChannel) { 207 // Make sure this is the data channel we were looking for. (Not the one 208 // created in signalPair). 209 if answerDC.Label() != expectedLabel { 210 return 211 } 212 213 answerDC.OnOpen(func() { 214 assert.Equal(t, answerBufferedAmountLowThreshold, answerDC.BufferedAmountLowThreshold(), "value mismatch") 215 216 for i := 0; i < nPacketsToSend; i++ { 217 e := answerDC.Send(buf) 218 if e != nil { 219 t.Fatalf("Failed to send string on data channel") 220 } 221 } 222 }) 223 224 answerDC.OnMessage(func(DataChannelMessage) { 225 atomic.AddUint32(&nAnswerReceived, 1) 226 }) 227 assert.True(t, answerDC.Ordered(), "Ordered should be set to true") 228 229 // The value is temporarily stored in the answerDC object 230 // until the answerDC gets opened 231 answerDC.SetBufferedAmountLowThreshold(answerBufferedAmountLowThreshold) 232 // The callback function is temporarily stored in the answerDC object 233 // until the answerDC gets opened 234 answerDC.OnBufferedAmountLow(func() { 235 atomic.AddUint32(&nAnswerBufferedAmountLowCbs, 1) 236 if atomic.LoadUint32(&nOfferBufferedAmountLowCbs) > 0 { 237 done <- true 238 } 239 }) 240 }) 241 242 offerDC, err := offerPC.CreateDataChannel(expectedLabel, nil) 243 if err != nil { 244 t.Fatalf("Failed to create a PC pair for testing") 245 } 246 247 assert.True(t, offerDC.Ordered(), "Ordered should be set to true") 248 249 offerDC.OnOpen(func() { 250 assert.Equal(t, offerBufferedAmountLowThreshold, offerDC.BufferedAmountLowThreshold(), "value mismatch") 251 252 for i := 0; i < nPacketsToSend; i++ { 253 e := offerDC.Send(buf) 254 if e != nil { 255 t.Fatalf("Failed to send string on data channel") 256 } 257 // assert.Equal(t, (i+1)*len(buf), int(offerDC.BufferedAmount()), "unexpected bufferedAmount") 258 } 259 }) 260 261 offerDC.OnMessage(func(DataChannelMessage) { 262 atomic.AddUint32(&nOfferReceived, 1) 263 }) 264 265 // The value is temporarily stored in the offerDC object 266 // until the offerDC gets opened 267 offerDC.SetBufferedAmountLowThreshold(offerBufferedAmountLowThreshold) 268 // The callback function is temporarily stored in the offerDC object 269 // until the offerDC gets opened 270 offerDC.OnBufferedAmountLow(func() { 271 atomic.AddUint32(&nOfferBufferedAmountLowCbs, 1) 272 if atomic.LoadUint32(&nAnswerBufferedAmountLowCbs) > 0 { 273 done <- true 274 } 275 }) 276 277 err = signalPair(offerPC, answerPC) 278 if err != nil { 279 t.Fatalf("Failed to signal our PC pair for testing") 280 } 281 282 closePair(t, offerPC, answerPC, done) 283 284 t.Logf("nOfferBufferedAmountLowCbs : %d", nOfferBufferedAmountLowCbs) 285 t.Logf("nAnswerBufferedAmountLowCbs: %d", nAnswerBufferedAmountLowCbs) 286 assert.True(t, nOfferBufferedAmountLowCbs > uint32(0), "callback should be made at least once") 287 assert.True(t, nAnswerBufferedAmountLowCbs > uint32(0), "callback should be made at least once") 288 }) 289 290 t.Run("set after datachannel becomes open", func(t *testing.T) { 291 report := test.CheckRoutines(t) 292 defer report() 293 294 var nCbs int 295 buf := make([]byte, 1000) 296 297 offerPC, answerPC, err := newPair() 298 if err != nil { 299 t.Fatalf("Failed to create a PC pair for testing") 300 } 301 302 done := make(chan bool) 303 304 answerPC.OnDataChannel(func(d *DataChannel) { 305 // Make sure this is the data channel we were looking for. (Not the one 306 // created in signalPair). 307 if d.Label() != expectedLabel { 308 return 309 } 310 var nPacketsReceived int 311 d.OnMessage(func(DataChannelMessage) { 312 nPacketsReceived++ 313 314 if nPacketsReceived == 10 { 315 go func() { 316 time.Sleep(time.Second) 317 done <- true 318 }() 319 } 320 }) 321 assert.True(t, d.Ordered(), "Ordered should be set to true") 322 }) 323 324 dc, err := offerPC.CreateDataChannel(expectedLabel, nil) 325 if err != nil { 326 t.Fatalf("Failed to create a PC pair for testing") 327 } 328 329 assert.True(t, dc.Ordered(), "Ordered should be set to true") 330 331 dc.OnOpen(func() { 332 // The value should directly be passed to sctp 333 dc.SetBufferedAmountLowThreshold(1500) 334 // The callback function should directly be passed to sctp 335 dc.OnBufferedAmountLow(func() { 336 nCbs++ 337 }) 338 339 for i := 0; i < 10; i++ { 340 e := dc.Send(buf) 341 if e != nil { 342 t.Fatalf("Failed to send string on data channel") 343 } 344 assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "value mismatch") 345 // assert.Equal(t, (i+1)*len(buf), int(dc.BufferedAmount()), "unexpected bufferedAmount") 346 } 347 }) 348 349 dc.OnMessage(func(DataChannelMessage) { 350 }) 351 352 err = signalPair(offerPC, answerPC) 353 if err != nil { 354 t.Fatalf("Failed to signal our PC pair for testing") 355 } 356 357 closePair(t, offerPC, answerPC, done) 358 359 assert.True(t, nCbs > 0, "callback should be made at least once") 360 }) 361 } 362 363 func TestEOF(t *testing.T) { 364 report := test.CheckRoutines(t) 365 defer report() 366 367 log := logging.NewDefaultLoggerFactory().NewLogger("test") 368 label := "test-channel" 369 testData := []byte("this is some test data") 370 371 t.Run("Detach", func(t *testing.T) { 372 // Use Detach data channels mode 373 s := SettingEngine{} 374 s.DetachDataChannels() 375 api := NewAPI(WithSettingEngine(s)) 376 377 // Set up two peer connections. 378 config := Configuration{} 379 pca, err := api.NewPeerConnection(config) 380 if err != nil { 381 t.Fatal(err) 382 } 383 pcb, err := api.NewPeerConnection(config) 384 if err != nil { 385 t.Fatal(err) 386 } 387 388 defer closePairNow(t, pca, pcb) 389 390 var wg sync.WaitGroup 391 392 dcChan := make(chan datachannel.ReadWriteCloser) 393 pcb.OnDataChannel(func(dc *DataChannel) { 394 if dc.Label() != label { 395 return 396 } 397 log.Debug("OnDataChannel was called") 398 dc.OnOpen(func() { 399 detached, err2 := dc.Detach() 400 if err2 != nil { 401 log.Debugf("Detach failed: %s", err2.Error()) 402 t.Error(err2) 403 } 404 405 dcChan <- detached 406 }) 407 }) 408 409 wg.Add(1) 410 go func() { 411 defer wg.Done() 412 413 var msg []byte 414 415 log.Debug("Waiting for OnDataChannel") 416 dc := <-dcChan 417 log.Debug("data channel opened") 418 defer func() { assert.NoError(t, dc.Close(), "should succeed") }() 419 420 log.Debug("Waiting for ping...") 421 msg, err2 := io.ReadAll(dc) 422 log.Debugf("Received ping! \"%s\"", string(msg)) 423 if err2 != nil { 424 t.Error(err2) 425 } 426 427 if !bytes.Equal(msg, testData) { 428 t.Errorf("expected %q, got %q", string(msg), string(testData)) 429 } else { 430 log.Debug("Received ping successfully!") 431 } 432 }() 433 434 if err = signalPair(pca, pcb); err != nil { 435 t.Fatal(err) 436 } 437 438 attached, err := pca.CreateDataChannel(label, nil) 439 if err != nil { 440 t.Fatal(err) 441 } 442 log.Debug("Waiting for data channel to open") 443 open := make(chan struct{}) 444 attached.OnOpen(func() { 445 open <- struct{}{} 446 }) 447 <-open 448 log.Debug("data channel opened") 449 450 var dc io.ReadWriteCloser 451 dc, err = attached.Detach() 452 if err != nil { 453 t.Fatal(err) 454 } 455 456 wg.Add(1) 457 go func() { 458 defer wg.Done() 459 log.Debug("Sending ping...") 460 if _, err2 := dc.Write(testData); err2 != nil { 461 t.Error(err2) 462 } 463 log.Debug("Sent ping") 464 465 assert.NoError(t, dc.Close(), "should succeed") 466 467 log.Debug("Wating for EOF") 468 ret, err2 := io.ReadAll(dc) 469 assert.Nil(t, err2, "should succeed") 470 assert.Equal(t, 0, len(ret), "should be empty") 471 }() 472 473 wg.Wait() 474 }) 475 476 t.Run("No detach", func(t *testing.T) { 477 lim := test.TimeOut(time.Second * 5) 478 defer lim.Stop() 479 480 // Set up two peer connections. 481 config := Configuration{} 482 pca, err := NewPeerConnection(config) 483 if err != nil { 484 t.Fatal(err) 485 } 486 pcb, err := NewPeerConnection(config) 487 if err != nil { 488 t.Fatal(err) 489 } 490 491 defer closePairNow(t, pca, pcb) 492 493 var dca, dcb *DataChannel 494 dcaClosedCh := make(chan struct{}) 495 dcbClosedCh := make(chan struct{}) 496 497 pcb.OnDataChannel(func(dc *DataChannel) { 498 if dc.Label() != label { 499 return 500 } 501 502 log.Debugf("pcb: new datachannel: %s", dc.Label()) 503 504 dcb = dc 505 // Register channel opening handling 506 dcb.OnOpen(func() { 507 log.Debug("pcb: datachannel opened") 508 }) 509 510 dcb.OnClose(func() { 511 // (2) 512 log.Debug("pcb: data channel closed") 513 close(dcbClosedCh) 514 }) 515 516 // Register the OnMessage to handle incoming messages 517 log.Debug("pcb: registering onMessage callback") 518 dcb.OnMessage(func(dcMsg DataChannelMessage) { 519 log.Debugf("pcb: received ping: %s", string(dcMsg.Data)) 520 if !reflect.DeepEqual(dcMsg.Data, testData) { 521 t.Error("data mismatch") 522 } 523 }) 524 }) 525 526 dca, err = pca.CreateDataChannel(label, nil) 527 if err != nil { 528 t.Fatal(err) 529 } 530 531 dca.OnOpen(func() { 532 log.Debug("pca: data channel opened") 533 log.Debugf("pca: sending \"%s\"", string(testData)) 534 if err := dca.Send(testData); err != nil { 535 t.Fatal(err) 536 } 537 log.Debug("pca: sent ping") 538 assert.NoError(t, dca.Close(), "should succeed") // <-- dca closes 539 }) 540 541 dca.OnClose(func() { 542 // (1) 543 log.Debug("pca: data channel closed") 544 close(dcaClosedCh) 545 }) 546 547 // Register the OnMessage to handle incoming messages 548 log.Debug("pca: registering onMessage callback") 549 dca.OnMessage(func(dcMsg DataChannelMessage) { 550 log.Debugf("pca: received pong: %s", string(dcMsg.Data)) 551 if !reflect.DeepEqual(dcMsg.Data, testData) { 552 t.Error("data mismatch") 553 } 554 }) 555 556 if err := signalPair(pca, pcb); err != nil { 557 t.Fatal(err) 558 } 559 560 // When dca closes the channel, 561 // (1) dca.Onclose() will fire immediately, then 562 // (2) dcb.OnClose will also fire 563 <-dcaClosedCh // (1) 564 <-dcbClosedCh // (2) 565 }) 566 } 567 568 // Assert that a Session Description that doesn't follow 569 // draft-ietf-mmusic-sctp-sdp is still accepted 570 func TestDataChannel_NonStandardSessionDescription(t *testing.T) { 571 to := test.TimeOut(time.Second * 20) 572 defer to.Stop() 573 574 report := test.CheckRoutines(t) 575 defer report() 576 577 offerPC, answerPC, err := newPair() 578 assert.NoError(t, err) 579 580 _, err = offerPC.CreateDataChannel("foo", nil) 581 assert.NoError(t, err) 582 583 onDataChannelCalled := make(chan struct{}) 584 answerPC.OnDataChannel(func(_ *DataChannel) { 585 close(onDataChannelCalled) 586 }) 587 588 offer, err := offerPC.CreateOffer(nil) 589 assert.NoError(t, err) 590 591 offerGatheringComplete := GatheringCompletePromise(offerPC) 592 assert.NoError(t, offerPC.SetLocalDescription(offer)) 593 <-offerGatheringComplete 594 595 offer = *offerPC.LocalDescription() 596 597 // Replace with old values 598 const ( 599 oldApplication = "m=application 63743 DTLS/SCTP 5000\r" 600 oldAttribute = "a=sctpmap:5000 webrtc-datachannel 256\r" 601 ) 602 603 offer.SDP = regexp.MustCompile(`m=application (.*?)\r`).ReplaceAllString(offer.SDP, oldApplication) 604 offer.SDP = regexp.MustCompile(`a=sctp-port(.*?)\r`).ReplaceAllString(offer.SDP, oldAttribute) 605 606 // Assert that replace worked 607 assert.True(t, strings.Contains(offer.SDP, oldApplication)) 608 assert.True(t, strings.Contains(offer.SDP, oldAttribute)) 609 610 assert.NoError(t, answerPC.SetRemoteDescription(offer)) 611 612 answer, err := answerPC.CreateAnswer(nil) 613 assert.NoError(t, err) 614 615 answerGatheringComplete := GatheringCompletePromise(answerPC) 616 assert.NoError(t, answerPC.SetLocalDescription(answer)) 617 <-answerGatheringComplete 618 assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription())) 619 620 <-onDataChannelCalled 621 closePairNow(t, offerPC, answerPC) 622 } 623 624 func TestDataChannel_Dial(t *testing.T) { 625 t.Run("handler should be called once, by dialing peer only", func(t *testing.T) { 626 report := test.CheckRoutines(t) 627 defer report() 628 629 dialCalls := make(chan bool, 2) 630 wg := new(sync.WaitGroup) 631 wg.Add(2) 632 633 offerPC, answerPC, err := newPair() 634 if err != nil { 635 t.Fatalf("Failed to create a PC pair for testing") 636 } 637 638 answerPC.OnDataChannel(func(d *DataChannel) { 639 if d.Label() != expectedLabel { 640 return 641 } 642 643 d.OnDial(func() { 644 // only dialing side should fire OnDial 645 t.Fatalf("answering side should not call on dial") 646 }) 647 648 d.OnOpen(wg.Done) 649 }) 650 651 d, err := offerPC.CreateDataChannel(expectedLabel, nil) 652 assert.NoError(t, err) 653 d.OnDial(func() { 654 dialCalls <- true 655 wg.Done() 656 }) 657 658 assert.NoError(t, signalPair(offerPC, answerPC)) 659 660 wg.Wait() 661 closePairNow(t, offerPC, answerPC) 662 663 assert.Len(t, dialCalls, 1) 664 }) 665 666 t.Run("handler should be called immediately if already dialed", func(t *testing.T) { 667 report := test.CheckRoutines(t) 668 defer report() 669 670 done := make(chan bool) 671 672 offerPC, answerPC, err := newPair() 673 if err != nil { 674 t.Fatalf("Failed to create a PC pair for testing") 675 } 676 677 d, err := offerPC.CreateDataChannel(expectedLabel, nil) 678 assert.NoError(t, err) 679 d.OnOpen(func() { 680 // when the offer DC has been opened, its guaranteed to have dialed since it has 681 // received a response to said dial. this test represents an unrealistic usage, 682 // but its the best way to guarantee we "missed" the dial event and still invoke 683 // the handler. 684 d.OnDial(func() { 685 done <- true 686 }) 687 }) 688 689 assert.NoError(t, signalPair(offerPC, answerPC)) 690 691 closePair(t, offerPC, answerPC, done) 692 }) 693 } 694 695 func TestDetachRemovesDatachannelReference(t *testing.T) { 696 // Use Detach data channels mode 697 s := SettingEngine{} 698 s.DetachDataChannels() 699 api := NewAPI(WithSettingEngine(s)) 700 701 // Set up two peer connections. 702 config := Configuration{} 703 pca, err := api.NewPeerConnection(config) 704 if err != nil { 705 t.Fatal(err) 706 } 707 pcb, err := api.NewPeerConnection(config) 708 if err != nil { 709 t.Fatal(err) 710 } 711 712 defer closePairNow(t, pca, pcb) 713 714 dcChan := make(chan *DataChannel, 1) 715 pcb.OnDataChannel(func(d *DataChannel) { 716 d.OnOpen(func() { 717 if _, detachErr := d.Detach(); detachErr != nil { 718 t.Error(detachErr) 719 } 720 721 dcChan <- d 722 }) 723 }) 724 725 if err = signalPair(pca, pcb); err != nil { 726 t.Fatal(err) 727 } 728 729 attached, err := pca.CreateDataChannel("", nil) 730 if err != nil { 731 t.Fatal(err) 732 } 733 open := make(chan struct{}, 1) 734 attached.OnOpen(func() { 735 open <- struct{}{} 736 }) 737 <-open 738 739 d := <-dcChan 740 d.sctpTransport.lock.RLock() 741 defer d.sctpTransport.lock.RUnlock() 742 for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] { 743 if dc == d { 744 t.Errorf("expected sctpTransport to drop reference to datachannel") 745 } 746 } 747 }