github.com/pion/webrtc/v3@v3.2.24/datachannel.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"math"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	"github.com/pion/datachannel"
    19  	"github.com/pion/logging"
    20  	"github.com/pion/webrtc/v3/pkg/rtcerr"
    21  )
    22  
    23  const dataChannelBufferSize = math.MaxUint16 // message size limit for Chromium
    24  var errSCTPNotEstablished = errors.New("SCTP not established")
    25  
    26  // DataChannel represents a WebRTC DataChannel
    27  // The DataChannel interface represents a network channel
    28  // which can be used for bidirectional peer-to-peer transfers of arbitrary data
    29  type DataChannel struct {
    30  	mu sync.RWMutex
    31  
    32  	statsID                    string
    33  	label                      string
    34  	ordered                    bool
    35  	maxPacketLifeTime          *uint16
    36  	maxRetransmits             *uint16
    37  	protocol                   string
    38  	negotiated                 bool
    39  	id                         *uint16
    40  	readyState                 atomic.Value // DataChannelState
    41  	bufferedAmountLowThreshold uint64
    42  	detachCalled               bool
    43  
    44  	// The binaryType represents attribute MUST, on getting, return the value to
    45  	// which it was last set. On setting, if the new value is either the string
    46  	// "blob" or the string "arraybuffer", then set the IDL attribute to this
    47  	// new value. Otherwise, throw a SyntaxError. When an DataChannel object
    48  	// is created, the binaryType attribute MUST be initialized to the string
    49  	// "blob". This attribute controls how binary data is exposed to scripts.
    50  	// binaryType                 string
    51  
    52  	onMessageHandler    func(DataChannelMessage)
    53  	openHandlerOnce     sync.Once
    54  	onOpenHandler       func()
    55  	dialHandlerOnce     sync.Once
    56  	onDialHandler       func()
    57  	onCloseHandler      func()
    58  	onBufferedAmountLow func()
    59  	onErrorHandler      func(error)
    60  
    61  	sctpTransport *SCTPTransport
    62  	dataChannel   *datachannel.DataChannel
    63  
    64  	// A reference to the associated api object used by this datachannel
    65  	api *API
    66  	log logging.LeveledLogger
    67  }
    68  
    69  // NewDataChannel creates a new DataChannel.
    70  // This constructor is part of the ORTC API. It is not
    71  // meant to be used together with the basic WebRTC API.
    72  func (api *API) NewDataChannel(transport *SCTPTransport, params *DataChannelParameters) (*DataChannel, error) {
    73  	d, err := api.newDataChannel(params, nil, api.settingEngine.LoggerFactory.NewLogger("ortc"))
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	err = d.open(transport)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	return d, nil
    84  }
    85  
    86  // newDataChannel is an internal constructor for the data channel used to
    87  // create the DataChannel object before the networking is set up.
    88  func (api *API) newDataChannel(params *DataChannelParameters, sctpTransport *SCTPTransport, log logging.LeveledLogger) (*DataChannel, error) {
    89  	// https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #5)
    90  	if len(params.Label) > 65535 {
    91  		return nil, &rtcerr.TypeError{Err: ErrStringSizeLimit}
    92  	}
    93  
    94  	d := &DataChannel{
    95  		sctpTransport:     sctpTransport,
    96  		statsID:           fmt.Sprintf("DataChannel-%d", time.Now().UnixNano()),
    97  		label:             params.Label,
    98  		protocol:          params.Protocol,
    99  		negotiated:        params.Negotiated,
   100  		id:                params.ID,
   101  		ordered:           params.Ordered,
   102  		maxPacketLifeTime: params.MaxPacketLifeTime,
   103  		maxRetransmits:    params.MaxRetransmits,
   104  		api:               api,
   105  		log:               log,
   106  	}
   107  
   108  	d.setReadyState(DataChannelStateConnecting)
   109  	return d, nil
   110  }
   111  
   112  // open opens the datachannel over the sctp transport
   113  func (d *DataChannel) open(sctpTransport *SCTPTransport) error {
   114  	association := sctpTransport.association()
   115  	if association == nil {
   116  		return errSCTPNotEstablished
   117  	}
   118  
   119  	d.mu.Lock()
   120  	if d.sctpTransport != nil { // already open
   121  		d.mu.Unlock()
   122  		return nil
   123  	}
   124  	d.sctpTransport = sctpTransport
   125  	var channelType datachannel.ChannelType
   126  	var reliabilityParameter uint32
   127  
   128  	switch {
   129  	case d.maxPacketLifeTime == nil && d.maxRetransmits == nil:
   130  		if d.ordered {
   131  			channelType = datachannel.ChannelTypeReliable
   132  		} else {
   133  			channelType = datachannel.ChannelTypeReliableUnordered
   134  		}
   135  
   136  	case d.maxRetransmits != nil:
   137  		reliabilityParameter = uint32(*d.maxRetransmits)
   138  		if d.ordered {
   139  			channelType = datachannel.ChannelTypePartialReliableRexmit
   140  		} else {
   141  			channelType = datachannel.ChannelTypePartialReliableRexmitUnordered
   142  		}
   143  	default:
   144  		reliabilityParameter = uint32(*d.maxPacketLifeTime)
   145  		if d.ordered {
   146  			channelType = datachannel.ChannelTypePartialReliableTimed
   147  		} else {
   148  			channelType = datachannel.ChannelTypePartialReliableTimedUnordered
   149  		}
   150  	}
   151  
   152  	cfg := &datachannel.Config{
   153  		ChannelType:          channelType,
   154  		Priority:             datachannel.ChannelPriorityNormal,
   155  		ReliabilityParameter: reliabilityParameter,
   156  		Label:                d.label,
   157  		Protocol:             d.protocol,
   158  		Negotiated:           d.negotiated,
   159  		LoggerFactory:        d.api.settingEngine.LoggerFactory,
   160  	}
   161  
   162  	if d.id == nil {
   163  		// avoid holding lock when generating ID, since id generation locks
   164  		d.mu.Unlock()
   165  		var dcID *uint16
   166  		err := d.sctpTransport.generateAndSetDataChannelID(d.sctpTransport.dtlsTransport.role(), &dcID)
   167  		if err != nil {
   168  			return err
   169  		}
   170  		d.mu.Lock()
   171  		d.id = dcID
   172  	}
   173  	dc, err := datachannel.Dial(association, *d.id, cfg)
   174  	if err != nil {
   175  		d.mu.Unlock()
   176  		return err
   177  	}
   178  
   179  	// bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier
   180  	dc.SetBufferedAmountLowThreshold(d.bufferedAmountLowThreshold)
   181  	dc.OnBufferedAmountLow(d.onBufferedAmountLow)
   182  	d.mu.Unlock()
   183  
   184  	d.onDial()
   185  	d.handleOpen(dc, false, d.negotiated)
   186  	return nil
   187  }
   188  
   189  // Transport returns the SCTPTransport instance the DataChannel is sending over.
   190  func (d *DataChannel) Transport() *SCTPTransport {
   191  	d.mu.RLock()
   192  	defer d.mu.RUnlock()
   193  
   194  	return d.sctpTransport
   195  }
   196  
   197  // After onOpen is complete check that the user called detach
   198  // and provide an error message if the call was missed
   199  func (d *DataChannel) checkDetachAfterOpen() {
   200  	d.mu.RLock()
   201  	defer d.mu.RUnlock()
   202  
   203  	if d.api.settingEngine.detach.DataChannels && !d.detachCalled {
   204  		d.log.Warn("webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen")
   205  	}
   206  }
   207  
   208  // OnOpen sets an event handler which is invoked when
   209  // the underlying data transport has been established (or re-established).
   210  func (d *DataChannel) OnOpen(f func()) {
   211  	d.mu.Lock()
   212  	d.openHandlerOnce = sync.Once{}
   213  	d.onOpenHandler = f
   214  	d.mu.Unlock()
   215  
   216  	if d.ReadyState() == DataChannelStateOpen {
   217  		// If the data channel is already open, call the handler immediately.
   218  		go d.openHandlerOnce.Do(func() {
   219  			f()
   220  			d.checkDetachAfterOpen()
   221  		})
   222  	}
   223  }
   224  
   225  func (d *DataChannel) onOpen() {
   226  	d.mu.RLock()
   227  	handler := d.onOpenHandler
   228  	d.mu.RUnlock()
   229  
   230  	if handler != nil {
   231  		go d.openHandlerOnce.Do(func() {
   232  			handler()
   233  			d.checkDetachAfterOpen()
   234  		})
   235  	}
   236  }
   237  
   238  // OnDial sets an event handler which is invoked when the
   239  // peer has been dialed, but before said peer has responsed
   240  func (d *DataChannel) OnDial(f func()) {
   241  	d.mu.Lock()
   242  	d.dialHandlerOnce = sync.Once{}
   243  	d.onDialHandler = f
   244  	d.mu.Unlock()
   245  
   246  	if d.ReadyState() == DataChannelStateOpen {
   247  		// If the data channel is already open, call the handler immediately.
   248  		go d.dialHandlerOnce.Do(f)
   249  	}
   250  }
   251  
   252  func (d *DataChannel) onDial() {
   253  	d.mu.RLock()
   254  	handler := d.onDialHandler
   255  	d.mu.RUnlock()
   256  
   257  	if handler != nil {
   258  		go d.dialHandlerOnce.Do(handler)
   259  	}
   260  }
   261  
   262  // OnClose sets an event handler which is invoked when
   263  // the underlying data transport has been closed.
   264  func (d *DataChannel) OnClose(f func()) {
   265  	d.mu.Lock()
   266  	defer d.mu.Unlock()
   267  	d.onCloseHandler = f
   268  }
   269  
   270  func (d *DataChannel) onClose() {
   271  	d.mu.RLock()
   272  	handler := d.onCloseHandler
   273  	d.mu.RUnlock()
   274  
   275  	if handler != nil {
   276  		go handler()
   277  	}
   278  }
   279  
   280  // OnMessage sets an event handler which is invoked on a binary
   281  // message arrival over the sctp transport from a remote peer.
   282  // OnMessage can currently receive messages up to 16384 bytes
   283  // in size. Check out the detach API if you want to use larger
   284  // message sizes. Note that browser support for larger messages
   285  // is also limited.
   286  func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) {
   287  	d.mu.Lock()
   288  	defer d.mu.Unlock()
   289  	d.onMessageHandler = f
   290  }
   291  
   292  func (d *DataChannel) onMessage(msg DataChannelMessage) {
   293  	d.mu.RLock()
   294  	handler := d.onMessageHandler
   295  	d.mu.RUnlock()
   296  
   297  	if handler == nil {
   298  		return
   299  	}
   300  	handler(msg)
   301  }
   302  
   303  func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) {
   304  	d.mu.Lock()
   305  	d.dataChannel = dc
   306  	bufferedAmountLowThreshold := d.bufferedAmountLowThreshold
   307  	onBufferedAmountLow := d.onBufferedAmountLow
   308  	d.mu.Unlock()
   309  	d.setReadyState(DataChannelStateOpen)
   310  
   311  	// Fire the OnOpen handler immediately not using pion/datachannel
   312  	// * detached datachannels have no read loop, the user needs to read and query themselves
   313  	// * remote datachannels should fire OnOpened. This isn't spec compliant, but we can't break behavior yet
   314  	// * already negotiated datachannels should fire OnOpened
   315  	if d.api.settingEngine.detach.DataChannels || isRemote || isAlreadyNegotiated {
   316  		// bufferedAmountLowThreshold and onBufferedAmountLow might be set earlier
   317  		d.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
   318  		d.dataChannel.OnBufferedAmountLow(onBufferedAmountLow)
   319  		d.onOpen()
   320  	} else {
   321  		dc.OnOpen(func() {
   322  			d.onOpen()
   323  		})
   324  	}
   325  
   326  	d.mu.Lock()
   327  	defer d.mu.Unlock()
   328  
   329  	if !d.api.settingEngine.detach.DataChannels {
   330  		go d.readLoop()
   331  	}
   332  }
   333  
   334  // OnError sets an event handler which is invoked when
   335  // the underlying data transport cannot be read.
   336  func (d *DataChannel) OnError(f func(err error)) {
   337  	d.mu.Lock()
   338  	defer d.mu.Unlock()
   339  	d.onErrorHandler = f
   340  }
   341  
   342  func (d *DataChannel) onError(err error) {
   343  	d.mu.RLock()
   344  	handler := d.onErrorHandler
   345  	d.mu.RUnlock()
   346  
   347  	if handler != nil {
   348  		go handler(err)
   349  	}
   350  }
   351  
   352  // See https://github.com/pion/webrtc/issues/1516
   353  // nolint:gochecknoglobals
   354  var rlBufPool = sync.Pool{New: func() interface{} {
   355  	return make([]byte, dataChannelBufferSize)
   356  }}
   357  
   358  func (d *DataChannel) readLoop() {
   359  	for {
   360  		buffer := rlBufPool.Get().([]byte) //nolint:forcetypeassert
   361  		n, isString, err := d.dataChannel.ReadDataChannel(buffer)
   362  		if err != nil {
   363  			rlBufPool.Put(buffer) // nolint:staticcheck
   364  			d.setReadyState(DataChannelStateClosed)
   365  			if !errors.Is(err, io.EOF) {
   366  				d.onError(err)
   367  			}
   368  			d.onClose()
   369  			return
   370  		}
   371  
   372  		m := DataChannelMessage{Data: make([]byte, n), IsString: isString}
   373  		copy(m.Data, buffer[:n])
   374  		// The 'staticcheck' pragma is a false positive on the part of the CI linter.
   375  		rlBufPool.Put(buffer) // nolint:staticcheck
   376  
   377  		// NB: Why was DataChannelMessage not passed as a pointer value?
   378  		d.onMessage(m) // nolint:staticcheck
   379  	}
   380  }
   381  
   382  // Send sends the binary message to the DataChannel peer
   383  func (d *DataChannel) Send(data []byte) error {
   384  	err := d.ensureOpen()
   385  	if err != nil {
   386  		return err
   387  	}
   388  
   389  	_, err = d.dataChannel.WriteDataChannel(data, false)
   390  	return err
   391  }
   392  
   393  // SendText sends the text message to the DataChannel peer
   394  func (d *DataChannel) SendText(s string) error {
   395  	err := d.ensureOpen()
   396  	if err != nil {
   397  		return err
   398  	}
   399  
   400  	_, err = d.dataChannel.WriteDataChannel([]byte(s), true)
   401  	return err
   402  }
   403  
   404  func (d *DataChannel) ensureOpen() error {
   405  	d.mu.RLock()
   406  	defer d.mu.RUnlock()
   407  	if d.ReadyState() != DataChannelStateOpen {
   408  		return io.ErrClosedPipe
   409  	}
   410  	return nil
   411  }
   412  
   413  // Detach allows you to detach the underlying datachannel. This provides
   414  // an idiomatic API to work with, however it disables the OnMessage callback.
   415  // Before calling Detach you have to enable this behavior by calling
   416  // webrtc.DetachDataChannels(). Combining detached and normal data channels
   417  // is not supported.
   418  // Please refer to the data-channels-detach example and the
   419  // pion/datachannel documentation for the correct way to handle the
   420  // resulting DataChannel object.
   421  func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
   422  	d.mu.Lock()
   423  	defer d.mu.Unlock()
   424  
   425  	if !d.api.settingEngine.detach.DataChannels {
   426  		return nil, errDetachNotEnabled
   427  	}
   428  
   429  	if d.dataChannel == nil {
   430  		return nil, errDetachBeforeOpened
   431  	}
   432  
   433  	d.detachCalled = true
   434  
   435  	return d.dataChannel, nil
   436  }
   437  
   438  // Close Closes the DataChannel. It may be called regardless of whether
   439  // the DataChannel object was created by this peer or the remote peer.
   440  func (d *DataChannel) Close() error {
   441  	d.mu.Lock()
   442  	haveSctpTransport := d.dataChannel != nil
   443  	d.mu.Unlock()
   444  
   445  	if d.ReadyState() == DataChannelStateClosed {
   446  		return nil
   447  	}
   448  
   449  	d.setReadyState(DataChannelStateClosing)
   450  	if !haveSctpTransport {
   451  		return nil
   452  	}
   453  
   454  	return d.dataChannel.Close()
   455  }
   456  
   457  // Label represents a label that can be used to distinguish this
   458  // DataChannel object from other DataChannel objects. Scripts are
   459  // allowed to create multiple DataChannel objects with the same label.
   460  func (d *DataChannel) Label() string {
   461  	d.mu.RLock()
   462  	defer d.mu.RUnlock()
   463  
   464  	return d.label
   465  }
   466  
   467  // Ordered returns true if the DataChannel is ordered, and false if
   468  // out-of-order delivery is allowed.
   469  func (d *DataChannel) Ordered() bool {
   470  	d.mu.RLock()
   471  	defer d.mu.RUnlock()
   472  
   473  	return d.ordered
   474  }
   475  
   476  // MaxPacketLifeTime represents the length of the time window (msec) during
   477  // which transmissions and retransmissions may occur in unreliable mode.
   478  func (d *DataChannel) MaxPacketLifeTime() *uint16 {
   479  	d.mu.RLock()
   480  	defer d.mu.RUnlock()
   481  
   482  	return d.maxPacketLifeTime
   483  }
   484  
   485  // MaxRetransmits represents the maximum number of retransmissions that are
   486  // attempted in unreliable mode.
   487  func (d *DataChannel) MaxRetransmits() *uint16 {
   488  	d.mu.RLock()
   489  	defer d.mu.RUnlock()
   490  
   491  	return d.maxRetransmits
   492  }
   493  
   494  // Protocol represents the name of the sub-protocol used with this
   495  // DataChannel.
   496  func (d *DataChannel) Protocol() string {
   497  	d.mu.RLock()
   498  	defer d.mu.RUnlock()
   499  
   500  	return d.protocol
   501  }
   502  
   503  // Negotiated represents whether this DataChannel was negotiated by the
   504  // application (true), or not (false).
   505  func (d *DataChannel) Negotiated() bool {
   506  	d.mu.RLock()
   507  	defer d.mu.RUnlock()
   508  
   509  	return d.negotiated
   510  }
   511  
   512  // ID represents the ID for this DataChannel. The value is initially
   513  // null, which is what will be returned if the ID was not provided at
   514  // channel creation time, and the DTLS role of the SCTP transport has not
   515  // yet been negotiated. Otherwise, it will return the ID that was either
   516  // selected by the script or generated. After the ID is set to a non-null
   517  // value, it will not change.
   518  func (d *DataChannel) ID() *uint16 {
   519  	d.mu.RLock()
   520  	defer d.mu.RUnlock()
   521  
   522  	return d.id
   523  }
   524  
   525  // ReadyState represents the state of the DataChannel object.
   526  func (d *DataChannel) ReadyState() DataChannelState {
   527  	if v, ok := d.readyState.Load().(DataChannelState); ok {
   528  		return v
   529  	}
   530  	return DataChannelState(0)
   531  }
   532  
   533  // BufferedAmount represents the number of bytes of application data
   534  // (UTF-8 text and binary data) that have been queued using send(). Even
   535  // though the data transmission can occur in parallel, the returned value
   536  // MUST NOT be decreased before the current task yielded back to the event
   537  // loop to prevent race conditions. The value does not include framing
   538  // overhead incurred by the protocol, or buffering done by the operating
   539  // system or network hardware. The value of BufferedAmount slot will only
   540  // increase with each call to the send() method as long as the ReadyState is
   541  // open; however, BufferedAmount does not reset to zero once the channel
   542  // closes.
   543  func (d *DataChannel) BufferedAmount() uint64 {
   544  	d.mu.RLock()
   545  	defer d.mu.RUnlock()
   546  
   547  	if d.dataChannel == nil {
   548  		return 0
   549  	}
   550  	return d.dataChannel.BufferedAmount()
   551  }
   552  
   553  // BufferedAmountLowThreshold represents the threshold at which the
   554  // bufferedAmount is considered to be low. When the bufferedAmount decreases
   555  // from above this threshold to equal or below it, the bufferedamountlow
   556  // event fires. BufferedAmountLowThreshold is initially zero on each new
   557  // DataChannel, but the application may change its value at any time.
   558  // The threshold is set to 0 by default.
   559  func (d *DataChannel) BufferedAmountLowThreshold() uint64 {
   560  	d.mu.RLock()
   561  	defer d.mu.RUnlock()
   562  
   563  	if d.dataChannel == nil {
   564  		return d.bufferedAmountLowThreshold
   565  	}
   566  	return d.dataChannel.BufferedAmountLowThreshold()
   567  }
   568  
   569  // SetBufferedAmountLowThreshold is used to update the threshold.
   570  // See BufferedAmountLowThreshold().
   571  func (d *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
   572  	d.mu.Lock()
   573  	defer d.mu.Unlock()
   574  
   575  	d.bufferedAmountLowThreshold = th
   576  
   577  	if d.dataChannel != nil {
   578  		d.dataChannel.SetBufferedAmountLowThreshold(th)
   579  	}
   580  }
   581  
   582  // OnBufferedAmountLow sets an event handler which is invoked when
   583  // the number of bytes of outgoing data becomes lower than the
   584  // BufferedAmountLowThreshold.
   585  func (d *DataChannel) OnBufferedAmountLow(f func()) {
   586  	d.mu.Lock()
   587  	defer d.mu.Unlock()
   588  
   589  	d.onBufferedAmountLow = f
   590  	if d.dataChannel != nil {
   591  		d.dataChannel.OnBufferedAmountLow(f)
   592  	}
   593  }
   594  
   595  func (d *DataChannel) getStatsID() string {
   596  	d.mu.Lock()
   597  	defer d.mu.Unlock()
   598  	return d.statsID
   599  }
   600  
   601  func (d *DataChannel) collectStats(collector *statsReportCollector) {
   602  	collector.Collecting()
   603  
   604  	d.mu.Lock()
   605  	defer d.mu.Unlock()
   606  
   607  	stats := DataChannelStats{
   608  		Timestamp: statsTimestampNow(),
   609  		Type:      StatsTypeDataChannel,
   610  		ID:        d.statsID,
   611  		Label:     d.label,
   612  		Protocol:  d.protocol,
   613  		// TransportID string `json:"transportId"`
   614  		State: d.ReadyState(),
   615  	}
   616  
   617  	if d.id != nil {
   618  		stats.DataChannelIdentifier = int32(*d.id)
   619  	}
   620  
   621  	if d.dataChannel != nil {
   622  		stats.MessagesSent = d.dataChannel.MessagesSent()
   623  		stats.BytesSent = d.dataChannel.BytesSent()
   624  		stats.MessagesReceived = d.dataChannel.MessagesReceived()
   625  		stats.BytesReceived = d.dataChannel.BytesReceived()
   626  	}
   627  
   628  	collector.Collect(stats.ID, stats)
   629  }
   630  
   631  func (d *DataChannel) setReadyState(r DataChannelState) {
   632  	d.readyState.Store(r)
   633  }