github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/tools/proxy/proxy.go (about) 1 package quicproxy 2 3 import ( 4 "net" 5 "sort" 6 "sync" 7 "time" 8 9 "github.com/daeuniverse/quic-go/internal/protocol" 10 "github.com/daeuniverse/quic-go/internal/utils" 11 ) 12 13 // Connection is a UDP connection 14 type connection struct { 15 ClientAddr *net.UDPAddr // Address of the client 16 ServerConn *net.UDPConn // UDP connection to server 17 18 incomingPackets chan packetEntry 19 20 Incoming *queue 21 Outgoing *queue 22 } 23 24 func (c *connection) queuePacket(t time.Time, b []byte) { 25 c.incomingPackets <- packetEntry{Time: t, Raw: b} 26 } 27 28 // Direction is the direction a packet is sent. 29 type Direction int 30 31 const ( 32 // DirectionIncoming is the direction from the client to the server. 33 DirectionIncoming Direction = iota 34 // DirectionOutgoing is the direction from the server to the client. 35 DirectionOutgoing 36 // DirectionBoth is both incoming and outgoing 37 DirectionBoth 38 ) 39 40 type packetEntry struct { 41 Time time.Time 42 Raw []byte 43 } 44 45 type packetEntries []packetEntry 46 47 func (e packetEntries) Len() int { return len(e) } 48 func (e packetEntries) Less(i, j int) bool { return e[i].Time.Before(e[j].Time) } 49 func (e packetEntries) Swap(i, j int) { e[i], e[j] = e[j], e[i] } 50 51 type queue struct { 52 sync.Mutex 53 54 timer *utils.Timer 55 Packets packetEntries 56 } 57 58 func newQueue() *queue { 59 return &queue{timer: utils.NewTimer()} 60 } 61 62 func (q *queue) Add(e packetEntry) { 63 q.Lock() 64 q.Packets = append(q.Packets, e) 65 if len(q.Packets) > 1 { 66 lastIndex := len(q.Packets) - 1 67 if q.Packets[lastIndex].Time.Before(q.Packets[lastIndex-1].Time) { 68 sort.Stable(q.Packets) 69 } 70 } 71 q.timer.Reset(q.Packets[0].Time) 72 q.Unlock() 73 } 74 75 func (q *queue) Get() []byte { 76 q.Lock() 77 raw := q.Packets[0].Raw 78 q.Packets = q.Packets[1:] 79 if len(q.Packets) > 0 { 80 q.timer.Reset(q.Packets[0].Time) 81 } 82 q.Unlock() 83 return raw 84 } 85 86 func (q *queue) Timer() <-chan time.Time { return q.timer.Chan() } 87 func (q *queue) SetTimerRead() { q.timer.SetRead() } 88 89 func (q *queue) Close() { q.timer.Stop() } 90 91 func (d Direction) String() string { 92 switch d { 93 case DirectionIncoming: 94 return "Incoming" 95 case DirectionOutgoing: 96 return "Outgoing" 97 case DirectionBoth: 98 return "both" 99 default: 100 panic("unknown direction") 101 } 102 } 103 104 // Is says if one direction matches another direction. 105 // For example, incoming matches both incoming and both, but not outgoing. 106 func (d Direction) Is(dir Direction) bool { 107 if d == DirectionBoth || dir == DirectionBoth { 108 return true 109 } 110 return d == dir 111 } 112 113 // DropCallback is a callback that determines which packet gets dropped. 114 type DropCallback func(dir Direction, packet []byte) bool 115 116 // NoDropper doesn't drop packets. 117 var NoDropper DropCallback = func(Direction, []byte) bool { 118 return false 119 } 120 121 // DelayCallback is a callback that determines how much delay to apply to a packet. 122 type DelayCallback func(dir Direction, packet []byte) time.Duration 123 124 // NoDelay doesn't apply a delay. 125 var NoDelay DelayCallback = func(Direction, []byte) time.Duration { 126 return 0 127 } 128 129 // Opts are proxy options. 130 type Opts struct { 131 // The address this proxy proxies packets to. 132 RemoteAddr string 133 // DropPacket determines whether a packet gets dropped. 134 DropPacket DropCallback 135 // DelayPacket determines how long a packet gets delayed. This allows 136 // simulating a connection with non-zero RTTs. 137 // Note that the RTT is the sum of the delay for the incoming and the outgoing packet. 138 DelayPacket DelayCallback 139 } 140 141 // QuicProxy is a QUIC proxy that can drop and delay packets. 142 type QuicProxy struct { 143 mutex sync.Mutex 144 145 closeChan chan struct{} 146 147 conn *net.UDPConn 148 serverAddr *net.UDPAddr 149 150 dropPacket DropCallback 151 delayPacket DelayCallback 152 153 // Mapping from client addresses (as host:port) to connection 154 clientDict map[string]*connection 155 156 logger utils.Logger 157 } 158 159 // NewQuicProxy creates a new UDP proxy 160 func NewQuicProxy(local string, opts *Opts) (*QuicProxy, error) { 161 if opts == nil { 162 opts = &Opts{} 163 } 164 laddr, err := net.ResolveUDPAddr("udp", local) 165 if err != nil { 166 return nil, err 167 } 168 conn, err := net.ListenUDP("udp", laddr) 169 if err != nil { 170 return nil, err 171 } 172 if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { 173 return nil, err 174 } 175 if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { 176 return nil, err 177 } 178 raddr, err := net.ResolveUDPAddr("udp", opts.RemoteAddr) 179 if err != nil { 180 return nil, err 181 } 182 183 packetDropper := NoDropper 184 if opts.DropPacket != nil { 185 packetDropper = opts.DropPacket 186 } 187 188 packetDelayer := NoDelay 189 if opts.DelayPacket != nil { 190 packetDelayer = opts.DelayPacket 191 } 192 193 p := QuicProxy{ 194 clientDict: make(map[string]*connection), 195 conn: conn, 196 closeChan: make(chan struct{}), 197 serverAddr: raddr, 198 dropPacket: packetDropper, 199 delayPacket: packetDelayer, 200 logger: utils.DefaultLogger.WithPrefix("proxy"), 201 } 202 203 p.logger.Debugf("Starting UDP Proxy %s <-> %s", conn.LocalAddr(), raddr) 204 go p.runProxy() 205 return &p, nil 206 } 207 208 // Close stops the UDP Proxy 209 func (p *QuicProxy) Close() error { 210 p.mutex.Lock() 211 defer p.mutex.Unlock() 212 close(p.closeChan) 213 for _, c := range p.clientDict { 214 if err := c.ServerConn.Close(); err != nil { 215 return err 216 } 217 c.Incoming.Close() 218 c.Outgoing.Close() 219 } 220 return p.conn.Close() 221 } 222 223 // LocalAddr is the address the proxy is listening on. 224 func (p *QuicProxy) LocalAddr() net.Addr { 225 return p.conn.LocalAddr() 226 } 227 228 // LocalPort is the UDP port number the proxy is listening on. 229 func (p *QuicProxy) LocalPort() int { 230 return p.conn.LocalAddr().(*net.UDPAddr).Port 231 } 232 233 func (p *QuicProxy) newConnection(cliAddr *net.UDPAddr) (*connection, error) { 234 conn, err := net.DialUDP("udp", nil, p.serverAddr) 235 if err != nil { 236 return nil, err 237 } 238 if err := conn.SetReadBuffer(protocol.DesiredReceiveBufferSize); err != nil { 239 return nil, err 240 } 241 if err := conn.SetWriteBuffer(protocol.DesiredSendBufferSize); err != nil { 242 return nil, err 243 } 244 return &connection{ 245 ClientAddr: cliAddr, 246 ServerConn: conn, 247 incomingPackets: make(chan packetEntry, 10), 248 Incoming: newQueue(), 249 Outgoing: newQueue(), 250 }, nil 251 } 252 253 // runProxy listens on the proxy address and handles incoming packets. 254 func (p *QuicProxy) runProxy() error { 255 for { 256 buffer := make([]byte, protocol.MaxPacketBufferSize) 257 n, cliaddr, err := p.conn.ReadFromUDP(buffer) 258 if err != nil { 259 return err 260 } 261 raw := buffer[0:n] 262 263 saddr := cliaddr.String() 264 p.mutex.Lock() 265 conn, ok := p.clientDict[saddr] 266 267 if !ok { 268 conn, err = p.newConnection(cliaddr) 269 if err != nil { 270 p.mutex.Unlock() 271 return err 272 } 273 p.clientDict[saddr] = conn 274 go p.runIncomingConnection(conn) 275 go p.runOutgoingConnection(conn) 276 } 277 p.mutex.Unlock() 278 279 if p.dropPacket(DirectionIncoming, raw) { 280 if p.logger.Debug() { 281 p.logger.Debugf("dropping incoming packet(%d bytes)", n) 282 } 283 continue 284 } 285 286 delay := p.delayPacket(DirectionIncoming, raw) 287 if delay == 0 { 288 if p.logger.Debug() { 289 p.logger.Debugf("forwarding incoming packet (%d bytes) to %s", len(raw), conn.ServerConn.RemoteAddr()) 290 } 291 if _, err := conn.ServerConn.Write(raw); err != nil { 292 return err 293 } 294 } else { 295 now := time.Now() 296 if p.logger.Debug() { 297 p.logger.Debugf("delaying incoming packet (%d bytes) to %s by %s", len(raw), conn.ServerConn.RemoteAddr(), delay) 298 } 299 conn.queuePacket(now.Add(delay), raw) 300 } 301 } 302 } 303 304 // runConnection handles packets from server to a single client 305 func (p *QuicProxy) runOutgoingConnection(conn *connection) error { 306 outgoingPackets := make(chan packetEntry, 10) 307 go func() { 308 for { 309 buffer := make([]byte, protocol.MaxPacketBufferSize) 310 n, err := conn.ServerConn.Read(buffer) 311 if err != nil { 312 return 313 } 314 raw := buffer[0:n] 315 316 if p.dropPacket(DirectionOutgoing, raw) { 317 if p.logger.Debug() { 318 p.logger.Debugf("dropping outgoing packet(%d bytes)", n) 319 } 320 continue 321 } 322 323 delay := p.delayPacket(DirectionOutgoing, raw) 324 if delay == 0 { 325 if p.logger.Debug() { 326 p.logger.Debugf("forwarding outgoing packet (%d bytes) to %s", len(raw), conn.ClientAddr) 327 } 328 if _, err := p.conn.WriteToUDP(raw, conn.ClientAddr); err != nil { 329 return 330 } 331 } else { 332 now := time.Now() 333 if p.logger.Debug() { 334 p.logger.Debugf("delaying outgoing packet (%d bytes) to %s by %s", len(raw), conn.ClientAddr, delay) 335 } 336 outgoingPackets <- packetEntry{Time: now.Add(delay), Raw: raw} 337 } 338 } 339 }() 340 341 for { 342 select { 343 case <-p.closeChan: 344 return nil 345 case e := <-outgoingPackets: 346 conn.Outgoing.Add(e) 347 case <-conn.Outgoing.Timer(): 348 conn.Outgoing.SetTimerRead() 349 if _, err := p.conn.WriteTo(conn.Outgoing.Get(), conn.ClientAddr); err != nil { 350 return err 351 } 352 } 353 } 354 } 355 356 func (p *QuicProxy) runIncomingConnection(conn *connection) error { 357 for { 358 select { 359 case <-p.closeChan: 360 return nil 361 case e := <-conn.incomingPackets: 362 // Send the packet to the server 363 conn.Incoming.Add(e) 364 case <-conn.Incoming.Timer(): 365 conn.Incoming.SetTimerRead() 366 if _, err := conn.ServerConn.Write(conn.Incoming.Get()); err != nil { 367 return err 368 } 369 } 370 } 371 }