github.com/pion/webrtc/v4@v4.0.1/dtlstransport.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  //go:build !js
     5  // +build !js
     6  
     7  package webrtc
     8  
     9  import (
    10  	"crypto/ecdsa"
    11  	"crypto/elliptic"
    12  	"crypto/rand"
    13  	"crypto/tls"
    14  	"crypto/x509"
    15  	"errors"
    16  	"fmt"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/pion/dtls/v3"
    23  	"github.com/pion/dtls/v3/pkg/crypto/fingerprint"
    24  	"github.com/pion/interceptor"
    25  	"github.com/pion/logging"
    26  	"github.com/pion/rtcp"
    27  	"github.com/pion/srtp/v3"
    28  	"github.com/pion/webrtc/v4/internal/mux"
    29  	"github.com/pion/webrtc/v4/internal/util"
    30  	"github.com/pion/webrtc/v4/pkg/rtcerr"
    31  )
    32  
// DTLSTransport allows an application access to information about the DTLS
// transport over which RTP and RTCP packets are sent and received by
// RTPSender and RTPReceiver, as well other data such as SCTP packets sent
// and received by data channels.
type DTLSTransport struct {
	// lock protects the mutable state below (state, handlers, remote
	// parameters/certificate, conn, simulcastStreams). onStateChange
	// requires the caller to already hold it.
	lock sync.RWMutex

	// iceTransport: ICE transport whose mux hosts the DTLS, SRTP and SRTCP
	//   endpoints created in Start.
	// certificates: local certificates; Start presents certificates[0] in the
	//   handshake and GetLocalParameters reports fingerprints for all of them.
	// remoteParameters: remote role and fingerprints recorded by Start; read by
	//   role and validateFingerPrint.
	// remoteCertificate: raw DER bytes of the peer's leaf certificate, set
	//   after a successful handshake.
	// state: current transport state, changed only through onStateChange.
	// srtpProtectionProfile: SRTP profile selected during the handshake and
	//   consumed by startSRTP.
	iceTransport          *ICETransport
	certificates          []Certificate
	remoteParameters      DTLSParameters
	remoteCertificate     []byte
	state                 DTLSTransportState
	srtpProtectionProfile srtp.ProtectionProfile

	// onStateChangeHandler: user callback registered via OnStateChange.
	// internalOnCloseHandler: installed as the DTLS mux endpoint's on-close
	//   callback in Start.
	onStateChangeHandler   func(DTLSTransportState)
	internalOnCloseHandler func()

	// conn is the established DTLS connection; nil until Start has completed
	// the handshake and verification successfully.
	conn *dtls.Conn

	// srtpSession/srtcpSession hold *srtp.SessionSRTP / *srtp.SessionSRTCP once
	// startSRTP has run; atomic.Value allows lock-free reads.
	// srtpEndpoint/srtcpEndpoint are the mux endpoints created in Start.
	// simulcastStreams are read-stream pairs that Stop closes.
	// srtpReady is closed by startSRTP once both sessions are stored.
	srtpSession, srtcpSession   atomic.Value
	srtpEndpoint, srtcpEndpoint *mux.Endpoint
	simulcastStreams            []simulcastStreamPair
	srtpReady                   chan struct{}

	// dtlsMatcher is initialised to mux.MatchDTLS by NewDTLSTransport.
	// NOTE(review): not referenced elsewhere in this file — confirm usage.
	dtlsMatcher mux.MatchFunc

	api *API
	log logging.LeveledLogger
}
    62  
// simulcastStreamPair is an SRTP/SRTCP read-stream pair registered through
// storeSimulcastStream; DTLSTransport.Stop closes both streams.
type simulcastStreamPair struct {
	srtp  *srtp.ReadStreamSRTP
	srtcp *srtp.ReadStreamSRTCP
}
    67  
    68  // NewDTLSTransport creates a new DTLSTransport.
    69  // This constructor is part of the ORTC API. It is not
    70  // meant to be used together with the basic WebRTC API.
    71  func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) {
    72  	t := &DTLSTransport{
    73  		iceTransport: transport,
    74  		api:          api,
    75  		state:        DTLSTransportStateNew,
    76  		dtlsMatcher:  mux.MatchDTLS,
    77  		srtpReady:    make(chan struct{}),
    78  		log:          api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
    79  	}
    80  
    81  	if len(certificates) > 0 {
    82  		now := time.Now()
    83  		for _, x509Cert := range certificates {
    84  			if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) {
    85  				return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
    86  			}
    87  			t.certificates = append(t.certificates, x509Cert)
    88  		}
    89  	} else {
    90  		sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
    91  		if err != nil {
    92  			return nil, &rtcerr.UnknownError{Err: err}
    93  		}
    94  		certificate, err := GenerateCertificate(sk)
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  		t.certificates = []Certificate{*certificate}
    99  	}
   100  
   101  	return t, nil
   102  }
   103  
   104  // ICETransport returns the currently-configured *ICETransport or nil
   105  // if one has not been configured
   106  func (t *DTLSTransport) ICETransport() *ICETransport {
   107  	t.lock.RLock()
   108  	defer t.lock.RUnlock()
   109  	return t.iceTransport
   110  }
   111  
   112  // onStateChange requires the caller holds the lock
   113  func (t *DTLSTransport) onStateChange(state DTLSTransportState) {
   114  	t.state = state
   115  	handler := t.onStateChangeHandler
   116  	if handler != nil {
   117  		handler(state)
   118  	}
   119  }
   120  
   121  // OnStateChange sets a handler that is fired when the DTLS
   122  // connection state changes.
   123  func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) {
   124  	t.lock.Lock()
   125  	defer t.lock.Unlock()
   126  	t.onStateChangeHandler = f
   127  }
   128  
   129  // State returns the current dtls transport state.
   130  func (t *DTLSTransport) State() DTLSTransportState {
   131  	t.lock.RLock()
   132  	defer t.lock.RUnlock()
   133  	return t.state
   134  }
   135  
   136  // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
   137  // packet is discarded.
   138  func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) {
   139  	raw, err := rtcp.Marshal(pkts)
   140  	if err != nil {
   141  		return 0, err
   142  	}
   143  
   144  	srtcpSession, err := t.getSRTCPSession()
   145  	if err != nil {
   146  		return 0, err
   147  	}
   148  
   149  	writeStream, err := srtcpSession.OpenWriteStream()
   150  	if err != nil {
   151  		// nolint
   152  		return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
   153  	}
   154  
   155  	return writeStream.Write(raw)
   156  }
   157  
   158  // GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
   159  func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) {
   160  	fingerprints := []DTLSFingerprint{}
   161  
   162  	for _, c := range t.certificates {
   163  		prints, err := c.GetFingerprints()
   164  		if err != nil {
   165  			return DTLSParameters{}, err
   166  		}
   167  
   168  		fingerprints = append(fingerprints, prints...)
   169  	}
   170  
   171  	return DTLSParameters{
   172  		Role:         DTLSRoleAuto, // always returns the default role
   173  		Fingerprints: fingerprints,
   174  	}, nil
   175  }
   176  
   177  // GetRemoteCertificate returns the certificate chain in use by the remote side
   178  // returns an empty list prior to selection of the remote certificate
   179  func (t *DTLSTransport) GetRemoteCertificate() []byte {
   180  	t.lock.RLock()
   181  	defer t.lock.RUnlock()
   182  	return t.remoteCertificate
   183  }
   184  
   185  func (t *DTLSTransport) startSRTP() error {
   186  	srtpConfig := &srtp.Config{
   187  		Profile:       t.srtpProtectionProfile,
   188  		BufferFactory: t.api.settingEngine.BufferFactory,
   189  		LoggerFactory: t.api.settingEngine.LoggerFactory,
   190  	}
   191  	if t.api.settingEngine.replayProtection.SRTP != nil {
   192  		srtpConfig.RemoteOptions = append(
   193  			srtpConfig.RemoteOptions,
   194  			srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP),
   195  		)
   196  	}
   197  
   198  	if t.api.settingEngine.disableSRTPReplayProtection {
   199  		srtpConfig.RemoteOptions = append(
   200  			srtpConfig.RemoteOptions,
   201  			srtp.SRTPNoReplayProtection(),
   202  		)
   203  	}
   204  
   205  	if t.api.settingEngine.replayProtection.SRTCP != nil {
   206  		srtpConfig.RemoteOptions = append(
   207  			srtpConfig.RemoteOptions,
   208  			srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP),
   209  		)
   210  	}
   211  
   212  	if t.api.settingEngine.disableSRTCPReplayProtection {
   213  		srtpConfig.RemoteOptions = append(
   214  			srtpConfig.RemoteOptions,
   215  			srtp.SRTCPNoReplayProtection(),
   216  		)
   217  	}
   218  
   219  	connState, ok := t.conn.ConnectionState()
   220  	if !ok {
   221  		// nolint
   222  		return fmt.Errorf("%w: Failed to get DTLS ConnectionState", errDtlsKeyExtractionFailed)
   223  	}
   224  
   225  	err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
   226  	if err != nil {
   227  		// nolint
   228  		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
   229  	}
   230  
   231  	srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
   232  	if err != nil {
   233  		// nolint
   234  		return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
   235  	}
   236  
   237  	srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
   238  	if err != nil {
   239  		// nolint
   240  		return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
   241  	}
   242  
   243  	t.srtpSession.Store(srtpSession)
   244  	t.srtcpSession.Store(srtcpSession)
   245  	close(t.srtpReady)
   246  	return nil
   247  }
   248  
   249  func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) {
   250  	if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
   251  		return value, nil
   252  	}
   253  
   254  	return nil, errDtlsTransportNotStarted
   255  }
   256  
   257  func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) {
   258  	if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
   259  		return value, nil
   260  	}
   261  
   262  	return nil, errDtlsTransportNotStarted
   263  }
   264  
   265  func (t *DTLSTransport) role() DTLSRole {
   266  	// If remote has an explicit role use the inverse
   267  	switch t.remoteParameters.Role {
   268  	case DTLSRoleClient:
   269  		return DTLSRoleServer
   270  	case DTLSRoleServer:
   271  		return DTLSRoleClient
   272  	default:
   273  	}
   274  
   275  	// If SettingEngine has an explicit role
   276  	switch t.api.settingEngine.answeringDTLSRole {
   277  	case DTLSRoleServer:
   278  		return DTLSRoleServer
   279  	case DTLSRoleClient:
   280  		return DTLSRoleClient
   281  	default:
   282  	}
   283  
   284  	// Remote was auto and no explicit role was configured via SettingEngine
   285  	if t.iceTransport.Role() == ICERoleControlling {
   286  		return DTLSRoleServer
   287  	}
   288  	return defaultDtlsRoleAnswer
   289  }
   290  
   291  // Start DTLS transport negotiation with the parameters of the remote DTLS transport
   292  func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint: gocognit
   293  	// Take lock and prepare connection, we must not hold the lock
   294  	// when connecting
   295  	prepareTransport := func() (DTLSRole, *dtls.Config, error) {
   296  		t.lock.Lock()
   297  		defer t.lock.Unlock()
   298  
   299  		if err := t.ensureICEConn(); err != nil {
   300  			return DTLSRole(0), nil, err
   301  		}
   302  
   303  		if t.state != DTLSTransportStateNew {
   304  			return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
   305  		}
   306  
   307  		t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
   308  		t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
   309  		t.remoteParameters = remoteParameters
   310  
   311  		cert := t.certificates[0]
   312  		t.onStateChange(DTLSTransportStateConnecting)
   313  
   314  		return t.role(), &dtls.Config{
   315  			Certificates: []tls.Certificate{
   316  				{
   317  					Certificate: [][]byte{cert.x509Cert.Raw},
   318  					PrivateKey:  cert.privateKey,
   319  				},
   320  			},
   321  			SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
   322  				if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
   323  					return t.api.settingEngine.srtpProtectionProfiles
   324  				}
   325  
   326  				return defaultSrtpProtectionProfiles()
   327  			}(),
   328  			ClientAuth:         dtls.RequireAnyClientCert,
   329  			LoggerFactory:      t.api.settingEngine.LoggerFactory,
   330  			InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify,
   331  			CustomCipherSuites: t.api.settingEngine.dtls.customCipherSuites,
   332  		}, nil
   333  	}
   334  
   335  	var dtlsConn *dtls.Conn
   336  	dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
   337  	dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)
   338  	role, dtlsConfig, err := prepareTransport()
   339  	if err != nil {
   340  		return err
   341  	}
   342  
   343  	if t.api.settingEngine.replayProtection.DTLS != nil {
   344  		dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS)
   345  	}
   346  
   347  	if t.api.settingEngine.dtls.clientAuth != nil {
   348  		dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth
   349  	}
   350  
   351  	dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval
   352  	dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify
   353  	dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves
   354  	dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret
   355  	dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs
   356  	dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
   357  	dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter
   358  	dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook
   359  	dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook
   360  	dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook
   361  
   362  	// Connect as DTLS Client/Server, function is blocking and we
   363  	// must not hold the DTLSTransport lock
   364  	if role == DTLSRoleClient {
   365  		dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)
   366  	} else {
   367  		dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig)
   368  	}
   369  
   370  	if err == nil {
   371  		if t.api.settingEngine.dtls.connectContextMaker != nil {
   372  			handshakeCtx, _ := t.api.settingEngine.dtls.connectContextMaker()
   373  			err = dtlsConn.HandshakeContext(handshakeCtx)
   374  		} else {
   375  			err = dtlsConn.Handshake()
   376  		}
   377  	}
   378  
   379  	// Re-take the lock, nothing beyond here is blocking
   380  	t.lock.Lock()
   381  	defer t.lock.Unlock()
   382  
   383  	if err != nil {
   384  		t.onStateChange(DTLSTransportStateFailed)
   385  		return err
   386  	}
   387  
   388  	srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
   389  	if !ok {
   390  		t.onStateChange(DTLSTransportStateFailed)
   391  		return ErrNoSRTPProtectionProfile
   392  	}
   393  
   394  	switch srtpProfile {
   395  	case dtls.SRTP_AEAD_AES_128_GCM:
   396  		t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
   397  	case dtls.SRTP_AEAD_AES_256_GCM:
   398  		t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm
   399  	case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
   400  		t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
   401  	case dtls.SRTP_NULL_HMAC_SHA1_80:
   402  		t.srtpProtectionProfile = srtp.ProtectionProfileNullHmacSha1_80
   403  	default:
   404  		t.onStateChange(DTLSTransportStateFailed)
   405  		return ErrNoSRTPProtectionProfile
   406  	}
   407  
   408  	// Check the fingerprint if a certificate was exchanged
   409  	connectionState, ok := dtlsConn.ConnectionState()
   410  	if !ok {
   411  		t.onStateChange(DTLSTransportStateFailed)
   412  		return errNoRemoteCertificate
   413  	}
   414  
   415  	if len(connectionState.PeerCertificates) == 0 {
   416  		t.onStateChange(DTLSTransportStateFailed)
   417  		return errNoRemoteCertificate
   418  	}
   419  	t.remoteCertificate = connectionState.PeerCertificates[0]
   420  
   421  	if !t.api.settingEngine.disableCertificateFingerprintVerification {
   422  		parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
   423  		if err != nil {
   424  			if closeErr := dtlsConn.Close(); closeErr != nil {
   425  				t.log.Error(err.Error())
   426  			}
   427  
   428  			t.onStateChange(DTLSTransportStateFailed)
   429  			return err
   430  		}
   431  
   432  		if err = t.validateFingerPrint(parsedRemoteCert); err != nil {
   433  			if closeErr := dtlsConn.Close(); closeErr != nil {
   434  				t.log.Error(err.Error())
   435  			}
   436  
   437  			t.onStateChange(DTLSTransportStateFailed)
   438  			return err
   439  		}
   440  	}
   441  
   442  	t.conn = dtlsConn
   443  	t.onStateChange(DTLSTransportStateConnected)
   444  
   445  	return t.startSRTP()
   446  }
   447  
   448  // Stop stops and closes the DTLSTransport object.
   449  func (t *DTLSTransport) Stop() error {
   450  	t.lock.Lock()
   451  	defer t.lock.Unlock()
   452  
   453  	// Try closing everything and collect the errors
   454  	var closeErrs []error
   455  
   456  	if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil {
   457  		closeErrs = append(closeErrs, srtpSession.Close())
   458  	}
   459  
   460  	if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil {
   461  		closeErrs = append(closeErrs, srtcpSession.Close())
   462  	}
   463  
   464  	for i := range t.simulcastStreams {
   465  		closeErrs = append(closeErrs, t.simulcastStreams[i].srtp.Close())
   466  		closeErrs = append(closeErrs, t.simulcastStreams[i].srtcp.Close())
   467  	}
   468  
   469  	if t.conn != nil {
   470  		// dtls connection may be closed on sctp close.
   471  		if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) {
   472  			closeErrs = append(closeErrs, err)
   473  		}
   474  	}
   475  	t.onStateChange(DTLSTransportStateClosed)
   476  	return util.FlattenErrs(closeErrs)
   477  }
   478  
   479  func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error {
   480  	for _, fp := range t.remoteParameters.Fingerprints {
   481  		hashAlgo, err := fingerprint.HashFromString(fp.Algorithm)
   482  		if err != nil {
   483  			return err
   484  		}
   485  
   486  		remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo)
   487  		if err != nil {
   488  			return err
   489  		}
   490  
   491  		if strings.EqualFold(remoteValue, fp.Value) {
   492  			return nil
   493  		}
   494  	}
   495  
   496  	return errNoMatchingCertificateFingerprint
   497  }
   498  
   499  func (t *DTLSTransport) ensureICEConn() error {
   500  	if t.iceTransport == nil {
   501  		return errICEConnectionNotStarted
   502  	}
   503  
   504  	return nil
   505  }
   506  
   507  func (t *DTLSTransport) storeSimulcastStream(srtpReadStream *srtp.ReadStreamSRTP, srtcpReadStream *srtp.ReadStreamSRTCP) {
   508  	t.lock.Lock()
   509  	defer t.lock.Unlock()
   510  
   511  	t.simulcastStreams = append(t.simulcastStreams, simulcastStreamPair{srtpReadStream, srtcpReadStream})
   512  }
   513  
   514  func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
   515  	srtpSession, err := t.getSRTPSession()
   516  	if err != nil {
   517  		return nil, nil, nil, nil, err
   518  	}
   519  
   520  	rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
   521  	if err != nil {
   522  		return nil, nil, nil, nil, err
   523  	}
   524  
   525  	rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
   526  		n, err = rtpReadStream.Read(in)
   527  		return n, a, err
   528  	}))
   529  
   530  	srtcpSession, err := t.getSRTCPSession()
   531  	if err != nil {
   532  		return nil, nil, nil, nil, err
   533  	}
   534  
   535  	rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
   536  	if err != nil {
   537  		return nil, nil, nil, nil, err
   538  	}
   539  
   540  	rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
   541  		n, err = rtcpReadStream.Read(in)
   542  		return n, a, err
   543  	}))
   544  
   545  	return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
   546  }