github.com/pion/webrtc/v4@v4.0.1/sctptransport.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 import ( 10 "errors" 11 "io" 12 "math" 13 "sync" 14 "time" 15 16 "github.com/pion/datachannel" 17 "github.com/pion/logging" 18 "github.com/pion/sctp" 19 "github.com/pion/webrtc/v4/pkg/rtcerr" 20 ) 21 22 const sctpMaxChannels = uint16(65535) 23 24 // SCTPTransport provides details about the SCTP transport. 25 type SCTPTransport struct { 26 lock sync.RWMutex 27 28 dtlsTransport *DTLSTransport 29 30 // State represents the current state of the SCTP transport. 31 state SCTPTransportState 32 33 // SCTPTransportState doesn't have an enum to distinguish between New/Connecting 34 // so we need a dedicated field 35 isStarted bool 36 37 // MaxMessageSize represents the maximum size of data that can be passed to 38 // DataChannel's send() method. 39 maxMessageSize float64 40 41 // MaxChannels represents the maximum amount of DataChannel's that can 42 // be used simultaneously. 43 maxChannels *uint16 44 45 // OnStateChange func() 46 47 onErrorHandler func(error) 48 onCloseHandler func(error) 49 50 sctpAssociation *sctp.Association 51 onDataChannelHandler func(*DataChannel) 52 onDataChannelOpenedHandler func(*DataChannel) 53 54 // DataChannels 55 dataChannels []*DataChannel 56 dataChannelIDsUsed map[uint16]struct{} 57 dataChannelsOpened uint32 58 dataChannelsRequested uint32 59 dataChannelsAccepted uint32 60 61 api *API 62 log logging.LeveledLogger 63 } 64 65 // NewSCTPTransport creates a new SCTPTransport. 66 // This constructor is part of the ORTC API. It is not 67 // meant to be used together with the basic WebRTC API. 68 func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport { 69 res := &SCTPTransport{ 70 dtlsTransport: dtls, 71 state: SCTPTransportStateConnecting, 72 api: api, 73 log: api.settingEngine.LoggerFactory.NewLogger("ortc"), 74 dataChannelIDsUsed: make(map[uint16]struct{}), 75 } 76 77 res.updateMessageSize() 78 res.updateMaxChannels() 79 80 return res 81 } 82 83 // Transport returns the DTLSTransport instance the SCTPTransport is sending over. 84 func (r *SCTPTransport) Transport() *DTLSTransport { 85 r.lock.RLock() 86 defer r.lock.RUnlock() 87 88 return r.dtlsTransport 89 } 90 91 // GetCapabilities returns the SCTPCapabilities of the SCTPTransport. 92 func (r *SCTPTransport) GetCapabilities() SCTPCapabilities { 93 return SCTPCapabilities{ 94 MaxMessageSize: 0, 95 } 96 } 97 98 // Start the SCTPTransport. Since both local and remote parties must mutually 99 // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish 100 // a connection over SCTP. 101 func (r *SCTPTransport) Start(SCTPCapabilities) error { 102 if r.isStarted { 103 return nil 104 } 105 r.isStarted = true 106 107 dtlsTransport := r.Transport() 108 if dtlsTransport == nil || dtlsTransport.conn == nil { 109 return errSCTPTransportDTLS 110 } 111 sctpAssociation, err := sctp.Client(sctp.Config{ 112 NetConn: dtlsTransport.conn, 113 MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize, 114 EnableZeroChecksum: r.api.settingEngine.sctp.enableZeroChecksum, 115 LoggerFactory: r.api.settingEngine.LoggerFactory, 116 RTOMax: float64(r.api.settingEngine.sctp.rtoMax) / float64(time.Millisecond), 117 }) 118 if err != nil { 119 return err 120 } 121 122 r.lock.Lock() 123 r.sctpAssociation = sctpAssociation 124 r.state = SCTPTransportStateConnected 125 dataChannels := append([]*DataChannel{}, r.dataChannels...) 126 r.lock.Unlock() 127 128 var openedDCCount uint32 129 for _, d := range dataChannels { 130 if d.ReadyState() == DataChannelStateConnecting { 131 err := d.open(r) 132 if err != nil { 133 r.log.Warnf("failed to open data channel: %s", err) 134 continue 135 } 136 openedDCCount++ 137 } 138 } 139 140 r.lock.Lock() 141 r.dataChannelsOpened += openedDCCount 142 r.lock.Unlock() 143 144 go r.acceptDataChannels(sctpAssociation) 145 146 return nil 147 } 148 149 // Stop stops the SCTPTransport 150 func (r *SCTPTransport) Stop() error { 151 r.lock.Lock() 152 defer r.lock.Unlock() 153 if r.sctpAssociation == nil { 154 return nil 155 } 156 157 r.sctpAssociation.Abort("") 158 159 r.sctpAssociation = nil 160 r.state = SCTPTransportStateClosed 161 162 return nil 163 } 164 165 func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { 166 r.lock.RLock() 167 dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels)) 168 for _, dc := range r.dataChannels { 169 dc.mu.Lock() 170 isNil := dc.dataChannel == nil 171 dc.mu.Unlock() 172 if isNil { 173 continue 174 } 175 dataChannels = append(dataChannels, dc.dataChannel) 176 } 177 r.lock.RUnlock() 178 179 ACCEPT: 180 for { 181 dc, err := datachannel.Accept(a, &datachannel.Config{ 182 LoggerFactory: r.api.settingEngine.LoggerFactory, 183 }, dataChannels...) 184 if err != nil { 185 if !errors.Is(err, io.EOF) { 186 r.log.Errorf("Failed to accept data channel: %v", err) 187 r.onError(err) 188 r.onClose(err) 189 } else { 190 r.onClose(nil) 191 } 192 return 193 } 194 for _, ch := range dataChannels { 195 if ch.StreamIdentifier() == dc.StreamIdentifier() { 196 continue ACCEPT 197 } 198 } 199 200 var ( 201 maxRetransmits *uint16 202 maxPacketLifeTime *uint16 203 ) 204 val := uint16(dc.Config.ReliabilityParameter) 205 ordered := true 206 207 switch dc.Config.ChannelType { 208 case datachannel.ChannelTypeReliable: 209 ordered = true 210 case datachannel.ChannelTypeReliableUnordered: 211 ordered = false 212 case datachannel.ChannelTypePartialReliableRexmit: 213 ordered = true 214 maxRetransmits = &val 215 case datachannel.ChannelTypePartialReliableRexmitUnordered: 216 ordered = false 217 maxRetransmits = &val 218 case datachannel.ChannelTypePartialReliableTimed: 219 ordered = true 220 maxPacketLifeTime = &val 221 case datachannel.ChannelTypePartialReliableTimedUnordered: 222 ordered = false 223 maxPacketLifeTime = &val 224 default: 225 } 226 227 sid := dc.StreamIdentifier() 228 rtcDC, err := r.api.newDataChannel(&DataChannelParameters{ 229 ID: &sid, 230 Label: dc.Config.Label, 231 Protocol: dc.Config.Protocol, 232 Negotiated: dc.Config.Negotiated, 233 Ordered: ordered, 234 MaxPacketLifeTime: maxPacketLifeTime, 235 MaxRetransmits: maxRetransmits, 236 }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc")) 237 if err != nil { 238 // This data channel is invalid. Close it and log an error. 239 if err1 := dc.Close(); err1 != nil { 240 r.log.Errorf("Failed to close invalid data channel: %v", err1) 241 } 242 r.log.Errorf("Failed to accept data channel: %v", err) 243 r.onError(err) 244 // We've received a datachannel with invalid configuration. We can still receive other datachannels. 245 continue ACCEPT 246 } 247 248 <-r.onDataChannel(rtcDC) 249 rtcDC.handleOpen(dc, true, dc.Config.Negotiated) 250 251 r.lock.Lock() 252 r.dataChannelsOpened++ 253 handler := r.onDataChannelOpenedHandler 254 r.lock.Unlock() 255 256 if handler != nil { 257 handler(rtcDC) 258 } 259 } 260 } 261 262 // OnError sets an event handler which is invoked when the SCTP Association errors. 263 func (r *SCTPTransport) OnError(f func(err error)) { 264 r.lock.Lock() 265 defer r.lock.Unlock() 266 r.onErrorHandler = f 267 } 268 269 func (r *SCTPTransport) onError(err error) { 270 r.lock.RLock() 271 handler := r.onErrorHandler 272 r.lock.RUnlock() 273 274 if handler != nil { 275 go handler(err) 276 } 277 } 278 279 // OnClose sets an event handler which is invoked when the SCTP Association closes. 280 func (r *SCTPTransport) OnClose(f func(err error)) { 281 r.lock.Lock() 282 defer r.lock.Unlock() 283 r.onCloseHandler = f 284 } 285 286 func (r *SCTPTransport) onClose(err error) { 287 r.lock.RLock() 288 handler := r.onCloseHandler 289 r.lock.RUnlock() 290 291 if handler != nil { 292 go handler(err) 293 } 294 } 295 296 // OnDataChannel sets an event handler which is invoked when a data 297 // channel message arrives from a remote peer. 298 func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) { 299 r.lock.Lock() 300 defer r.lock.Unlock() 301 r.onDataChannelHandler = f 302 } 303 304 // OnDataChannelOpened sets an event handler which is invoked when a data 305 // channel is opened 306 func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) { 307 r.lock.Lock() 308 defer r.lock.Unlock() 309 r.onDataChannelOpenedHandler = f 310 } 311 312 func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) { 313 r.lock.Lock() 314 r.dataChannels = append(r.dataChannels, dc) 315 r.dataChannelsAccepted++ 316 if dc.ID() != nil { 317 r.dataChannelIDsUsed[*dc.ID()] = struct{}{} 318 } else { 319 // This cannot happen, the constructor for this datachannel in the caller 320 // takes a pointer to the id. 321 r.log.Errorf("accepted data channel with no ID") 322 } 323 handler := r.onDataChannelHandler 324 r.lock.Unlock() 325 326 done = make(chan struct{}) 327 if handler == nil || dc == nil { 328 close(done) 329 return 330 } 331 332 // Run this synchronously to allow setup done in onDataChannelFn() 333 // to complete before datachannel event handlers might be called. 334 go func() { 335 handler(dc) 336 close(done) 337 }() 338 339 return 340 } 341 342 func (r *SCTPTransport) updateMessageSize() { 343 r.lock.Lock() 344 defer r.lock.Unlock() 345 346 var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758 347 var canSendSize float64 = 65536 // pion/webrtc#758 348 349 r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize) 350 } 351 352 func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 { 353 switch { 354 case remoteMaxMessageSize == 0 && 355 canSendSize == 0: 356 return math.Inf(1) 357 358 case remoteMaxMessageSize == 0: 359 return canSendSize 360 361 case canSendSize == 0: 362 return remoteMaxMessageSize 363 364 case canSendSize > remoteMaxMessageSize: 365 return remoteMaxMessageSize 366 367 default: 368 return canSendSize 369 } 370 } 371 372 func (r *SCTPTransport) updateMaxChannels() { 373 val := sctpMaxChannels 374 r.maxChannels = &val 375 } 376 377 // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously. 378 func (r *SCTPTransport) MaxChannels() uint16 { 379 r.lock.Lock() 380 defer r.lock.Unlock() 381 382 if r.maxChannels == nil { 383 return sctpMaxChannels 384 } 385 386 return *r.maxChannels 387 } 388 389 // State returns the current state of the SCTPTransport 390 func (r *SCTPTransport) State() SCTPTransportState { 391 r.lock.RLock() 392 defer r.lock.RUnlock() 393 return r.state 394 } 395 396 func (r *SCTPTransport) collectStats(collector *statsReportCollector) { 397 collector.Collecting() 398 399 stats := SCTPTransportStats{ 400 Timestamp: statsTimestampFrom(time.Now()), 401 Type: StatsTypeSCTPTransport, 402 ID: "sctpTransport", 403 } 404 405 association := r.association() 406 if association != nil { 407 stats.BytesSent = association.BytesSent() 408 stats.BytesReceived = association.BytesReceived() 409 stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds 410 stats.CongestionWindow = association.CWND() 411 stats.ReceiverWindow = association.RWND() 412 stats.MTU = association.MTU() 413 } 414 415 collector.Collect(stats.ID, stats) 416 } 417 418 func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error { 419 var id uint16 420 if dtlsRole != DTLSRoleClient { 421 id++ 422 } 423 424 maxVal := r.MaxChannels() 425 426 r.lock.Lock() 427 defer r.lock.Unlock() 428 429 for ; id < maxVal-1; id += 2 { 430 if _, ok := r.dataChannelIDsUsed[id]; ok { 431 continue 432 } 433 *idOut = &id 434 r.dataChannelIDsUsed[id] = struct{}{} 435 return nil 436 } 437 438 return &rtcerr.OperationError{Err: ErrMaxDataChannelID} 439 } 440 441 func (r *SCTPTransport) association() *sctp.Association { 442 if r == nil { 443 return nil 444 } 445 r.lock.RLock() 446 association := r.sctpAssociation 447 r.lock.RUnlock() 448 return association 449 }