
     1  package udp
     3  import (
     4  	"errors"
     5  	"net"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"syscall"
    10  	"time"
    12  	""
    13  	""
    15  	""
    16  )
    18  const (
    19  	packetQueueSize = 1024
    20  )
// ObfsUDPHopClientPacketConn is the UDP port-hopping packet connection for client side.
// It hops to a different local & server port every once in a while.
//
// At any moment it owns up to two underlying sockets: currentConn, used for
// all writes, and prevConn, the socket that was current before the last hop.
// prevConn keeps receiving until the next hop closes it, so replies still in
// flight to the old port are not lost. Every socket has its own receive
// goroutine feeding recvQueue.
type ObfsUDPHopClientPacketConn struct {
	// serverAddr is the address reported to callers by ReadFrom.
	serverAddr  net.Addr // Combined udpHopAddr
	// serverAddrs holds one concrete UDP address per configured server port.
	serverAddrs []net.Addr
	// hopInterval is the period between hops.
	hopInterval time.Duration

	// obfs, when non-nil, wraps every underlying socket.
	obfs obfs.Obfuscator

	// connMutex guards prevConn, currentConn, addrIndex, the buffer sizes
	// and closed.
	connMutex   sync.RWMutex
	prevConn    net.PacketConn
	currentConn net.PacketConn
	// addrIndex selects the entry of serverAddrs that WriteTo sends to.
	addrIndex   int

	// Last values passed to SetReadBuffer/SetWriteBuffer (0 = never set);
	// reapplied to the new socket on every hop.
	readBufferSize  int
	writeBufferSize int

	// recvQueue carries packets from all receive goroutines to ReadFrom.
	recvQueue chan *udpPacket
	// closeChan is closed by Close to unblock ReadFrom and hopRoutine.
	closeChan chan struct{}
	closed    bool

	// bufPool recycles receive buffers ([]byte values).
	bufPool sync.Pool
}
    46  type udpHopAddr string
    48  func (a *udpHopAddr) Network() string {
    49  	return "udp-hop"
    50  }
    52  func (a *udpHopAddr) String() string {
    53  	return string(*a)
    54  }
// udpPacket is one datagram handed from a receive goroutine to ReadFrom.
type udpPacket struct {
	buf  []byte   // buffer taken from bufPool; ReadFrom returns it after copying
	n    int      // number of valid bytes at the start of buf
	addr net.Addr // source address reported by the underlying socket
}
    62  func NewObfsUDPHopClientPacketConn(server string, serverPorts string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (net.PacketConn, error) {
    63  	ports, err := parsePorts(serverPorts)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	// Resolve the server IP address, then attach the ports to UDP addresses
    68  	rAddr, err := dialer.RemoteAddr(server)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	ip, _, err := net.SplitHostPort(rAddr.String())
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	serverAddrs := make([]net.Addr, len(ports))
    77  	for i, port := range ports {
    78  		serverAddrs[i] = &net.UDPAddr{
    79  			IP:   net.ParseIP(ip),
    80  			Port: int(port),
    81  		}
    82  	}
    83  	hopAddr := udpHopAddr(server)
    84  	conn := &ObfsUDPHopClientPacketConn{
    85  		serverAddr:  &hopAddr,
    86  		serverAddrs: serverAddrs,
    87  		hopInterval: hopInterval,
    88  		obfs:        obfs,
    89  		addrIndex:   fastrand.Intn(len(serverAddrs)),
    90  		recvQueue:   make(chan *udpPacket, packetQueueSize),
    91  		closeChan:   make(chan struct{}),
    92  		bufPool: sync.Pool{
    93  			New: func() interface{} {
    94  				return make([]byte, udpBufferSize)
    95  			},
    96  		},
    97  	}
    98  	curConn, err := dialer.ListenPacket(rAddr)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	if obfs != nil {
   103  		conn.currentConn = NewObfsUDPConn(curConn, obfs)
   104  	} else {
   105  		conn.currentConn = curConn
   106  	}
   107  	go conn.recvRoutine(conn.currentConn)
   108  	go conn.hopRoutine(dialer, rAddr)
   109  	if _, ok := conn.currentConn.(syscall.Conn); ok {
   110  		return &ObfsUDPHopClientPacketConnWithSyscall{conn}, nil
   111  	}
   112  	return conn, nil
   113  }
   115  func (c *ObfsUDPHopClientPacketConn) recvRoutine(conn net.PacketConn) {
   116  	for {
   117  		buf := c.bufPool.Get().([]byte)
   118  		n, addr, err := conn.ReadFrom(buf)
   119  		if err != nil {
   120  			return
   121  		}
   122  		select {
   123  		case c.recvQueue <- &udpPacket{buf, n, addr}:
   124  		default:
   125  			// Drop the packet if the queue is full
   126  			c.bufPool.Put(buf)
   127  		}
   128  	}
   129  }
   131  func (c *ObfsUDPHopClientPacketConn) hopRoutine(dialer utils.PacketDialer, rAddr net.Addr) {
   132  	ticker := time.NewTicker(c.hopInterval)
   133  	defer ticker.Stop()
   134  	for {
   135  		select {
   136  		case <-ticker.C:
   137  			c.hop(dialer, rAddr)
   138  		case <-c.closeChan:
   139  			return
   140  		}
   141  	}
   142  }
   144  func (c *ObfsUDPHopClientPacketConn) hop(dialer utils.PacketDialer, rAddr net.Addr) {
   145  	c.connMutex.Lock()
   146  	defer c.connMutex.Unlock()
   147  	if c.closed {
   148  		return
   149  	}
   150  	newConn, err := dialer.ListenPacket(rAddr)
   151  	if err != nil {
   152  		// Skip this hop if failed to listen
   153  		return
   154  	}
   155  	// Close prevConn,
   156  	// prevConn <- currentConn
   157  	// currentConn <- newConn
   158  	// update addrIndex
   159  	//
   160  	// We need to keep receiving packets from the previous connection,
   161  	// because otherwise there will be packet loss due to the time gap
   162  	// between we hop to a new port and the server acknowledges this change.
   163  	if c.prevConn != nil {
   164  		_ = c.prevConn.Close() // recvRoutine will exit on error
   165  	}
   166  	c.prevConn = c.currentConn
   167  	if c.obfs != nil {
   168  		c.currentConn = NewObfsUDPConn(newConn, c.obfs)
   169  	} else {
   170  		c.currentConn = newConn
   171  	}
   172  	// Set buffer sizes if previously set
   173  	if c.readBufferSize > 0 {
   174  		_ = trySetPacketConnReadBuffer(c.currentConn, c.readBufferSize)
   175  	}
   176  	if c.writeBufferSize > 0 {
   177  		_ = trySetPacketConnWriteBuffer(c.currentConn, c.writeBufferSize)
   178  	}
   179  	go c.recvRoutine(c.currentConn)
   180  	c.addrIndex = fastrand.Intn(len(c.serverAddrs))
   181  }
   183  func (c *ObfsUDPHopClientPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   184  	for {
   185  		select {
   186  		case p := <-c.recvQueue:
   187  			/*
   188  				// Check if the packet is from one of the server addresses
   189  				for _, addr := range c.serverAddrs {
   190  					if addr.String() == p.addr.String() {
   191  						// Copy the packet to the buffer
   192  						n := copy(b, p.buf[:p.n])
   193  						c.bufPool.Put(p.buf)
   194  						return n, c.serverAddr, nil
   195  					}
   196  				}
   197  				// Drop the packet, continue
   198  				c.bufPool.Put(p.buf)
   199  			*/
   200  			// The above code was causing performance issues when the range is large,
   201  			// so we skip the check for now. Should probably still check by using a map
   202  			// or something in the future.
   203  			n := copy(b, p.buf[:p.n])
   204  			c.bufPool.Put(p.buf)
   205  			return n, c.serverAddr, nil
   206  		case <-c.closeChan:
   207  			return 0, nil, net.ErrClosed
   208  		}
   209  		// Ignore packets from other addresses
   210  	}
   211  }
   213  func (c *ObfsUDPHopClientPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   214  	c.connMutex.RLock()
   215  	defer c.connMutex.RUnlock()
   216  	if c.closed {
   217  		return 0, net.ErrClosed
   218  	}
   219  	/*
   220  		// Check if the address is the server address
   221  		if addr.String() != c.serverAddr.String() {
   222  			return 0, net.ErrWriteToConnected
   223  		}
   224  	*/
   225  	// Skip the check for now, always write to the server
   226  	return c.currentConn.WriteTo(b, c.serverAddrs[c.addrIndex])
   227  }
   229  func (c *ObfsUDPHopClientPacketConn) Close() error {
   230  	c.connMutex.Lock()
   231  	defer c.connMutex.Unlock()
   232  	if c.closed {
   233  		return nil
   234  	}
   235  	// Close prevConn and currentConn
   236  	// Close closeChan to unblock ReadFrom & hopRoutine
   237  	// Set closed flag to true to prevent double close
   238  	if c.prevConn != nil {
   239  		_ = c.prevConn.Close()
   240  	}
   241  	err := c.currentConn.Close()
   242  	close(c.closeChan)
   243  	c.closed = true
   244  	c.serverAddrs = nil // For GC
   245  	return err
   246  }
   248  func (c *ObfsUDPHopClientPacketConn) LocalAddr() net.Addr {
   249  	c.connMutex.RLock()
   250  	defer c.connMutex.RUnlock()
   251  	return c.currentConn.LocalAddr()
   252  }
   254  func (c *ObfsUDPHopClientPacketConn) SetReadDeadline(t time.Time) error {
   255  	// Not supported
   256  	return nil
   257  }
   259  func (c *ObfsUDPHopClientPacketConn) SetWriteDeadline(t time.Time) error {
   260  	// Not supported
   261  	return nil
   262  }
   264  func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error {
   265  	err := c.SetReadDeadline(t)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	return c.SetWriteDeadline(t)
   270  }
   272  func (c *ObfsUDPHopClientPacketConn) SetReadBuffer(bytes int) error {
   273  	c.connMutex.Lock()
   274  	defer c.connMutex.Unlock()
   275  	c.readBufferSize = bytes
   276  	if c.prevConn != nil {
   277  		_ = trySetPacketConnReadBuffer(c.prevConn, bytes)
   278  	}
   279  	return trySetPacketConnReadBuffer(c.currentConn, bytes)
   280  }
   282  func (c *ObfsUDPHopClientPacketConn) SetWriteBuffer(bytes int) error {
   283  	c.connMutex.Lock()
   284  	defer c.connMutex.Unlock()
   285  	c.writeBufferSize = bytes
   286  	if c.prevConn != nil {
   287  		_ = trySetPacketConnWriteBuffer(c.prevConn, bytes)
   288  	}
   289  	return trySetPacketConnWriteBuffer(c.currentConn, bytes)
   290  }
   292  func trySetPacketConnReadBuffer(pc net.PacketConn, bytes int) error {
   293  	sc, ok := pc.(interface {
   294  		SetReadBuffer(bytes int) error
   295  	})
   296  	if ok {
   297  		return sc.SetReadBuffer(bytes)
   298  	}
   299  	return nil
   300  }
   302  func trySetPacketConnWriteBuffer(pc net.PacketConn, bytes int) error {
   303  	sc, ok := pc.(interface {
   304  		SetWriteBuffer(bytes int) error
   305  	})
   306  	if ok {
   307  		return sc.SetWriteBuffer(bytes)
   308  	}
   309  	return nil
   310  }
// ObfsUDPHopClientPacketConnWithSyscall embeds ObfsUDPHopClientPacketConn and
// additionally exposes SyscallConn. NewObfsUDPHopClientPacketConn returns this
// wrapper when its initial underlying socket implements syscall.Conn.
type ObfsUDPHopClientPacketConnWithSyscall struct {
	*ObfsUDPHopClientPacketConn
}
   316  func (c *ObfsUDPHopClientPacketConnWithSyscall) SyscallConn() (syscall.RawConn, error) {
   317  	c.connMutex.RLock()
   318  	defer c.connMutex.RUnlock()
   319  	sc, ok := c.currentConn.(syscall.Conn)
   320  	if !ok {
   321  		return nil, errors.New("not supported")
   322  	}
   323  	return sc.SyscallConn()
   324  }
   326  // parsePorts parses the multi-port server address and returns the host and ports.
   327  // Supports both comma-separated single ports and dash-separated port ranges.
   328  // Format: "host:port1,port2-port3,port4"
   329  func parsePorts(serverPorts string) (ports []uint16, err error) {
   330  	portStrs := strings.Split(serverPorts, ",")
   331  	for _, portStr := range portStrs {
   332  		if strings.Contains(portStr, "-") {
   333  			// Port range
   334  			portRange := strings.Split(portStr, "-")
   335  			if len(portRange) != 2 {
   336  				return nil, net.InvalidAddrError("invalid port range")
   337  			}
   338  			start, err := strconv.ParseUint(portRange[0], 10, 16)
   339  			if err != nil {
   340  				return nil, net.InvalidAddrError("invalid port range")
   341  			}
   342  			end, err := strconv.ParseUint(portRange[1], 10, 16)
   343  			if err != nil {
   344  				return nil, net.InvalidAddrError("invalid port range")
   345  			}
   346  			if start > end {
   347  				start, end = end, start
   348  			}
   349  			for i := start; i <= end; i++ {
   350  				ports = append(ports, uint16(i))
   351  			}
   352  		} else {
   353  			// Single port
   354  			port, err := strconv.ParseUint(portStr, 10, 16)
   355  			if err != nil {
   356  				return nil, net.InvalidAddrError("invalid port")
   357  			}
   358  			ports = append(ports, uint16(port))
   359  		}
   360  	}
   361  	if len(ports) == 0 {
   362  		return nil, net.InvalidAddrError("invalid port")
   363  	}
   364  	return ports, nil
   365  }