github.com/pion/webrtc/v3@v3.2.24/interceptor_test.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  //
    10  import (
    11  	"context"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/pion/interceptor"
    17  	mock_interceptor "github.com/pion/interceptor/pkg/mock"
    18  	"github.com/pion/rtp"
    19  	"github.com/pion/transport/v2/test"
    20  	"github.com/pion/webrtc/v3/pkg/media"
    21  	"github.com/stretchr/testify/assert"
    22  )
    23  
    24  // E2E test of the features of Interceptors
    25  // * Assert an extension can be set on an outbound packet
    26  // * Assert an extension can be read on an outbound packet
    27  // * Assert that attributes set by an interceptor are returned to the Reader
    28  func TestPeerConnection_Interceptor(t *testing.T) {
    29  	to := test.TimeOut(time.Second * 20)
    30  	defer to.Stop()
    31  
    32  	report := test.CheckRoutines(t)
    33  	defer report()
    34  
    35  	createPC := func() *PeerConnection {
    36  		m := &MediaEngine{}
    37  		assert.NoError(t, m.RegisterDefaultCodecs())
    38  
    39  		ir := &interceptor.Registry{}
    40  		ir.Add(&mock_interceptor.Factory{
    41  			NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
    42  				return &mock_interceptor.Interceptor{
    43  					BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
    44  						return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
    45  							// set extension on outgoing packet
    46  							header.Extension = true
    47  							header.ExtensionProfile = 0xBEDE
    48  							assert.NoError(t, header.SetExtension(2, []byte("foo")))
    49  
    50  							return writer.Write(header, payload, attributes)
    51  						})
    52  					},
    53  					BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
    54  						return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
    55  							if a == nil {
    56  								a = interceptor.Attributes{}
    57  							}
    58  
    59  							a.Set("attribute", "value")
    60  							return reader.Read(b, a)
    61  						})
    62  					},
    63  				}, nil
    64  			},
    65  		})
    66  
    67  		pc, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
    68  		assert.NoError(t, err)
    69  
    70  		return pc
    71  	}
    72  
    73  	offerer := createPC()
    74  	answerer := createPC()
    75  
    76  	track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
    77  	assert.NoError(t, err)
    78  
    79  	_, err = offerer.AddTrack(track)
    80  	assert.NoError(t, err)
    81  
    82  	seenRTP, seenRTPCancel := context.WithCancel(context.Background())
    83  	answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
    84  		p, attributes, readErr := track.ReadRTP()
    85  		assert.NoError(t, readErr)
    86  
    87  		assert.Equal(t, p.Extension, true)
    88  		assert.Equal(t, "foo", string(p.GetExtension(2)))
    89  		assert.Equal(t, "value", attributes.Get("attribute"))
    90  
    91  		seenRTPCancel()
    92  	})
    93  
    94  	assert.NoError(t, signalPair(offerer, answerer))
    95  
    96  	func() {
    97  		ticker := time.NewTicker(time.Millisecond * 20)
    98  		for {
    99  			select {
   100  			case <-seenRTP.Done():
   101  				return
   102  			case <-ticker.C:
   103  				assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
   104  			}
   105  		}
   106  	}()
   107  
   108  	closePairNow(t, offerer, answerer)
   109  }
   110  
   111  func Test_Interceptor_BindUnbind(t *testing.T) {
   112  	lim := test.TimeOut(time.Second * 10)
   113  	defer lim.Stop()
   114  
   115  	report := test.CheckRoutines(t)
   116  	defer report()
   117  
   118  	m := &MediaEngine{}
   119  	assert.NoError(t, m.RegisterDefaultCodecs())
   120  
   121  	var (
   122  		cntBindRTCPReader     uint32
   123  		cntBindRTCPWriter     uint32
   124  		cntBindLocalStream    uint32
   125  		cntUnbindLocalStream  uint32
   126  		cntBindRemoteStream   uint32
   127  		cntUnbindRemoteStream uint32
   128  		cntClose              uint32
   129  	)
   130  	mockInterceptor := &mock_interceptor.Interceptor{
   131  		BindRTCPReaderFn: func(reader interceptor.RTCPReader) interceptor.RTCPReader {
   132  			atomic.AddUint32(&cntBindRTCPReader, 1)
   133  			return reader
   134  		},
   135  		BindRTCPWriterFn: func(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
   136  			atomic.AddUint32(&cntBindRTCPWriter, 1)
   137  			return writer
   138  		},
   139  		BindLocalStreamFn: func(i *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
   140  			atomic.AddUint32(&cntBindLocalStream, 1)
   141  			return writer
   142  		},
   143  		UnbindLocalStreamFn: func(i *interceptor.StreamInfo) {
   144  			atomic.AddUint32(&cntUnbindLocalStream, 1)
   145  		},
   146  		BindRemoteStreamFn: func(i *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
   147  			atomic.AddUint32(&cntBindRemoteStream, 1)
   148  			return reader
   149  		},
   150  		UnbindRemoteStreamFn: func(i *interceptor.StreamInfo) {
   151  			atomic.AddUint32(&cntUnbindRemoteStream, 1)
   152  		},
   153  		CloseFn: func() error {
   154  			atomic.AddUint32(&cntClose, 1)
   155  			return nil
   156  		},
   157  	}
   158  	ir := &interceptor.Registry{}
   159  	ir.Add(&mock_interceptor.Factory{
   160  		NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil },
   161  	})
   162  
   163  	sender, receiver, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{})
   164  	assert.NoError(t, err)
   165  
   166  	track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
   167  	assert.NoError(t, err)
   168  
   169  	_, err = sender.AddTrack(track)
   170  	assert.NoError(t, err)
   171  
   172  	receiverReady, receiverReadyFn := context.WithCancel(context.Background())
   173  	receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
   174  		_, _, readErr := track.ReadRTP()
   175  		assert.NoError(t, readErr)
   176  		receiverReadyFn()
   177  	})
   178  
   179  	assert.NoError(t, signalPair(sender, receiver))
   180  
   181  	ticker := time.NewTicker(time.Millisecond * 20)
   182  	defer ticker.Stop()
   183  	func() {
   184  		for {
   185  			select {
   186  			case <-receiverReady.Done():
   187  				return
   188  			case <-ticker.C:
   189  				// Send packet to make receiver track actual creates RTPReceiver.
   190  				assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
   191  			}
   192  		}
   193  	}()
   194  
   195  	closePairNow(t, sender, receiver)
   196  
   197  	// Bind/UnbindLocal/RemoteStream should be called from one side.
   198  	if cnt := atomic.LoadUint32(&cntBindLocalStream); cnt != 1 {
   199  		t.Errorf("BindLocalStreamFn is expected to be called once, but called %d times", cnt)
   200  	}
   201  	if cnt := atomic.LoadUint32(&cntUnbindLocalStream); cnt != 1 {
   202  		t.Errorf("UnbindLocalStreamFn is expected to be called once, but called %d times", cnt)
   203  	}
   204  	if cnt := atomic.LoadUint32(&cntBindRemoteStream); cnt != 1 {
   205  		t.Errorf("BindRemoteStreamFn is expected to be called once, but called %d times", cnt)
   206  	}
   207  	if cnt := atomic.LoadUint32(&cntUnbindRemoteStream); cnt != 1 {
   208  		t.Errorf("UnbindRemoteStreamFn is expected to be called once, but called %d times", cnt)
   209  	}
   210  
   211  	// BindRTCPWriter/Reader and Close should be called from both side.
   212  	if cnt := atomic.LoadUint32(&cntBindRTCPWriter); cnt != 2 {
   213  		t.Errorf("BindRTCPWriterFn is expected to be called twice, but called %d times", cnt)
   214  	}
   215  	if cnt := atomic.LoadUint32(&cntBindRTCPReader); cnt != 2 {
   216  		t.Errorf("BindRTCPReaderFn is expected to be called twice, but called %d times", cnt)
   217  	}
   218  	if cnt := atomic.LoadUint32(&cntClose); cnt != 2 {
   219  		t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt)
   220  	}
   221  }
   222  
   223  func Test_InterceptorRegistry_Build(t *testing.T) {
   224  	registryBuildCount := 0
   225  
   226  	ir := &interceptor.Registry{}
   227  	ir.Add(&mock_interceptor.Factory{
   228  		NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) {
   229  			registryBuildCount++
   230  			return &interceptor.NoOp{}, nil
   231  		},
   232  	})
   233  
   234  	peerConnectionA, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
   235  	assert.NoError(t, err)
   236  
   237  	peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
   238  	assert.NoError(t, err)
   239  
   240  	assert.Equal(t, 2, registryBuildCount)
   241  	closePairNow(t, peerConnectionA, peerConnectionB)
   242  }