github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tproxy/udp.go (about) 1 package tproxy 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "os" 9 "strconv" 10 "syscall" 11 "unsafe" 12 13 "github.com/Asutorufa/yuhaiin/pkg/log" 14 "github.com/Asutorufa/yuhaiin/pkg/net/nat" 15 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 16 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 17 "golang.org/x/sys/unix" 18 ) 19 20 func controlUDP(c syscall.RawConn) error { 21 var fn = func(s uintptr) { 22 err := syscall.SetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1) 23 if err != nil { 24 log.Error("set socket with SOL_IP, IP_TRANSPARENT failed", "err", err) 25 } 26 27 val, err := syscall.GetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_TRANSPARENT) 28 if err != nil { 29 log.Error("get socket with SOL_IP, IP_TRANSPARENT failed", "err", err) 30 } else { 31 log.Error("value of IP_TRANSPARENT option", "val", val) 32 } 33 34 err = syscall.SetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1) 35 if err != nil { 36 log.Error("set socket with SOL_IP, IP_RECVORIGDSTADDR failed", "err", err) 37 } 38 39 val, err = syscall.GetsockoptInt(int(s), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR) 40 if err != nil { 41 log.Error("get socket with SOL_IP, IP_RECVORIGDSTADDR failed", "err", err) 42 } else { 43 log.Error("value of IP_RECVORIGDSTADDR option", "val", val) 44 } 45 } 46 47 if err := c.Control(fn); err != nil { 48 return err 49 } 50 51 return nil 52 } 53 54 // DialUDP connects to the remote address raddr on the network net, 55 // which must be "udp", "udp4", or "udp6". If laddr is not nil, it is 56 // used as the local address for the connection. 
57 func DialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (*net.UDPConn, error) { 58 remoteSocketAddress, err := udpAddrToSocketAddr(raddr) 59 if err != nil { 60 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build destination socket address: %w", err)} 61 } 62 63 localSocketAddress, err := udpAddrToSocketAddr(laddr) 64 if err != nil { 65 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build local socket address: %w", err)} 66 } 67 68 fileDescriptor, err := syscall.Socket(udpAddrFamily(network, laddr, raddr), syscall.SOCK_DGRAM, 0) 69 if err != nil { 70 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %w", err)} 71 } 72 73 if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { 74 syscall.Close(fileDescriptor) 75 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %w", err)} 76 } 77 78 if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil { 79 syscall.Close(fileDescriptor) 80 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %w", err)} 81 } 82 83 if err = syscall.Bind(fileDescriptor, localSocketAddress); err != nil { 84 syscall.Close(fileDescriptor) 85 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket bind: %w", err)} 86 } 87 88 if err = syscall.Connect(fileDescriptor, remoteSocketAddress); err != nil { 89 syscall.Close(fileDescriptor) 90 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket connect: %w", err)} 91 } 92 93 fdFile := os.NewFile(uintptr(fileDescriptor), "net-udp-dial-"+raddr.String()) 94 defer fdFile.Close() 95 96 remoteConn, err := net.FileConn(fdFile) 97 if err != nil { 98 syscall.Close(fileDescriptor) 99 return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: %w", err)} 100 } 101 102 return remoteConn.(*net.UDPConn), nil 103 } 104 105 // udpAddToSockerAddr will convert a UDPAddr 
106 // into a Sockaddr that may be used when 107 // connecting and binding sockets 108 func udpAddrToSocketAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) { 109 switch { 110 case addr.IP.To4() != nil: 111 return &syscall.SockaddrInet4{Addr: [4]byte(addr.IP.To4()), Port: addr.Port}, nil 112 113 default: 114 var zoneID uint64 115 if addr.Zone != "" { 116 var err error 117 zoneID, err = strconv.ParseUint(addr.Zone, 10, 32) 118 if err != nil { 119 return nil, err 120 } 121 } 122 123 return &syscall.SockaddrInet6{Addr: [16]byte(addr.IP.To16()), Port: addr.Port, ZoneId: uint32(zoneID)}, nil 124 } 125 } 126 127 // udpAddrFamily will attempt to work 128 // out the address family based on the 129 // network and UDP addresses 130 func udpAddrFamily(net string, laddr, raddr *net.UDPAddr) int { 131 switch net[len(net)-1] { 132 case '4': 133 return syscall.AF_INET 134 case '6': 135 return syscall.AF_INET6 136 } 137 138 if (laddr == nil || laddr.IP.To4() != nil) && 139 (raddr == nil || raddr.IP.To4() != nil) { 140 return syscall.AF_INET 141 } 142 return syscall.AF_INET6 143 } 144 145 //credit: https://github.com/LiamHaworth/go-tproxy/blob/master/tproxy_udp.go , which is under MIT License 146 147 var errContinue = errors.New("continue") 148 149 // ReadFromUDP reads a UDP packet from c, copying the payload into b. 150 // It returns the number of bytes copied into b and the return address 151 // that was on the packet. 152 // 153 // Out-of-band data is also read in so that the original destination 154 // address can be identified and parsed. 
func ReadFromUDP(conn *net.UDPConn, b []byte) (n int, srcAddr *net.UDPAddr, dstAddr *net.UDPAddr, err error) {
	oob := make([]byte, 1024)

	var oobn int
	n, oobn, _, srcAddr, err = conn.ReadMsgUDP(b, oob)
	if err != nil {
		return n, srcAddr, nil, err
	}

	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return n, srcAddr, nil, fmt.Errorf("%w: parsing socket control message: %w", errContinue, err)
	}

	// If several ORIGDSTADDR messages are present, the last one wins.
	for i := range msgs {
		if addr := udpOrigDstAddr(&msgs[i]); addr != nil {
			dstAddr = addr
		}
	}

	if dstAddr == nil {
		return n, srcAddr, nil, fmt.Errorf("%w: unable to obtain original destination (src: %v)", errContinue, srcAddr)
	}

	return n, srcAddr, dstAddr, nil
}

// udpOrigDstAddr decodes an IP_ORIGDSTADDR or IPV6_ORIGDSTADDR control
// message. The layout handling follows golang.org/x/sys/unix ParseOrigDstAddr.
// It returns nil for any other message, or when the payload is too short to
// hold the expected raw sockaddr, instead of reading past it. The returned IP
// is copied, so it does not alias the control-message buffer.
func udpOrigDstAddr(m *syscall.SocketControlMessage) *net.UDPAddr {
	switch {
	case m.Header.Level == syscall.SOL_IP && m.Header.Type == syscall.IP_ORIGDSTADDR:
		if len(m.Data) < syscall.SizeofSockaddrInet4 {
			return nil
		}
		pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(&m.Data[0]))
		// Port is stored in network byte order.
		p := (*[2]byte)(unsafe.Pointer(&pp.Port))
		return &net.UDPAddr{
			IP:   net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
			Port: int(p[0])<<8 | int(p[1]),
		}

	case m.Header.Level == syscall.SOL_IPV6 && m.Header.Type == unix.IPV6_ORIGDSTADDR:
		if len(m.Data) < syscall.SizeofSockaddrInet6 {
			return nil
		}
		pp := (*syscall.RawSockaddrInet6)(unsafe.Pointer(&m.Data[0]))
		p := (*[2]byte)(unsafe.Pointer(&pp.Port))
		ip := make(net.IP, net.IPv6len)
		copy(ip, pp.Addr[:])
		addr := &net.UDPAddr{IP: ip, Port: int(p[0])<<8 | int(p[1])}
		// Leave Zone empty for unscoped addresses instead of producing "0".
		if pp.Scope_id != 0 {
			addr.Zone = strconv.FormatUint(uint64(pp.Scope_id), 10)
		}
		return addr
	}

	return nil
}

// newUDP opens the TPROXY UDP listener, enables the transparent-proxy socket
// options on it (controlUDP), and starts a goroutine that reads packets and
// forwards them via t.SendPacket. Replies go through WriteBack, which dials a
// transparent socket bound to the reply's source address and connected to the
// original client.
func (t *Tproxy) newUDP() error {
	lis, err := t.lis.Packet(t.Context())
	if err != nil {
		return err
	}

	udpLis, ok := lis.(*net.UDPConn)
	if !ok {
		lis.Close()
		return fmt.Errorf("listen is not udplistener")
	}

	sysConn, err := udpLis.SyscallConn()
	if err == nil {
		err = controlUDP(sysConn)
	}
	if err != nil {
		lis.Close()
		return err
	}

	log.Info("new tproxy udp server", "host", lis.LocalAddr())

	go func() {
		defer lis.Close()

		for {
			buf := pool.GetBytesBuffer(nat.MaxSegmentSize)
			n, src, dst, err := ReadFromUDP(udpLis, buf.Bytes())
			if err != nil {
				buf.Free()
				if !errors.Is(err, errContinue) {
					// Non-recoverable read error (e.g. the listener was closed).
					log.Error("tproxy udp server read failed, stopping", "err", err)
					return
				}
				// Per-packet problem: drop this packet and keep serving.
				log.Error("tproxy udp server dropped packet", "err", err)
				continue
			}

			dstAddr, err := netapi.ParseSysAddr(dst)
			if err != nil {
				buf.Free()
				log.Error("tproxy udp server dropped packet: parse original destination", "dst", dst, "err", err)
				continue
			}

			buf.Refactor(0, n)

			err = t.SendPacket(&netapi.Packet{
				Src:     src,
				Dst:     dstAddr,
				Payload: buf,
				WriteBack: func(b []byte, addr net.Addr) (int, error) {
					ad, err := netapi.ParseSysAddr(addr)
					if err != nil {
						return 0, err
					}

					ur := ad.UDPAddr(context.Background())
					if ur.Err != nil {
						return 0, ur.Err
					}

					// Bind to the reply's source address so the client sees the
					// reply coming from the address it originally sent to.
					back, err := DialUDP("udp", ur.V, src)
					if err != nil {
						return 0, fmt.Errorf("udp server dial failed: %w", err)
					}
					defer back.Close()

					return back.Write(b)
				},
			})
			if err != nil {
				// NOTE(review): SendPacket presumably only fails once the server
				// is shutting down, so stop serving; confirm against its impl.
				log.Info("tproxy udp server stopped", "err", err)
				return
			}
		}
	}()

	return nil
}