github.com/metacubex/mihomo@v1.18.5/transport/hysteria/conns/udp/hop.go (about)

     1  package udp
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"syscall"
    10  	"time"
    11  
    12  	"github.com/metacubex/mihomo/transport/hysteria/obfs"
    13  	"github.com/metacubex/mihomo/transport/hysteria/utils"
    14  
    15  	"github.com/zhangyunhao116/fastrand"
    16  )
    17  
    18  const (
    19  	packetQueueSize = 1024
    20  )
    21  
// ObfsUDPHopClientPacketConn is the UDP port-hopping packet connection for client side.
// It hops to a different local & server port every once in a while.
//
// Every socket created by a hop feeds one shared receive queue that ReadFrom
// drains, while writes always go through the current socket to the currently
// selected server port.
type ObfsUDPHopClientPacketConn struct {
	serverAddr  net.Addr      // Combined udpHopAddr; reported by ReadFrom as the source of every packet
	serverAddrs []net.Addr    // One UDP address per server port; addrIndex selects the write target
	hopInterval time.Duration // Period between hops

	obfs obfs.Obfuscator // When non-nil, every underlying socket is wrapped with it

	// connMutex guards prevConn, currentConn, addrIndex, serverAddrs, the
	// buffer sizes and closed.
	connMutex   sync.RWMutex
	prevConn    net.PacketConn // Socket replaced by the last hop; still read until the next hop closes it
	currentConn net.PacketConn // Socket used for writes
	addrIndex   int            // Index into serverAddrs of the current destination

	// Sizes last requested via SetReadBuffer/SetWriteBuffer; re-applied to the
	// new socket on each hop when positive.
	readBufferSize  int
	writeBufferSize int

	recvQueue chan *udpPacket // Filled by recvRoutine goroutines, drained by ReadFrom
	closeChan chan struct{}   // Closed by Close to unblock ReadFrom and stop hopRoutine
	closed    bool            // Set once by Close

	bufPool sync.Pool // Receive buffers passed from recvRoutine to ReadFrom
}
    45  
    46  type udpHopAddr string
    47  
    48  func (a *udpHopAddr) Network() string {
    49  	return "udp-hop"
    50  }
    51  
    52  func (a *udpHopAddr) String() string {
    53  	return string(*a)
    54  }
    55  
// udpPacket is one received datagram queued by recvRoutine for ReadFrom.
// recvRoutine builds it positionally (udpPacket{buf, n, addr}), so the field
// order must not change.
type udpPacket struct {
	buf  []byte   // Buffer taken from bufPool; ReadFrom puts it back after copying
	n    int      // Number of valid bytes in buf
	addr net.Addr // Source address reported by the underlying socket
}
    61  
    62  func NewObfsUDPHopClientPacketConn(server string, serverPorts string, hopInterval time.Duration, obfs obfs.Obfuscator, dialer utils.PacketDialer) (net.PacketConn, error) {
    63  	ports, err := parsePorts(serverPorts)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	// Resolve the server IP address, then attach the ports to UDP addresses
    68  	rAddr, err := dialer.RemoteAddr(server)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	ip, _, err := net.SplitHostPort(rAddr.String())
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	serverAddrs := make([]net.Addr, len(ports))
    77  	for i, port := range ports {
    78  		serverAddrs[i] = &net.UDPAddr{
    79  			IP:   net.ParseIP(ip),
    80  			Port: int(port),
    81  		}
    82  	}
    83  	hopAddr := udpHopAddr(server)
    84  	conn := &ObfsUDPHopClientPacketConn{
    85  		serverAddr:  &hopAddr,
    86  		serverAddrs: serverAddrs,
    87  		hopInterval: hopInterval,
    88  		obfs:        obfs,
    89  		addrIndex:   fastrand.Intn(len(serverAddrs)),
    90  		recvQueue:   make(chan *udpPacket, packetQueueSize),
    91  		closeChan:   make(chan struct{}),
    92  		bufPool: sync.Pool{
    93  			New: func() interface{} {
    94  				return make([]byte, udpBufferSize)
    95  			},
    96  		},
    97  	}
    98  	curConn, err := dialer.ListenPacket(rAddr)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	if obfs != nil {
   103  		conn.currentConn = NewObfsUDPConn(curConn, obfs)
   104  	} else {
   105  		conn.currentConn = curConn
   106  	}
   107  	go conn.recvRoutine(conn.currentConn)
   108  	go conn.hopRoutine(dialer, rAddr)
   109  	if _, ok := conn.currentConn.(syscall.Conn); ok {
   110  		return &ObfsUDPHopClientPacketConnWithSyscall{conn}, nil
   111  	}
   112  	return conn, nil
   113  }
   114  
   115  func (c *ObfsUDPHopClientPacketConn) recvRoutine(conn net.PacketConn) {
   116  	for {
   117  		buf := c.bufPool.Get().([]byte)
   118  		n, addr, err := conn.ReadFrom(buf)
   119  		if err != nil {
   120  			return
   121  		}
   122  		select {
   123  		case c.recvQueue <- &udpPacket{buf, n, addr}:
   124  		default:
   125  			// Drop the packet if the queue is full
   126  			c.bufPool.Put(buf)
   127  		}
   128  	}
   129  }
   130  
   131  func (c *ObfsUDPHopClientPacketConn) hopRoutine(dialer utils.PacketDialer, rAddr net.Addr) {
   132  	ticker := time.NewTicker(c.hopInterval)
   133  	defer ticker.Stop()
   134  	for {
   135  		select {
   136  		case <-ticker.C:
   137  			c.hop(dialer, rAddr)
   138  		case <-c.closeChan:
   139  			return
   140  		}
   141  	}
   142  }
   143  
   144  func (c *ObfsUDPHopClientPacketConn) hop(dialer utils.PacketDialer, rAddr net.Addr) {
   145  	c.connMutex.Lock()
   146  	defer c.connMutex.Unlock()
   147  	if c.closed {
   148  		return
   149  	}
   150  	newConn, err := dialer.ListenPacket(rAddr)
   151  	if err != nil {
   152  		// Skip this hop if failed to listen
   153  		return
   154  	}
   155  	// Close prevConn,
   156  	// prevConn <- currentConn
   157  	// currentConn <- newConn
   158  	// update addrIndex
   159  	//
   160  	// We need to keep receiving packets from the previous connection,
   161  	// because otherwise there will be packet loss due to the time gap
   162  	// between we hop to a new port and the server acknowledges this change.
   163  	if c.prevConn != nil {
   164  		_ = c.prevConn.Close() // recvRoutine will exit on error
   165  	}
   166  	c.prevConn = c.currentConn
   167  	if c.obfs != nil {
   168  		c.currentConn = NewObfsUDPConn(newConn, c.obfs)
   169  	} else {
   170  		c.currentConn = newConn
   171  	}
   172  	// Set buffer sizes if previously set
   173  	if c.readBufferSize > 0 {
   174  		_ = trySetPacketConnReadBuffer(c.currentConn, c.readBufferSize)
   175  	}
   176  	if c.writeBufferSize > 0 {
   177  		_ = trySetPacketConnWriteBuffer(c.currentConn, c.writeBufferSize)
   178  	}
   179  	go c.recvRoutine(c.currentConn)
   180  	c.addrIndex = fastrand.Intn(len(c.serverAddrs))
   181  }
   182  
   183  func (c *ObfsUDPHopClientPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   184  	for {
   185  		select {
   186  		case p := <-c.recvQueue:
   187  			/*
   188  				// Check if the packet is from one of the server addresses
   189  				for _, addr := range c.serverAddrs {
   190  					if addr.String() == p.addr.String() {
   191  						// Copy the packet to the buffer
   192  						n := copy(b, p.buf[:p.n])
   193  						c.bufPool.Put(p.buf)
   194  						return n, c.serverAddr, nil
   195  					}
   196  				}
   197  				// Drop the packet, continue
   198  				c.bufPool.Put(p.buf)
   199  			*/
   200  			// The above code was causing performance issues when the range is large,
   201  			// so we skip the check for now. Should probably still check by using a map
   202  			// or something in the future.
   203  			n := copy(b, p.buf[:p.n])
   204  			c.bufPool.Put(p.buf)
   205  			return n, c.serverAddr, nil
   206  		case <-c.closeChan:
   207  			return 0, nil, net.ErrClosed
   208  		}
   209  		// Ignore packets from other addresses
   210  	}
   211  }
   212  
   213  func (c *ObfsUDPHopClientPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   214  	c.connMutex.RLock()
   215  	defer c.connMutex.RUnlock()
   216  	if c.closed {
   217  		return 0, net.ErrClosed
   218  	}
   219  	/*
   220  		// Check if the address is the server address
   221  		if addr.String() != c.serverAddr.String() {
   222  			return 0, net.ErrWriteToConnected
   223  		}
   224  	*/
   225  	// Skip the check for now, always write to the server
   226  	return c.currentConn.WriteTo(b, c.serverAddrs[c.addrIndex])
   227  }
   228  
   229  func (c *ObfsUDPHopClientPacketConn) Close() error {
   230  	c.connMutex.Lock()
   231  	defer c.connMutex.Unlock()
   232  	if c.closed {
   233  		return nil
   234  	}
   235  	// Close prevConn and currentConn
   236  	// Close closeChan to unblock ReadFrom & hopRoutine
   237  	// Set closed flag to true to prevent double close
   238  	if c.prevConn != nil {
   239  		_ = c.prevConn.Close()
   240  	}
   241  	err := c.currentConn.Close()
   242  	close(c.closeChan)
   243  	c.closed = true
   244  	c.serverAddrs = nil // For GC
   245  	return err
   246  }
   247  
   248  func (c *ObfsUDPHopClientPacketConn) LocalAddr() net.Addr {
   249  	c.connMutex.RLock()
   250  	defer c.connMutex.RUnlock()
   251  	return c.currentConn.LocalAddr()
   252  }
   253  
   254  func (c *ObfsUDPHopClientPacketConn) SetReadDeadline(t time.Time) error {
   255  	// Not supported
   256  	return nil
   257  }
   258  
   259  func (c *ObfsUDPHopClientPacketConn) SetWriteDeadline(t time.Time) error {
   260  	// Not supported
   261  	return nil
   262  }
   263  
   264  func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error {
   265  	err := c.SetReadDeadline(t)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	return c.SetWriteDeadline(t)
   270  }
   271  
   272  func (c *ObfsUDPHopClientPacketConn) SetReadBuffer(bytes int) error {
   273  	c.connMutex.Lock()
   274  	defer c.connMutex.Unlock()
   275  	c.readBufferSize = bytes
   276  	if c.prevConn != nil {
   277  		_ = trySetPacketConnReadBuffer(c.prevConn, bytes)
   278  	}
   279  	return trySetPacketConnReadBuffer(c.currentConn, bytes)
   280  }
   281  
   282  func (c *ObfsUDPHopClientPacketConn) SetWriteBuffer(bytes int) error {
   283  	c.connMutex.Lock()
   284  	defer c.connMutex.Unlock()
   285  	c.writeBufferSize = bytes
   286  	if c.prevConn != nil {
   287  		_ = trySetPacketConnWriteBuffer(c.prevConn, bytes)
   288  	}
   289  	return trySetPacketConnWriteBuffer(c.currentConn, bytes)
   290  }
   291  
   292  func trySetPacketConnReadBuffer(pc net.PacketConn, bytes int) error {
   293  	sc, ok := pc.(interface {
   294  		SetReadBuffer(bytes int) error
   295  	})
   296  	if ok {
   297  		return sc.SetReadBuffer(bytes)
   298  	}
   299  	return nil
   300  }
   301  
   302  func trySetPacketConnWriteBuffer(pc net.PacketConn, bytes int) error {
   303  	sc, ok := pc.(interface {
   304  		SetWriteBuffer(bytes int) error
   305  	})
   306  	if ok {
   307  		return sc.SetWriteBuffer(bytes)
   308  	}
   309  	return nil
   310  }
   311  
   312  type ObfsUDPHopClientPacketConnWithSyscall struct {
   313  	*ObfsUDPHopClientPacketConn
   314  }
   315  
   316  func (c *ObfsUDPHopClientPacketConnWithSyscall) SyscallConn() (syscall.RawConn, error) {
   317  	c.connMutex.RLock()
   318  	defer c.connMutex.RUnlock()
   319  	sc, ok := c.currentConn.(syscall.Conn)
   320  	if !ok {
   321  		return nil, errors.New("not supported")
   322  	}
   323  	return sc.SyscallConn()
   324  }
   325  
   326  // parsePorts parses the multi-port server address and returns the host and ports.
   327  // Supports both comma-separated single ports and dash-separated port ranges.
   328  // Format: "host:port1,port2-port3,port4"
   329  func parsePorts(serverPorts string) (ports []uint16, err error) {
   330  	portStrs := strings.Split(serverPorts, ",")
   331  	for _, portStr := range portStrs {
   332  		if strings.Contains(portStr, "-") {
   333  			// Port range
   334  			portRange := strings.Split(portStr, "-")
   335  			if len(portRange) != 2 {
   336  				return nil, net.InvalidAddrError("invalid port range")
   337  			}
   338  			start, err := strconv.ParseUint(portRange[0], 10, 16)
   339  			if err != nil {
   340  				return nil, net.InvalidAddrError("invalid port range")
   341  			}
   342  			end, err := strconv.ParseUint(portRange[1], 10, 16)
   343  			if err != nil {
   344  				return nil, net.InvalidAddrError("invalid port range")
   345  			}
   346  			if start > end {
   347  				start, end = end, start
   348  			}
   349  			for i := start; i <= end; i++ {
   350  				ports = append(ports, uint16(i))
   351  			}
   352  		} else {
   353  			// Single port
   354  			port, err := strconv.ParseUint(portStr, 10, 16)
   355  			if err != nil {
   356  				return nil, net.InvalidAddrError("invalid port")
   357  			}
   358  			ports = append(ports, uint16(port))
   359  		}
   360  	}
   361  	if len(ports) == 0 {
   362  		return nil, net.InvalidAddrError("invalid port")
   363  	}
   364  	return ports, nil
   365  }