github.com/pion/webrtc/v4@v4.0.1/sctptransport.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"errors"
    11  	"io"
    12  	"math"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/pion/datachannel"
    17  	"github.com/pion/logging"
    18  	"github.com/pion/sctp"
    19  	"github.com/pion/webrtc/v4/pkg/rtcerr"
    20  )
    21  
    22  const sctpMaxChannels = uint16(65535)
    23  
    24  // SCTPTransport provides details about the SCTP transport.
    25  type SCTPTransport struct {
    26  	lock sync.RWMutex
    27  
    28  	dtlsTransport *DTLSTransport
    29  
    30  	// State represents the current state of the SCTP transport.
    31  	state SCTPTransportState
    32  
    33  	// SCTPTransportState doesn't have an enum to distinguish between New/Connecting
    34  	// so we need a dedicated field
    35  	isStarted bool
    36  
    37  	// MaxMessageSize represents the maximum size of data that can be passed to
    38  	// DataChannel's send() method.
    39  	maxMessageSize float64
    40  
    41  	// MaxChannels represents the maximum amount of DataChannel's that can
    42  	// be used simultaneously.
    43  	maxChannels *uint16
    44  
    45  	// OnStateChange  func()
    46  
    47  	onErrorHandler func(error)
    48  	onCloseHandler func(error)
    49  
    50  	sctpAssociation            *sctp.Association
    51  	onDataChannelHandler       func(*DataChannel)
    52  	onDataChannelOpenedHandler func(*DataChannel)
    53  
    54  	// DataChannels
    55  	dataChannels          []*DataChannel
    56  	dataChannelIDsUsed    map[uint16]struct{}
    57  	dataChannelsOpened    uint32
    58  	dataChannelsRequested uint32
    59  	dataChannelsAccepted  uint32
    60  
    61  	api *API
    62  	log logging.LeveledLogger
    63  }
    64  
    65  // NewSCTPTransport creates a new SCTPTransport.
    66  // This constructor is part of the ORTC API. It is not
    67  // meant to be used together with the basic WebRTC API.
    68  func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
    69  	res := &SCTPTransport{
    70  		dtlsTransport:      dtls,
    71  		state:              SCTPTransportStateConnecting,
    72  		api:                api,
    73  		log:                api.settingEngine.LoggerFactory.NewLogger("ortc"),
    74  		dataChannelIDsUsed: make(map[uint16]struct{}),
    75  	}
    76  
    77  	res.updateMessageSize()
    78  	res.updateMaxChannels()
    79  
    80  	return res
    81  }
    82  
    83  // Transport returns the DTLSTransport instance the SCTPTransport is sending over.
    84  func (r *SCTPTransport) Transport() *DTLSTransport {
    85  	r.lock.RLock()
    86  	defer r.lock.RUnlock()
    87  
    88  	return r.dtlsTransport
    89  }
    90  
    91  // GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
    92  func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
    93  	return SCTPCapabilities{
    94  		MaxMessageSize: 0,
    95  	}
    96  }
    97  
    98  // Start the SCTPTransport. Since both local and remote parties must mutually
    99  // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
   100  // a connection over SCTP.
   101  func (r *SCTPTransport) Start(SCTPCapabilities) error {
   102  	if r.isStarted {
   103  		return nil
   104  	}
   105  	r.isStarted = true
   106  
   107  	dtlsTransport := r.Transport()
   108  	if dtlsTransport == nil || dtlsTransport.conn == nil {
   109  		return errSCTPTransportDTLS
   110  	}
   111  	sctpAssociation, err := sctp.Client(sctp.Config{
   112  		NetConn:              dtlsTransport.conn,
   113  		MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
   114  		EnableZeroChecksum:   r.api.settingEngine.sctp.enableZeroChecksum,
   115  		LoggerFactory:        r.api.settingEngine.LoggerFactory,
   116  		RTOMax:               float64(r.api.settingEngine.sctp.rtoMax) / float64(time.Millisecond),
   117  	})
   118  	if err != nil {
   119  		return err
   120  	}
   121  
   122  	r.lock.Lock()
   123  	r.sctpAssociation = sctpAssociation
   124  	r.state = SCTPTransportStateConnected
   125  	dataChannels := append([]*DataChannel{}, r.dataChannels...)
   126  	r.lock.Unlock()
   127  
   128  	var openedDCCount uint32
   129  	for _, d := range dataChannels {
   130  		if d.ReadyState() == DataChannelStateConnecting {
   131  			err := d.open(r)
   132  			if err != nil {
   133  				r.log.Warnf("failed to open data channel: %s", err)
   134  				continue
   135  			}
   136  			openedDCCount++
   137  		}
   138  	}
   139  
   140  	r.lock.Lock()
   141  	r.dataChannelsOpened += openedDCCount
   142  	r.lock.Unlock()
   143  
   144  	go r.acceptDataChannels(sctpAssociation)
   145  
   146  	return nil
   147  }
   148  
   149  // Stop stops the SCTPTransport
   150  func (r *SCTPTransport) Stop() error {
   151  	r.lock.Lock()
   152  	defer r.lock.Unlock()
   153  	if r.sctpAssociation == nil {
   154  		return nil
   155  	}
   156  
   157  	r.sctpAssociation.Abort("")
   158  
   159  	r.sctpAssociation = nil
   160  	r.state = SCTPTransportStateClosed
   161  
   162  	return nil
   163  }
   164  
   165  func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
   166  	r.lock.RLock()
   167  	dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
   168  	for _, dc := range r.dataChannels {
   169  		dc.mu.Lock()
   170  		isNil := dc.dataChannel == nil
   171  		dc.mu.Unlock()
   172  		if isNil {
   173  			continue
   174  		}
   175  		dataChannels = append(dataChannels, dc.dataChannel)
   176  	}
   177  	r.lock.RUnlock()
   178  
   179  ACCEPT:
   180  	for {
   181  		dc, err := datachannel.Accept(a, &datachannel.Config{
   182  			LoggerFactory: r.api.settingEngine.LoggerFactory,
   183  		}, dataChannels...)
   184  		if err != nil {
   185  			if !errors.Is(err, io.EOF) {
   186  				r.log.Errorf("Failed to accept data channel: %v", err)
   187  				r.onError(err)
   188  				r.onClose(err)
   189  			} else {
   190  				r.onClose(nil)
   191  			}
   192  			return
   193  		}
   194  		for _, ch := range dataChannels {
   195  			if ch.StreamIdentifier() == dc.StreamIdentifier() {
   196  				continue ACCEPT
   197  			}
   198  		}
   199  
   200  		var (
   201  			maxRetransmits    *uint16
   202  			maxPacketLifeTime *uint16
   203  		)
   204  		val := uint16(dc.Config.ReliabilityParameter)
   205  		ordered := true
   206  
   207  		switch dc.Config.ChannelType {
   208  		case datachannel.ChannelTypeReliable:
   209  			ordered = true
   210  		case datachannel.ChannelTypeReliableUnordered:
   211  			ordered = false
   212  		case datachannel.ChannelTypePartialReliableRexmit:
   213  			ordered = true
   214  			maxRetransmits = &val
   215  		case datachannel.ChannelTypePartialReliableRexmitUnordered:
   216  			ordered = false
   217  			maxRetransmits = &val
   218  		case datachannel.ChannelTypePartialReliableTimed:
   219  			ordered = true
   220  			maxPacketLifeTime = &val
   221  		case datachannel.ChannelTypePartialReliableTimedUnordered:
   222  			ordered = false
   223  			maxPacketLifeTime = &val
   224  		default:
   225  		}
   226  
   227  		sid := dc.StreamIdentifier()
   228  		rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
   229  			ID:                &sid,
   230  			Label:             dc.Config.Label,
   231  			Protocol:          dc.Config.Protocol,
   232  			Negotiated:        dc.Config.Negotiated,
   233  			Ordered:           ordered,
   234  			MaxPacketLifeTime: maxPacketLifeTime,
   235  			MaxRetransmits:    maxRetransmits,
   236  		}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
   237  		if err != nil {
   238  			// This data channel is invalid. Close it and log an error.
   239  			if err1 := dc.Close(); err1 != nil {
   240  				r.log.Errorf("Failed to close invalid data channel: %v", err1)
   241  			}
   242  			r.log.Errorf("Failed to accept data channel: %v", err)
   243  			r.onError(err)
   244  			// We've received a datachannel with invalid configuration. We can still receive other datachannels.
   245  			continue ACCEPT
   246  		}
   247  
   248  		<-r.onDataChannel(rtcDC)
   249  		rtcDC.handleOpen(dc, true, dc.Config.Negotiated)
   250  
   251  		r.lock.Lock()
   252  		r.dataChannelsOpened++
   253  		handler := r.onDataChannelOpenedHandler
   254  		r.lock.Unlock()
   255  
   256  		if handler != nil {
   257  			handler(rtcDC)
   258  		}
   259  	}
   260  }
   261  
   262  // OnError sets an event handler which is invoked when the SCTP Association errors.
   263  func (r *SCTPTransport) OnError(f func(err error)) {
   264  	r.lock.Lock()
   265  	defer r.lock.Unlock()
   266  	r.onErrorHandler = f
   267  }
   268  
   269  func (r *SCTPTransport) onError(err error) {
   270  	r.lock.RLock()
   271  	handler := r.onErrorHandler
   272  	r.lock.RUnlock()
   273  
   274  	if handler != nil {
   275  		go handler(err)
   276  	}
   277  }
   278  
   279  // OnClose sets an event handler which is invoked when the SCTP Association closes.
   280  func (r *SCTPTransport) OnClose(f func(err error)) {
   281  	r.lock.Lock()
   282  	defer r.lock.Unlock()
   283  	r.onCloseHandler = f
   284  }
   285  
   286  func (r *SCTPTransport) onClose(err error) {
   287  	r.lock.RLock()
   288  	handler := r.onCloseHandler
   289  	r.lock.RUnlock()
   290  
   291  	if handler != nil {
   292  		go handler(err)
   293  	}
   294  }
   295  
   296  // OnDataChannel sets an event handler which is invoked when a data
   297  // channel message arrives from a remote peer.
   298  func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
   299  	r.lock.Lock()
   300  	defer r.lock.Unlock()
   301  	r.onDataChannelHandler = f
   302  }
   303  
   304  // OnDataChannelOpened sets an event handler which is invoked when a data
   305  // channel is opened
   306  func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
   307  	r.lock.Lock()
   308  	defer r.lock.Unlock()
   309  	r.onDataChannelOpenedHandler = f
   310  }
   311  
   312  func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
   313  	r.lock.Lock()
   314  	r.dataChannels = append(r.dataChannels, dc)
   315  	r.dataChannelsAccepted++
   316  	if dc.ID() != nil {
   317  		r.dataChannelIDsUsed[*dc.ID()] = struct{}{}
   318  	} else {
   319  		// This cannot happen, the constructor for this datachannel in the caller
   320  		// takes a pointer to the id.
   321  		r.log.Errorf("accepted data channel with no ID")
   322  	}
   323  	handler := r.onDataChannelHandler
   324  	r.lock.Unlock()
   325  
   326  	done = make(chan struct{})
   327  	if handler == nil || dc == nil {
   328  		close(done)
   329  		return
   330  	}
   331  
   332  	// Run this synchronously to allow setup done in onDataChannelFn()
   333  	// to complete before datachannel event handlers might be called.
   334  	go func() {
   335  		handler(dc)
   336  		close(done)
   337  	}()
   338  
   339  	return
   340  }
   341  
   342  func (r *SCTPTransport) updateMessageSize() {
   343  	r.lock.Lock()
   344  	defer r.lock.Unlock()
   345  
   346  	var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
   347  	var canSendSize float64 = 65536          // pion/webrtc#758
   348  
   349  	r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
   350  }
   351  
   352  func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
   353  	switch {
   354  	case remoteMaxMessageSize == 0 &&
   355  		canSendSize == 0:
   356  		return math.Inf(1)
   357  
   358  	case remoteMaxMessageSize == 0:
   359  		return canSendSize
   360  
   361  	case canSendSize == 0:
   362  		return remoteMaxMessageSize
   363  
   364  	case canSendSize > remoteMaxMessageSize:
   365  		return remoteMaxMessageSize
   366  
   367  	default:
   368  		return canSendSize
   369  	}
   370  }
   371  
   372  func (r *SCTPTransport) updateMaxChannels() {
   373  	val := sctpMaxChannels
   374  	r.maxChannels = &val
   375  }
   376  
   377  // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
   378  func (r *SCTPTransport) MaxChannels() uint16 {
   379  	r.lock.Lock()
   380  	defer r.lock.Unlock()
   381  
   382  	if r.maxChannels == nil {
   383  		return sctpMaxChannels
   384  	}
   385  
   386  	return *r.maxChannels
   387  }
   388  
   389  // State returns the current state of the SCTPTransport
   390  func (r *SCTPTransport) State() SCTPTransportState {
   391  	r.lock.RLock()
   392  	defer r.lock.RUnlock()
   393  	return r.state
   394  }
   395  
   396  func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
   397  	collector.Collecting()
   398  
   399  	stats := SCTPTransportStats{
   400  		Timestamp: statsTimestampFrom(time.Now()),
   401  		Type:      StatsTypeSCTPTransport,
   402  		ID:        "sctpTransport",
   403  	}
   404  
   405  	association := r.association()
   406  	if association != nil {
   407  		stats.BytesSent = association.BytesSent()
   408  		stats.BytesReceived = association.BytesReceived()
   409  		stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds
   410  		stats.CongestionWindow = association.CWND()
   411  		stats.ReceiverWindow = association.RWND()
   412  		stats.MTU = association.MTU()
   413  	}
   414  
   415  	collector.Collect(stats.ID, stats)
   416  }
   417  
   418  func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
   419  	var id uint16
   420  	if dtlsRole != DTLSRoleClient {
   421  		id++
   422  	}
   423  
   424  	maxVal := r.MaxChannels()
   425  
   426  	r.lock.Lock()
   427  	defer r.lock.Unlock()
   428  
   429  	for ; id < maxVal-1; id += 2 {
   430  		if _, ok := r.dataChannelIDsUsed[id]; ok {
   431  			continue
   432  		}
   433  		*idOut = &id
   434  		r.dataChannelIDsUsed[id] = struct{}{}
   435  		return nil
   436  	}
   437  
   438  	return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
   439  }
   440  
   441  func (r *SCTPTransport) association() *sctp.Association {
   442  	if r == nil {
   443  		return nil
   444  	}
   445  	r.lock.RLock()
   446  	association := r.sctpAssociation
   447  	r.lock.RUnlock()
   448  	return association
   449  }