github.com/pion/webrtc/v4@v4.0.1/interceptor_test.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 // 10 import ( 11 "context" 12 "sync/atomic" 13 "testing" 14 "time" 15 16 "github.com/pion/interceptor" 17 mock_interceptor "github.com/pion/interceptor/pkg/mock" 18 "github.com/pion/rtp" 19 "github.com/pion/transport/v3/test" 20 "github.com/pion/webrtc/v4/pkg/media" 21 "github.com/stretchr/testify/assert" 22 ) 23 24 // E2E test of the features of Interceptors 25 // * Assert an extension can be set on an outbound packet 26 // * Assert an extension can be read on an outbound packet 27 // * Assert that attributes set by an interceptor are returned to the Reader 28 func TestPeerConnection_Interceptor(t *testing.T) { 29 to := test.TimeOut(time.Second * 20) 30 defer to.Stop() 31 32 report := test.CheckRoutines(t) 33 defer report() 34 35 createPC := func() *PeerConnection { 36 ir := &interceptor.Registry{} 37 ir.Add(&mock_interceptor.Factory{ 38 NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { 39 return &mock_interceptor.Interceptor{ 40 BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { 41 return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { 42 // set extension on outgoing packet 43 header.Extension = true 44 header.ExtensionProfile = 0xBEDE 45 assert.NoError(t, header.SetExtension(2, []byte("foo"))) 46 47 return writer.Write(header, payload, attributes) 48 }) 49 }, 50 BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { 51 return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { 52 if a == nil { 53 a = interceptor.Attributes{} 54 } 55 56 a.Set("attribute", "value") 57 return reader.Read(b, a) 58 }) 59 }, 60 }, nil 61 }, 62 }) 63 64 pc, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) 65 assert.NoError(t, err) 66 67 return pc 68 } 69 70 offerer := createPC() 71 answerer := createPC() 72 73 track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") 74 assert.NoError(t, err) 75 76 _, err = offerer.AddTrack(track) 77 assert.NoError(t, err) 78 79 seenRTP, seenRTPCancel := context.WithCancel(context.Background()) 80 answerer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { 81 p, attributes, readErr := track.ReadRTP() 82 assert.NoError(t, readErr) 83 84 assert.Equal(t, p.Extension, true) 85 assert.Equal(t, "foo", string(p.GetExtension(2))) 86 assert.Equal(t, "value", attributes.Get("attribute")) 87 88 seenRTPCancel() 89 }) 90 91 assert.NoError(t, signalPair(offerer, answerer)) 92 93 func() { 94 ticker := time.NewTicker(time.Millisecond * 20) 95 defer ticker.Stop() 96 for { 97 select { 98 case <-seenRTP.Done(): 99 return 100 case <-ticker.C: 101 assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) 102 } 103 } 104 }() 105 106 closePairNow(t, offerer, answerer) 107 } 108 109 func Test_Interceptor_BindUnbind(t *testing.T) { 110 lim := test.TimeOut(time.Second * 10) 111 defer lim.Stop() 112 113 report := test.CheckRoutines(t) 114 defer report() 115 116 var ( 117 cntBindRTCPReader uint32 118 cntBindRTCPWriter uint32 119 cntBindLocalStream uint32 120 cntUnbindLocalStream uint32 121 cntBindRemoteStream uint32 122 cntUnbindRemoteStream uint32 123 cntClose uint32 124 ) 125 mockInterceptor := &mock_interceptor.Interceptor{ 126 BindRTCPReaderFn: func(reader interceptor.RTCPReader) interceptor.RTCPReader { 127 atomic.AddUint32(&cntBindRTCPReader, 1) 128 return reader 129 }, 130 BindRTCPWriterFn: func(writer interceptor.RTCPWriter) interceptor.RTCPWriter { 131 atomic.AddUint32(&cntBindRTCPWriter, 1) 132 return writer 133 }, 134 BindLocalStreamFn: func(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { 135 atomic.AddUint32(&cntBindLocalStream, 1) 136 return writer 137 }, 138 UnbindLocalStreamFn: func(*interceptor.StreamInfo) { 139 atomic.AddUint32(&cntUnbindLocalStream, 1) 140 }, 141 BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { 142 atomic.AddUint32(&cntBindRemoteStream, 1) 143 return reader 144 }, 145 UnbindRemoteStreamFn: func(_ *interceptor.StreamInfo) { 146 atomic.AddUint32(&cntUnbindRemoteStream, 1) 147 }, 148 CloseFn: func() error { 149 atomic.AddUint32(&cntClose, 1) 150 return nil 151 }, 152 } 153 ir := &interceptor.Registry{} 154 ir.Add(&mock_interceptor.Factory{ 155 NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil }, 156 }) 157 158 sender, receiver, err := NewAPI(WithInterceptorRegistry(ir)).newPair(Configuration{}) 159 assert.NoError(t, err) 160 161 track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") 162 assert.NoError(t, err) 163 164 _, err = sender.AddTrack(track) 165 assert.NoError(t, err) 166 167 receiverReady, receiverReadyFn := context.WithCancel(context.Background()) 168 receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { 169 _, _, readErr := track.ReadRTP() 170 assert.NoError(t, readErr) 171 receiverReadyFn() 172 }) 173 174 assert.NoError(t, signalPair(sender, receiver)) 175 176 ticker := time.NewTicker(time.Millisecond * 20) 177 defer ticker.Stop() 178 func() { 179 for { 180 select { 181 case <-receiverReady.Done(): 182 return 183 case <-ticker.C: 184 // Send packet to make receiver track actual creates RTPReceiver. 185 assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) 186 } 187 } 188 }() 189 190 closePairNow(t, sender, receiver) 191 192 // Bind/UnbindLocal/RemoteStream should be called from one side. 193 if cnt := atomic.LoadUint32(&cntBindLocalStream); cnt != 1 { 194 t.Errorf("BindLocalStreamFn is expected to be called once, but called %d times", cnt) 195 } 196 if cnt := atomic.LoadUint32(&cntUnbindLocalStream); cnt != 1 { 197 t.Errorf("UnbindLocalStreamFn is expected to be called once, but called %d times", cnt) 198 } 199 if cnt := atomic.LoadUint32(&cntBindRemoteStream); cnt != 2 { 200 t.Errorf("BindRemoteStreamFn is expected to be called once, but called %d times", cnt) 201 } 202 if cnt := atomic.LoadUint32(&cntUnbindRemoteStream); cnt != 2 { 203 t.Errorf("UnbindRemoteStreamFn is expected to be called once, but called %d times", cnt) 204 } 205 206 // BindRTCPWriter/Reader and Close should be called from both side. 207 if cnt := atomic.LoadUint32(&cntBindRTCPWriter); cnt != 2 { 208 t.Errorf("BindRTCPWriterFn is expected to be called twice, but called %d times", cnt) 209 } 210 if cnt := atomic.LoadUint32(&cntBindRTCPReader); cnt != 3 { 211 t.Errorf("BindRTCPReaderFn is expected to be called twice, but called %d times", cnt) 212 } 213 if cnt := atomic.LoadUint32(&cntClose); cnt != 2 { 214 t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt) 215 } 216 } 217 218 func Test_InterceptorRegistry_Build(t *testing.T) { 219 registryBuildCount := 0 220 221 ir := &interceptor.Registry{} 222 ir.Add(&mock_interceptor.Factory{ 223 NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { 224 registryBuildCount++ 225 return &interceptor.NoOp{}, nil 226 }, 227 }) 228 229 peerConnectionA, peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).newPair(Configuration{}) 230 assert.NoError(t, err) 231 232 assert.Equal(t, 2, registryBuildCount) 233 closePairNow(t, peerConnectionA, peerConnectionB) 234 } 235 236 func Test_Interceptor_ZeroSSRC(t *testing.T) { 237 to := test.TimeOut(time.Second * 20) 238 defer to.Stop() 239 240 report := test.CheckRoutines(t) 241 defer report() 242 243 track, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") 244 assert.NoError(t, err) 245 246 offerer, answerer, err := newPair() 247 assert.NoError(t, err) 248 249 _, err = offerer.AddTrack(track) 250 assert.NoError(t, err) 251 252 probeReceiverCreated := make(chan struct{}) 253 254 go func() { 255 sequenceNumber := uint16(0) 256 ticker := time.NewTicker(time.Millisecond * 20) 257 defer ticker.Stop() 258 for range ticker.C { 259 track.mu.Lock() 260 if len(track.bindings) == 1 { 261 _, err = track.bindings[0].writeStream.WriteRTP(&rtp.Header{ 262 Version: 2, 263 SSRC: 0, 264 SequenceNumber: sequenceNumber, 265 }, []byte{0, 1, 2, 3, 4, 5}) 266 assert.NoError(t, err) 267 } 268 sequenceNumber++ 269 track.mu.Unlock() 270 271 if nonMediaBandwidthProbe, ok := answerer.nonMediaBandwidthProbe.Load().(*RTPReceiver); ok { 272 assert.Equal(t, len(nonMediaBandwidthProbe.Tracks()), 1) 273 close(probeReceiverCreated) 274 return 275 } 276 } 277 }() 278 279 assert.NoError(t, signalPair(offerer, answerer)) 280 281 peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, offerer, answerer) 282 peerConnectionConnected.Wait() 283 284 <-probeReceiverCreated 285 closePairNow(t, offerer, answerer) 286 }