github.com/pion/webrtc/v3@v3.2.24/dtlstransport.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 import ( 10 "crypto/ecdsa" 11 "crypto/elliptic" 12 "crypto/rand" 13 "crypto/tls" 14 "crypto/x509" 15 "errors" 16 "fmt" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "time" 21 22 "github.com/pion/dtls/v2" 23 "github.com/pion/dtls/v2/pkg/crypto/fingerprint" 24 "github.com/pion/interceptor" 25 "github.com/pion/logging" 26 "github.com/pion/rtcp" 27 "github.com/pion/srtp/v2" 28 "github.com/pion/webrtc/v3/internal/mux" 29 "github.com/pion/webrtc/v3/internal/util" 30 "github.com/pion/webrtc/v3/pkg/rtcerr" 31 ) 32 33 // DTLSTransport allows an application access to information about the DTLS 34 // transport over which RTP and RTCP packets are sent and received by 35 // RTPSender and RTPReceiver, as well other data such as SCTP packets sent 36 // and received by data channels. 37 type DTLSTransport struct { 38 lock sync.RWMutex 39 40 iceTransport *ICETransport 41 certificates []Certificate 42 remoteParameters DTLSParameters 43 remoteCertificate []byte 44 state DTLSTransportState 45 srtpProtectionProfile srtp.ProtectionProfile 46 47 onStateChangeHandler func(DTLSTransportState) 48 49 conn *dtls.Conn 50 51 srtpSession, srtcpSession atomic.Value 52 srtpEndpoint, srtcpEndpoint *mux.Endpoint 53 simulcastStreams []*srtp.ReadStreamSRTP 54 srtpReady chan struct{} 55 56 dtlsMatcher mux.MatchFunc 57 58 api *API 59 log logging.LeveledLogger 60 } 61 62 // NewDTLSTransport creates a new DTLSTransport. 63 // This constructor is part of the ORTC API. It is not 64 // meant to be used together with the basic WebRTC API. 65 func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) { 66 t := &DTLSTransport{ 67 iceTransport: transport, 68 api: api, 69 state: DTLSTransportStateNew, 70 dtlsMatcher: mux.MatchDTLS, 71 srtpReady: make(chan struct{}), 72 log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"), 73 } 74 75 if len(certificates) > 0 { 76 now := time.Now() 77 for _, x509Cert := range certificates { 78 if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) { 79 return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired} 80 } 81 t.certificates = append(t.certificates, x509Cert) 82 } 83 } else { 84 sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 85 if err != nil { 86 return nil, &rtcerr.UnknownError{Err: err} 87 } 88 certificate, err := GenerateCertificate(sk) 89 if err != nil { 90 return nil, err 91 } 92 t.certificates = []Certificate{*certificate} 93 } 94 95 return t, nil 96 } 97 98 // ICETransport returns the currently-configured *ICETransport or nil 99 // if one has not been configured 100 func (t *DTLSTransport) ICETransport() *ICETransport { 101 t.lock.RLock() 102 defer t.lock.RUnlock() 103 return t.iceTransport 104 } 105 106 // onStateChange requires the caller holds the lock 107 func (t *DTLSTransport) onStateChange(state DTLSTransportState) { 108 t.state = state 109 handler := t.onStateChangeHandler 110 if handler != nil { 111 handler(state) 112 } 113 } 114 115 // OnStateChange sets a handler that is fired when the DTLS 116 // connection state changes. 117 func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) { 118 t.lock.Lock() 119 defer t.lock.Unlock() 120 t.onStateChangeHandler = f 121 } 122 123 // State returns the current dtls transport state. 124 func (t *DTLSTransport) State() DTLSTransportState { 125 t.lock.RLock() 126 defer t.lock.RUnlock() 127 return t.state 128 } 129 130 // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the 131 // packet is discarded. 132 func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) { 133 raw, err := rtcp.Marshal(pkts) 134 if err != nil { 135 return 0, err 136 } 137 138 srtcpSession, err := t.getSRTCPSession() 139 if err != nil { 140 return 0, err 141 } 142 143 writeStream, err := srtcpSession.OpenWriteStream() 144 if err != nil { 145 // nolint 146 return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) 147 } 148 149 return writeStream.Write(raw) 150 } 151 152 // GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction. 153 func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) { 154 fingerprints := []DTLSFingerprint{} 155 156 for _, c := range t.certificates { 157 prints, err := c.GetFingerprints() 158 if err != nil { 159 return DTLSParameters{}, err 160 } 161 162 fingerprints = append(fingerprints, prints...) 163 } 164 165 return DTLSParameters{ 166 Role: DTLSRoleAuto, // always returns the default role 167 Fingerprints: fingerprints, 168 }, nil 169 } 170 171 // GetRemoteCertificate returns the certificate chain in use by the remote side 172 // returns an empty list prior to selection of the remote certificate 173 func (t *DTLSTransport) GetRemoteCertificate() []byte { 174 t.lock.RLock() 175 defer t.lock.RUnlock() 176 return t.remoteCertificate 177 } 178 179 func (t *DTLSTransport) startSRTP() error { 180 srtpConfig := &srtp.Config{ 181 Profile: t.srtpProtectionProfile, 182 BufferFactory: t.api.settingEngine.BufferFactory, 183 LoggerFactory: t.api.settingEngine.LoggerFactory, 184 } 185 if t.api.settingEngine.replayProtection.SRTP != nil { 186 srtpConfig.RemoteOptions = append( 187 srtpConfig.RemoteOptions, 188 srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP), 189 ) 190 } 191 192 if t.api.settingEngine.disableSRTPReplayProtection { 193 srtpConfig.RemoteOptions = append( 194 srtpConfig.RemoteOptions, 195 srtp.SRTPNoReplayProtection(), 196 ) 197 } 198 199 if t.api.settingEngine.replayProtection.SRTCP != nil { 200 srtpConfig.RemoteOptions = append( 201 srtpConfig.RemoteOptions, 202 srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP), 203 ) 204 } 205 206 if t.api.settingEngine.disableSRTCPReplayProtection { 207 srtpConfig.RemoteOptions = append( 208 srtpConfig.RemoteOptions, 209 srtp.SRTCPNoReplayProtection(), 210 ) 211 } 212 213 connState := t.conn.ConnectionState() 214 err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient) 215 if err != nil { 216 // nolint 217 return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) 218 } 219 220 srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) 221 if err != nil { 222 // nolint 223 return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) 224 } 225 226 srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig) 227 if err != nil { 228 // nolint 229 return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) 230 } 231 232 t.srtpSession.Store(srtpSession) 233 t.srtcpSession.Store(srtcpSession) 234 close(t.srtpReady) 235 return nil 236 } 237 238 func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) { 239 if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok { 240 return value, nil 241 } 242 243 return nil, errDtlsTransportNotStarted 244 } 245 246 func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) { 247 if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok { 248 return value, nil 249 } 250 251 return nil, errDtlsTransportNotStarted 252 } 253 254 func (t *DTLSTransport) role() DTLSRole { 255 // If remote has an explicit role use the inverse 256 switch t.remoteParameters.Role { 257 case DTLSRoleClient: 258 return DTLSRoleServer 259 case DTLSRoleServer: 260 return DTLSRoleClient 261 default: 262 } 263 264 // If SettingEngine has an explicit role 265 switch t.api.settingEngine.answeringDTLSRole { 266 case DTLSRoleServer: 267 return DTLSRoleServer 268 case DTLSRoleClient: 269 return DTLSRoleClient 270 default: 271 } 272 273 // Remote was auto and no explicit role was configured via SettingEngine 274 if t.iceTransport.Role() == ICERoleControlling { 275 return DTLSRoleServer 276 } 277 return defaultDtlsRoleAnswer 278 } 279 280 // Start DTLS transport negotiation with the parameters of the remote DTLS transport 281 func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { 282 // Take lock and prepare connection, we must not hold the lock 283 // when connecting 284 prepareTransport := func() (DTLSRole, *dtls.Config, error) { 285 t.lock.Lock() 286 defer t.lock.Unlock() 287 288 if err := t.ensureICEConn(); err != nil { 289 return DTLSRole(0), nil, err 290 } 291 292 if t.state != DTLSTransportStateNew { 293 return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)} 294 } 295 296 t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP) 297 t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP) 298 t.remoteParameters = remoteParameters 299 300 cert := t.certificates[0] 301 t.onStateChange(DTLSTransportStateConnecting) 302 303 return t.role(), &dtls.Config{ 304 Certificates: []tls.Certificate{ 305 { 306 Certificate: [][]byte{cert.x509Cert.Raw}, 307 PrivateKey: cert.privateKey, 308 }, 309 }, 310 SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile { 311 if len(t.api.settingEngine.srtpProtectionProfiles) > 0 { 312 return t.api.settingEngine.srtpProtectionProfiles 313 } 314 315 return defaultSrtpProtectionProfiles() 316 }(), 317 ClientAuth: dtls.RequireAnyClientCert, 318 LoggerFactory: t.api.settingEngine.LoggerFactory, 319 InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify, 320 }, nil 321 } 322 323 var dtlsConn *dtls.Conn 324 dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS) 325 role, dtlsConfig, err := prepareTransport() 326 if err != nil { 327 return err 328 } 329 330 if t.api.settingEngine.replayProtection.DTLS != nil { 331 dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS) 332 } 333 334 if t.api.settingEngine.dtls.clientAuth != nil { 335 dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth 336 } 337 338 dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval 339 dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify 340 dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves 341 dtlsConfig.ConnectContextMaker = t.api.settingEngine.dtls.connectContextMaker 342 dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret 343 dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs 344 dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs 345 dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter 346 347 // Connect as DTLS Client/Server, function is blocking and we 348 // must not hold the DTLSTransport lock 349 if role == DTLSRoleClient { 350 dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig) 351 } else { 352 dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig) 353 } 354 355 // Re-take the lock, nothing beyond here is blocking 356 t.lock.Lock() 357 defer t.lock.Unlock() 358 359 if err != nil { 360 t.onStateChange(DTLSTransportStateFailed) 361 return err 362 } 363 364 srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile() 365 if !ok { 366 t.onStateChange(DTLSTransportStateFailed) 367 return ErrNoSRTPProtectionProfile 368 } 369 370 switch srtpProfile { 371 case dtls.SRTP_AEAD_AES_128_GCM: 372 t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm 373 case dtls.SRTP_AEAD_AES_256_GCM: 374 t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm 375 case dtls.SRTP_AES128_CM_HMAC_SHA1_80: 376 t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80 377 default: 378 t.onStateChange(DTLSTransportStateFailed) 379 return ErrNoSRTPProtectionProfile 380 } 381 382 // Check the fingerprint if a certificate was exchanged 383 remoteCerts := dtlsConn.ConnectionState().PeerCertificates 384 if len(remoteCerts) == 0 { 385 t.onStateChange(DTLSTransportStateFailed) 386 return errNoRemoteCertificate 387 } 388 t.remoteCertificate = remoteCerts[0] 389 390 if !t.api.settingEngine.disableCertificateFingerprintVerification { 391 parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate) 392 if err != nil { 393 if closeErr := dtlsConn.Close(); closeErr != nil { 394 t.log.Error(err.Error()) 395 } 396 397 t.onStateChange(DTLSTransportStateFailed) 398 return err 399 } 400 401 if err = t.validateFingerPrint(parsedRemoteCert); err != nil { 402 if closeErr := dtlsConn.Close(); closeErr != nil { 403 t.log.Error(err.Error()) 404 } 405 406 t.onStateChange(DTLSTransportStateFailed) 407 return err 408 } 409 } 410 411 t.conn = dtlsConn 412 t.onStateChange(DTLSTransportStateConnected) 413 414 return t.startSRTP() 415 } 416 417 // Stop stops and closes the DTLSTransport object. 418 func (t *DTLSTransport) Stop() error { 419 t.lock.Lock() 420 defer t.lock.Unlock() 421 422 // Try closing everything and collect the errors 423 var closeErrs []error 424 425 if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil { 426 closeErrs = append(closeErrs, srtpSession.Close()) 427 } 428 429 if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil { 430 closeErrs = append(closeErrs, srtcpSession.Close()) 431 } 432 433 for i := range t.simulcastStreams { 434 closeErrs = append(closeErrs, t.simulcastStreams[i].Close()) 435 } 436 437 if t.conn != nil { 438 // dtls connection may be closed on sctp close. 439 if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) { 440 closeErrs = append(closeErrs, err) 441 } 442 } 443 t.onStateChange(DTLSTransportStateClosed) 444 return util.FlattenErrs(closeErrs) 445 } 446 447 func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error { 448 for _, fp := range t.remoteParameters.Fingerprints { 449 hashAlgo, err := fingerprint.HashFromString(fp.Algorithm) 450 if err != nil { 451 return err 452 } 453 454 remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo) 455 if err != nil { 456 return err 457 } 458 459 if strings.EqualFold(remoteValue, fp.Value) { 460 return nil 461 } 462 } 463 464 return errNoMatchingCertificateFingerprint 465 } 466 467 func (t *DTLSTransport) ensureICEConn() error { 468 if t.iceTransport == nil { 469 return errICEConnectionNotStarted 470 } 471 472 return nil 473 } 474 475 func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) { 476 t.lock.Lock() 477 defer t.lock.Unlock() 478 479 t.simulcastStreams = append(t.simulcastStreams, s) 480 } 481 482 func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { 483 srtpSession, err := t.getSRTPSession() 484 if err != nil { 485 return nil, nil, nil, nil, err 486 } 487 488 rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc)) 489 if err != nil { 490 return nil, nil, nil, nil, err 491 } 492 493 rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { 494 n, err = rtpReadStream.Read(in) 495 return n, a, err 496 })) 497 498 srtcpSession, err := t.getSRTCPSession() 499 if err != nil { 500 return nil, nil, nil, nil, err 501 } 502 503 rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc)) 504 if err != nil { 505 return nil, nil, nil, nil, err 506 } 507 508 rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { 509 n, err = rtcpReadStream.Read(in) 510 return n, a, err 511 })) 512 513 return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil 514 }