github.com/pion/webrtc/v3@v3.2.24/sctptransport.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"errors"
    11  	"io"
    12  	"math"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/pion/datachannel"
    17  	"github.com/pion/logging"
    18  	"github.com/pion/sctp"
    19  	"github.com/pion/webrtc/v3/pkg/rtcerr"
    20  )
    21  
    22  const sctpMaxChannels = uint16(65535)
    23  
// SCTPTransport provides details about the SCTP transport that carries
// DataChannels on top of a DTLSTransport.
type SCTPTransport struct {
	// lock guards the mutable fields below wherever this file's methods
	// take it (association, state, handlers, data channel bookkeeping).
	lock sync.RWMutex

	// dtlsTransport is the DTLS transport the SCTP association is run over.
	dtlsTransport *DTLSTransport

	// state represents the current state of the SCTP transport.
	state SCTPTransportState

	// SCTPTransportState has no enum value that distinguishes New from
	// Connecting, so Start records in this dedicated field whether it has
	// already run.
	isStarted bool

	// maxMessageSize represents the maximum size of data that can be passed
	// to DataChannel's send() method. calcMessageSize yields +Inf when both
	// of its inputs are 0.
	maxMessageSize float64

	// maxChannels represents the maximum number of DataChannels that can be
	// used simultaneously. It is set by updateMaxChannels; MaxChannels falls
	// back to sctpMaxChannels while it is nil.
	maxChannels *uint16

	// OnStateChange  func()
	// (placeholder: no state-change callback is implemented yet)

	// onErrorHandler is invoked on its own goroutine by onError.
	onErrorHandler func(error)

	// sctpAssociation is set by a successful Start and cleared by Stop.
	// onDataChannelHandler / onDataChannelOpenedHandler are registered via
	// OnDataChannel / OnDataChannelOpened and called for remotely opened
	// channels in acceptDataChannels.
	sctpAssociation            *sctp.Association
	onDataChannelHandler       func(*DataChannel)
	onDataChannelOpenedHandler func(*DataChannel)

	// DataChannels known to this transport, plus counters. In this file
	// dataChannelsOpened is incremented by Start and acceptDataChannels,
	// dataChannelsAccepted by onDataChannel; dataChannelsRequested is not
	// updated here (presumably maintained elsewhere — TODO confirm).
	dataChannels          []*DataChannel
	dataChannelsOpened    uint32
	dataChannelsRequested uint32
	dataChannelsAccepted  uint32

	// api supplies the setting engine (logger factory, SCTP options).
	api *API
	log logging.LeveledLogger
}
    62  
    63  // NewSCTPTransport creates a new SCTPTransport.
    64  // This constructor is part of the ORTC API. It is not
    65  // meant to be used together with the basic WebRTC API.
    66  func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
    67  	res := &SCTPTransport{
    68  		dtlsTransport: dtls,
    69  		state:         SCTPTransportStateConnecting,
    70  		api:           api,
    71  		log:           api.settingEngine.LoggerFactory.NewLogger("ortc"),
    72  	}
    73  
    74  	res.updateMessageSize()
    75  	res.updateMaxChannels()
    76  
    77  	return res
    78  }
    79  
    80  // Transport returns the DTLSTransport instance the SCTPTransport is sending over.
    81  func (r *SCTPTransport) Transport() *DTLSTransport {
    82  	r.lock.RLock()
    83  	defer r.lock.RUnlock()
    84  
    85  	return r.dtlsTransport
    86  }
    87  
    88  // GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
    89  func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
    90  	return SCTPCapabilities{
    91  		MaxMessageSize: 0,
    92  	}
    93  }
    94  
    95  // Start the SCTPTransport. Since both local and remote parties must mutually
    96  // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
    97  // a connection over SCTP.
    98  func (r *SCTPTransport) Start(SCTPCapabilities) error {
    99  	if r.isStarted {
   100  		return nil
   101  	}
   102  	r.isStarted = true
   103  
   104  	dtlsTransport := r.Transport()
   105  	if dtlsTransport == nil || dtlsTransport.conn == nil {
   106  		return errSCTPTransportDTLS
   107  	}
   108  
   109  	sctpAssociation, err := sctp.Client(sctp.Config{
   110  		NetConn:              dtlsTransport.conn,
   111  		MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
   112  		LoggerFactory:        r.api.settingEngine.LoggerFactory,
   113  	})
   114  	if err != nil {
   115  		return err
   116  	}
   117  
   118  	r.lock.Lock()
   119  	r.sctpAssociation = sctpAssociation
   120  	r.state = SCTPTransportStateConnected
   121  	dataChannels := append([]*DataChannel{}, r.dataChannels...)
   122  	r.lock.Unlock()
   123  
   124  	var openedDCCount uint32
   125  	for _, d := range dataChannels {
   126  		if d.ReadyState() == DataChannelStateConnecting {
   127  			err := d.open(r)
   128  			if err != nil {
   129  				r.log.Warnf("failed to open data channel: %s", err)
   130  				continue
   131  			}
   132  			openedDCCount++
   133  		}
   134  	}
   135  
   136  	r.lock.Lock()
   137  	r.dataChannelsOpened += openedDCCount
   138  	r.lock.Unlock()
   139  
   140  	go r.acceptDataChannels(sctpAssociation)
   141  
   142  	return nil
   143  }
   144  
   145  // Stop stops the SCTPTransport
   146  func (r *SCTPTransport) Stop() error {
   147  	r.lock.Lock()
   148  	defer r.lock.Unlock()
   149  	if r.sctpAssociation == nil {
   150  		return nil
   151  	}
   152  	err := r.sctpAssociation.Close()
   153  	if err != nil {
   154  		return err
   155  	}
   156  
   157  	r.sctpAssociation = nil
   158  	r.state = SCTPTransportStateClosed
   159  
   160  	return nil
   161  }
   162  
   163  func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
   164  	r.lock.RLock()
   165  	dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
   166  	for _, dc := range r.dataChannels {
   167  		dc.mu.Lock()
   168  		isNil := dc.dataChannel == nil
   169  		dc.mu.Unlock()
   170  		if isNil {
   171  			continue
   172  		}
   173  		dataChannels = append(dataChannels, dc.dataChannel)
   174  	}
   175  	r.lock.RUnlock()
   176  ACCEPT:
   177  	for {
   178  		dc, err := datachannel.Accept(a, &datachannel.Config{
   179  			LoggerFactory: r.api.settingEngine.LoggerFactory,
   180  		}, dataChannels...)
   181  		if err != nil {
   182  			if !errors.Is(err, io.EOF) {
   183  				r.log.Errorf("Failed to accept data channel: %v", err)
   184  				r.onError(err)
   185  			}
   186  			return
   187  		}
   188  		for _, ch := range dataChannels {
   189  			if ch.StreamIdentifier() == dc.StreamIdentifier() {
   190  				continue ACCEPT
   191  			}
   192  		}
   193  
   194  		var (
   195  			maxRetransmits    *uint16
   196  			maxPacketLifeTime *uint16
   197  		)
   198  		val := uint16(dc.Config.ReliabilityParameter)
   199  		ordered := true
   200  
   201  		switch dc.Config.ChannelType {
   202  		case datachannel.ChannelTypeReliable:
   203  			ordered = true
   204  		case datachannel.ChannelTypeReliableUnordered:
   205  			ordered = false
   206  		case datachannel.ChannelTypePartialReliableRexmit:
   207  			ordered = true
   208  			maxRetransmits = &val
   209  		case datachannel.ChannelTypePartialReliableRexmitUnordered:
   210  			ordered = false
   211  			maxRetransmits = &val
   212  		case datachannel.ChannelTypePartialReliableTimed:
   213  			ordered = true
   214  			maxPacketLifeTime = &val
   215  		case datachannel.ChannelTypePartialReliableTimedUnordered:
   216  			ordered = false
   217  			maxPacketLifeTime = &val
   218  		default:
   219  		}
   220  
   221  		sid := dc.StreamIdentifier()
   222  		rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
   223  			ID:                &sid,
   224  			Label:             dc.Config.Label,
   225  			Protocol:          dc.Config.Protocol,
   226  			Negotiated:        dc.Config.Negotiated,
   227  			Ordered:           ordered,
   228  			MaxPacketLifeTime: maxPacketLifeTime,
   229  			MaxRetransmits:    maxRetransmits,
   230  		}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
   231  		if err != nil {
   232  			r.log.Errorf("Failed to accept data channel: %v", err)
   233  			r.onError(err)
   234  			return
   235  		}
   236  
   237  		<-r.onDataChannel(rtcDC)
   238  		rtcDC.handleOpen(dc, true, dc.Config.Negotiated)
   239  
   240  		r.lock.Lock()
   241  		r.dataChannelsOpened++
   242  		handler := r.onDataChannelOpenedHandler
   243  		r.lock.Unlock()
   244  
   245  		if handler != nil {
   246  			handler(rtcDC)
   247  		}
   248  	}
   249  }
   250  
   251  // OnError sets an event handler which is invoked when
   252  // the SCTP connection error occurs.
   253  func (r *SCTPTransport) OnError(f func(err error)) {
   254  	r.lock.Lock()
   255  	defer r.lock.Unlock()
   256  	r.onErrorHandler = f
   257  }
   258  
   259  func (r *SCTPTransport) onError(err error) {
   260  	r.lock.RLock()
   261  	handler := r.onErrorHandler
   262  	r.lock.RUnlock()
   263  
   264  	if handler != nil {
   265  		go handler(err)
   266  	}
   267  }
   268  
   269  // OnDataChannel sets an event handler which is invoked when a data
   270  // channel message arrives from a remote peer.
   271  func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
   272  	r.lock.Lock()
   273  	defer r.lock.Unlock()
   274  	r.onDataChannelHandler = f
   275  }
   276  
   277  // OnDataChannelOpened sets an event handler which is invoked when a data
   278  // channel is opened
   279  func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
   280  	r.lock.Lock()
   281  	defer r.lock.Unlock()
   282  	r.onDataChannelOpenedHandler = f
   283  }
   284  
   285  func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
   286  	r.lock.Lock()
   287  	r.dataChannels = append(r.dataChannels, dc)
   288  	r.dataChannelsAccepted++
   289  	handler := r.onDataChannelHandler
   290  	r.lock.Unlock()
   291  
   292  	done = make(chan struct{})
   293  	if handler == nil || dc == nil {
   294  		close(done)
   295  		return
   296  	}
   297  
   298  	// Run this synchronously to allow setup done in onDataChannelFn()
   299  	// to complete before datachannel event handlers might be called.
   300  	go func() {
   301  		handler(dc)
   302  		close(done)
   303  	}()
   304  
   305  	return
   306  }
   307  
   308  func (r *SCTPTransport) updateMessageSize() {
   309  	r.lock.Lock()
   310  	defer r.lock.Unlock()
   311  
   312  	var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
   313  	var canSendSize float64 = 65536          // pion/webrtc#758
   314  
   315  	r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
   316  }
   317  
   318  func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
   319  	switch {
   320  	case remoteMaxMessageSize == 0 &&
   321  		canSendSize == 0:
   322  		return math.Inf(1)
   323  
   324  	case remoteMaxMessageSize == 0:
   325  		return canSendSize
   326  
   327  	case canSendSize == 0:
   328  		return remoteMaxMessageSize
   329  
   330  	case canSendSize > remoteMaxMessageSize:
   331  		return remoteMaxMessageSize
   332  
   333  	default:
   334  		return canSendSize
   335  	}
   336  }
   337  
   338  func (r *SCTPTransport) updateMaxChannels() {
   339  	val := sctpMaxChannels
   340  	r.maxChannels = &val
   341  }
   342  
   343  // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
   344  func (r *SCTPTransport) MaxChannels() uint16 {
   345  	r.lock.Lock()
   346  	defer r.lock.Unlock()
   347  
   348  	if r.maxChannels == nil {
   349  		return sctpMaxChannels
   350  	}
   351  
   352  	return *r.maxChannels
   353  }
   354  
   355  // State returns the current state of the SCTPTransport
   356  func (r *SCTPTransport) State() SCTPTransportState {
   357  	r.lock.RLock()
   358  	defer r.lock.RUnlock()
   359  	return r.state
   360  }
   361  
   362  func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
   363  	collector.Collecting()
   364  
   365  	stats := SCTPTransportStats{
   366  		Timestamp: statsTimestampFrom(time.Now()),
   367  		Type:      StatsTypeSCTPTransport,
   368  		ID:        "sctpTransport",
   369  	}
   370  
   371  	association := r.association()
   372  	if association != nil {
   373  		stats.BytesSent = association.BytesSent()
   374  		stats.BytesReceived = association.BytesReceived()
   375  		stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds
   376  		stats.CongestionWindow = association.CWND()
   377  		stats.ReceiverWindow = association.RWND()
   378  		stats.MTU = association.MTU()
   379  	}
   380  
   381  	collector.Collect(stats.ID, stats)
   382  }
   383  
   384  func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
   385  	var id uint16
   386  	if dtlsRole != DTLSRoleClient {
   387  		id++
   388  	}
   389  
   390  	max := r.MaxChannels()
   391  
   392  	r.lock.Lock()
   393  	defer r.lock.Unlock()
   394  
   395  	// Create map of ids so we can compare without double-looping each time.
   396  	idsMap := make(map[uint16]struct{}, len(r.dataChannels))
   397  	for _, dc := range r.dataChannels {
   398  		if dc.ID() == nil {
   399  			continue
   400  		}
   401  
   402  		idsMap[*dc.ID()] = struct{}{}
   403  	}
   404  
   405  	for ; id < max-1; id += 2 {
   406  		if _, ok := idsMap[id]; ok {
   407  			continue
   408  		}
   409  		*idOut = &id
   410  		return nil
   411  	}
   412  
   413  	return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
   414  }
   415  
   416  func (r *SCTPTransport) association() *sctp.Association {
   417  	if r == nil {
   418  		return nil
   419  	}
   420  	r.lock.RLock()
   421  	association := r.sctpAssociation
   422  	r.lock.RUnlock()
   423  	return association
   424  }