github.com/pion/webrtc/v4@v4.0.1/srtp_writer_future.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"io"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/pion/rtp"
    16  	"github.com/pion/srtp/v3"
    17  )
    18  
    19  // srtpWriterFuture blocks Read/Write calls until
    20  // the SRTP Session is available
    21  type srtpWriterFuture struct {
    22  	ssrc           SSRC
    23  	rtpSender      *RTPSender
    24  	rtcpReadStream atomic.Value // *srtp.ReadStreamSRTCP
    25  	rtpWriteStream atomic.Value // *srtp.WriteStreamSRTP
    26  	mu             sync.Mutex
    27  	closed         bool
    28  }
    29  
    30  func (s *srtpWriterFuture) init(returnWhenNoSRTP bool) error {
    31  	if returnWhenNoSRTP {
    32  		select {
    33  		case <-s.rtpSender.stopCalled:
    34  			return io.ErrClosedPipe
    35  		case <-s.rtpSender.transport.srtpReady:
    36  		default:
    37  			return nil
    38  		}
    39  	} else {
    40  		select {
    41  		case <-s.rtpSender.stopCalled:
    42  			return io.ErrClosedPipe
    43  		case <-s.rtpSender.transport.srtpReady:
    44  		}
    45  	}
    46  
    47  	s.mu.Lock()
    48  	defer s.mu.Unlock()
    49  
    50  	if s.closed {
    51  		return io.ErrClosedPipe
    52  	}
    53  
    54  	srtcpSession, err := s.rtpSender.transport.getSRTCPSession()
    55  	if err != nil {
    56  		return err
    57  	}
    58  
    59  	rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(s.ssrc))
    60  	if err != nil {
    61  		return err
    62  	}
    63  
    64  	srtpSession, err := s.rtpSender.transport.getSRTPSession()
    65  	if err != nil {
    66  		return err
    67  	}
    68  
    69  	rtpWriteStream, err := srtpSession.OpenWriteStream()
    70  	if err != nil {
    71  		return err
    72  	}
    73  
    74  	s.rtcpReadStream.Store(rtcpReadStream)
    75  	s.rtpWriteStream.Store(rtpWriteStream)
    76  	return nil
    77  }
    78  
    79  func (s *srtpWriterFuture) Close() error {
    80  	s.mu.Lock()
    81  	defer s.mu.Unlock()
    82  
    83  	if s.closed {
    84  		return nil
    85  	}
    86  	s.closed = true
    87  
    88  	if value, ok := s.rtcpReadStream.Load().(*srtp.ReadStreamSRTCP); ok {
    89  		return value.Close()
    90  	}
    91  
    92  	return nil
    93  }
    94  
    95  func (s *srtpWriterFuture) Read(b []byte) (n int, err error) {
    96  	if value, ok := s.rtcpReadStream.Load().(*srtp.ReadStreamSRTCP); ok {
    97  		return value.Read(b)
    98  	}
    99  
   100  	if err := s.init(false); err != nil || s.rtcpReadStream.Load() == nil {
   101  		return 0, err
   102  	}
   103  
   104  	return s.Read(b)
   105  }
   106  
   107  func (s *srtpWriterFuture) SetReadDeadline(t time.Time) error {
   108  	if value, ok := s.rtcpReadStream.Load().(*srtp.ReadStreamSRTCP); ok {
   109  		return value.SetReadDeadline(t)
   110  	}
   111  
   112  	if err := s.init(false); err != nil || s.rtcpReadStream.Load() == nil {
   113  		return err
   114  	}
   115  
   116  	return s.SetReadDeadline(t)
   117  }
   118  
   119  func (s *srtpWriterFuture) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
   120  	if value, ok := s.rtpWriteStream.Load().(*srtp.WriteStreamSRTP); ok {
   121  		return value.WriteRTP(header, payload)
   122  	}
   123  
   124  	if err := s.init(true); err != nil || s.rtpWriteStream.Load() == nil {
   125  		return 0, err
   126  	}
   127  
   128  	return s.WriteRTP(header, payload)
   129  }
   130  
   131  func (s *srtpWriterFuture) Write(b []byte) (int, error) {
   132  	if value, ok := s.rtpWriteStream.Load().(*srtp.WriteStreamSRTP); ok {
   133  		return value.Write(b)
   134  	}
   135  
   136  	if err := s.init(true); err != nil || s.rtpWriteStream.Load() == nil {
   137  		return 0, err
   138  	}
   139  
   140  	return s.Write(b)
   141  }