github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/tun2socket/nat/nat.go (about)

     1  package nat
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"math"
     7  	"net"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/log"
    10  	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/netlink"
    13  	tun "github.com/Asutorufa/yuhaiin/pkg/net/proxy/tun/gvisor"
    14  	"gvisor.dev/gvisor/pkg/tcpip"
    15  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    16  	"gvisor.dev/gvisor/pkg/tcpip/header"
    17  )
    18  
// Nat holds the tun2socket NAT state. It embeds the TCP side (local
// listener plus tuple table) and the UDP side, and keeps the addresses and
// port used when rewriting packets read from the tun device.
type Nat struct {
	*TCP
	*UDP

	// address is the IPv4 gateway address (opt.V4Address().Addr()); outbound
	// TCP segments are rewritten to this destination.
	address tcpip.Address
	// portal is address+1; outbound TCP segments are rewritten to come from it.
	portal tcpip.Address
	// addressV6 / portalV6 are the IPv6 counterparts of address / portal.
	addressV6 tcpip.Address
	portalV6  tcpip.Address
	// gatewayPort is the port of the local TCP listener created in Start.
	gatewayPort uint16
	// mtu is opt.MTU as configured (or defaulted) in Start.
	mtu int32

	// tab maps original TCP tuples to ports; the same table is shared with TCP.
	tab *tableSplit
}
    32  
    33  func Start(opt *tun.Opt) (*Nat, error) {
    34  	listener, err := dialer.ListenContextWithOptions(context.Background(), "tcp", "", &dialer.Options{})
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	log.Info("new tun2socket tcp server", "host", listener.Addr(),
    40  		"gateway", opt.V4Address(), "portal", opt.V4Address().Addr().Next(),
    41  		"gatewayv6", opt.V6Address(), "portalv6", opt.V6Address().Addr().Next(),
    42  	)
    43  
    44  	err = netlink.Route(opt.Options)
    45  	if err != nil {
    46  		log.Warn("set route failed", "err", err)
    47  	}
    48  
    49  	if opt.MTU <= 0 {
    50  		opt.MTU = nat.MaxSegmentSize
    51  	}
    52  
    53  	tab := newTable()
    54  
    55  	nat := &Nat{
    56  		address:     tcpip.AddrFromSlice(opt.V4Address().Addr().AsSlice()),
    57  		portal:      tcpip.AddrFromSlice(opt.V4Address().Addr().Next().AsSlice()),
    58  		addressV6:   tcpip.AddrFromSlice(opt.V6Address().Addr().AsSlice()),
    59  		portalV6:    tcpip.AddrFromSlice(opt.V6Address().Addr().Next().AsSlice()),
    60  		gatewayPort: uint16(listener.Addr().(*net.TCPAddr).Port),
    61  		mtu:         int32(opt.MTU),
    62  		tab:         tab,
    63  		TCP: &TCP{
    64  			listener: listener.(*net.TCPListener),
    65  			portal:   opt.V4Address().Addr().Next().AsSlice(),
    66  			portalv6: opt.V6Address().Addr().Next().AsSlice(),
    67  			table:    tab,
    68  		},
    69  		UDP: NewUDPv2(int32(opt.MTU), opt.Writer),
    70  	}
    71  
    72  	subnet := tcpip.AddressWithPrefix{Address: nat.address, PrefixLen: opt.V4Address().Bits()}.Subnet()
    73  	broadcast := subnet.Broadcast()
    74  	if broadcast.Equal(nat.address) || broadcast.Equal(nat.portal) {
    75  		broadcast = tcpip.AddrFrom4([4]byte{255, 255, 255, 255})
    76  	}
    77  
    78  	go func() {
    79  		defer nat.Close()
    80  
    81  		sizes := make([]int, opt.Writer.Tun().BatchSize())
    82  		bufs := make([][]byte, opt.Writer.Tun().BatchSize())
    83  		for i := range bufs {
    84  			bufs[i] = make([]byte, opt.MTU)
    85  		}
    86  
    87  		wbufs := make([][]byte, opt.Writer.Tun().BatchSize())
    88  
    89  		for {
    90  			n, err := opt.Writer.Read(bufs, sizes)
    91  			if err != nil {
    92  				log.Error("tun device read failed", "err", err)
    93  				return
    94  			}
    95  
    96  			wbufs = wbufs[:0]
    97  
    98  			for i := range n {
    99  				if sizes[i] < header.IPv4MinimumSize {
   100  					continue
   101  				}
   102  
   103  				raw := bufs[i][:sizes[i]]
   104  
   105  				ip := nat.processIP(raw)
   106  				if ip == nil {
   107  					continue
   108  				}
   109  
   110  				if len(ip.Payload()) > len(raw) {
   111  					continue
   112  				}
   113  
   114  				dst, src := ip.DestinationAddress(), ip.SourceAddress()
   115  
   116  				if !net.IP(dst.AsSlice()).IsGlobalUnicast() || dst.Equal(broadcast) {
   117  					continue
   118  				}
   119  
   120  				var tp header.Transport
   121  				var pseudoHeaderSum uint16
   122  				var ok bool
   123  
   124  				switch ip.TransportProtocol() {
   125  				case header.TCPProtocolNumber:
   126  					tp, pseudoHeaderSum, ok = nat.processTCP(ip, src, dst)
   127  
   128  				case header.ICMPv4ProtocolNumber:
   129  					tp, pseudoHeaderSum, ok = processICMP(ip)
   130  
   131  				case header.ICMPv6ProtocolNumber:
   132  					tp, pseudoHeaderSum, ok = processICMPv6(ip)
   133  
   134  				case header.UDPProtocolNumber:
   135  					u := header.UDP(ip.Payload())
   136  					if u.Length() == 0 {
   137  						continue
   138  					}
   139  
   140  					nat.UDP.handleUDPPacket(
   141  						Tuple{
   142  							SourceAddr:      src,
   143  							SourcePort:      u.SourcePort(),
   144  							DestinationAddr: dst,
   145  							DestinationPort: u.DestinationPort(),
   146  						}, u.Payload())
   147  
   148  					continue
   149  
   150  				default:
   151  					continue
   152  				}
   153  
   154  				if !ok {
   155  					continue
   156  				}
   157  
   158  				resetCheckSum(ip, tp, pseudoHeaderSum)
   159  
   160  				wbufs = append(wbufs, raw)
   161  			}
   162  
   163  			if len(wbufs) == 0 {
   164  				continue
   165  			}
   166  
   167  			if _, err = opt.Writer.Write(wbufs); err != nil {
   168  				log.Error("write tcp raw to tun device failed", "err", err)
   169  			}
   170  
   171  		}
   172  	}()
   173  
   174  	return nat, nil
   175  }
   176  
   177  func (n *Nat) processIP(raw []byte) header.Network {
   178  	switch header.IPVersion(raw) {
   179  	case header.IPv4Version:
   180  		ipv4 := header.IPv4(raw)
   181  
   182  		if !ipv4.IsValid(int(ipv4.TotalLength())) {
   183  			return nil
   184  		}
   185  
   186  		if ipv4.More() {
   187  			return nil
   188  		}
   189  
   190  		if ipv4.FragmentOffset() != 0 {
   191  			return nil
   192  		}
   193  
   194  		return ipv4
   195  
   196  	case header.IPv6Version:
   197  		ipv6 := header.IPv6(raw)
   198  
   199  		if ipv6.HopLimit() == 0x00 {
   200  			return nil
   201  		}
   202  
   203  		return ipv6
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  func (n *Nat) processTCP(ip header.Network, src, dst tcpip.Address) (_ header.Transport, pseudoHeaderSum uint16, _ bool) {
   210  	t := header.TCP(ip.Payload())
   211  
   212  	sourcePort := t.SourcePort()
   213  	destinationPort := t.DestinationPort()
   214  
   215  	var address, portal tcpip.Address
   216  	if _, ok := ip.(header.IPv4); ok {
   217  		address, portal = n.address, n.portal
   218  	} else {
   219  		address, portal = n.addressV6, n.portalV6
   220  	}
   221  
   222  	if address.Unspecified() || portal.Unspecified() {
   223  		return nil, 0, false
   224  	}
   225  	if src == address && sourcePort == n.gatewayPort {
   226  		tup := n.tab.tupleOf(destinationPort, dst.Len() == 16)
   227  		if tup == zeroTuple {
   228  			return nil, 0, false
   229  		}
   230  
   231  		ip.SetDestinationAddress(tup.SourceAddr)
   232  		t.SetDestinationPort(tup.SourcePort)
   233  		ip.SetSourceAddress(tup.DestinationAddr)
   234  		t.SetSourcePort(tup.DestinationPort)
   235  	} else {
   236  		tup := Tuple{
   237  			SourceAddr:      src,
   238  			SourcePort:      sourcePort,
   239  			DestinationAddr: dst,
   240  			DestinationPort: destinationPort,
   241  		}
   242  
   243  		port := n.tab.portOf(tup)
   244  		ip.SetDestinationAddress(address)
   245  		t.SetDestinationPort(n.gatewayPort)
   246  		ip.SetSourceAddress(portal)
   247  		t.SetSourcePort(port)
   248  	}
   249  
   250  	pseudoHeaderSum = header.PseudoHeaderChecksum(header.TCPProtocolNumber,
   251  		ip.SourceAddress(),
   252  		ip.DestinationAddress(),
   253  		uint16(len(ip.Payload())),
   254  	)
   255  
   256  	return t, pseudoHeaderSum, true
   257  }
   258  
   259  func (n *Nat) Close() error {
   260  	var err error
   261  
   262  	if n.UDP != nil {
   263  		if er := n.UDP.Close(); er != nil {
   264  			err = errors.Join(err, er)
   265  		}
   266  	}
   267  
   268  	if n.TCP != nil {
   269  		if er := n.TCP.Close(); er != nil {
   270  			err = errors.Join(err, er)
   271  		}
   272  	}
   273  
   274  	return err
   275  }
   276  
   277  func processICMP(ip header.Network) (_ header.Transport, pseudoHeaderSum uint16, _ bool) {
   278  	i := header.ICMPv4(ip.Payload())
   279  
   280  	if i.Type() != header.ICMPv4Echo || i.Code() != 0 {
   281  		return nil, 0, false
   282  	}
   283  
   284  	i.SetType(header.ICMPv4EchoReply)
   285  
   286  	destination := ip.DestinationAddress()
   287  	ip.SetDestinationAddress(ip.SourceAddress())
   288  	ip.SetSourceAddress(destination)
   289  
   290  	pseudoHeaderSum = 0
   291  
   292  	return i, pseudoHeaderSum, true
   293  }
   294  
   295  func processICMPv6(ip header.Network) (_ header.Transport, pseudoHeaderSum uint16, _ bool) {
   296  	i := header.ICMPv6(ip.Payload())
   297  
   298  	if i.Type() != header.ICMPv6EchoRequest || i.Code() != 0 {
   299  		return nil, 0, false
   300  	}
   301  
   302  	i.SetType(header.ICMPv6EchoReply)
   303  
   304  	destination := ip.DestinationAddress()
   305  	ip.SetDestinationAddress(ip.SourceAddress())
   306  	ip.SetSourceAddress(destination)
   307  
   308  	pseudoHeaderSum = header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber,
   309  		ip.SourceAddress(), ip.DestinationAddress(),
   310  		uint16(len(i)),
   311  	)
   312  
   313  	return i, pseudoHeaderSum, true
   314  }
   315  
   316  func resetCheckSum(ip header.Network, tp header.Transport, pseudoHeaderSum uint16) {
   317  	resetIPCheckSum(ip)
   318  	resetTransportCheckSum(ip, tp, pseudoHeaderSum)
   319  }
   320  
   321  func resetIPCheckSum(ip header.Network) {
   322  	if ip, ok := ip.(header.IPv4); ok {
   323  		ip.SetChecksum(0)
   324  		sum := ip.CalculateChecksum()
   325  		ip.SetChecksum(^sum)
   326  	}
   327  }
   328  
   329  func resetTransportCheckSum(ip header.Network, tp header.Transport, pseudoHeaderSum uint16) {
   330  	tp.SetChecksum(0)
   331  	sum := checksum.Checksum(ip.Payload(), pseudoHeaderSum)
   332  
   333  	//https://datatracker.ietf.org/doc/html/rfc768
   334  	//
   335  	// If the computed  checksum  is zero,  it is transmitted  as all ones (the
   336  	// equivalent  in one's complement  arithmetic).   An all zero  transmitted
   337  	// checksum  value means that the transmitter  generated  no checksum  (for
   338  	// debugging or for higher level protocols that don't care).
   339  	//
   340  	// https://datatracker.ietf.org/doc/html/rfc8200
   341  	// Unlike IPv4, the default behavior when UDP packets are
   342  	//  originated by an IPv6 node is that the UDP checksum is not
   343  	//  optional.  That is, whenever originating a UDP packet, an IPv6
   344  	//  node must compute a UDP checksum over the packet and the
   345  	//  pseudo-header, and, if that computation yields a result of
   346  	//  zero, it must be changed to hex FFFF for placement in the UDP
   347  	//  header.  IPv6 receivers must discard UDP packets containing a
   348  	//  zero checksum and should log the error.
   349  	if ip.TransportProtocol() != header.UDPProtocolNumber || sum != math.MaxUint16 {
   350  		sum = ^sum
   351  	}
   352  	tp.SetChecksum(sum)
   353  }