github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/wireguard/device.go (about)

     1  package wireguard
     2  
import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"os"
	"strings"

	gun "github.com/Asutorufa/yuhaiin/pkg/net/proxy/tun/gvisor"
	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
	"github.com/tailscale/wireguard-go/tun"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
    22  
    23  type netTun struct {
    24  	ep           *gun.Endpoint
    25  	stack        *stack.Stack
    26  	events       chan tun.Event
    27  	hasV4, hasV6 bool
    28  	dev          *gun.ChannelTun
    29  }
    30  
    31  type Net netTun
    32  
    33  func CreateNetTUN(localAddresses []netip.Prefix, mtu int) (tun.Device, *Net, error) {
    34  	opts := stack.Options{
    35  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    36  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
    37  		HandleLocal:        true,
    38  	}
    39  
    40  	rwc := gun.NewChannelTun(context.TODO(), mtu)
    41  	dev := &netTun{
    42  		ep:     gun.NewEndpoint(gun.NewDevice(rwc, 0), uint32(mtu)),
    43  		dev:    rwc,
    44  		stack:  stack.New(opts),
    45  		events: make(chan tun.Event, 1),
    46  	}
    47  
    48  	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
    49  	tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
    50  	if tcpipErr != nil {
    51  		dev.Close()
    52  		return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
    53  	}
    54  
    55  	tcpipErr = dev.stack.CreateNIC(1, dev.ep)
    56  	if tcpipErr != nil {
    57  		dev.Close()
    58  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
    59  	}
    60  
    61  	for _, ip := range localAddresses {
    62  		var protoNumber tcpip.NetworkProtocolNumber
    63  		if ip.Addr().Is4() {
    64  			protoNumber = ipv4.ProtocolNumber
    65  		} else if ip.Addr().Is6() {
    66  			protoNumber = ipv6.ProtocolNumber
    67  		}
    68  
    69  		protoAddr := tcpip.ProtocolAddress{
    70  			AddressWithPrefix: tcpip.AddressWithPrefix{
    71  				Address:   tcpip.AddrFromSlice(ip.Addr().Unmap().AsSlice()),
    72  				PrefixLen: ip.Bits(),
    73  			},
    74  			Protocol: protoNumber,
    75  		}
    76  
    77  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
    78  		if tcpipErr != nil {
    79  			dev.Close()
    80  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
    81  		}
    82  		if ip.Addr().Is4() {
    83  			dev.hasV4 = true
    84  		} else if ip.Addr().Is6() {
    85  			dev.hasV6 = true
    86  		}
    87  	}
    88  
    89  	if dev.hasV4 {
    90  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
    91  	}
    92  	if dev.hasV6 {
    93  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
    94  	}
    95  
    96  	opt := tcpip.CongestionControlOption("cubic")
    97  	if tcpipErr = dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); tcpipErr != nil {
    98  		dev.Close()
    99  		return nil, nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, tcpipErr)
   100  	}
   101  
   102  	dev.events <- tun.EventUp
   103  	return dev, (*Net)(dev), nil
   104  }
   105  
   106  // convert endpoint string to netip.Addr
   107  func parseEndpoints(conf *protocol.Wireguard) ([]netip.Prefix, error) {
   108  	endpoints := make([]netip.Prefix, 0, len(conf.Endpoint))
   109  	for _, str := range conf.Endpoint {
   110  		prefix, err := netip.ParsePrefix(str)
   111  		if err != nil {
   112  			addr, err := netip.ParseAddr(str)
   113  			if err != nil {
   114  				return nil, err
   115  			}
   116  
   117  			if addr.Is4() {
   118  				prefix = netip.PrefixFrom(addr, 32)
   119  			} else {
   120  				prefix = netip.PrefixFrom(addr, 128)
   121  			}
   122  		}
   123  
   124  		endpoints = append(endpoints, prefix)
   125  	}
   126  
   127  	return endpoints, nil
   128  }
   129  
   130  func (tun *netTun) Name() (string, error)    { return "go", nil }
   131  func (tun *netTun) File() *os.File           { return nil }
   132  func (tun *netTun) Events() <-chan tun.Event { return tun.events }
   133  func (tun *netTun) BatchSize() int           { return 1 }
   134  
   135  func (tun *netTun) Read(buf [][]byte, size []int, offset int) (int, error) {
   136  	var err error
   137  	size[0], err = tun.dev.Inbound(buf[0][offset:])
   138  	if err != nil {
   139  		return 0, err
   140  	}
   141  
   142  	return 1, nil
   143  
   144  }
   145  
   146  func (tun *netTun) Write(buffers [][]byte, offset int) (int, error) {
   147  	n := 0
   148  	for _, buf := range buffers {
   149  		packet := buf[offset:]
   150  		if len(packet) == 0 {
   151  			continue
   152  		}
   153  
   154  		err := tun.dev.Outbound(packet)
   155  		if err != nil {
   156  			if n > 0 {
   157  				return n, nil
   158  			}
   159  			return 0, err
   160  		}
   161  
   162  		n++
   163  	}
   164  
   165  	return n, nil
   166  }
   167  
   168  func (tun *netTun) Flush() error { return nil }
   169  
   170  func (tun *netTun) Close() error {
   171  	tun.stack.Destroy()
   172  
   173  	if tun.events != nil {
   174  		close(tun.events)
   175  	}
   176  	tun.ep.Close()
   177  	tun.dev.Close()
   178  	return nil
   179  }
   180  
   181  func (tun *netTun) MTU() (int, error) { return int(tun.ep.MTU()), nil }
   182  
   183  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   184  	var protoNumber tcpip.NetworkProtocolNumber
   185  	if endpoint.Addr().Is4() {
   186  		protoNumber = ipv4.ProtocolNumber
   187  	} else {
   188  		protoNumber = ipv6.ProtocolNumber
   189  	}
   190  	return tcpip.FullAddress{
   191  		NIC:  1,
   192  		Addr: tcpip.AddrFromSlice(endpoint.Addr().Unmap().AsSlice()),
   193  		Port: endpoint.Port(),
   194  	}, protoNumber
   195  }
   196  
   197  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   198  	fa, pn := convertToFullAddr(addr)
   199  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   200  }
   201  
   202  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
   203  	if addr == nil {
   204  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
   205  	}
   206  	ip, _ := netip.AddrFromSlice(addr.IP)
   207  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
   208  }
   209  
   210  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
   211  	fa, pn := convertToFullAddr(addr)
   212  	return gonet.DialTCP(net.stack, fa, pn)
   213  }
   214  
   215  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
   216  	if addr == nil {
   217  		return net.DialTCPAddrPort(netip.AddrPort{})
   218  	}
   219  	ip, _ := netip.AddrFromSlice(addr.IP)
   220  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   221  }
   222  
   223  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
   224  	fa, pn := convertToFullAddr(addr)
   225  	return gonet.ListenTCP(net.stack, fa, pn)
   226  }
   227  
   228  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
   229  	if addr == nil {
   230  		return net.ListenTCPAddrPort(netip.AddrPort{})
   231  	}
   232  	ip, _ := netip.AddrFromSlice(addr.IP)
   233  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   234  }
   235  
   236  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   237  	var lfa, rfa *tcpip.FullAddress
   238  	var pn tcpip.NetworkProtocolNumber
   239  	if laddr.IsValid() || laddr.Port() > 0 {
   240  		var addr tcpip.FullAddress
   241  		addr, pn = convertToFullAddr(laddr)
   242  		lfa = &addr
   243  	}
   244  	if raddr.IsValid() || raddr.Port() > 0 {
   245  		var addr tcpip.FullAddress
   246  		addr, pn = convertToFullAddr(raddr)
   247  		rfa = &addr
   248  	}
   249  
   250  	if pn == 0 {
   251  		if net.HasV6() {
   252  			pn = ipv6.ProtocolNumber
   253  		} else {
   254  			pn = ipv4.ProtocolNumber
   255  		}
   256  	}
   257  
   258  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   259  }
   260  
   261  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
   262  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
   263  }
   264  
   265  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
   266  	var la, ra netip.AddrPort
   267  	if laddr != nil {
   268  		ip, _ := netip.AddrFromSlice(laddr.IP)
   269  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
   270  	}
   271  	if raddr != nil {
   272  		ip, _ := netip.AddrFromSlice(raddr.IP)
   273  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
   274  	}
   275  	return net.DialUDPAddrPort(la, ra)
   276  }
   277  
   278  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
   279  	return net.DialUDP(laddr, nil)
   280  }
   281  
   282  func (n *Net) HasV4() bool { return n.hasV4 }
   283  func (n *Net) HasV6() bool { return n.hasV6 }