github.com/pion/webrtc/v3@v3.2.24/internal/mux/mux.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  // Package mux multiplexes packets on a single socket (RFC7983)
     5  package mux
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"net"
    11  	"sync"
    12  
    13  	"github.com/pion/ice/v2"
    14  	"github.com/pion/logging"
    15  	"github.com/pion/transport/v2/packetio"
    16  )
    17  
    18  // The maximum amount of data that can be buffered before returning errors.
    19  const maxBufferSize = 1000 * 1000 // 1MB
    20  
    21  // Config collects the arguments to mux.Mux construction into
    22  // a single structure
    23  type Config struct {
    24  	Conn          net.Conn
    25  	BufferSize    int
    26  	LoggerFactory logging.LoggerFactory
    27  }
    28  
    29  // Mux allows multiplexing
    30  type Mux struct {
    31  	lock       sync.RWMutex
    32  	nextConn   net.Conn
    33  	endpoints  map[*Endpoint]MatchFunc
    34  	bufferSize int
    35  	closedCh   chan struct{}
    36  
    37  	log logging.LeveledLogger
    38  }
    39  
    40  // NewMux creates a new Mux
    41  func NewMux(config Config) *Mux {
    42  	m := &Mux{
    43  		nextConn:   config.Conn,
    44  		endpoints:  make(map[*Endpoint]MatchFunc),
    45  		bufferSize: config.BufferSize,
    46  		closedCh:   make(chan struct{}),
    47  		log:        config.LoggerFactory.NewLogger("mux"),
    48  	}
    49  
    50  	go m.readLoop()
    51  
    52  	return m
    53  }
    54  
    55  // NewEndpoint creates a new Endpoint
    56  func (m *Mux) NewEndpoint(f MatchFunc) *Endpoint {
    57  	e := &Endpoint{
    58  		mux:    m,
    59  		buffer: packetio.NewBuffer(),
    60  	}
    61  
    62  	// Set a maximum size of the buffer in bytes.
    63  	e.buffer.SetLimitSize(maxBufferSize)
    64  
    65  	m.lock.Lock()
    66  	m.endpoints[e] = f
    67  	m.lock.Unlock()
    68  
    69  	return e
    70  }
    71  
    72  // RemoveEndpoint removes an endpoint from the Mux
    73  func (m *Mux) RemoveEndpoint(e *Endpoint) {
    74  	m.lock.Lock()
    75  	defer m.lock.Unlock()
    76  	delete(m.endpoints, e)
    77  }
    78  
    79  // Close closes the Mux and all associated Endpoints.
    80  func (m *Mux) Close() error {
    81  	m.lock.Lock()
    82  	for e := range m.endpoints {
    83  		if err := e.close(); err != nil {
    84  			m.lock.Unlock()
    85  			return err
    86  		}
    87  
    88  		delete(m.endpoints, e)
    89  	}
    90  	m.lock.Unlock()
    91  
    92  	err := m.nextConn.Close()
    93  	if err != nil {
    94  		return err
    95  	}
    96  
    97  	// Wait for readLoop to end
    98  	<-m.closedCh
    99  
   100  	return nil
   101  }
   102  
   103  func (m *Mux) readLoop() {
   104  	defer func() {
   105  		close(m.closedCh)
   106  	}()
   107  
   108  	buf := make([]byte, m.bufferSize)
   109  	for {
   110  		n, err := m.nextConn.Read(buf)
   111  		switch {
   112  		case errors.Is(err, io.EOF), errors.Is(err, ice.ErrClosed):
   113  			return
   114  		case errors.Is(err, io.ErrShortBuffer), errors.Is(err, packetio.ErrTimeout):
   115  			m.log.Errorf("mux: failed to read from packetio.Buffer %s", err.Error())
   116  			continue
   117  		case err != nil:
   118  			m.log.Errorf("mux: ending readLoop packetio.Buffer error %s", err.Error())
   119  			return
   120  		}
   121  
   122  		if err = m.dispatch(buf[:n]); err != nil {
   123  			m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error())
   124  			return
   125  		}
   126  	}
   127  }
   128  
   129  func (m *Mux) dispatch(buf []byte) error {
   130  	var endpoint *Endpoint
   131  
   132  	m.lock.Lock()
   133  	for e, f := range m.endpoints {
   134  		if f(buf) {
   135  			endpoint = e
   136  			break
   137  		}
   138  	}
   139  	m.lock.Unlock()
   140  
   141  	if endpoint == nil {
   142  		if len(buf) > 0 {
   143  			m.log.Warnf("Warning: mux: no endpoint for packet starting with %d", buf[0])
   144  		} else {
   145  			m.log.Warnf("Warning: mux: no endpoint for zero length packet")
   146  		}
   147  		return nil
   148  	}
   149  
   150  	_, err := endpoint.buffer.Write(buf)
   151  
   152  	// Expected when bytes are received faster than the endpoint can process them (#2152, #2180)
   153  	if errors.Is(err, packetio.ErrFull) {
   154  		m.log.Infof("mux: endpoint buffer is full, dropping packet")
   155  		return nil
   156  	}
   157  
   158  	return err
   159  }