github.com/pion/webrtc/v4@v4.0.1/interceptor_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  //
    10  import (
    11  	"context"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/pion/interceptor"
    17  	mock_interceptor "github.com/pion/interceptor/pkg/mock"
    18  	"github.com/pion/rtp"
    19  	"github.com/pion/transport/v3/test"
    20  	"github.com/pion/webrtc/v4/pkg/media"
    21  	"github.com/stretchr/testify/assert"
    22  )
    23  
    24  // E2E test of the features of Interceptors
    25  // * Assert an extension can be set on an outbound packet
    26  // * Assert an extension can be read on an outbound packet
    27  // * Assert that attributes set by an interceptor are returned to the Reader
    28  func TestPeerConnection_Interceptor(t *testing.T) {
    29  	to := test.TimeOut(time.Second * 20)
    30  	defer to.Stop()
    31  
    32  	report := test.CheckRoutines(t)
    33  	defer report()
    34  
    35  	createPC := func() *PeerConnection {
    36  		ir := &interceptor.Registry{}
    37  		ir.Add(&mock_interceptor.Factory{
    38  			NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
    39  				return &mock_interceptor.Interceptor{
    40  					BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
    41  						return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
    42  							// set extension on outgoing packet
    43  							header.Extension = true
    44  							header.ExtensionProfile = 0xBEDE
    45  							assert.NoError(t, header.SetExtension(2, []byte("foo")))
    46  
    47  							return writer.Write(header, payload, attributes)
    48  						})
    49  					},
    50  					BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
    51  						return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
    52  							if a == nil {
    53  								a = interceptor.Attributes{}
    54  							}
    55  
    56  							a.Set("attribute", "value")
    57  							return reader.Read(b, a)
    58  						})
    59  					},
    60  				}, nil
    61  			},
    62  		})
    63  
    64  		pc, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
    65  		assert.NoError(t, err)
    66  
    67  		return pc
    68  	}
    69  
    70  	offerer := createPC()
    71  	answerer := createPC()
    72  
    73  	track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
    74  	assert.NoError(t, err)
    75  
    76  	_, err = offerer.AddTrack(track)
    77  	assert.NoError(t, err)
    78  
    79  	seenRTP, seenRTPCancel := context.WithCancel(context.Background())
    80  	answerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
    81  		p, attributes, readErr := track.ReadRTP()
    82  		assert.NoError(t, readErr)
    83  
    84  		assert.Equal(t, p.Extension, true)
    85  		assert.Equal(t, "foo", string(p.GetExtension(2)))
    86  		assert.Equal(t, "value", attributes.Get("attribute"))
    87  
    88  		seenRTPCancel()
    89  	})
    90  
    91  	assert.NoError(t, signalPair(offerer, answerer))
    92  
    93  	func() {
    94  		ticker := time.NewTicker(time.Millisecond * 20)
    95  		defer ticker.Stop()
    96  		for {
    97  			select {
    98  			case <-seenRTP.Done():
    99  				return
   100  			case <-ticker.C:
   101  				assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
   102  			}
   103  		}
   104  	}()
   105  
   106  	closePairNow(t, offerer, answerer)
   107  }
   108  
   109  func Test_Interceptor_BindUnbind(t *testing.T) {
   110  	lim := test.TimeOut(time.Second * 10)
   111  	defer lim.Stop()
   112  
   113  	report := test.CheckRoutines(t)
   114  	defer report()
   115  
   116  	var (
   117  		cntBindRTCPReader     uint32
   118  		cntBindRTCPWriter     uint32
   119  		cntBindLocalStream    uint32
   120  		cntUnbindLocalStream  uint32
   121  		cntBindRemoteStream   uint32
   122  		cntUnbindRemoteStream uint32
   123  		cntClose              uint32
   124  	)
   125  	mockInterceptor := &mock_interceptor.Interceptor{
   126  		BindRTCPReaderFn: func(reader interceptor.RTCPReader) interceptor.RTCPReader {
   127  			atomic.AddUint32(&cntBindRTCPReader, 1)
   128  			return reader
   129  		},
   130  		BindRTCPWriterFn: func(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
   131  			atomic.AddUint32(&cntBindRTCPWriter, 1)
   132  			return writer
   133  		},
   134  		BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
   135  			atomic.AddUint32(&cntBindLocalStream, 1)
   136  			return writer
   137  		},
   138  		UnbindLocalStreamFn: func(*interceptor.StreamInfo) {
   139  			atomic.AddUint32(&cntUnbindLocalStream, 1)
   140  		},
   141  		BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
   142  			atomic.AddUint32(&cntBindRemoteStream, 1)
   143  			return reader
   144  		},
   145  		UnbindRemoteStreamFn: func(_ *interceptor.StreamInfo) {
   146  			atomic.AddUint32(&cntUnbindRemoteStream, 1)
   147  		},
   148  		CloseFn: func() error {
   149  			atomic.AddUint32(&cntClose, 1)
   150  			return nil
   151  		},
   152  	}
   153  	ir := &interceptor.Registry{}
   154  	ir.Add(&mock_interceptor.Factory{
   155  		NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil },
   156  	})
   157  
   158  	sender, receiver, err := NewAPI(WithInterceptorRegistry(ir)).newPair(Configuration{})
   159  	assert.NoError(t, err)
   160  
   161  	track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
   162  	assert.NoError(t, err)
   163  
   164  	_, err = sender.AddTrack(track)
   165  	assert.NoError(t, err)
   166  
   167  	receiverReady, receiverReadyFn := context.WithCancel(context.Background())
   168  	receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
   169  		_, _, readErr := track.ReadRTP()
   170  		assert.NoError(t, readErr)
   171  		receiverReadyFn()
   172  	})
   173  
   174  	assert.NoError(t, signalPair(sender, receiver))
   175  
   176  	ticker := time.NewTicker(time.Millisecond * 20)
   177  	defer ticker.Stop()
   178  	func() {
   179  		for {
   180  			select {
   181  			case <-receiverReady.Done():
   182  				return
   183  			case <-ticker.C:
   184  				// Send packet to make receiver track actual creates RTPReceiver.
   185  				assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
   186  			}
   187  		}
   188  	}()
   189  
   190  	closePairNow(t, sender, receiver)
   191  
   192  	// Bind/UnbindLocal/RemoteStream should be called from one side.
   193  	if cnt := atomic.LoadUint32(&cntBindLocalStream); cnt != 1 {
   194  		t.Errorf("BindLocalStreamFn is expected to be called once, but called %d times", cnt)
   195  	}
   196  	if cnt := atomic.LoadUint32(&cntUnbindLocalStream); cnt != 1 {
   197  		t.Errorf("UnbindLocalStreamFn is expected to be called once, but called %d times", cnt)
   198  	}
   199  	if cnt := atomic.LoadUint32(&cntBindRemoteStream); cnt != 2 {
   200  		t.Errorf("BindRemoteStreamFn is expected to be called once, but called %d times", cnt)
   201  	}
   202  	if cnt := atomic.LoadUint32(&cntUnbindRemoteStream); cnt != 2 {
   203  		t.Errorf("UnbindRemoteStreamFn is expected to be called once, but called %d times", cnt)
   204  	}
   205  
   206  	// BindRTCPWriter/Reader and Close should be called from both side.
   207  	if cnt := atomic.LoadUint32(&cntBindRTCPWriter); cnt != 2 {
   208  		t.Errorf("BindRTCPWriterFn is expected to be called twice, but called %d times", cnt)
   209  	}
   210  	if cnt := atomic.LoadUint32(&cntBindRTCPReader); cnt != 3 {
   211  		t.Errorf("BindRTCPReaderFn is expected to be called twice, but called %d times", cnt)
   212  	}
   213  	if cnt := atomic.LoadUint32(&cntClose); cnt != 2 {
   214  		t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt)
   215  	}
   216  }
   217  
   218  func Test_InterceptorRegistry_Build(t *testing.T) {
   219  	registryBuildCount := 0
   220  
   221  	ir := &interceptor.Registry{}
   222  	ir.Add(&mock_interceptor.Factory{
   223  		NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
   224  			registryBuildCount++
   225  			return &interceptor.NoOp{}, nil
   226  		},
   227  	})
   228  
   229  	peerConnectionA, peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).newPair(Configuration{})
   230  	assert.NoError(t, err)
   231  
   232  	assert.Equal(t, 2, registryBuildCount)
   233  	closePairNow(t, peerConnectionA, peerConnectionB)
   234  }
   235  
   236  func Test_Interceptor_ZeroSSRC(t *testing.T) {
   237  	to := test.TimeOut(time.Second * 20)
   238  	defer to.Stop()
   239  
   240  	report := test.CheckRoutines(t)
   241  	defer report()
   242  
   243  	track, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
   244  	assert.NoError(t, err)
   245  
   246  	offerer, answerer, err := newPair()
   247  	assert.NoError(t, err)
   248  
   249  	_, err = offerer.AddTrack(track)
   250  	assert.NoError(t, err)
   251  
   252  	probeReceiverCreated := make(chan struct{})
   253  
   254  	go func() {
   255  		sequenceNumber := uint16(0)
   256  		ticker := time.NewTicker(time.Millisecond * 20)
   257  		defer ticker.Stop()
   258  		for range ticker.C {
   259  			track.mu.Lock()
   260  			if len(track.bindings) == 1 {
   261  				_, err = track.bindings[0].writeStream.WriteRTP(&rtp.Header{
   262  					Version:        2,
   263  					SSRC:           0,
   264  					SequenceNumber: sequenceNumber,
   265  				}, []byte{0, 1, 2, 3, 4, 5})
   266  				assert.NoError(t, err)
   267  			}
   268  			sequenceNumber++
   269  			track.mu.Unlock()
   270  
   271  			if nonMediaBandwidthProbe, ok := answerer.nonMediaBandwidthProbe.Load().(*RTPReceiver); ok {
   272  				assert.Equal(t, len(nonMediaBandwidthProbe.Tracks()), 1)
   273  				close(probeReceiverCreated)
   274  				return
   275  			}
   276  		}
   277  	}()
   278  
   279  	assert.NoError(t, signalPair(offerer, answerer))
   280  
   281  	peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, offerer, answerer)
   282  	peerConnectionConnected.Wait()
   283  
   284  	<-probeReceiverCreated
   285  	closePairNow(t, offerer, answerer)
   286  }