github.com/pion/webrtc/v3@v3.2.24/sctptransport.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 //go:build !js 5 // +build !js 6 7 package webrtc 8 9 import ( 10 "errors" 11 "io" 12 "math" 13 "sync" 14 "time" 15 16 "github.com/pion/datachannel" 17 "github.com/pion/logging" 18 "github.com/pion/sctp" 19 "github.com/pion/webrtc/v3/pkg/rtcerr" 20 ) 21 22 const sctpMaxChannels = uint16(65535) 23 24 // SCTPTransport provides details about the SCTP transport. 25 type SCTPTransport struct { 26 lock sync.RWMutex 27 28 dtlsTransport *DTLSTransport 29 30 // State represents the current state of the SCTP transport. 31 state SCTPTransportState 32 33 // SCTPTransportState doesn't have an enum to distinguish between New/Connecting 34 // so we need a dedicated field 35 isStarted bool 36 37 // MaxMessageSize represents the maximum size of data that can be passed to 38 // DataChannel's send() method. 39 maxMessageSize float64 40 41 // MaxChannels represents the maximum amount of DataChannel's that can 42 // be used simultaneously. 43 maxChannels *uint16 44 45 // OnStateChange func() 46 47 onErrorHandler func(error) 48 49 sctpAssociation *sctp.Association 50 onDataChannelHandler func(*DataChannel) 51 onDataChannelOpenedHandler func(*DataChannel) 52 53 // DataChannels 54 dataChannels []*DataChannel 55 dataChannelsOpened uint32 56 dataChannelsRequested uint32 57 dataChannelsAccepted uint32 58 59 api *API 60 log logging.LeveledLogger 61 } 62 63 // NewSCTPTransport creates a new SCTPTransport. 64 // This constructor is part of the ORTC API. It is not 65 // meant to be used together with the basic WebRTC API. 66 func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport { 67 res := &SCTPTransport{ 68 dtlsTransport: dtls, 69 state: SCTPTransportStateConnecting, 70 api: api, 71 log: api.settingEngine.LoggerFactory.NewLogger("ortc"), 72 } 73 74 res.updateMessageSize() 75 res.updateMaxChannels() 76 77 return res 78 } 79 80 // Transport returns the DTLSTransport instance the SCTPTransport is sending over. 81 func (r *SCTPTransport) Transport() *DTLSTransport { 82 r.lock.RLock() 83 defer r.lock.RUnlock() 84 85 return r.dtlsTransport 86 } 87 88 // GetCapabilities returns the SCTPCapabilities of the SCTPTransport. 89 func (r *SCTPTransport) GetCapabilities() SCTPCapabilities { 90 return SCTPCapabilities{ 91 MaxMessageSize: 0, 92 } 93 } 94 95 // Start the SCTPTransport. Since both local and remote parties must mutually 96 // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish 97 // a connection over SCTP. 98 func (r *SCTPTransport) Start(SCTPCapabilities) error { 99 if r.isStarted { 100 return nil 101 } 102 r.isStarted = true 103 104 dtlsTransport := r.Transport() 105 if dtlsTransport == nil || dtlsTransport.conn == nil { 106 return errSCTPTransportDTLS 107 } 108 109 sctpAssociation, err := sctp.Client(sctp.Config{ 110 NetConn: dtlsTransport.conn, 111 MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize, 112 LoggerFactory: r.api.settingEngine.LoggerFactory, 113 }) 114 if err != nil { 115 return err 116 } 117 118 r.lock.Lock() 119 r.sctpAssociation = sctpAssociation 120 r.state = SCTPTransportStateConnected 121 dataChannels := append([]*DataChannel{}, r.dataChannels...) 122 r.lock.Unlock() 123 124 var openedDCCount uint32 125 for _, d := range dataChannels { 126 if d.ReadyState() == DataChannelStateConnecting { 127 err := d.open(r) 128 if err != nil { 129 r.log.Warnf("failed to open data channel: %s", err) 130 continue 131 } 132 openedDCCount++ 133 } 134 } 135 136 r.lock.Lock() 137 r.dataChannelsOpened += openedDCCount 138 r.lock.Unlock() 139 140 go r.acceptDataChannels(sctpAssociation) 141 142 return nil 143 } 144 145 // Stop stops the SCTPTransport 146 func (r *SCTPTransport) Stop() error { 147 r.lock.Lock() 148 defer r.lock.Unlock() 149 if r.sctpAssociation == nil { 150 return nil 151 } 152 err := r.sctpAssociation.Close() 153 if err != nil { 154 return err 155 } 156 157 r.sctpAssociation = nil 158 r.state = SCTPTransportStateClosed 159 160 return nil 161 } 162 163 func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { 164 r.lock.RLock() 165 dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels)) 166 for _, dc := range r.dataChannels { 167 dc.mu.Lock() 168 isNil := dc.dataChannel == nil 169 dc.mu.Unlock() 170 if isNil { 171 continue 172 } 173 dataChannels = append(dataChannels, dc.dataChannel) 174 } 175 r.lock.RUnlock() 176 ACCEPT: 177 for { 178 dc, err := datachannel.Accept(a, &datachannel.Config{ 179 LoggerFactory: r.api.settingEngine.LoggerFactory, 180 }, dataChannels...) 181 if err != nil { 182 if !errors.Is(err, io.EOF) { 183 r.log.Errorf("Failed to accept data channel: %v", err) 184 r.onError(err) 185 } 186 return 187 } 188 for _, ch := range dataChannels { 189 if ch.StreamIdentifier() == dc.StreamIdentifier() { 190 continue ACCEPT 191 } 192 } 193 194 var ( 195 maxRetransmits *uint16 196 maxPacketLifeTime *uint16 197 ) 198 val := uint16(dc.Config.ReliabilityParameter) 199 ordered := true 200 201 switch dc.Config.ChannelType { 202 case datachannel.ChannelTypeReliable: 203 ordered = true 204 case datachannel.ChannelTypeReliableUnordered: 205 ordered = false 206 case datachannel.ChannelTypePartialReliableRexmit: 207 ordered = true 208 maxRetransmits = &val 209 case datachannel.ChannelTypePartialReliableRexmitUnordered: 210 ordered = false 211 maxRetransmits = &val 212 case datachannel.ChannelTypePartialReliableTimed: 213 ordered = true 214 maxPacketLifeTime = &val 215 case datachannel.ChannelTypePartialReliableTimedUnordered: 216 ordered = false 217 maxPacketLifeTime = &val 218 default: 219 } 220 221 sid := dc.StreamIdentifier() 222 rtcDC, err := r.api.newDataChannel(&DataChannelParameters{ 223 ID: &sid, 224 Label: dc.Config.Label, 225 Protocol: dc.Config.Protocol, 226 Negotiated: dc.Config.Negotiated, 227 Ordered: ordered, 228 MaxPacketLifeTime: maxPacketLifeTime, 229 MaxRetransmits: maxRetransmits, 230 }, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc")) 231 if err != nil { 232 r.log.Errorf("Failed to accept data channel: %v", err) 233 r.onError(err) 234 return 235 } 236 237 <-r.onDataChannel(rtcDC) 238 rtcDC.handleOpen(dc, true, dc.Config.Negotiated) 239 240 r.lock.Lock() 241 r.dataChannelsOpened++ 242 handler := r.onDataChannelOpenedHandler 243 r.lock.Unlock() 244 245 if handler != nil { 246 handler(rtcDC) 247 } 248 } 249 } 250 251 // OnError sets an event handler which is invoked when 252 // the SCTP connection error occurs. 253 func (r *SCTPTransport) OnError(f func(err error)) { 254 r.lock.Lock() 255 defer r.lock.Unlock() 256 r.onErrorHandler = f 257 } 258 259 func (r *SCTPTransport) onError(err error) { 260 r.lock.RLock() 261 handler := r.onErrorHandler 262 r.lock.RUnlock() 263 264 if handler != nil { 265 go handler(err) 266 } 267 } 268 269 // OnDataChannel sets an event handler which is invoked when a data 270 // channel message arrives from a remote peer. 271 func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) { 272 r.lock.Lock() 273 defer r.lock.Unlock() 274 r.onDataChannelHandler = f 275 } 276 277 // OnDataChannelOpened sets an event handler which is invoked when a data 278 // channel is opened 279 func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) { 280 r.lock.Lock() 281 defer r.lock.Unlock() 282 r.onDataChannelOpenedHandler = f 283 } 284 285 func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) { 286 r.lock.Lock() 287 r.dataChannels = append(r.dataChannels, dc) 288 r.dataChannelsAccepted++ 289 handler := r.onDataChannelHandler 290 r.lock.Unlock() 291 292 done = make(chan struct{}) 293 if handler == nil || dc == nil { 294 close(done) 295 return 296 } 297 298 // Run this synchronously to allow setup done in onDataChannelFn() 299 // to complete before datachannel event handlers might be called. 300 go func() { 301 handler(dc) 302 close(done) 303 }() 304 305 return 306 } 307 308 func (r *SCTPTransport) updateMessageSize() { 309 r.lock.Lock() 310 defer r.lock.Unlock() 311 312 var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758 313 var canSendSize float64 = 65536 // pion/webrtc#758 314 315 r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize) 316 } 317 318 func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 { 319 switch { 320 case remoteMaxMessageSize == 0 && 321 canSendSize == 0: 322 return math.Inf(1) 323 324 case remoteMaxMessageSize == 0: 325 return canSendSize 326 327 case canSendSize == 0: 328 return remoteMaxMessageSize 329 330 case canSendSize > remoteMaxMessageSize: 331 return remoteMaxMessageSize 332 333 default: 334 return canSendSize 335 } 336 } 337 338 func (r *SCTPTransport) updateMaxChannels() { 339 val := sctpMaxChannels 340 r.maxChannels = &val 341 } 342 343 // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously. 344 func (r *SCTPTransport) MaxChannels() uint16 { 345 r.lock.Lock() 346 defer r.lock.Unlock() 347 348 if r.maxChannels == nil { 349 return sctpMaxChannels 350 } 351 352 return *r.maxChannels 353 } 354 355 // State returns the current state of the SCTPTransport 356 func (r *SCTPTransport) State() SCTPTransportState { 357 r.lock.RLock() 358 defer r.lock.RUnlock() 359 return r.state 360 } 361 362 func (r *SCTPTransport) collectStats(collector *statsReportCollector) { 363 collector.Collecting() 364 365 stats := SCTPTransportStats{ 366 Timestamp: statsTimestampFrom(time.Now()), 367 Type: StatsTypeSCTPTransport, 368 ID: "sctpTransport", 369 } 370 371 association := r.association() 372 if association != nil { 373 stats.BytesSent = association.BytesSent() 374 stats.BytesReceived = association.BytesReceived() 375 stats.SmoothedRoundTripTime = association.SRTT() * 0.001 // convert milliseconds to seconds 376 stats.CongestionWindow = association.CWND() 377 stats.ReceiverWindow = association.RWND() 378 stats.MTU = association.MTU() 379 } 380 381 collector.Collect(stats.ID, stats) 382 } 383 384 func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error { 385 var id uint16 386 if dtlsRole != DTLSRoleClient { 387 id++ 388 } 389 390 max := r.MaxChannels() 391 392 r.lock.Lock() 393 defer r.lock.Unlock() 394 395 // Create map of ids so we can compare without double-looping each time. 396 idsMap := make(map[uint16]struct{}, len(r.dataChannels)) 397 for _, dc := range r.dataChannels { 398 if dc.ID() == nil { 399 continue 400 } 401 402 idsMap[*dc.ID()] = struct{}{} 403 } 404 405 for ; id < max-1; id += 2 { 406 if _, ok := idsMap[id]; ok { 407 continue 408 } 409 *idOut = &id 410 return nil 411 } 412 413 return &rtcerr.OperationError{Err: ErrMaxDataChannelID} 414 } 415 416 func (r *SCTPTransport) association() *sctp.Association { 417 if r == nil { 418 return nil 419 } 420 r.lock.RLock() 421 association := r.sctpAssociation 422 r.lock.RUnlock() 423 return association 424 }