github.com/pion/webrtc/v4@v4.0.1/dtlstransport.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 import ( 10 "crypto/ecdsa" 11 "crypto/elliptic" 12 "crypto/rand" 13 "crypto/tls" 14 "crypto/x509" 15 "errors" 16 "fmt" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "time" 21 22 "github.com/pion/dtls/v3" 23 "github.com/pion/dtls/v3/pkg/crypto/fingerprint" 24 "github.com/pion/interceptor" 25 "github.com/pion/logging" 26 "github.com/pion/rtcp" 27 "github.com/pion/srtp/v3" 28 "github.com/pion/webrtc/v4/internal/mux" 29 "github.com/pion/webrtc/v4/internal/util" 30 "github.com/pion/webrtc/v4/pkg/rtcerr" 31 ) 32 33 // DTLSTransport allows an application access to information about the DTLS 34 // transport over which RTP and RTCP packets are sent and received by 35 // RTPSender and RTPReceiver, as well other data such as SCTP packets sent 36 // and received by data channels. 37 type DTLSTransport struct { 38 lock sync.RWMutex 39 40 iceTransport *ICETransport 41 certificates []Certificate 42 remoteParameters DTLSParameters 43 remoteCertificate []byte 44 state DTLSTransportState 45 srtpProtectionProfile srtp.ProtectionProfile 46 47 onStateChangeHandler func(DTLSTransportState) 48 internalOnCloseHandler func() 49 50 conn *dtls.Conn 51 52 srtpSession, srtcpSession atomic.Value 53 srtpEndpoint, srtcpEndpoint *mux.Endpoint 54 simulcastStreams []simulcastStreamPair 55 srtpReady chan struct{} 56 57 dtlsMatcher mux.MatchFunc 58 59 api *API 60 log logging.LeveledLogger 61 } 62 63 type simulcastStreamPair struct { 64 srtp *srtp.ReadStreamSRTP 65 srtcp *srtp.ReadStreamSRTCP 66 } 67 68 // NewDTLSTransport creates a new DTLSTransport. 69 // This constructor is part of the ORTC API. It is not 70 // meant to be used together with the basic WebRTC API. 71 func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) { 72 t := &DTLSTransport{ 73 iceTransport: transport, 74 api: api, 75 state: DTLSTransportStateNew, 76 dtlsMatcher: mux.MatchDTLS, 77 srtpReady: make(chan struct{}), 78 log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"), 79 } 80 81 if len(certificates) > 0 { 82 now := time.Now() 83 for _, x509Cert := range certificates { 84 if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) { 85 return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired} 86 } 87 t.certificates = append(t.certificates, x509Cert) 88 } 89 } else { 90 sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 91 if err != nil { 92 return nil, &rtcerr.UnknownError{Err: err} 93 } 94 certificate, err := GenerateCertificate(sk) 95 if err != nil { 96 return nil, err 97 } 98 t.certificates = []Certificate{*certificate} 99 } 100 101 return t, nil 102 } 103 104 // ICETransport returns the currently-configured *ICETransport or nil 105 // if one has not been configured 106 func (t *DTLSTransport) ICETransport() *ICETransport { 107 t.lock.RLock() 108 defer t.lock.RUnlock() 109 return t.iceTransport 110 } 111 112 // onStateChange requires the caller holds the lock 113 func (t *DTLSTransport) onStateChange(state DTLSTransportState) { 114 t.state = state 115 handler := t.onStateChangeHandler 116 if handler != nil { 117 handler(state) 118 } 119 } 120 121 // OnStateChange sets a handler that is fired when the DTLS 122 // connection state changes. 123 func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) { 124 t.lock.Lock() 125 defer t.lock.Unlock() 126 t.onStateChangeHandler = f 127 } 128 129 // State returns the current dtls transport state. 130 func (t *DTLSTransport) State() DTLSTransportState { 131 t.lock.RLock() 132 defer t.lock.RUnlock() 133 return t.state 134 } 135 136 // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the 137 // packet is discarded. 138 func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) { 139 raw, err := rtcp.Marshal(pkts) 140 if err != nil { 141 return 0, err 142 } 143 144 srtcpSession, err := t.getSRTCPSession() 145 if err != nil { 146 return 0, err 147 } 148 149 writeStream, err := srtcpSession.OpenWriteStream() 150 if err != nil { 151 // nolint 152 return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err) 153 } 154 155 return writeStream.Write(raw) 156 } 157 158 // GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction. 159 func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) { 160 fingerprints := []DTLSFingerprint{} 161 162 for _, c := range t.certificates { 163 prints, err := c.GetFingerprints() 164 if err != nil { 165 return DTLSParameters{}, err 166 } 167 168 fingerprints = append(fingerprints, prints...) 169 } 170 171 return DTLSParameters{ 172 Role: DTLSRoleAuto, // always returns the default role 173 Fingerprints: fingerprints, 174 }, nil 175 } 176 177 // GetRemoteCertificate returns the certificate chain in use by the remote side 178 // returns an empty list prior to selection of the remote certificate 179 func (t *DTLSTransport) GetRemoteCertificate() []byte { 180 t.lock.RLock() 181 defer t.lock.RUnlock() 182 return t.remoteCertificate 183 } 184 185 func (t *DTLSTransport) startSRTP() error { 186 srtpConfig := &srtp.Config{ 187 Profile: t.srtpProtectionProfile, 188 BufferFactory: t.api.settingEngine.BufferFactory, 189 LoggerFactory: t.api.settingEngine.LoggerFactory, 190 } 191 if t.api.settingEngine.replayProtection.SRTP != nil { 192 srtpConfig.RemoteOptions = append( 193 srtpConfig.RemoteOptions, 194 srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP), 195 ) 196 } 197 198 if t.api.settingEngine.disableSRTPReplayProtection { 199 srtpConfig.RemoteOptions = append( 200 srtpConfig.RemoteOptions, 201 srtp.SRTPNoReplayProtection(), 202 ) 203 } 204 205 if t.api.settingEngine.replayProtection.SRTCP != nil { 206 srtpConfig.RemoteOptions = append( 207 srtpConfig.RemoteOptions, 208 srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP), 209 ) 210 } 211 212 if t.api.settingEngine.disableSRTCPReplayProtection { 213 srtpConfig.RemoteOptions = append( 214 srtpConfig.RemoteOptions, 215 srtp.SRTCPNoReplayProtection(), 216 ) 217 } 218 219 connState, ok := t.conn.ConnectionState() 220 if !ok { 221 // nolint 222 return fmt.Errorf("%w: Failed to get DTLS ConnectionState", errDtlsKeyExtractionFailed) 223 } 224 225 err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient) 226 if err != nil { 227 // nolint 228 return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err) 229 } 230 231 srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig) 232 if err != nil { 233 // nolint 234 return fmt.Errorf("%w: %v", errFailedToStartSRTP, err) 235 } 236 237 srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig) 238 if err != nil { 239 // nolint 240 return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err) 241 } 242 243 t.srtpSession.Store(srtpSession) 244 t.srtcpSession.Store(srtcpSession) 245 close(t.srtpReady) 246 return nil 247 } 248 249 func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) { 250 if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok { 251 return value, nil 252 } 253 254 return nil, errDtlsTransportNotStarted 255 } 256 257 func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) { 258 if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok { 259 return value, nil 260 } 261 262 return nil, errDtlsTransportNotStarted 263 } 264 265 func (t *DTLSTransport) role() DTLSRole { 266 // If remote has an explicit role use the inverse 267 switch t.remoteParameters.Role { 268 case DTLSRoleClient: 269 return DTLSRoleServer 270 case DTLSRoleServer: 271 return DTLSRoleClient 272 default: 273 } 274 275 // If SettingEngine has an explicit role 276 switch t.api.settingEngine.answeringDTLSRole { 277 case DTLSRoleServer: 278 return DTLSRoleServer 279 case DTLSRoleClient: 280 return DTLSRoleClient 281 default: 282 } 283 284 // Remote was auto and no explicit role was configured via SettingEngine 285 if t.iceTransport.Role() == ICERoleControlling { 286 return DTLSRoleServer 287 } 288 return defaultDtlsRoleAnswer 289 } 290 291 // Start DTLS transport negotiation with the parameters of the remote DTLS transport 292 func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error { //nolint: gocognit 293 // Take lock and prepare connection, we must not hold the lock 294 // when connecting 295 prepareTransport := func() (DTLSRole, *dtls.Config, error) { 296 t.lock.Lock() 297 defer t.lock.Unlock() 298 299 if err := t.ensureICEConn(); err != nil { 300 return DTLSRole(0), nil, err 301 } 302 303 if t.state != DTLSTransportStateNew { 304 return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)} 305 } 306 307 t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP) 308 t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP) 309 t.remoteParameters = remoteParameters 310 311 cert := t.certificates[0] 312 t.onStateChange(DTLSTransportStateConnecting) 313 314 return t.role(), &dtls.Config{ 315 Certificates: []tls.Certificate{ 316 { 317 Certificate: [][]byte{cert.x509Cert.Raw}, 318 PrivateKey: cert.privateKey, 319 }, 320 }, 321 SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile { 322 if len(t.api.settingEngine.srtpProtectionProfiles) > 0 { 323 return t.api.settingEngine.srtpProtectionProfiles 324 } 325 326 return defaultSrtpProtectionProfiles() 327 }(), 328 ClientAuth: dtls.RequireAnyClientCert, 329 LoggerFactory: t.api.settingEngine.LoggerFactory, 330 InsecureSkipVerify: !t.api.settingEngine.dtls.disableInsecureSkipVerify, 331 CustomCipherSuites: t.api.settingEngine.dtls.customCipherSuites, 332 }, nil 333 } 334 335 var dtlsConn *dtls.Conn 336 dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS) 337 dtlsEndpoint.SetOnClose(t.internalOnCloseHandler) 338 role, dtlsConfig, err := prepareTransport() 339 if err != nil { 340 return err 341 } 342 343 if t.api.settingEngine.replayProtection.DTLS != nil { 344 dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS) 345 } 346 347 if t.api.settingEngine.dtls.clientAuth != nil { 348 dtlsConfig.ClientAuth = *t.api.settingEngine.dtls.clientAuth 349 } 350 351 dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval 352 dtlsConfig.InsecureSkipVerifyHello = t.api.settingEngine.dtls.insecureSkipHelloVerify 353 dtlsConfig.EllipticCurves = t.api.settingEngine.dtls.ellipticCurves 354 dtlsConfig.ExtendedMasterSecret = t.api.settingEngine.dtls.extendedMasterSecret 355 dtlsConfig.ClientCAs = t.api.settingEngine.dtls.clientCAs 356 dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs 357 dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter 358 dtlsConfig.ClientHelloMessageHook = t.api.settingEngine.dtls.clientHelloMessageHook 359 dtlsConfig.ServerHelloMessageHook = t.api.settingEngine.dtls.serverHelloMessageHook 360 dtlsConfig.CertificateRequestMessageHook = t.api.settingEngine.dtls.certificateRequestMessageHook 361 362 // Connect as DTLS Client/Server, function is blocking and we 363 // must not hold the DTLSTransport lock 364 if role == DTLSRoleClient { 365 dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig) 366 } else { 367 dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsEndpoint.RemoteAddr(), dtlsConfig) 368 } 369 370 if err == nil { 371 if t.api.settingEngine.dtls.connectContextMaker != nil { 372 handshakeCtx, _ := t.api.settingEngine.dtls.connectContextMaker() 373 err = dtlsConn.HandshakeContext(handshakeCtx) 374 } else { 375 err = dtlsConn.Handshake() 376 } 377 } 378 379 // Re-take the lock, nothing beyond here is blocking 380 t.lock.Lock() 381 defer t.lock.Unlock() 382 383 if err != nil { 384 t.onStateChange(DTLSTransportStateFailed) 385 return err 386 } 387 388 srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile() 389 if !ok { 390 t.onStateChange(DTLSTransportStateFailed) 391 return ErrNoSRTPProtectionProfile 392 } 393 394 switch srtpProfile { 395 case dtls.SRTP_AEAD_AES_128_GCM: 396 t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm 397 case dtls.SRTP_AEAD_AES_256_GCM: 398 t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes256Gcm 399 case dtls.SRTP_AES128_CM_HMAC_SHA1_80: 400 t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80 401 case dtls.SRTP_NULL_HMAC_SHA1_80: 402 t.srtpProtectionProfile = srtp.ProtectionProfileNullHmacSha1_80 403 default: 404 t.onStateChange(DTLSTransportStateFailed) 405 return ErrNoSRTPProtectionProfile 406 } 407 408 // Check the fingerprint if a certificate was exchanged 409 connectionState, ok := dtlsConn.ConnectionState() 410 if !ok { 411 t.onStateChange(DTLSTransportStateFailed) 412 return errNoRemoteCertificate 413 } 414 415 if len(connectionState.PeerCertificates) == 0 { 416 t.onStateChange(DTLSTransportStateFailed) 417 return errNoRemoteCertificate 418 } 419 t.remoteCertificate = connectionState.PeerCertificates[0] 420 421 if !t.api.settingEngine.disableCertificateFingerprintVerification { 422 parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate) 423 if err != nil { 424 if closeErr := dtlsConn.Close(); closeErr != nil { 425 t.log.Error(err.Error()) 426 } 427 428 t.onStateChange(DTLSTransportStateFailed) 429 return err 430 } 431 432 if err = t.validateFingerPrint(parsedRemoteCert); err != nil { 433 if closeErr := dtlsConn.Close(); closeErr != nil { 434 t.log.Error(err.Error()) 435 } 436 437 t.onStateChange(DTLSTransportStateFailed) 438 return err 439 } 440 } 441 442 t.conn = dtlsConn 443 t.onStateChange(DTLSTransportStateConnected) 444 445 return t.startSRTP() 446 } 447 448 // Stop stops and closes the DTLSTransport object. 449 func (t *DTLSTransport) Stop() error { 450 t.lock.Lock() 451 defer t.lock.Unlock() 452 453 // Try closing everything and collect the errors 454 var closeErrs []error 455 456 if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil { 457 closeErrs = append(closeErrs, srtpSession.Close()) 458 } 459 460 if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil { 461 closeErrs = append(closeErrs, srtcpSession.Close()) 462 } 463 464 for i := range t.simulcastStreams { 465 closeErrs = append(closeErrs, t.simulcastStreams[i].srtp.Close()) 466 closeErrs = append(closeErrs, t.simulcastStreams[i].srtcp.Close()) 467 } 468 469 if t.conn != nil { 470 // dtls connection may be closed on sctp close. 471 if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) { 472 closeErrs = append(closeErrs, err) 473 } 474 } 475 t.onStateChange(DTLSTransportStateClosed) 476 return util.FlattenErrs(closeErrs) 477 } 478 479 func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error { 480 for _, fp := range t.remoteParameters.Fingerprints { 481 hashAlgo, err := fingerprint.HashFromString(fp.Algorithm) 482 if err != nil { 483 return err 484 } 485 486 remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo) 487 if err != nil { 488 return err 489 } 490 491 if strings.EqualFold(remoteValue, fp.Value) { 492 return nil 493 } 494 } 495 496 return errNoMatchingCertificateFingerprint 497 } 498 499 func (t *DTLSTransport) ensureICEConn() error { 500 if t.iceTransport == nil { 501 return errICEConnectionNotStarted 502 } 503 504 return nil 505 } 506 507 func (t *DTLSTransport) storeSimulcastStream(srtpReadStream *srtp.ReadStreamSRTP, srtcpReadStream *srtp.ReadStreamSRTCP) { 508 t.lock.Lock() 509 defer t.lock.Unlock() 510 511 t.simulcastStreams = append(t.simulcastStreams, simulcastStreamPair{srtpReadStream, srtcpReadStream}) 512 } 513 514 func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { 515 srtpSession, err := t.getSRTPSession() 516 if err != nil { 517 return nil, nil, nil, nil, err 518 } 519 520 rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc)) 521 if err != nil { 522 return nil, nil, nil, nil, err 523 } 524 525 rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { 526 n, err = rtpReadStream.Read(in) 527 return n, a, err 528 })) 529 530 srtcpSession, err := t.getSRTCPSession() 531 if err != nil { 532 return nil, nil, nil, nil, err 533 } 534 535 rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc)) 536 if err != nil { 537 return nil, nil, nil, nil, err 538 } 539 540 rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { 541 n, err = rtcpReadStream.Read(in) 542 return n, a, err 543 })) 544 545 return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil 546 }