github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/wireguard/device.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "net/netip" 8 "os" 9 10 gun "github.com/Asutorufa/yuhaiin/pkg/net/proxy/tun/gvisor" 11 "github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol" 12 "github.com/tailscale/wireguard-go/tun" 13 "gvisor.dev/gvisor/pkg/tcpip" 14 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 15 "gvisor.dev/gvisor/pkg/tcpip/header" 16 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 17 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 18 "gvisor.dev/gvisor/pkg/tcpip/stack" 19 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 20 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 21 ) 22 23 type netTun struct { 24 ep *gun.Endpoint 25 stack *stack.Stack 26 events chan tun.Event 27 hasV4, hasV6 bool 28 dev *gun.ChannelTun 29 } 30 31 type Net netTun 32 33 func CreateNetTUN(localAddresses []netip.Prefix, mtu int) (tun.Device, *Net, error) { 34 opts := stack.Options{ 35 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 36 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, 37 HandleLocal: true, 38 } 39 40 rwc := gun.NewChannelTun(context.TODO(), mtu) 41 dev := &netTun{ 42 ep: gun.NewEndpoint(gun.NewDevice(rwc, 0), uint32(mtu)), 43 dev: rwc, 44 stack: stack.New(opts), 45 events: make(chan tun.Event, 1), 46 } 47 48 sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default 49 tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) 50 if tcpipErr != nil { 51 dev.Close() 52 return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) 53 } 54 55 tcpipErr = dev.stack.CreateNIC(1, dev.ep) 56 if tcpipErr != nil { 57 dev.Close() 58 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) 59 } 60 61 for _, ip := range localAddresses { 62 var protoNumber tcpip.NetworkProtocolNumber 63 if ip.Addr().Is4() { 64 protoNumber = ipv4.ProtocolNumber 65 } else if ip.Addr().Is6() { 66 protoNumber = ipv6.ProtocolNumber 67 } 68 69 protoAddr := tcpip.ProtocolAddress{ 70 AddressWithPrefix: tcpip.AddressWithPrefix{ 71 Address: tcpip.AddrFromSlice(ip.Addr().Unmap().AsSlice()), 72 PrefixLen: ip.Bits(), 73 }, 74 Protocol: protoNumber, 75 } 76 77 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 78 if tcpipErr != nil { 79 dev.Close() 80 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 81 } 82 if ip.Addr().Is4() { 83 dev.hasV4 = true 84 } else if ip.Addr().Is6() { 85 dev.hasV6 = true 86 } 87 } 88 89 if dev.hasV4 { 90 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 91 } 92 if dev.hasV6 { 93 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 94 } 95 96 opt := tcpip.CongestionControlOption("cubic") 97 if tcpipErr = dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); tcpipErr != nil { 98 dev.Close() 99 return nil, nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, tcpipErr) 100 } 101 102 dev.events <- tun.EventUp 103 return dev, (*Net)(dev), nil 104 } 105 106 // convert endpoint string to netip.Addr 107 func parseEndpoints(conf *protocol.Wireguard) ([]netip.Prefix, error) { 108 endpoints := make([]netip.Prefix, 0, len(conf.Endpoint)) 109 for _, str := range conf.Endpoint { 110 prefix, err := netip.ParsePrefix(str) 111 if err != nil { 112 addr, err := netip.ParseAddr(str) 113 if err != nil { 114 return nil, err 115 } 116 117 if addr.Is4() { 118 prefix = netip.PrefixFrom(addr, 32) 119 } else { 120 prefix = netip.PrefixFrom(addr, 128) 121 } 122 } 123 124 endpoints = append(endpoints, prefix) 125 } 126 127 return endpoints, nil 128 } 129 130 func (tun *netTun) Name() (string, error) { return "go", nil } 131 func (tun *netTun) File() *os.File { return nil } 132 func (tun *netTun) Events() <-chan tun.Event { return tun.events } 133 func (tun *netTun) BatchSize() int { return 1 } 134 135 func (tun *netTun) Read(buf [][]byte, size []int, offset int) (int, error) { 136 var err error 137 size[0], err = tun.dev.Inbound(buf[0][offset:]) 138 if err != nil { 139 return 0, err 140 } 141 142 return 1, nil 143 144 } 145 146 func (tun *netTun) Write(buffers [][]byte, offset int) (int, error) { 147 n := 0 148 for _, buf := range buffers { 149 packet := buf[offset:] 150 if len(packet) == 0 { 151 continue 152 } 153 154 err := tun.dev.Outbound(packet) 155 if err != nil { 156 if n > 0 { 157 return n, nil 158 } 159 return 0, err 160 } 161 162 n++ 163 } 164 165 return n, nil 166 } 167 168 func (tun *netTun) Flush() error { return nil } 169 170 func (tun *netTun) Close() error { 171 tun.stack.Destroy() 172 173 if tun.events != nil { 174 close(tun.events) 175 } 176 tun.ep.Close() 177 tun.dev.Close() 178 return nil 179 } 180 181 func (tun *netTun) MTU() (int, error) { return int(tun.ep.MTU()), nil } 182 183 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 184 var protoNumber tcpip.NetworkProtocolNumber 185 if endpoint.Addr().Is4() { 186 protoNumber = ipv4.ProtocolNumber 187 } else { 188 protoNumber = ipv6.ProtocolNumber 189 } 190 return tcpip.FullAddress{ 191 NIC: 1, 192 Addr: tcpip.AddrFromSlice(endpoint.Addr().Unmap().AsSlice()), 193 Port: endpoint.Port(), 194 }, protoNumber 195 } 196 197 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 198 fa, pn := convertToFullAddr(addr) 199 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 200 } 201 202 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { 203 if addr == nil { 204 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) 205 } 206 ip, _ := netip.AddrFromSlice(addr.IP) 207 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) 208 } 209 210 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { 211 fa, pn := convertToFullAddr(addr) 212 return gonet.DialTCP(net.stack, fa, pn) 213 } 214 215 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { 216 if addr == nil { 217 return net.DialTCPAddrPort(netip.AddrPort{}) 218 } 219 ip, _ := netip.AddrFromSlice(addr.IP) 220 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 221 } 222 223 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { 224 fa, pn := convertToFullAddr(addr) 225 return gonet.ListenTCP(net.stack, fa, pn) 226 } 227 228 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { 229 if addr == nil { 230 return net.ListenTCPAddrPort(netip.AddrPort{}) 231 } 232 ip, _ := netip.AddrFromSlice(addr.IP) 233 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 234 } 235 236 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 237 var lfa, rfa *tcpip.FullAddress 238 var pn tcpip.NetworkProtocolNumber 239 if laddr.IsValid() || laddr.Port() > 0 { 240 var addr tcpip.FullAddress 241 addr, pn = convertToFullAddr(laddr) 242 lfa = &addr 243 } 244 if raddr.IsValid() || raddr.Port() > 0 { 245 var addr tcpip.FullAddress 246 addr, pn = convertToFullAddr(raddr) 247 rfa = &addr 248 } 249 250 if pn == 0 { 251 if net.HasV6() { 252 pn = ipv6.ProtocolNumber 253 } else { 254 pn = ipv4.ProtocolNumber 255 } 256 } 257 258 return gonet.DialUDP(net.stack, lfa, rfa, pn) 259 } 260 261 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { 262 return net.DialUDPAddrPort(laddr, netip.AddrPort{}) 263 } 264 265 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { 266 var la, ra netip.AddrPort 267 if laddr != nil { 268 ip, _ := netip.AddrFromSlice(laddr.IP) 269 la = netip.AddrPortFrom(ip, uint16(laddr.Port)) 270 } 271 if raddr != nil { 272 ip, _ := netip.AddrFromSlice(raddr.IP) 273 ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) 274 } 275 return net.DialUDPAddrPort(la, ra) 276 } 277 278 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { 279 return net.DialUDP(laddr, nil) 280 } 281 282 func (n *Net) HasV4() bool { return n.hasV4 } 283 func (n *Net) HasV6() bool { return n.hasV6 }