github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/tools/proxy/proxy.go (about)

     1  package quicproxy
     2  
     3  import (
     4  	"net"
     5  	"sort"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/daeuniverse/quic-go/internal/protocol"
    10  	"github.com/daeuniverse/quic-go/internal/utils"
    11  )
    12  
// connection is the proxy's per-client state: the client's address, a UDP
// socket dialed to the server exclusively for this client, and the queues
// holding delayed packets for each direction.
type connection struct {
	ClientAddr *net.UDPAddr // Address of the client
	ServerConn *net.UDPConn // UDP connection to server

	// incomingPackets carries delayed client->server packets from runProxy
	// (via queuePacket) to runIncomingConnection, which adds them to Incoming.
	incomingPackets chan packetEntry

	Incoming *queue // delayed packets waiting to be written to ServerConn
	Outgoing *queue // delayed packets waiting to be written to ClientAddr
}
    23  
    24  func (c *connection) queuePacket(t time.Time, b []byte) {
    25  	c.incomingPackets <- packetEntry{Time: t, Raw: b}
    26  }
    27  
// Direction is the direction a packet is sent.
type Direction int

const (
	// DirectionIncoming is the direction from the client to the server.
	DirectionIncoming Direction = iota
	// DirectionOutgoing is the direction from the server to the client.
	DirectionOutgoing
	// DirectionBoth is both incoming and outgoing. Direction.Is treats it as
	// a wildcard that matches any direction.
	DirectionBoth
)
    39  
    40  type packetEntry struct {
    41  	Time time.Time
    42  	Raw  []byte
    43  }
    44  
    45  type packetEntries []packetEntry
    46  
    47  func (e packetEntries) Len() int           { return len(e) }
    48  func (e packetEntries) Less(i, j int) bool { return e[i].Time.Before(e[j].Time) }
    49  func (e packetEntries) Swap(i, j int)      { e[i], e[j] = e[j], e[i] }
    50  
    51  type queue struct {
    52  	sync.Mutex
    53  
    54  	timer   *utils.Timer
    55  	Packets packetEntries
    56  }
    57  
    58  func newQueue() *queue {
    59  	return &queue{timer: utils.NewTimer()}
    60  }
    61  
    62  func (q *queue) Add(e packetEntry) {
    63  	q.Lock()
    64  	q.Packets = append(q.Packets, e)
    65  	if len(q.Packets) > 1 {
    66  		lastIndex := len(q.Packets) - 1
    67  		if q.Packets[lastIndex].Time.Before(q.Packets[lastIndex-1].Time) {
    68  			sort.Stable(q.Packets)
    69  		}
    70  	}
    71  	q.timer.Reset(q.Packets[0].Time)
    72  	q.Unlock()
    73  }
    74  
    75  func (q *queue) Get() []byte {
    76  	q.Lock()
    77  	raw := q.Packets[0].Raw
    78  	q.Packets = q.Packets[1:]
    79  	if len(q.Packets) > 0 {
    80  		q.timer.Reset(q.Packets[0].Time)
    81  	}
    82  	q.Unlock()
    83  	return raw
    84  }
    85  
    86  func (q *queue) Timer() <-chan time.Time { return q.timer.Chan() }
    87  func (q *queue) SetTimerRead()           { q.timer.SetRead() }
    88  
    89  func (q *queue) Close() { q.timer.Stop() }
    90  
    91  func (d Direction) String() string {
    92  	switch d {
    93  	case DirectionIncoming:
    94  		return "Incoming"
    95  	case DirectionOutgoing:
    96  		return "Outgoing"
    97  	case DirectionBoth:
    98  		return "both"
    99  	default:
   100  		panic("unknown direction")
   101  	}
   102  }
   103  
   104  // Is says if one direction matches another direction.
   105  // For example, incoming matches both incoming and both, but not outgoing.
   106  func (d Direction) Is(dir Direction) bool {
   107  	if d == DirectionBoth || dir == DirectionBoth {
   108  		return true
   109  	}
   110  	return d == dir
   111  }
   112  
// DropCallback is a callback that determines which packet gets dropped.
// It receives the packet's direction and raw payload and returns true if the
// packet should be discarded instead of forwarded. The proxy calls it from
// several goroutines (the listener loop and each connection's server reader),
// so implementations must be safe for concurrent use.
type DropCallback func(dir Direction, packet []byte) bool

// NoDropper doesn't drop packets.
// NewQuicProxy uses it when Opts.DropPacket is nil.
var NoDropper DropCallback = func(Direction, []byte) bool {
	return false
}

// DelayCallback is a callback that determines how much delay to apply to a packet.
// A return value of 0 forwards the packet immediately; otherwise it is queued
// and forwarded once the delay has elapsed. Like DropCallback, it may be
// called concurrently from several goroutines.
type DelayCallback func(dir Direction, packet []byte) time.Duration

// NoDelay doesn't apply a delay.
// NewQuicProxy uses it when Opts.DelayPacket is nil.
var NoDelay DelayCallback = func(Direction, []byte) time.Duration {
	return 0
}
   128  
// Opts are proxy options.
type Opts struct {
	// The address this proxy proxies packets to.
	// It is resolved with net.ResolveUDPAddr("udp", ...) when the proxy is created.
	RemoteAddr string
	// DropPacket determines whether a packet gets dropped.
	// If nil, NoDropper is used and no packets are dropped.
	DropPacket DropCallback
	// DelayPacket determines how long a packet gets delayed. This allows
	// simulating a connection with non-zero RTTs.
	// Note that the RTT is the sum of the delay for the incoming and the outgoing packet.
	// If nil, NoDelay is used and packets are forwarded immediately.
	DelayPacket DelayCallback
}
   140  
// QuicProxy is a QUIC proxy that can drop and delay packets.
// It listens on a local UDP socket and, for every distinct client address,
// dials a separate UDP socket to the server and relays packets in both
// directions.
type QuicProxy struct {
	// mutex guards clientDict and serializes Close with connection creation.
	mutex sync.Mutex

	// closeChan is closed by Close to stop the per-connection goroutines.
	closeChan chan struct{}

	// conn is the local socket clients send to; serverAddr is the resolved
	// Opts.RemoteAddr.
	conn       *net.UDPConn
	serverAddr *net.UDPAddr

	dropPacket  DropCallback
	delayPacket DelayCallback

	// Mapping from client addresses (as host:port) to connection
	clientDict map[string]*connection

	logger utils.Logger
}
   158  
   159  // NewQuicProxy creates a new UDP proxy
   160  func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) {
   161  	if opts == nil {
   162  		opts = &Opts{}
   163  	}
   164  	laddr, err := net.ResolveUDPAddr("udp", local)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	conn, err := net.ListenUDP("udp", laddr)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
   173  		return nil, err
   174  	}
   175  	if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
   176  		return nil, err
   177  	}
   178  	raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  
   183  	packetDropper := NoDropper
   184  	if opts.DropPacket != nil {
   185  		packetDropper = opts.DropPacket
   186  	}
   187  
   188  	packetDelayer := NoDelay
   189  	if opts.DelayPacket != nil {
   190  		packetDelayer = opts.DelayPacket
   191  	}
   192  
   193  	p := QuicProxy{
   194  		clientDict:  make(map[string]*connection),
   195  		conn:        conn,
   196  		closeChan:   make(chan struct{}),
   197  		serverAddr:  raddr,
   198  		dropPacket:  packetDropper,
   199  		delayPacket: packetDelayer,
   200  		logger:      utils.DefaultLogger.WithPrefix("proxy"),
   201  	}
   202  
   203  	p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr)
   204  	go p.runProxy()
   205  	return &p, nil
   206  }
   207  
   208  // Close stops the UDP Proxy
   209  func (p *QuicProxy) Close() error {
   210  	p.mutex.Lock()
   211  	defer p.mutex.Unlock()
   212  	close(p.closeChan)
   213  	for _, c := range p.clientDict {
   214  		if err := c.ServerConn.Close(); err != nil {
   215  			return err
   216  		}
   217  		c.Incoming.Close()
   218  		c.Outgoing.Close()
   219  	}
   220  	return p.conn.Close()
   221  }
   222  
   223  // LocalAddr is the address the proxy is listening on.
   224  func (p *QuicProxy) LocalAddr() net.Addr {
   225  	return p.conn.LocalAddr()
   226  }
   227  
   228  // LocalPort is the UDP port number the proxy is listening on.
   229  func (p *QuicProxy) LocalPort() int {
   230  	return p.conn.LocalAddr().(*net.UDPAddr).Port
   231  }
   232  
   233  func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) {
   234  	conn, err := net.DialUDP("udp", nil, p.serverAddr)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil {
   239  		return nil, err
   240  	}
   241  	if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil {
   242  		return nil, err
   243  	}
   244  	return &connection{
   245  		ClientAddr:      cliAddr,
   246  		ServerConn:      conn,
   247  		incomingPackets: make(chan packetEntry, 10),
   248  		Incoming:        newQueue(),
   249  		Outgoing:        newQueue(),
   250  	}, nil
   251  }
   252  
   253  // runProxy listens on the proxy address and handles incoming packets.
   254  func (p *QuicProxy) runProxy() error {
   255  	for {
   256  		buffer := make([]byte, protocol.MaxPacketBufferSize)
   257  		n, cliaddr, err := p.conn.ReadFromUDP(buffer)
   258  		if err != nil {
   259  			return err
   260  		}
   261  		raw := buffer[0:n]
   262  
   263  		saddr := cliaddr.String()
   264  		p.mutex.Lock()
   265  		conn, ok := p.clientDict[saddr]
   266  
   267  		if !ok {
   268  			conn, err = p.newConnection(cliaddr)
   269  			if err != nil {
   270  				p.mutex.Unlock()
   271  				return err
   272  			}
   273  			p.clientDict[saddr] = conn
   274  			go p.runIncomingConnection(conn)
   275  			go p.runOutgoingConnection(conn)
   276  		}
   277  		p.mutex.Unlock()
   278  
   279  		if p.dropPacket(DirectionIncoming, raw) {
   280  			if p.logger.Debug() {
   281  				p.logger.Debugf("dropping incoming packet(%d bytes)", n)
   282  			}
   283  			continue
   284  		}
   285  
   286  		delay := p.delayPacket(DirectionIncoming, raw)
   287  		if delay == 0 {
   288  			if p.logger.Debug() {
   289  				p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr())
   290  			}
   291  			if _, err := conn.ServerConn.Write(raw); err != nil {
   292  				return err
   293  			}
   294  		} else {
   295  			now := time.Now()
   296  			if p.logger.Debug() {
   297  				p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay)
   298  			}
   299  			conn.queuePacket(now.Add(delay), raw)
   300  		}
   301  	}
   302  }
   303  
   304  // runConnection handles packets from server to a single client
   305  func (p *QuicProxy) runOutgoingConnection(conn *connection) error {
   306  	outgoingPackets := make(chan packetEntry, 10)
   307  	go func() {
   308  		for {
   309  			buffer := make([]byte, protocol.MaxPacketBufferSize)
   310  			n, err := conn.ServerConn.Read(buffer)
   311  			if err != nil {
   312  				return
   313  			}
   314  			raw := buffer[0:n]
   315  
   316  			if p.dropPacket(DirectionOutgoing, raw) {
   317  				if p.logger.Debug() {
   318  					p.logger.Debugf("dropping outgoing packet(%d bytes)", n)
   319  				}
   320  				continue
   321  			}
   322  
   323  			delay := p.delayPacket(DirectionOutgoing, raw)
   324  			if delay == 0 {
   325  				if p.logger.Debug() {
   326  					p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr)
   327  				}
   328  				if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil {
   329  					return
   330  				}
   331  			} else {
   332  				now := time.Now()
   333  				if p.logger.Debug() {
   334  					p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay)
   335  				}
   336  				outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw}
   337  			}
   338  		}
   339  	}()
   340  
   341  	for {
   342  		select {
   343  		case <-p.closeChan:
   344  			return nil
   345  		case e := <-outgoingPackets:
   346  			conn.Outgoing.Add(e)
   347  		case <-conn.Outgoing.Timer():
   348  			conn.Outgoing.SetTimerRead()
   349  			if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil {
   350  				return err
   351  			}
   352  		}
   353  	}
   354  }
   355  
   356  func (p *QuicProxy) runIncomingConnection(conn *connection) error {
   357  	for {
   358  		select {
   359  		case <-p.closeChan:
   360  			return nil
   361  		case e := <-conn.incomingPackets:
   362  			// Send the packet to the server
   363  			conn.Incoming.Add(e)
   364  		case <-conn.Incoming.Timer():
   365  			conn.Incoming.SetTimerRead()
   366  			if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil {
   367  				return err
   368  			}
   369  		}
   370  	}
   371  }