github.com/pion/webrtc/v4@v4.0.1/rtpreceiver.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"encoding/binary"
    11  	"fmt"
    12  	"io"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/pion/interceptor"
    17  	"github.com/pion/rtcp"
    18  	"github.com/pion/srtp/v3"
    19  	"github.com/pion/webrtc/v4/internal/util"
    20  )
    21  
// trackStreams maps a single remote track to the RTP/RTCP streams (and the
// optional RTX repair streams) that feed it. An RTPReceiver holds one
// trackStreams per track, so it holds several when receiving Simulcast.
type trackStreams struct {
	// track is the TrackRemote these streams belong to.
	track *TrackRemote

	// streamInfo describes the primary media stream; repairStreamInfo describes
	// the RTX repair stream and stays nil when no repair stream was bound.
	// Both are unbound from the interceptor in Stop.
	streamInfo, repairStreamInfo *interceptor.StreamInfo

	// Primary RTP read stream and the interceptor chain packets are read through.
	rtpReadStream  *srtp.ReadStreamSRTP
	rtpInterceptor interceptor.RTPReader

	// Primary RTCP read stream and the interceptor chain packets are read through.
	rtcpReadStream  *srtp.ReadStreamSRTCP
	rtcpInterceptor interceptor.RTCPReader

	// RTX repair RTP stream, its interceptor chain, and the buffered channel on
	// which rewritten repair packets are queued for readRTX.
	repairReadStream    *srtp.ReadStreamSRTP
	repairInterceptor   interceptor.RTPReader
	repairStreamChannel chan rtxPacketWithAttributes

	// RTCP read stream bound for the RTX repair SSRC.
	repairRtcpReadStream  *srtp.ReadStreamSRTCP
	repairRtcpInterceptor interceptor.RTCPReader
}
    42  
// rtxPacketWithAttributes is an RTX repair packet that has already been
// rewritten into the format of its primary stream, together with the
// interceptor attributes read with it. pkt is borrowed from pool and should be
// handed back by calling release once the caller is done with it.
type rtxPacketWithAttributes struct {
	// pkt is the rewritten packet; it is set to nil by release.
	pkt []byte
	// attributes carries the interceptor attributes, including the
	// AttributeRtx* values recorded by receiveForRtx.
	attributes interceptor.Attributes
	// pool is the sync.Pool pkt's backing array was taken from.
	pool *sync.Pool
}
    48  
    49  func (p *rtxPacketWithAttributes) release() {
    50  	if p.pkt != nil {
    51  		b := p.pkt[:cap(p.pkt)]
    52  		p.pool.Put(b) // nolint:staticcheck
    53  		p.pkt = nil
    54  	}
    55  }
    56  
// RTPReceiver allows an application to inspect the receipt of a TrackRemote
type RTPReceiver struct {
	// kind is the media kind (audio/video) this receiver handles.
	kind      RTPCodecType
	transport *DTLSTransport

	// tracks holds one entry per remote track; more than one when receiving
	// Simulcast. Entries are appended by configureReceive.
	tracks []trackStreams

	// closed is closed by Stop. received is closed when startReceive returns
	// (including when it returns an error).
	closed, received chan interface{}
	// mu is taken by the accessors and setup methods that touch tracks and tr.
	mu sync.RWMutex

	// tr is the owning transceiver, or nil if none has been set.
	tr *RTPTransceiver

	// A reference to the associated api object
	api *API

	// rtxPool recycles receive-MTU-sized buffers used to read RTX packets.
	rtxPool sync.Pool
}
    74  
    75  // NewRTPReceiver constructs a new RTPReceiver
    76  func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RTPReceiver, error) {
    77  	if transport == nil {
    78  		return nil, errRTPReceiverDTLSTransportNil
    79  	}
    80  
    81  	r := &RTPReceiver{
    82  		kind:      kind,
    83  		transport: transport,
    84  		api:       api,
    85  		closed:    make(chan interface{}),
    86  		received:  make(chan interface{}),
    87  		tracks:    []trackStreams{},
    88  		rtxPool: sync.Pool{New: func() interface{} {
    89  			return make([]byte, api.settingEngine.getReceiveMTU())
    90  		}},
    91  	}
    92  
    93  	return r, nil
    94  }
    95  
    96  func (r *RTPReceiver) setRTPTransceiver(tr *RTPTransceiver) {
    97  	r.mu.Lock()
    98  	defer r.mu.Unlock()
    99  	r.tr = tr
   100  }
   101  
   102  // Transport returns the currently-configured *DTLSTransport or nil
   103  // if one has not yet been configured
   104  func (r *RTPReceiver) Transport() *DTLSTransport {
   105  	r.mu.RLock()
   106  	defer r.mu.RUnlock()
   107  	return r.transport
   108  }
   109  
   110  func (r *RTPReceiver) getParameters() RTPParameters {
   111  	parameters := r.api.mediaEngine.getRTPParametersByKind(r.kind, []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly})
   112  	if r.tr != nil {
   113  		parameters.Codecs = r.tr.getCodecs()
   114  	}
   115  	return parameters
   116  }
   117  
   118  // GetParameters describes the current configuration for the encoding and
   119  // transmission of media on the receiver's track.
   120  func (r *RTPReceiver) GetParameters() RTPParameters {
   121  	r.mu.RLock()
   122  	defer r.mu.RUnlock()
   123  	return r.getParameters()
   124  }
   125  
   126  // Track returns the RtpTransceiver TrackRemote
   127  func (r *RTPReceiver) Track() *TrackRemote {
   128  	r.mu.RLock()
   129  	defer r.mu.RUnlock()
   130  
   131  	if len(r.tracks) != 1 {
   132  		return nil
   133  	}
   134  	return r.tracks[0].track
   135  }
   136  
   137  // Tracks returns the RtpTransceiver tracks
   138  // A RTPReceiver to support Simulcast may now have multiple tracks
   139  func (r *RTPReceiver) Tracks() []*TrackRemote {
   140  	r.mu.RLock()
   141  	defer r.mu.RUnlock()
   142  
   143  	var tracks []*TrackRemote
   144  	for i := range r.tracks {
   145  		tracks = append(tracks, r.tracks[i].track)
   146  	}
   147  	return tracks
   148  }
   149  
   150  // RTPTransceiver returns the RTPTransceiver this
   151  // RTPReceiver belongs too, or nil if none
   152  func (r *RTPReceiver) RTPTransceiver() *RTPTransceiver {
   153  	r.mu.Lock()
   154  	defer r.mu.Unlock()
   155  
   156  	return r.tr
   157  }
   158  
   159  // configureReceive initialize the track
   160  func (r *RTPReceiver) configureReceive(parameters RTPReceiveParameters) {
   161  	r.mu.Lock()
   162  	defer r.mu.Unlock()
   163  
   164  	for i := range parameters.Encodings {
   165  		t := trackStreams{
   166  			track: newTrackRemote(
   167  				r.kind,
   168  				parameters.Encodings[i].SSRC,
   169  				parameters.Encodings[i].RTX.SSRC,
   170  				parameters.Encodings[i].RID,
   171  				r,
   172  			),
   173  		}
   174  
   175  		r.tracks = append(r.tracks, t)
   176  	}
   177  }
   178  
// startReceive binds the RTP/RTCP read streams for every SSRC-identified
// encoding in parameters. When an encoding signals an RTX SSRC, it also binds
// that repair stream and starts processing it via receiveForRtx.
// RID-identified encodings are skipped here; receiveForRid binds them later.
//
// It may run only once: a second call returns errRTPReceiverReceiveAlreadyCalled.
// r.received is closed when this function returns, including on error returns,
// which releases anything blocked waiting on it.
func (r *RTPReceiver) startReceive(parameters RTPReceiveParameters) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	select {
	case <-r.received:
		return errRTPReceiverReceiveAlreadyCalled
	default:
	}
	// Deferred after the Unlock, so it runs first: received is closed while
	// r.mu is still held.
	defer close(r.received)

	// Streams are described using the first codec from getParameters
	// (the zero capability when no codec is available).
	globalParams := r.getParameters()
	codec := RTPCodecCapability{}
	if len(globalParams.Codecs) != 0 {
		codec = globalParams.Codecs[0].RTPCodecCapability
	}

	for i := range parameters.Encodings {
		if parameters.Encodings[i].RID != "" {
			// RID based tracks will be set up in receiveForRid
			continue
		}

		// Locate the track already created for this SSRC (presumably by
		// configureReceive, which Receive calls first).
		var t *trackStreams
		for idx, ts := range r.tracks {
			if ts.track != nil && ts.track.SSRC() == parameters.Encodings[i].SSRC {
				t = &r.tracks[idx]
				break
			}
		}
		if t == nil {
			return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, parameters.Encodings[i].SSRC)
		}

		// Bind the primary RTP/RTCP streams for this SSRC.
		t.streamInfo = createStreamInfo("", parameters.Encodings[i].SSRC, 0, 0, 0, 0, 0, codec, globalParams.HeaderExtensions)
		var err error
		if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *t.streamInfo); err != nil {
			return err
		}

		// Bind the RTX repair stream, if one was signalled, and start processing it.
		if rtxSsrc := parameters.Encodings[i].RTX.SSRC; rtxSsrc != 0 {
			streamInfo := createStreamInfo("", rtxSsrc, 0, 0, 0, 0, 0, codec, globalParams.HeaderExtensions)
			rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, err := r.transport.streamsForSSRC(rtxSsrc, *streamInfo)
			if err != nil {
				return err
			}

			if err = r.receiveForRtx(rtxSsrc, "", streamInfo, rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor); err != nil {
				return err
			}
		}
	}

	return nil
}
   234  
   235  // Receive initialize the track and starts all the transports
   236  func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
   237  	r.configureReceive(parameters)
   238  	return r.startReceive(parameters)
   239  }
   240  
   241  // Read reads incoming RTCP for this RTPReceiver
   242  func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error) {
   243  	select {
   244  	case <-r.received:
   245  		return r.tracks[0].rtcpInterceptor.Read(b, a)
   246  	case <-r.closed:
   247  		return 0, nil, io.ErrClosedPipe
   248  	}
   249  }
   250  
   251  // ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid
   252  func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) {
   253  	select {
   254  	case <-r.received:
   255  		var rtcpInterceptor interceptor.RTCPReader
   256  
   257  		r.mu.Lock()
   258  		for _, t := range r.tracks {
   259  			if t.track != nil && t.track.rid == rid {
   260  				rtcpInterceptor = t.rtcpInterceptor
   261  			}
   262  		}
   263  		r.mu.Unlock()
   264  
   265  		if rtcpInterceptor == nil {
   266  			return 0, nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
   267  		}
   268  		return rtcpInterceptor.Read(b, a)
   269  
   270  	case <-r.closed:
   271  		return 0, nil, io.ErrClosedPipe
   272  	}
   273  }
   274  
   275  // ReadRTCP is a convenience method that wraps Read and unmarshal for you.
   276  // It also runs any configured interceptors.
   277  func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) {
   278  	b := make([]byte, r.api.settingEngine.getReceiveMTU())
   279  	i, attributes, err := r.Read(b)
   280  	if err != nil {
   281  		return nil, nil, err
   282  	}
   283  
   284  	pkts, err := rtcp.Unmarshal(b[:i])
   285  	if err != nil {
   286  		return nil, nil, err
   287  	}
   288  
   289  	return pkts, attributes, nil
   290  }
   291  
   292  // ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you
   293  func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) {
   294  	b := make([]byte, r.api.settingEngine.getReceiveMTU())
   295  	i, attributes, err := r.ReadSimulcast(b, rid)
   296  	if err != nil {
   297  		return nil, nil, err
   298  	}
   299  
   300  	pkts, err := rtcp.Unmarshal(b[:i])
   301  	return pkts, attributes, err
   302  }
   303  
   304  func (r *RTPReceiver) haveReceived() bool {
   305  	select {
   306  	case <-r.received:
   307  		return true
   308  	default:
   309  		return false
   310  	}
   311  }
   312  
   313  // Stop irreversibly stops the RTPReceiver
   314  func (r *RTPReceiver) Stop() error {
   315  	r.mu.Lock()
   316  	defer r.mu.Unlock()
   317  	var err error
   318  
   319  	select {
   320  	case <-r.closed:
   321  		return err
   322  	default:
   323  	}
   324  
   325  	select {
   326  	case <-r.received:
   327  		for i := range r.tracks {
   328  			errs := []error{}
   329  
   330  			if r.tracks[i].rtcpReadStream != nil {
   331  				errs = append(errs, r.tracks[i].rtcpReadStream.Close())
   332  			}
   333  
   334  			if r.tracks[i].rtpReadStream != nil {
   335  				errs = append(errs, r.tracks[i].rtpReadStream.Close())
   336  			}
   337  
   338  			if r.tracks[i].repairReadStream != nil {
   339  				errs = append(errs, r.tracks[i].repairReadStream.Close())
   340  			}
   341  
   342  			if r.tracks[i].repairRtcpReadStream != nil {
   343  				errs = append(errs, r.tracks[i].repairRtcpReadStream.Close())
   344  			}
   345  
   346  			if r.tracks[i].streamInfo != nil {
   347  				r.api.interceptor.UnbindRemoteStream(r.tracks[i].streamInfo)
   348  			}
   349  
   350  			if r.tracks[i].repairStreamInfo != nil {
   351  				r.api.interceptor.UnbindRemoteStream(r.tracks[i].repairStreamInfo)
   352  			}
   353  
   354  			err = util.FlattenErrs(errs)
   355  		}
   356  	default:
   357  	}
   358  
   359  	close(r.closed)
   360  	return err
   361  }
   362  
   363  func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams {
   364  	for i := range r.tracks {
   365  		if r.tracks[i].track == t {
   366  			return &r.tracks[i]
   367  		}
   368  	}
   369  	return nil
   370  }
   371  
   372  // readRTP should only be called by a track, this only exists so we can keep state in one place
   373  func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) {
   374  	<-r.received
   375  	if t := r.streamsForTrack(reader); t != nil {
   376  		return t.rtpInterceptor.Read(b, a)
   377  	}
   378  
   379  	return 0, nil, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
   380  }
   381  
   382  // receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs
   383  // It populates all the internal state for the given RID
   384  func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) (*TrackRemote, error) {
   385  	r.mu.Lock()
   386  	defer r.mu.Unlock()
   387  
   388  	for i := range r.tracks {
   389  		if r.tracks[i].track.RID() == rid {
   390  			r.tracks[i].track.mu.Lock()
   391  			r.tracks[i].track.kind = r.kind
   392  			r.tracks[i].track.codec = params.Codecs[0]
   393  			r.tracks[i].track.params = params
   394  			r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC)
   395  			r.tracks[i].track.mu.Unlock()
   396  
   397  			r.tracks[i].streamInfo = streamInfo
   398  			r.tracks[i].rtpReadStream = rtpReadStream
   399  			r.tracks[i].rtpInterceptor = rtpInterceptor
   400  			r.tracks[i].rtcpReadStream = rtcpReadStream
   401  			r.tracks[i].rtcpInterceptor = rtcpInterceptor
   402  
   403  			return r.tracks[i].track, nil
   404  		}
   405  	}
   406  
   407  	return nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
   408  }
   409  
   410  // receiveForRtx starts a routine that processes the repair stream
   411  func (r *RTPReceiver) receiveForRtx(ssrc SSRC, rsid string, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) error {
   412  	var track *trackStreams
   413  	if ssrc != 0 && len(r.tracks) == 1 {
   414  		track = &r.tracks[0]
   415  	} else {
   416  		for i := range r.tracks {
   417  			if r.tracks[i].track.RID() == rsid {
   418  				track = &r.tracks[i]
   419  				if track.track.RtxSSRC() == 0 {
   420  					track.track.setRtxSSRC(SSRC(streamInfo.SSRC))
   421  				}
   422  				break
   423  			}
   424  		}
   425  	}
   426  
   427  	if track == nil {
   428  		return fmt.Errorf("%w: ssrc(%d) rsid(%s)", errRTPReceiverForRIDTrackStreamNotFound, ssrc, rsid)
   429  	}
   430  
   431  	track.repairStreamInfo = streamInfo
   432  	track.repairReadStream = rtpReadStream
   433  	track.repairInterceptor = rtpInterceptor
   434  	track.repairRtcpReadStream = rtcpReadStream
   435  	track.repairRtcpInterceptor = rtcpInterceptor
   436  	track.repairStreamChannel = make(chan rtxPacketWithAttributes, 50)
   437  
   438  	go func() {
   439  		for {
   440  			b := r.rtxPool.Get().([]byte) // nolint:forcetypeassert
   441  			i, attributes, err := track.repairInterceptor.Read(b, nil)
   442  			if err != nil {
   443  				r.rtxPool.Put(b) // nolint:staticcheck
   444  				return
   445  			}
   446  
   447  			// RTX packets have a different payload format. Move the OSN in the payload to the RTP header and rewrite the
   448  			// payload type and SSRC, so that we can return RTX packets to the caller 'transparently' i.e. in the same format
   449  			// as non-RTX RTP packets
   450  			hasExtension := b[0]&0b10000 > 0
   451  			hasPadding := b[0]&0b100000 > 0
   452  			csrcCount := b[0] & 0b1111
   453  			headerLength := uint16(12 + (4 * csrcCount))
   454  			paddingLength := 0
   455  			if hasExtension {
   456  				headerLength += 4 * (1 + binary.BigEndian.Uint16(b[headerLength+2:headerLength+4]))
   457  			}
   458  			if hasPadding {
   459  				paddingLength = int(b[i-1])
   460  			}
   461  
   462  			if i-int(headerLength)-paddingLength < 2 {
   463  				// BWE probe packet, ignore
   464  				r.rtxPool.Put(b) // nolint:staticcheck
   465  				continue
   466  			}
   467  
   468  			if attributes == nil {
   469  				attributes = make(interceptor.Attributes)
   470  			}
   471  			attributes.Set(AttributeRtxPayloadType, b[1]&0x7F)
   472  			attributes.Set(AttributeRtxSequenceNumber, binary.BigEndian.Uint16(b[2:4]))
   473  			attributes.Set(AttributeRtxSsrc, binary.BigEndian.Uint32(b[8:12]))
   474  
   475  			b[1] = (b[1] & 0x80) | uint8(track.track.PayloadType())
   476  			b[2] = b[headerLength]
   477  			b[3] = b[headerLength+1]
   478  			binary.BigEndian.PutUint32(b[8:12], uint32(track.track.SSRC()))
   479  			copy(b[headerLength:i-2], b[headerLength+2:i])
   480  
   481  			select {
   482  			case <-r.closed:
   483  				r.rtxPool.Put(b) // nolint:staticcheck
   484  				return
   485  			case track.repairStreamChannel <- rtxPacketWithAttributes{pkt: b[:i-2], attributes: attributes, pool: &r.rtxPool}:
   486  			default:
   487  				// skip the RTX packet if the repair stream channel is full, could be blocked in the application's read loop
   488  			}
   489  		}
   490  	}()
   491  	return nil
   492  }
   493  
   494  // SetReadDeadline sets the max amount of time the RTCP stream will block before returning. 0 is forever.
   495  func (r *RTPReceiver) SetReadDeadline(t time.Time) error {
   496  	r.mu.RLock()
   497  	defer r.mu.RUnlock()
   498  
   499  	return r.tracks[0].rtcpReadStream.SetReadDeadline(t)
   500  }
   501  
   502  // SetReadDeadlineSimulcast sets the max amount of time the RTCP stream for a given rid will block before returning. 0 is forever.
   503  func (r *RTPReceiver) SetReadDeadlineSimulcast(deadline time.Time, rid string) error {
   504  	r.mu.RLock()
   505  	defer r.mu.RUnlock()
   506  
   507  	for _, t := range r.tracks {
   508  		if t.track != nil && t.track.rid == rid {
   509  			return t.rtcpReadStream.SetReadDeadline(deadline)
   510  		}
   511  	}
   512  	return fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
   513  }
   514  
   515  // setRTPReadDeadline sets the max amount of time the RTP stream will block before returning. 0 is forever.
   516  // This should be fired by calling SetReadDeadline on the TrackRemote
   517  func (r *RTPReceiver) setRTPReadDeadline(deadline time.Time, reader *TrackRemote) error {
   518  	r.mu.RLock()
   519  	defer r.mu.RUnlock()
   520  
   521  	if t := r.streamsForTrack(reader); t != nil {
   522  		return t.rtpReadStream.SetReadDeadline(deadline)
   523  	}
   524  	return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC())
   525  }
   526  
   527  // readRTX returns an RTX packet if one is available on the RTX track, otherwise returns nil
   528  func (r *RTPReceiver) readRTX(reader *TrackRemote) *rtxPacketWithAttributes {
   529  	if !reader.HasRTX() {
   530  		return nil
   531  	}
   532  
   533  	select {
   534  	case <-r.received:
   535  	default:
   536  		return nil
   537  	}
   538  
   539  	if t := r.streamsForTrack(reader); t != nil {
   540  		select {
   541  		case rtxPacketReceived := <-t.repairStreamChannel:
   542  			return &rtxPacketReceived
   543  		default:
   544  		}
   545  	}
   546  	return nil
   547  }