github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tproxy/udp.go (about)

     1  package tproxy
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"os"
     9  	"strconv"
    10  	"syscall"
    11  	"unsafe"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/log"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    15  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    17  	"golang.org/x/sys/unix"
    18  )
    19  
    20  func controlUDP(c syscall.RawConn) error {
    21  	var fn = func(s uintptr) {
    22  		err := syscall.SetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
    23  		if err != nil {
    24  			log.Error("set socket with SOL_IP, IP_TRANSPARENT failed", "err", err)
    25  		}
    26  
    27  		val, err := syscall.GetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_TRANSPARENT)
    28  		if err != nil {
    29  			log.Error("get socket with SOL_IP, IP_TRANSPARENT failed", "err", err)
    30  		} else {
    31  			log.Error("value of IP_TRANSPARENT option", "val", val)
    32  		}
    33  
    34  		err = syscall.SetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
    35  		if err != nil {
    36  			log.Error("set socket with SOL_IP, IP_RECVORIGDSTADDR failed", "err", err)
    37  		}
    38  
    39  		val, err = syscall.GetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR)
    40  		if err != nil {
    41  			log.Error("get socket with SOL_IP, IP_RECVORIGDSTADDR failed", "err", err)
    42  		} else {
    43  			log.Error("value of IP_RECVORIGDSTADDR option", "val", val)
    44  		}
    45  	}
    46  
    47  	if err := c.Control(fn); err != nil {
    48  		return err
    49  	}
    50  
    51  	return nil
    52  }
    53  
    54  // DialUDP connects to the remote address raddr on the network net,
    55  // which must be "udp", "udp4", or "udp6".  If laddr is not nil, it is
    56  // used as the local address for the connection.
    57  func DialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
    58  	remoteSocketAddress, err := udpAddrToSocketAddr(raddr)
    59  	if err != nil {
    60  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build destination socket address: %w", err)}
    61  	}
    62  
    63  	localSocketAddress, err := udpAddrToSocketAddr(laddr)
    64  	if err != nil {
    65  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build local socket address: %w", err)}
    66  	}
    67  
    68  	fileDescriptor, err := syscall.Socket(udpAddrFamily(network, laddr, raddr), syscall.SOCK_DGRAM, 0)
    69  	if err != nil {
    70  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %w", err)}
    71  	}
    72  
    73  	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
    74  		syscall.Close(fileDescriptor)
    75  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %w", err)}
    76  	}
    77  
    78  	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
    79  		syscall.Close(fileDescriptor)
    80  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %w", err)}
    81  	}
    82  
    83  	if err = syscall.Bind(fileDescriptor, localSocketAddress); err != nil {
    84  		syscall.Close(fileDescriptor)
    85  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket bind: %w", err)}
    86  	}
    87  
    88  	if err = syscall.Connect(fileDescriptor, remoteSocketAddress); err != nil {
    89  		syscall.Close(fileDescriptor)
    90  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket connect: %w", err)}
    91  	}
    92  
    93  	fdFile := os.NewFile(uintptr(fileDescriptor), "net-udp-dial-"+raddr.String())
    94  	defer fdFile.Close()
    95  
    96  	remoteConn, err := net.FileConn(fdFile)
    97  	if err != nil {
    98  		syscall.Close(fileDescriptor)
    99  		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: %w", err)}
   100  	}
   101  
   102  	return remoteConn.(*net.UDPConn), nil
   103  }
   104  
   105  // udpAddToSockerAddr will convert a UDPAddr
   106  // into a Sockaddr that may be used when
   107  // connecting and binding sockets
   108  func udpAddrToSocketAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) {
   109  	switch {
   110  	case addr.IP.To4() != nil:
   111  		return &syscall.SockaddrInet4{Addr: [4]byte(addr.IP.To4()), Port: addr.Port}, nil
   112  
   113  	default:
   114  		var zoneID uint64
   115  		if addr.Zone != "" {
   116  			var err error
   117  			zoneID, err = strconv.ParseUint(addr.Zone, 10, 32)
   118  			if err != nil {
   119  				return nil, err
   120  			}
   121  		}
   122  
   123  		return &syscall.SockaddrInet6{Addr: [16]byte(addr.IP.To16()), Port: addr.Port, ZoneId: uint32(zoneID)}, nil
   124  	}
   125  }
   126  
   127  // udpAddrFamily will attempt to work
   128  // out the address family based on the
   129  // network and UDP addresses
   130  func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int {
   131  	switch net[len(net)-1] {
   132  	case '4':
   133  		return syscall.AF_INET
   134  	case '6':
   135  		return syscall.AF_INET6
   136  	}
   137  
   138  	if (laddr == nil || laddr.IP.To4() != nil) &&
   139  		(raddr == nil || raddr.IP.To4() != nil) {
   140  		return syscall.AF_INET
   141  	}
   142  	return syscall.AF_INET6
   143  }
   144  
   145  //credit: https://github.com/LiamHaworth/go-tproxy/blob/master/tproxy_udp.go ,  which is under MIT License
   146  
   147  var errContinue = errors.New("continue")
   148  
   149  // ReadFromUDP reads a UDP packet from c, copying the payload into b.
   150  // It returns the number of bytes copied into b and the return address
   151  // that was on the packet.
   152  //
   153  // Out-of-band data is also read in so that the original destination
   154  // address can be identified and parsed.
   155  func ReadFromUDP(conn *net.UDPConn, b []byte) (n int, srcAddr *net.UDPAddr, dstAddr *net.UDPAddr, err error) {
   156  	oob := make([]byte, 1024)
   157  	var oobn int
   158  	n, oobn, _, srcAddr, err = conn.ReadMsgUDP(b, oob)
   159  	if err != nil {
   160  		return
   161  	}
   162  
   163  	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
   164  	if err != nil {
   165  		err = fmt.Errorf("%w parsing socket control message: %s", errContinue, err)
   166  		return
   167  	}
   168  
   169  	//from golang.org/x/sys/unix/sockcmsg_linux.go ParseOrigDstAddr
   170  
   171  	for _, m := range msgs {
   172  
   173  		switch {
   174  		case m.Header.Level == syscall.SOL_IP && m.Header.Type == syscall.IP_ORIGDSTADDR:
   175  			pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(&m.Data[0]))
   176  
   177  			p := (*[2]byte)(unsafe.Pointer(&pp.Port))
   178  
   179  			dstAddr = &net.UDPAddr{
   180  				IP:   net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
   181  				Port: int(p[0])<<8 + int(p[1]),
   182  			}
   183  
   184  		case m.Header.Level == syscall.SOL_IPV6 && m.Header.Type == unix.IPV6_ORIGDSTADDR:
   185  			pp := (*syscall.RawSockaddrInet6)(unsafe.Pointer(&m.Data[0]))
   186  			p := (*[2]byte)(unsafe.Pointer(&pp.Port))
   187  			dstAddr = &net.UDPAddr{
   188  				IP:   net.IP(pp.Addr[:]),
   189  				Port: int(p[0])<<8 + int(p[1]),
   190  				Zone: strconv.Itoa(int(pp.Scope_id)),
   191  			}
   192  
   193  		}
   194  
   195  	}
   196  
   197  	if dstAddr == nil {
   198  		err = fmt.Errorf("%w unable to obtain original destination: %v (src: %v)", errContinue, err, srcAddr)
   199  		return
   200  	}
   201  
   202  	return
   203  }
   204  
   205  func (t *Tproxy) newUDP() error {
   206  	lis, err := t.lis.Packet(t.Context())
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	udpLis, ok := lis.(*net.UDPConn)
   212  	if !ok {
   213  		lis.Close()
   214  		return fmt.Errorf("listen is not udplistener")
   215  	}
   216  
   217  	sysConn, err := udpLis.SyscallConn()
   218  	if err != nil {
   219  		lis.Close()
   220  		return err
   221  	}
   222  
   223  	err = controlUDP(sysConn)
   224  	if err != nil {
   225  		lis.Close()
   226  		return err
   227  	}
   228  
   229  	log.Info("new tproxy udp server", "host", lis.LocalAddr())
   230  
   231  	go func() {
   232  		defer lis.Close()
   233  
   234  		for {
   235  			buf := pool.GetBytesBuffer(nat.MaxSegmentSize)
   236  			n, src, dst, err := ReadFromUDP(udpLis, buf.Bytes())
   237  			if err != nil {
   238  				buf.Free()
   239  				log.Error("start udp server failed", "err", err)
   240  				if !errors.Is(err, errContinue) {
   241  					break
   242  				}
   243  				continue
   244  			}
   245  
   246  			buf.Refactor(0, n)
   247  
   248  			dstAddr, _ := netapi.ParseSysAddr(dst)
   249  
   250  			err = t.SendPacket(&netapi.Packet{
   251  				Src:     src,
   252  				Dst:     dstAddr,
   253  				Payload: buf,
   254  				WriteBack: func(b []byte, addr net.Addr) (int, error) {
   255  					ad, err := netapi.ParseSysAddr(addr)
   256  					if err != nil {
   257  						return 0, err
   258  					}
   259  
   260  					ur := ad.UDPAddr(context.Background())
   261  
   262  					if ur.Err != nil {
   263  						return 0, ur.Err
   264  					}
   265  
   266  					back, err := DialUDP("udp", ur.V, src)
   267  					if err != nil {
   268  						return 0, fmt.Errorf("udp server dial failed: %w", err)
   269  					}
   270  					defer back.Close()
   271  
   272  					n, err := back.Write(b)
   273  					if err != nil {
   274  						return 0, err
   275  					}
   276  
   277  					return n, nil
   278  				},
   279  			})
   280  
   281  			if err != nil {
   282  				return
   283  			}
   284  		}
   285  	}()
   286  
   287  	return nil
   288  }