github.com/pion/webrtc/v3@v3.2.24/dtlstransport.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"crypto/ecdsa"
    11  	"crypto/elliptic"
    12  	"crypto/rand"
    13  	"crypto/tls"
    14  	"crypto/x509"
    15  	"errors"
    16  	"fmt"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/pion/dtls/v2"
    23  	"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
    24  	"github.com/pion/interceptor"
    25  	"github.com/pion/logging"
    26  	"github.com/pion/rtcp"
    27  	"github.com/pion/srtp/v2"
    28  	"github.com/pion/webrtc/v3/internal/mux"
    29  	"github.com/pion/webrtc/v3/internal/util"
    30  	"github.com/pion/webrtc/v3/pkg/rtcerr"
    31  )
    32  
// DTLSTransport allows an application access to information about the DTLS
// transport over which RTP and RTCP packets are sent and received by
// RTPSender and RTPReceiver, as well other data such as SCTP packets sent
// and received by data channels.
type DTLSTransport struct {
	// lock guards the mutable fields below. The SRTP/SRTCP sessions are the
	// exception: they are published through atomic.Values and read lock-free.
	lock sync.RWMutex

	// iceTransport supplies the mux endpoints for DTLS, SRTP and SRTCP traffic.
	iceTransport *ICETransport
	// certificates are the local certificates; Start presents the first one.
	certificates []Certificate
	// remoteParameters is set by Start and drives role() and fingerprint checks.
	remoteParameters DTLSParameters
	// remoteCertificate holds the raw bytes of the first certificate the peer
	// presented (parsed with x509.ParseCertificate in Start); nil before that.
	remoteCertificate []byte
	// state is the current transport state, changed through onStateChange.
	state DTLSTransportState
	// srtpProtectionProfile is the SRTP profile negotiated in the handshake.
	srtpProtectionProfile srtp.ProtectionProfile

	// onStateChangeHandler is invoked by onStateChange while lock is held.
	onStateChangeHandler func(DTLSTransportState)

	// conn is the DTLS connection; assigned only after every post-handshake
	// check in Start has passed.
	conn *dtls.Conn

	// srtpSession / srtcpSession hold *srtp.SessionSRTP / *srtp.SessionSRTCP
	// once startSRTP has created them.
	srtpSession, srtcpSession atomic.Value
	// srtpEndpoint / srtcpEndpoint are the mux endpoints created by Start.
	srtpEndpoint, srtcpEndpoint *mux.Endpoint
	// simulcastStreams are extra read streams registered via
	// storeSimulcastStream and closed by Stop.
	simulcastStreams []*srtp.ReadStreamSRTP
	// srtpReady is closed by startSRTP after both sessions are stored.
	srtpReady chan struct{}

	// dtlsMatcher is initialised to mux.MatchDTLS; it is not read in this file.
	dtlsMatcher mux.MatchFunc

	api *API
	log logging.LeveledLogger
}
    61  
    62  // NewDTLSTransport creates a new DTLSTransport.
    63  // This constructor is part of the ORTC API. It is not
    64  // meant to be used together with the basic WebRTC API.
    65  func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) {
    66  	t := &DTLSTransport{
    67  		iceTransport: transport,
    68  		api:          api,
    69  		state:        DTLSTransportStateNew,
    70  		dtlsMatcher:  mux.MatchDTLS,
    71  		srtpReady:    make(chan struct{}),
    72  		log:          api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
    73  	}
    74  
    75  	if len(certificates) > 0 {
    76  		now := time.Now()
    77  		for _, x509Cert := range certificates {
    78  			if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) {
    79  				return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
    80  			}
    81  			t.certificates = append(t.certificates, x509Cert)
    82  		}
    83  	} else {
    84  		sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    85  		if err != nil {
    86  			return nil, &rtcerr.UnknownError{Err: err}
    87  		}
    88  		certificate, err := GenerateCertificate(sk)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		t.certificates = []Certificate{*certificate}
    93  	}
    94  
    95  	return t, nil
    96  }
    97  
    98  // ICETransport returns the currently-configured *ICETransport or nil
    99  // if one has not been configured
   100  func (t *DTLSTransport) ICETransport() *ICETransport {
   101  	t.lock.RLock()
   102  	defer t.lock.RUnlock()
   103  	return t.iceTransport
   104  }
   105  
   106  // onStateChange requires the caller holds the lock
   107  func (t *DTLSTransport) onStateChange(state DTLSTransportState) {
   108  	t.state = state
   109  	handler := t.onStateChangeHandler
   110  	if handler != nil {
   111  		handler(state)
   112  	}
   113  }
   114  
   115  // OnStateChange sets a handler that is fired when the DTLS
   116  // connection state changes.
   117  func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) {
   118  	t.lock.Lock()
   119  	defer t.lock.Unlock()
   120  	t.onStateChangeHandler = f
   121  }
   122  
   123  // State returns the current dtls transport state.
   124  func (t *DTLSTransport) State() DTLSTransportState {
   125  	t.lock.RLock()
   126  	defer t.lock.RUnlock()
   127  	return t.state
   128  }
   129  
   130  // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
   131  // packet is discarded.
   132  func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) {
   133  	raw, err := rtcp.Marshal(pkts)
   134  	if err != nil {
   135  		return 0, err
   136  	}
   137  
   138  	srtcpSession, err := t.getSRTCPSession()
   139  	if err != nil {
   140  		return 0, err
   141  	}
   142  
   143  	writeStream, err := srtcpSession.OpenWriteStream()
   144  	if err != nil {
   145  		// nolint
   146  		return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
   147  	}
   148  
   149  	return writeStream.Write(raw)
   150  }
   151  
   152  // GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
   153  func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) {
   154  	fingerprints := []DTLSFingerprint{}
   155  
   156  	for _, c := range t.certificates {
   157  		prints, err := c.GetFingerprints()
   158  		if err != nil {
   159  			return DTLSParameters{}, err
   160  		}
   161  
   162  		fingerprints = append(fingerprints, prints...)
   163  	}
   164  
   165  	return DTLSParameters{
   166  		Role:         DTLSRoleAuto, // always returns the default role
   167  		Fingerprints: fingerprints,
   168  	}, nil
   169  }
   170  
   171  // GetRemoteCertificate returns the certificate chain in use by the remote side
   172  // returns an empty list prior to selection of the remote certificate
   173  func (t *DTLSTransport) GetRemoteCertificate() []byte {
   174  	t.lock.RLock()
   175  	defer t.lock.RUnlock()
   176  	return t.remoteCertificate
   177  }
   178  
   179  func (t *DTLSTransport) startSRTP() error {
   180  	srtpConfig := &srtp.Config{
   181  		Profile:       t.srtpProtectionProfile,
   182  		BufferFactory: t.api.settingEngine.BufferFactory,
   183  		LoggerFactory: t.api.settingEngine.LoggerFactory,
   184  	}
   185  	if t.api.settingEngine.replayProtection.SRTP != nil {
   186  		srtpConfig.RemoteOptions = append(
   187  			srtpConfig.RemoteOptions,
   188  			srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP),
   189  		)
   190  	}
   191  
   192  	if t.api.settingEngine.disableSRTPReplayProtection {
   193  		srtpConfig.RemoteOptions = append(
   194  			srtpConfig.RemoteOptions,
   195  			srtp.SRTPNoReplayProtection(),
   196  		)
   197  	}
   198  
   199  	if t.api.settingEngine.replayProtection.SRTCP != nil {
   200  		srtpConfig.RemoteOptions = append(
   201  			srtpConfig.RemoteOptions,
   202  			srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP),
   203  		)
   204  	}
   205  
   206  	if t.api.settingEngine.disableSRTCPReplayProtection {
   207  		srtpConfig.RemoteOptions = append(
   208  			srtpConfig.RemoteOptions,
   209  			srtp.SRTCPNoReplayProtection(),
   210  		)
   211  	}
   212  
   213  	connState := t.conn.ConnectionState()
   214  	err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
   215  	if err != nil {
   216  		// nolint
   217  		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
   218  	}
   219  
   220  	srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
   221  	if err != nil {
   222  		// nolint
   223  		return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
   224  	}
   225  
   226  	srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
   227  	if err != nil {
   228  		// nolint
   229  		return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
   230  	}
   231  
   232  	t.srtpSession.Store(srtpSession)
   233  	t.srtcpSession.Store(srtcpSession)
   234  	close(t.srtpReady)
   235  	return nil
   236  }
   237  
   238  func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) {
   239  	if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
   240  		return value, nil
   241  	}
   242  
   243  	return nil, errDtlsTransportNotStarted
   244  }
   245  
   246  func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) {
   247  	if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
   248  		return value, nil
   249  	}
   250  
   251  	return nil, errDtlsTransportNotStarted
   252  }
   253  
   254  func (t *DTLSTransport) role() DTLSRole {
   255  	// If remote has an explicit role use the inverse
   256  	switch t.remoteParameters.Role {
   257  	case DTLSRoleClient:
   258  		return DTLSRoleServer
   259  	case DTLSRoleServer:
   260  		return DTLSRoleClient
   261  	default:
   262  	}
   263  
   264  	// If SettingEngine has an explicit role
   265  	switch t.api.settingEngine.answeringDTLSRole {
   266  	case DTLSRoleServer:
   267  		return DTLSRoleServer
   268  	case DTLSRoleClient:
   269  		return DTLSRoleClient
   270  	default:
   271  	}
   272  
   273  	// Remote was auto and no explicit role was configured via SettingEngine
   274  	if t.iceTransport.Role() == ICERoleControlling {
   275  		return DTLSRoleServer
   276  	}
   277  	return defaultDtlsRoleAnswer
   278  }
   279  
   280  // Start DTLS transport negotiation with the parameters of the remote DTLS transport
   281  func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
   282  	// Take lock and prepare connection, we must not hold the lock
   283  	// when connecting
   284  	prepareTransport := func() (DTLSRole, *dtls.Config, error) {
   285  		t.lock.Lock()
   286  		defer t.lock.Unlock()
   287  
   288  		if err := t.ensureICEConn(); err != nil {
   289  			return DTLSRole(0), nil, err
   290  		}
   291  
   292  		if t.state != DTLSTransportStateNew {
   293  			return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
   294  		}
   295  
   296  		t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
   297  		t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
   298  		t.remoteParameters = remoteParameters
   299  
   300  		cert := t.certificates[0]
   301  		t.onStateChange(DTLSTransportStateConnecting)
   302  
   303  		return t.role(), &dtls.Config{
   304  			Certificates: []tls.Certificate{
   305  				{
   306  					Certificate: [][]byte{cert.x509Cert.Raw},
   307  					PrivateKey:  cert.privateKey,
   308  				},
   309  			},
   310  			SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
   311  				if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
   312  					return t.api.settingEngine.srtpProtectionProfiles
   313  				}
   314  
   315  				return defaultSrtpProtectionProfiles()
   316  			}(),
   317  			ClientAuth:         dtls.RequireAnyClientCert,
   318  			LoggerFactory:      t.api.settingEngine.LoggerFactory,
   319  			InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify,
   320  		}, nil
   321  	}
   322  
   323  	var dtlsConn *dtls.Conn
   324  	dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
   325  	role, dtlsConfig, err := prepareTransport()
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	if t.api.settingEngine.replayProtection.DTLS != nil {
   331  		dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS)
   332  	}
   333  
   334  	if t.api.settingEngine.dtls.clientAuth != nil {
   335  		dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth
   336  	}
   337  
   338  	dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval
   339  	dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify
   340  	dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves
   341  	dtlsConfig.ConnectContextMaker = t.api.settingEngine.dtls.connectContextMaker
   342  	dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret
   343  	dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
   344  	dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
   345  	dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
   346  
   347  	// Connect as DTLS Client/Server, function is blocking and we
   348  	// must not hold the DTLSTransport lock
   349  	if role == DTLSRoleClient {
   350  		dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig)
   351  	} else {
   352  		dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig)
   353  	}
   354  
   355  	// Re-take the lock, nothing beyond here is blocking
   356  	t.lock.Lock()
   357  	defer t.lock.Unlock()
   358  
   359  	if err != nil {
   360  		t.onStateChange(DTLSTransportStateFailed)
   361  		return err
   362  	}
   363  
   364  	srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
   365  	if !ok {
   366  		t.onStateChange(DTLSTransportStateFailed)
   367  		return ErrNoSRTPProtectionProfile
   368  	}
   369  
   370  	switch srtpProfile {
   371  	case dtls.SRTP_AEAD_AES_128_GCM:
   372  		t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
   373  	case dtls.SRTP_AEAD_AES_256_GCM:
   374  		t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm
   375  	case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
   376  		t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
   377  	default:
   378  		t.onStateChange(DTLSTransportStateFailed)
   379  		return ErrNoSRTPProtectionProfile
   380  	}
   381  
   382  	// Check the fingerprint if a certificate was exchanged
   383  	remoteCerts := dtlsConn.ConnectionState().PeerCertificates
   384  	if len(remoteCerts) == 0 {
   385  		t.onStateChange(DTLSTransportStateFailed)
   386  		return errNoRemoteCertificate
   387  	}
   388  	t.remoteCertificate = remoteCerts[0]
   389  
   390  	if !t.api.settingEngine.disableCertificateFingerprintVerification {
   391  		parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
   392  		if err != nil {
   393  			if closeErr := dtlsConn.Close(); closeErr != nil {
   394  				t.log.Error(err.Error())
   395  			}
   396  
   397  			t.onStateChange(DTLSTransportStateFailed)
   398  			return err
   399  		}
   400  
   401  		if err = t.validateFingerPrint(parsedRemoteCert); err != nil {
   402  			if closeErr := dtlsConn.Close(); closeErr != nil {
   403  				t.log.Error(err.Error())
   404  			}
   405  
   406  			t.onStateChange(DTLSTransportStateFailed)
   407  			return err
   408  		}
   409  	}
   410  
   411  	t.conn = dtlsConn
   412  	t.onStateChange(DTLSTransportStateConnected)
   413  
   414  	return t.startSRTP()
   415  }
   416  
   417  // Stop stops and closes the DTLSTransport object.
   418  func (t *DTLSTransport) Stop() error {
   419  	t.lock.Lock()
   420  	defer t.lock.Unlock()
   421  
   422  	// Try closing everything and collect the errors
   423  	var closeErrs []error
   424  
   425  	if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil {
   426  		closeErrs = append(closeErrs, srtpSession.Close())
   427  	}
   428  
   429  	if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil {
   430  		closeErrs = append(closeErrs, srtcpSession.Close())
   431  	}
   432  
   433  	for i := range t.simulcastStreams {
   434  		closeErrs = append(closeErrs, t.simulcastStreams[i].Close())
   435  	}
   436  
   437  	if t.conn != nil {
   438  		// dtls connection may be closed on sctp close.
   439  		if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) {
   440  			closeErrs = append(closeErrs, err)
   441  		}
   442  	}
   443  	t.onStateChange(DTLSTransportStateClosed)
   444  	return util.FlattenErrs(closeErrs)
   445  }
   446  
   447  func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error {
   448  	for _, fp := range t.remoteParameters.Fingerprints {
   449  		hashAlgo, err := fingerprint.HashFromString(fp.Algorithm)
   450  		if err != nil {
   451  			return err
   452  		}
   453  
   454  		remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo)
   455  		if err != nil {
   456  			return err
   457  		}
   458  
   459  		if strings.EqualFold(remoteValue, fp.Value) {
   460  			return nil
   461  		}
   462  	}
   463  
   464  	return errNoMatchingCertificateFingerprint
   465  }
   466  
   467  func (t *DTLSTransport) ensureICEConn() error {
   468  	if t.iceTransport == nil {
   469  		return errICEConnectionNotStarted
   470  	}
   471  
   472  	return nil
   473  }
   474  
   475  func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) {
   476  	t.lock.Lock()
   477  	defer t.lock.Unlock()
   478  
   479  	t.simulcastStreams = append(t.simulcastStreams, s)
   480  }
   481  
   482  func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
   483  	srtpSession, err := t.getSRTPSession()
   484  	if err != nil {
   485  		return nil, nil, nil, nil, err
   486  	}
   487  
   488  	rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
   489  	if err != nil {
   490  		return nil, nil, nil, nil, err
   491  	}
   492  
   493  	rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
   494  		n, err = rtpReadStream.Read(in)
   495  		return n, a, err
   496  	}))
   497  
   498  	srtcpSession, err := t.getSRTCPSession()
   499  	if err != nil {
   500  		return nil, nil, nil, nil, err
   501  	}
   502  
   503  	rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
   504  	if err != nil {
   505  		return nil, nil, nil, nil, err
   506  	}
   507  
   508  	rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
   509  		n, err = rtcpReadStream.Read(in)
   510  		return n, a, err
   511  	}))
   512  
   513  	return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
   514  }