github.com/pion/webrtc/v4@v4.0.1/internal/mux/mux.go (about) 1 // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly> 2 // SPDX-License-Identifier: MIT 3 4 // Package mux multiplexes packets on a single socket (RFC7983) 5 package mux 6 7 import ( 8 "errors" 9 "io" 10 "net" 11 "sync" 12 13 "github.com/pion/ice/v4" 14 "github.com/pion/logging" 15 "github.com/pion/transport/v3/packetio" 16 ) 17 18 const ( 19 // The maximum amount of data that can be buffered before returning errors. 20 maxBufferSize = 1000 * 1000 // 1MB 21 22 // How many total pending packets can be cached 23 maxPendingPackets = 15 24 ) 25 26 // Config collects the arguments to mux.Mux construction into 27 // a single structure 28 type Config struct { 29 Conn net.Conn 30 BufferSize int 31 LoggerFactory logging.LoggerFactory 32 } 33 34 // Mux allows multiplexing 35 type Mux struct { 36 nextConn net.Conn 37 bufferSize int 38 lock sync.Mutex 39 endpoints map[*Endpoint]MatchFunc 40 isClosed bool 41 42 pendingPackets [][]byte 43 44 closedCh chan struct{} 45 log logging.LeveledLogger 46 } 47 48 // NewMux creates a new Mux 49 func NewMux(config Config) *Mux { 50 m := &Mux{ 51 nextConn: config.Conn, 52 endpoints: make(map[*Endpoint]MatchFunc), 53 bufferSize: config.BufferSize, 54 closedCh: make(chan struct{}), 55 log: config.LoggerFactory.NewLogger("mux"), 56 } 57 58 go m.readLoop() 59 60 return m 61 } 62 63 // NewEndpoint creates a new Endpoint 64 func (m *Mux) NewEndpoint(f MatchFunc) *Endpoint { 65 e := &Endpoint{ 66 mux: m, 67 buffer: packetio.NewBuffer(), 68 } 69 70 // Set a maximum size of the buffer in bytes. 71 e.buffer.SetLimitSize(maxBufferSize) 72 73 m.lock.Lock() 74 m.endpoints[e] = f 75 m.lock.Unlock() 76 77 go m.handlePendingPackets(e, f) 78 79 return e 80 } 81 82 // RemoveEndpoint removes an endpoint from the Mux 83 func (m *Mux) RemoveEndpoint(e *Endpoint) { 84 m.lock.Lock() 85 defer m.lock.Unlock() 86 delete(m.endpoints, e) 87 } 88 89 // Close closes the Mux and all associated Endpoints. 90 func (m *Mux) Close() error { 91 m.lock.Lock() 92 for e := range m.endpoints { 93 if err := e.close(); err != nil { 94 m.lock.Unlock() 95 return err 96 } 97 98 delete(m.endpoints, e) 99 } 100 m.isClosed = true 101 m.lock.Unlock() 102 103 err := m.nextConn.Close() 104 if err != nil { 105 return err 106 } 107 108 // Wait for readLoop to end 109 <-m.closedCh 110 111 return nil 112 } 113 114 func (m *Mux) readLoop() { 115 defer func() { 116 close(m.closedCh) 117 }() 118 119 buf := make([]byte, m.bufferSize) 120 for { 121 n, err := m.nextConn.Read(buf) 122 switch { 123 case errors.Is(err, io.EOF), errors.Is(err, ice.ErrClosed): 124 return 125 case errors.Is(err, io.ErrShortBuffer), errors.Is(err, packetio.ErrTimeout): 126 m.log.Errorf("mux: failed to read from packetio.Buffer %s", err.Error()) 127 continue 128 case err != nil: 129 m.log.Errorf("mux: ending readLoop packetio.Buffer error %s", err.Error()) 130 return 131 } 132 133 if err = m.dispatch(buf[:n]); err != nil { 134 if errors.Is(err, io.ErrClosedPipe) { 135 // if the buffer was closed, that's not an error we care to report 136 return 137 } 138 m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error()) 139 return 140 } 141 } 142 } 143 144 func (m *Mux) dispatch(buf []byte) error { 145 if len(buf) == 0 { 146 m.log.Warnf("Warning: mux: unable to dispatch zero length packet") 147 return nil 148 } 149 150 var endpoint *Endpoint 151 152 m.lock.Lock() 153 for e, f := range m.endpoints { 154 if f(buf) { 155 endpoint = e 156 break 157 } 158 } 159 if endpoint == nil { 160 defer m.lock.Unlock() 161 162 if !m.isClosed { 163 if len(m.pendingPackets) >= maxPendingPackets { 164 m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", buf[0], len(m.pendingPackets)) 165 } else { 166 m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", buf[0], len(m.pendingPackets)) 167 m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...)) 168 } 169 } 170 return nil 171 } 172 173 m.lock.Unlock() 174 _, err := endpoint.buffer.Write(buf) 175 176 // Expected when bytes are received faster than the endpoint can process them (#2152, #2180) 177 if errors.Is(err, packetio.ErrFull) { 178 m.log.Infof("mux: endpoint buffer is full, dropping packet") 179 return nil 180 } 181 182 return err 183 } 184 185 func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) { 186 m.lock.Lock() 187 defer m.lock.Unlock() 188 189 pendingPackets := make([][]byte, len(m.pendingPackets)) 190 for _, buf := range m.pendingPackets { 191 if matchFunc(buf) { 192 if _, err := endpoint.buffer.Write(buf); err != nil { 193 m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err) 194 } 195 } else { 196 pendingPackets = append(pendingPackets, buf) 197 } 198 } 199 m.pendingPackets = pendingPackets 200 }