github.com/pion/webrtc/v3@v3.2.24/datachannel_go_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 import ( 10 "bytes" 11 "crypto/rand" 12 "encoding/binary" 13 "io" 14 "io/ioutil" 15 "math/big" 16 "reflect" 17 "regexp" 18 "strings" 19 "sync" 20 "sync/atomic" 21 "testing" 22 "time" 23 24 "github.com/pion/datachannel" 25 "github.com/pion/logging" 26 "github.com/pion/transport/v2/test" 27 "github.com/stretchr/testify/assert" 28 ) 29 30 func TestDataChannel_EventHandlers(t *testing.T) { 31 to := test.TimeOut(time.Second * 20) 32 defer to.Stop() 33 34 report := test.CheckRoutines(t) 35 defer report() 36 37 api := NewAPI() 38 dc := &DataChannel{api: api} 39 40 onDialCalled := make(chan struct{}) 41 onOpenCalled := make(chan struct{}) 42 onMessageCalled := make(chan struct{}) 43 44 // Verify that the noop case works 45 assert.NotPanics(t, func() { dc.onOpen() }) 46 47 dc.OnDial(func() { 48 close(onDialCalled) 49 }) 50 51 dc.OnOpen(func() { 52 close(onOpenCalled) 53 }) 54 55 dc.OnMessage(func(p DataChannelMessage) { 56 close(onMessageCalled) 57 }) 58 59 // Verify that the set handlers are called 60 assert.NotPanics(t, func() { dc.onDial() }) 61 assert.NotPanics(t, func() { dc.onOpen() }) 62 assert.NotPanics(t, func() { dc.onMessage(DataChannelMessage{Data: []byte("o hai")}) }) 63 64 // Wait for all handlers to be called 65 <-onDialCalled 66 <-onOpenCalled 67 <-onMessageCalled 68 } 69 70 func TestDataChannel_MessagesAreOrdered(t *testing.T) { 71 report := test.CheckRoutines(t) 72 defer report() 73 74 api := NewAPI() 75 dc := &DataChannel{api: api} 76 77 max := 512 78 out := make(chan int) 79 inner := func(msg DataChannelMessage) { 80 // randomly sleep 81 // math/rand a weak RNG, but this does not need to be secure. Ignore with #nosec 82 /* #nosec */ 83 randInt, err := rand.Int(rand.Reader, big.NewInt(int64(max))) 84 /* #nosec */ if err != nil { 85 t.Fatalf("Failed to get random sleep duration: %s", err) 86 } 87 time.Sleep(time.Duration(randInt.Int64()) * time.Microsecond) 88 s, _ := binary.Varint(msg.Data) 89 out <- int(s) 90 } 91 dc.OnMessage(func(p DataChannelMessage) { 92 inner(p) 93 }) 94 95 go func() { 96 for i := 1; i <= max; i++ { 97 buf := make([]byte, 8) 98 binary.PutVarint(buf, int64(i)) 99 dc.onMessage(DataChannelMessage{Data: buf}) 100 // Change the registered handler a couple of times to make sure 101 // that everything continues to work, we don't lose messages, etc. 102 if i%2 == 0 { 103 handler := func(msg DataChannelMessage) { 104 inner(msg) 105 } 106 dc.OnMessage(handler) 107 } 108 } 109 }() 110 111 values := make([]int, 0, max) 112 for v := range out { 113 values = append(values, v) 114 if len(values) == max { 115 close(out) 116 } 117 } 118 119 expected := make([]int, max) 120 for i := 1; i <= max; i++ { 121 expected[i-1] = i 122 } 123 assert.EqualValues(t, expected, values) 124 } 125 126 // Note(albrow): This test includes some features that aren't supported by the 127 // Wasm bindings (at least for now). 128 func TestDataChannelParamters_Go(t *testing.T) { 129 report := test.CheckRoutines(t) 130 defer report() 131 132 t.Run("MaxPacketLifeTime exchange", func(t *testing.T) { 133 ordered := true 134 var maxPacketLifeTime uint16 = 3 135 options := &DataChannelInit{ 136 Ordered: &ordered, 137 MaxPacketLifeTime: &maxPacketLifeTime, 138 } 139 140 offerPC, answerPC, dc, done := setUpDataChannelParametersTest(t, options) 141 142 // Check if parameters are correctly set 143 assert.True(t, dc.Ordered(), "Ordered should be set to true") 144 if assert.NotNil(t, dc.MaxPacketLifeTime(), "should not be nil") { 145 assert.Equal(t, maxPacketLifeTime, *dc.MaxPacketLifeTime(), "should match") 146 } 147 148 answerPC.OnDataChannel(func(d *DataChannel) { 149 // Make sure this is the data channel we were looking for. (Not the one 150 // created in signalPair). 151 if d.Label() != expectedLabel { 152 return 153 } 154 155 // Check if parameters are correctly set 156 assert.True(t, d.ordered, "Ordered should be set to true") 157 if assert.NotNil(t, d.maxPacketLifeTime, "should not be nil") { 158 assert.Equal(t, maxPacketLifeTime, *d.maxPacketLifeTime, "should match") 159 } 160 done <- true 161 }) 162 163 closeReliabilityParamTest(t, offerPC, answerPC, done) 164 }) 165 166 t.Run("All other property methods", func(t *testing.T) { 167 id := uint16(123) 168 dc := &DataChannel{} 169 dc.id = &id 170 dc.label = "mylabel" 171 dc.protocol = "myprotocol" 172 dc.negotiated = true 173 174 assert.Equal(t, dc.id, dc.ID(), "should match") 175 assert.Equal(t, dc.label, dc.Label(), "should match") 176 assert.Equal(t, dc.protocol, dc.Protocol(), "should match") 177 assert.Equal(t, dc.negotiated, dc.Negotiated(), "should match") 178 assert.Equal(t, uint64(0), dc.BufferedAmount(), "should match") 179 dc.SetBufferedAmountLowThreshold(1500) 180 assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "should match") 181 }) 182 } 183 184 func TestDataChannelBufferedAmount(t *testing.T) { 185 t.Run("set before datachannel becomes open", func(t *testing.T) { 186 report := test.CheckRoutines(t) 187 defer report() 188 189 var nOfferBufferedAmountLowCbs uint32 190 var offerBufferedAmountLowThreshold uint64 = 1500 191 var nAnswerBufferedAmountLowCbs uint32 192 var answerBufferedAmountLowThreshold uint64 = 1400 193 194 buf := make([]byte, 1000) 195 196 offerPC, answerPC, err := newPair() 197 if err != nil { 198 t.Fatalf("Failed to create a PC pair for testing") 199 } 200 201 nPacketsToSend := int(10) 202 var nOfferReceived uint32 203 var nAnswerReceived uint32 204 205 done := make(chan bool) 206 207 answerPC.OnDataChannel(func(answerDC *DataChannel) { 208 // Make sure this is the data channel we were looking for. (Not the one 209 // created in signalPair). 210 if answerDC.Label() != expectedLabel { 211 return 212 } 213 214 answerDC.OnOpen(func() { 215 assert.Equal(t, answerBufferedAmountLowThreshold, answerDC.BufferedAmountLowThreshold(), "value mismatch") 216 217 for i := 0; i < nPacketsToSend; i++ { 218 e := answerDC.Send(buf) 219 if e != nil { 220 t.Fatalf("Failed to send string on data channel") 221 } 222 } 223 }) 224 225 answerDC.OnMessage(func(msg DataChannelMessage) { 226 atomic.AddUint32(&nAnswerReceived, 1) 227 }) 228 assert.True(t, answerDC.Ordered(), "Ordered should be set to true") 229 230 // The value is temporarily stored in the answerDC object 231 // until the answerDC gets opened 232 answerDC.SetBufferedAmountLowThreshold(answerBufferedAmountLowThreshold) 233 // The callback function is temporarily stored in the answerDC object 234 // until the answerDC gets opened 235 answerDC.OnBufferedAmountLow(func() { 236 atomic.AddUint32(&nAnswerBufferedAmountLowCbs, 1) 237 if atomic.LoadUint32(&nOfferBufferedAmountLowCbs) > 0 { 238 done <- true 239 } 240 }) 241 }) 242 243 offerDC, err := offerPC.CreateDataChannel(expectedLabel, nil) 244 if err != nil { 245 t.Fatalf("Failed to create a PC pair for testing") 246 } 247 248 assert.True(t, offerDC.Ordered(), "Ordered should be set to true") 249 250 offerDC.OnOpen(func() { 251 assert.Equal(t, offerBufferedAmountLowThreshold, offerDC.BufferedAmountLowThreshold(), "value mismatch") 252 253 for i := 0; i < nPacketsToSend; i++ { 254 e := offerDC.Send(buf) 255 if e != nil { 256 t.Fatalf("Failed to send string on data channel") 257 } 258 // assert.Equal(t, (i+1)*len(buf), int(offerDC.BufferedAmount()), "unexpected bufferedAmount") 259 } 260 }) 261 262 offerDC.OnMessage(func(msg DataChannelMessage) { 263 atomic.AddUint32(&nOfferReceived, 1) 264 }) 265 266 // The value is temporarily stored in the offerDC object 267 // until the offerDC gets opened 268 offerDC.SetBufferedAmountLowThreshold(offerBufferedAmountLowThreshold) 269 // The callback function is temporarily stored in the offerDC object 270 // until the offerDC gets opened 271 offerDC.OnBufferedAmountLow(func() { 272 atomic.AddUint32(&nOfferBufferedAmountLowCbs, 1) 273 if atomic.LoadUint32(&nAnswerBufferedAmountLowCbs) > 0 { 274 done <- true 275 } 276 }) 277 278 err = signalPair(offerPC, answerPC) 279 if err != nil { 280 t.Fatalf("Failed to signal our PC pair for testing") 281 } 282 283 closePair(t, offerPC, answerPC, done) 284 285 t.Logf("nOfferBufferedAmountLowCbs : %d", nOfferBufferedAmountLowCbs) 286 t.Logf("nAnswerBufferedAmountLowCbs: %d", nAnswerBufferedAmountLowCbs) 287 assert.True(t, nOfferBufferedAmountLowCbs > uint32(0), "callback should be made at least once") 288 assert.True(t, nAnswerBufferedAmountLowCbs > uint32(0), "callback should be made at least once") 289 }) 290 291 t.Run("set after datachannel becomes open", func(t *testing.T) { 292 report := test.CheckRoutines(t) 293 defer report() 294 295 var nCbs int 296 buf := make([]byte, 1000) 297 298 offerPC, answerPC, err := newPair() 299 if err != nil { 300 t.Fatalf("Failed to create a PC pair for testing") 301 } 302 303 done := make(chan bool) 304 305 answerPC.OnDataChannel(func(d *DataChannel) { 306 // Make sure this is the data channel we were looking for. (Not the one 307 // created in signalPair). 308 if d.Label() != expectedLabel { 309 return 310 } 311 var nPacketsReceived int 312 d.OnMessage(func(msg DataChannelMessage) { 313 nPacketsReceived++ 314 315 if nPacketsReceived == 10 { 316 go func() { 317 time.Sleep(time.Second) 318 done <- true 319 }() 320 } 321 }) 322 assert.True(t, d.Ordered(), "Ordered should be set to true") 323 }) 324 325 dc, err := offerPC.CreateDataChannel(expectedLabel, nil) 326 if err != nil { 327 t.Fatalf("Failed to create a PC pair for testing") 328 } 329 330 assert.True(t, dc.Ordered(), "Ordered should be set to true") 331 332 dc.OnOpen(func() { 333 // The value should directly be passed to sctp 334 dc.SetBufferedAmountLowThreshold(1500) 335 // The callback function should directly be passed to sctp 336 dc.OnBufferedAmountLow(func() { 337 nCbs++ 338 }) 339 340 for i := 0; i < 10; i++ { 341 e := dc.Send(buf) 342 if e != nil { 343 t.Fatalf("Failed to send string on data channel") 344 } 345 assert.Equal(t, uint64(1500), dc.BufferedAmountLowThreshold(), "value mismatch") 346 // assert.Equal(t, (i+1)*len(buf), int(dc.BufferedAmount()), "unexpected bufferedAmount") 347 } 348 }) 349 350 dc.OnMessage(func(msg DataChannelMessage) { 351 }) 352 353 err = signalPair(offerPC, answerPC) 354 if err != nil { 355 t.Fatalf("Failed to signal our PC pair for testing") 356 } 357 358 closePair(t, offerPC, answerPC, done) 359 360 assert.True(t, nCbs > 0, "callback should be made at least once") 361 }) 362 } 363 364 func TestEOF(t *testing.T) { 365 report := test.CheckRoutines(t) 366 defer report() 367 368 log := logging.NewDefaultLoggerFactory().NewLogger("test") 369 label := "test-channel" 370 testData := []byte("this is some test data") 371 372 t.Run("Detach", func(t *testing.T) { 373 // Use Detach data channels mode 374 s := SettingEngine{} 375 s.DetachDataChannels() 376 api := NewAPI(WithSettingEngine(s)) 377 378 // Set up two peer connections. 379 config := Configuration{} 380 pca, err := api.NewPeerConnection(config) 381 if err != nil { 382 t.Fatal(err) 383 } 384 pcb, err := api.NewPeerConnection(config) 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 defer closePairNow(t, pca, pcb) 390 391 var wg sync.WaitGroup 392 393 dcChan := make(chan datachannel.ReadWriteCloser) 394 pcb.OnDataChannel(func(dc *DataChannel) { 395 if dc.Label() != label { 396 return 397 } 398 log.Debug("OnDataChannel was called") 399 dc.OnOpen(func() { 400 detached, err2 := dc.Detach() 401 if err2 != nil { 402 log.Debugf("Detach failed: %s", err2.Error()) 403 t.Error(err2) 404 } 405 406 dcChan <- detached 407 }) 408 }) 409 410 wg.Add(1) 411 go func() { 412 defer wg.Done() 413 414 var msg []byte 415 416 log.Debug("Waiting for OnDataChannel") 417 dc := <-dcChan 418 log.Debug("data channel opened") 419 defer func() { assert.NoError(t, dc.Close(), "should succeed") }() 420 421 log.Debug("Waiting for ping...") 422 msg, err2 := ioutil.ReadAll(dc) 423 log.Debugf("Received ping! \"%s\"", string(msg)) 424 if err2 != nil { 425 t.Error(err2) 426 } 427 428 if !bytes.Equal(msg, testData) { 429 t.Errorf("expected %q, got %q", string(msg), string(testData)) 430 } else { 431 log.Debug("Received ping successfully!") 432 } 433 }() 434 435 if err = signalPair(pca, pcb); err != nil { 436 t.Fatal(err) 437 } 438 439 attached, err := pca.CreateDataChannel(label, nil) 440 if err != nil { 441 t.Fatal(err) 442 } 443 log.Debug("Waiting for data channel to open") 444 open := make(chan struct{}) 445 attached.OnOpen(func() { 446 open <- struct{}{} 447 }) 448 <-open 449 log.Debug("data channel opened") 450 451 var dc io.ReadWriteCloser 452 dc, err = attached.Detach() 453 if err != nil { 454 t.Fatal(err) 455 } 456 457 wg.Add(1) 458 go func() { 459 defer wg.Done() 460 log.Debug("Sending ping...") 461 if _, err2 := dc.Write(testData); err2 != nil { 462 t.Error(err2) 463 } 464 log.Debug("Sent ping") 465 466 assert.NoError(t, dc.Close(), "should succeed") 467 468 log.Debug("Wating for EOF") 469 ret, err2 := ioutil.ReadAll(dc) 470 assert.Nil(t, err2, "should succeed") 471 assert.Equal(t, 0, len(ret), "should be empty") 472 }() 473 474 wg.Wait() 475 }) 476 477 t.Run("No detach", func(t *testing.T) { 478 lim := test.TimeOut(time.Second * 5) 479 defer lim.Stop() 480 481 // Set up two peer connections. 482 config := Configuration{} 483 pca, err := NewPeerConnection(config) 484 if err != nil { 485 t.Fatal(err) 486 } 487 pcb, err := NewPeerConnection(config) 488 if err != nil { 489 t.Fatal(err) 490 } 491 492 defer closePairNow(t, pca, pcb) 493 494 var dca, dcb *DataChannel 495 dcaClosedCh := make(chan struct{}) 496 dcbClosedCh := make(chan struct{}) 497 498 pcb.OnDataChannel(func(dc *DataChannel) { 499 if dc.Label() != label { 500 return 501 } 502 503 log.Debugf("pcb: new datachannel: %s", dc.Label()) 504 505 dcb = dc 506 // Register channel opening handling 507 dcb.OnOpen(func() { 508 log.Debug("pcb: datachannel opened") 509 }) 510 511 dcb.OnClose(func() { 512 // (2) 513 log.Debug("pcb: data channel closed") 514 close(dcbClosedCh) 515 }) 516 517 // Register the OnMessage to handle incoming messages 518 log.Debug("pcb: registering onMessage callback") 519 dcb.OnMessage(func(dcMsg DataChannelMessage) { 520 log.Debugf("pcb: received ping: %s", string(dcMsg.Data)) 521 if !reflect.DeepEqual(dcMsg.Data, testData) { 522 t.Error("data mismatch") 523 } 524 }) 525 }) 526 527 dca, err = pca.CreateDataChannel(label, nil) 528 if err != nil { 529 t.Fatal(err) 530 } 531 532 dca.OnOpen(func() { 533 log.Debug("pca: data channel opened") 534 log.Debugf("pca: sending \"%s\"", string(testData)) 535 if err := dca.Send(testData); err != nil { 536 t.Fatal(err) 537 } 538 log.Debug("pca: sent ping") 539 assert.NoError(t, dca.Close(), "should succeed") // <-- dca closes 540 }) 541 542 dca.OnClose(func() { 543 // (1) 544 log.Debug("pca: data channel closed") 545 close(dcaClosedCh) 546 }) 547 548 // Register the OnMessage to handle incoming messages 549 log.Debug("pca: registering onMessage callback") 550 dca.OnMessage(func(dcMsg DataChannelMessage) { 551 log.Debugf("pca: received pong: %s", string(dcMsg.Data)) 552 if !reflect.DeepEqual(dcMsg.Data, testData) { 553 t.Error("data mismatch") 554 } 555 }) 556 557 if err := signalPair(pca, pcb); err != nil { 558 t.Fatal(err) 559 } 560 561 // When dca closes the channel, 562 // (1) dca.Onclose() will fire immediately, then 563 // (2) dcb.OnClose will also fire 564 <-dcaClosedCh // (1) 565 <-dcbClosedCh // (2) 566 }) 567 } 568 569 // Assert that a Session Description that doesn't follow 570 // draft-ietf-mmusic-sctp-sdp is still accepted 571 func TestDataChannel_NonStandardSessionDescription(t *testing.T) { 572 to := test.TimeOut(time.Second * 20) 573 defer to.Stop() 574 575 report := test.CheckRoutines(t) 576 defer report() 577 578 offerPC, answerPC, err := newPair() 579 assert.NoError(t, err) 580 581 _, err = offerPC.CreateDataChannel("foo", nil) 582 assert.NoError(t, err) 583 584 onDataChannelCalled := make(chan struct{}) 585 answerPC.OnDataChannel(func(_ *DataChannel) { 586 close(onDataChannelCalled) 587 }) 588 589 offer, err := offerPC.CreateOffer(nil) 590 assert.NoError(t, err) 591 592 offerGatheringComplete := GatheringCompletePromise(offerPC) 593 assert.NoError(t, offerPC.SetLocalDescription(offer)) 594 <-offerGatheringComplete 595 596 offer = *offerPC.LocalDescription() 597 598 // Replace with old values 599 const ( 600 oldApplication = "m=application 63743 DTLS/SCTP 5000\r" 601 oldAttribute = "a=sctpmap:5000 webrtc-datachannel 256\r" 602 ) 603 604 offer.SDP = regexp.MustCompile(`m=application (.*?)\r`).ReplaceAllString(offer.SDP, oldApplication) 605 offer.SDP = regexp.MustCompile(`a=sctp-port(.*?)\r`).ReplaceAllString(offer.SDP, oldAttribute) 606 607 // Assert that replace worked 608 assert.True(t, strings.Contains(offer.SDP, oldApplication)) 609 assert.True(t, strings.Contains(offer.SDP, oldAttribute)) 610 611 assert.NoError(t, answerPC.SetRemoteDescription(offer)) 612 613 answer, err := answerPC.CreateAnswer(nil) 614 assert.NoError(t, err) 615 616 answerGatheringComplete := GatheringCompletePromise(answerPC) 617 assert.NoError(t, answerPC.SetLocalDescription(answer)) 618 <-answerGatheringComplete 619 assert.NoError(t, offerPC.SetRemoteDescription(*answerPC.LocalDescription())) 620 621 <-onDataChannelCalled 622 closePairNow(t, offerPC, answerPC) 623 } 624 625 func TestDataChannel_Dial(t *testing.T) { 626 t.Run("handler should be called once, by dialing peer only", func(t *testing.T) { 627 report := test.CheckRoutines(t) 628 defer report() 629 630 dialCalls := make(chan bool, 2) 631 wg := new(sync.WaitGroup) 632 wg.Add(2) 633 634 offerPC, answerPC, err := newPair() 635 if err != nil { 636 t.Fatalf("Failed to create a PC pair for testing") 637 } 638 639 answerPC.OnDataChannel(func(d *DataChannel) { 640 if d.Label() != expectedLabel { 641 return 642 } 643 644 d.OnDial(func() { 645 // only dialing side should fire OnDial 646 t.Fatalf("answering side should not call on dial") 647 }) 648 649 d.OnOpen(wg.Done) 650 }) 651 652 d, err := offerPC.CreateDataChannel(expectedLabel, nil) 653 assert.NoError(t, err) 654 d.OnDial(func() { 655 dialCalls <- true 656 wg.Done() 657 }) 658 659 assert.NoError(t, signalPair(offerPC, answerPC)) 660 661 wg.Wait() 662 closePairNow(t, offerPC, answerPC) 663 664 assert.Len(t, dialCalls, 1) 665 }) 666 667 t.Run("handler should be called immediately if already dialed", func(t *testing.T) { 668 report := test.CheckRoutines(t) 669 defer report() 670 671 done := make(chan bool) 672 673 offerPC, answerPC, err := newPair() 674 if err != nil { 675 t.Fatalf("Failed to create a PC pair for testing") 676 } 677 678 d, err := offerPC.CreateDataChannel(expectedLabel, nil) 679 assert.NoError(t, err) 680 d.OnOpen(func() { 681 // when the offer DC has been opened, its guaranteed to have dialed since it has 682 // received a response to said dial. this test represents an unrealistic usage, 683 // but its the best way to guarantee we "missed" the dial event and still invoke 684 // the handler. 685 d.OnDial(func() { 686 done <- true 687 }) 688 }) 689 690 assert.NoError(t, signalPair(offerPC, answerPC)) 691 692 closePair(t, offerPC, answerPC, done) 693 }) 694 }