github.com/pion/webrtc/v4@v4.0.1/internal/mux/mux.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  // Package mux multiplexes packets on a single socket (RFC7983)
     5  package mux
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"net"
    11  	"sync"
    12  
    13  	"github.com/pion/ice/v4"
    14  	"github.com/pion/logging"
    15  	"github.com/pion/transport/v3/packetio"
    16  )
    17  
    18  const (
    19  	// The maximum amount of data that can be buffered before returning errors.
    20  	maxBufferSize = 1000 * 1000 // 1MB
    21  
    22  	// How many total pending packets can be cached
    23  	maxPendingPackets = 15
    24  )
    25  
// Config collects the arguments to mux.Mux construction into
// a single structure.
type Config struct {
	// Conn is the connection the Mux reads packets from; Mux.Close also
	// closes it.
	Conn          net.Conn
	// BufferSize is the size, in bytes, of the buffer each Read from Conn
	// is performed into.
	BufferSize    int
	// LoggerFactory supplies the "mux" logger used for diagnostics.
	LoggerFactory logging.LoggerFactory
}
    33  
    34  // Mux allows multiplexing
    35  type Mux struct {
    36  	nextConn   net.Conn
    37  	bufferSize int
    38  	lock       sync.Mutex
    39  	endpoints  map[*Endpoint]MatchFunc
    40  	isClosed   bool
    41  
    42  	pendingPackets [][]byte
    43  
    44  	closedCh chan struct{}
    45  	log      logging.LeveledLogger
    46  }
    47  
    48  // NewMux creates a new Mux
    49  func NewMux(config Config) *Mux {
    50  	m := &Mux{
    51  		nextConn:   config.Conn,
    52  		endpoints:  make(map[*Endpoint]MatchFunc),
    53  		bufferSize: config.BufferSize,
    54  		closedCh:   make(chan struct{}),
    55  		log:        config.LoggerFactory.NewLogger("mux"),
    56  	}
    57  
    58  	go m.readLoop()
    59  
    60  	return m
    61  }
    62  
    63  // NewEndpoint creates a new Endpoint
    64  func (m *Mux) NewEndpoint(f MatchFunc) *Endpoint {
    65  	e := &Endpoint{
    66  		mux:    m,
    67  		buffer: packetio.NewBuffer(),
    68  	}
    69  
    70  	// Set a maximum size of the buffer in bytes.
    71  	e.buffer.SetLimitSize(maxBufferSize)
    72  
    73  	m.lock.Lock()
    74  	m.endpoints[e] = f
    75  	m.lock.Unlock()
    76  
    77  	go m.handlePendingPackets(e, f)
    78  
    79  	return e
    80  }
    81  
    82  // RemoveEndpoint removes an endpoint from the Mux
    83  func (m *Mux) RemoveEndpoint(e *Endpoint) {
    84  	m.lock.Lock()
    85  	defer m.lock.Unlock()
    86  	delete(m.endpoints, e)
    87  }
    88  
    89  // Close closes the Mux and all associated Endpoints.
    90  func (m *Mux) Close() error {
    91  	m.lock.Lock()
    92  	for e := range m.endpoints {
    93  		if err := e.close(); err != nil {
    94  			m.lock.Unlock()
    95  			return err
    96  		}
    97  
    98  		delete(m.endpoints, e)
    99  	}
   100  	m.isClosed = true
   101  	m.lock.Unlock()
   102  
   103  	err := m.nextConn.Close()
   104  	if err != nil {
   105  		return err
   106  	}
   107  
   108  	// Wait for readLoop to end
   109  	<-m.closedCh
   110  
   111  	return nil
   112  }
   113  
   114  func (m *Mux) readLoop() {
   115  	defer func() {
   116  		close(m.closedCh)
   117  	}()
   118  
   119  	buf := make([]byte, m.bufferSize)
   120  	for {
   121  		n, err := m.nextConn.Read(buf)
   122  		switch {
   123  		case errors.Is(err, io.EOF), errors.Is(err, ice.ErrClosed):
   124  			return
   125  		case errors.Is(err, io.ErrShortBuffer), errors.Is(err, packetio.ErrTimeout):
   126  			m.log.Errorf("mux: failed to read from packetio.Buffer %s", err.Error())
   127  			continue
   128  		case err != nil:
   129  			m.log.Errorf("mux: ending readLoop packetio.Buffer error %s", err.Error())
   130  			return
   131  		}
   132  
   133  		if err = m.dispatch(buf[:n]); err != nil {
   134  			if errors.Is(err, io.ErrClosedPipe) {
   135  				// if the buffer was closed, that's not an error we care to report
   136  				return
   137  			}
   138  			m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error())
   139  			return
   140  		}
   141  	}
   142  }
   143  
   144  func (m *Mux) dispatch(buf []byte) error {
   145  	if len(buf) == 0 {
   146  		m.log.Warnf("Warning: mux: unable to dispatch zero length packet")
   147  		return nil
   148  	}
   149  
   150  	var endpoint *Endpoint
   151  
   152  	m.lock.Lock()
   153  	for e, f := range m.endpoints {
   154  		if f(buf) {
   155  			endpoint = e
   156  			break
   157  		}
   158  	}
   159  	if endpoint == nil {
   160  		defer m.lock.Unlock()
   161  
   162  		if !m.isClosed {
   163  			if len(m.pendingPackets) >= maxPendingPackets {
   164  				m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", buf[0], len(m.pendingPackets))
   165  			} else {
   166  				m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", buf[0], len(m.pendingPackets))
   167  				m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...))
   168  			}
   169  		}
   170  		return nil
   171  	}
   172  
   173  	m.lock.Unlock()
   174  	_, err := endpoint.buffer.Write(buf)
   175  
   176  	// Expected when bytes are received faster than the endpoint can process them (#2152, #2180)
   177  	if errors.Is(err, packetio.ErrFull) {
   178  		m.log.Infof("mux: endpoint buffer is full, dropping packet")
   179  		return nil
   180  	}
   181  
   182  	return err
   183  }
   184  
   185  func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) {
   186  	m.lock.Lock()
   187  	defer m.lock.Unlock()
   188  
   189  	pendingPackets := make([][]byte, len(m.pendingPackets))
   190  	for _, buf := range m.pendingPackets {
   191  		if matchFunc(buf) {
   192  			if _, err := endpoint.buffer.Write(buf); err != nil {
   193  				m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err)
   194  			}
   195  		} else {
   196  			pendingPackets = append(pendingPackets, buf)
   197  		}
   198  	}
   199  	m.pendingPackets = pendingPackets
   200  }