github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/tun2socket/tun2socket.go (about)

     1  package tun2socket
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/log"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    13  	tun "github.com/Asutorufa/yuhaiin/pkg/net/proxy/tun/gvisor"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/tun/tun2socket/nat"
    15  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    16  	"gvisor.dev/gvisor/pkg/tcpip"
    17  )
    18  
// Tun2socket connects a TUN device to the proxy layer. A NAT (nat.Nat) exposes
// the device's TCP flows as a listener and its UDP flows as a packet endpoint.
// The resulting streams and packets are handed out through the embedded
// ChannelServer.
type Tun2socket struct {
	// Mtu is taken from the TUN options in New and also sizes the buffer used
	// for each UDP read.
	Mtu int32

	device io.Closer // the opened TUN device; closed last by Close
	nat    *nat.Nat  // provides the TCP listener (nat.TCP) and UDP endpoint (nat.UDP)

	*netapi.ChannelServer
}
    27  
    28  func New(o *tun.Opt) (netapi.Accepter, error) {
    29  	device, err := tun.OpenWriter(o.Interface, int(o.Tun.Mtu))
    30  	if err != nil {
    31  		return nil, fmt.Errorf("open tun device failed: %w", err)
    32  	}
    33  
    34  	o.Writer = device
    35  
    36  	nat, err := nat.Start(o)
    37  	if err != nil {
    38  		device.Close()
    39  		return nil, err
    40  	}
    41  
    42  	handler := &Tun2socket{
    43  		nat:           nat,
    44  		device:        device,
    45  		Mtu:           o.Tun.Mtu,
    46  		ChannelServer: netapi.NewChannelServer(),
    47  	}
    48  
    49  	go handler.tcpLoop()
    50  	go handler.udpLoop()
    51  
    52  	return handler, nil
    53  }
    54  
    55  func (h *Tun2socket) Close() error {
    56  	h.ChannelServer.Close()
    57  	_ = h.nat.TCP.Close()
    58  	_ = h.nat.UDP.Close()
    59  	return h.device.Close()
    60  }
    61  
    62  func (h *Tun2socket) tcpLoop() {
    63  
    64  	defer h.nat.TCP.Close()
    65  
    66  	for h.nat.TCP.SetDeadline(time.Time{}) == nil {
    67  		conn, err := h.nat.TCP.Accept()
    68  		if err != nil {
    69  			log.Error("tun2socket tcp accept failed", "err", err)
    70  			continue
    71  		}
    72  
    73  		go func() {
    74  			if err = h.handleTCP(conn); err != nil {
    75  				if errors.Is(err, netapi.ErrBlocked) {
    76  					log.Debug(err.Error())
    77  				} else {
    78  					log.Error("handle tcp failed", "err", err)
    79  				}
    80  			}
    81  		}()
    82  
    83  	}
    84  }
    85  
    86  func (h *Tun2socket) udpLoop() {
    87  	defer h.nat.UDP.Close()
    88  	for {
    89  		if err := h.handleUDP(); err != nil {
    90  			if errors.Is(err, netapi.ErrBlocked) {
    91  				log.Debug(err.Error())
    92  			} else {
    93  				log.Error("handle udp failed", "err", err)
    94  			}
    95  			if errors.Is(err, errUDPAccept) {
    96  				return
    97  			}
    98  		}
    99  	}
   100  }
   101  
   102  func (h *Tun2socket) handleTCP(conn net.Conn) error {
   103  	// lAddrPort := conn.LocalAddr().(*net.TCPAddr).AddrPort()  // source
   104  	rAddrPort := conn.RemoteAddr().(*net.TCPAddr) // dst
   105  
   106  	if rAddrPort.IP.IsLoopback() {
   107  		return nil
   108  	}
   109  
   110  	return h.SendStream(&netapi.StreamMeta{
   111  		Source:      conn.LocalAddr(),
   112  		Destination: conn.RemoteAddr(),
   113  		Src:         conn,
   114  		Address:     netapi.ParseTCPAddress(rAddrPort),
   115  	})
   116  }
   117  
   118  var errUDPAccept = errors.New("tun2socket udp accept failed")
   119  
   120  func (h *Tun2socket) handleUDP() error {
   121  	buf := pool.GetBytesBuffer(h.Mtu)
   122  
   123  	n, tuple, err := h.nat.UDP.ReadFrom(buf.Bytes())
   124  	if err != nil {
   125  		return fmt.Errorf("%w: %v", errUDPAccept, err)
   126  	}
   127  
   128  	buf.Refactor(0, n)
   129  
   130  	return h.SendPacket(&netapi.Packet{
   131  		Src: &net.UDPAddr{
   132  			IP:   net.IP(tuple.SourceAddr.AsSlice()),
   133  			Port: int(tuple.SourcePort),
   134  		},
   135  		Dst: netapi.ParseUDPAddr(&net.UDPAddr{
   136  			IP:   net.IP(tuple.DestinationAddr.AsSlice()),
   137  			Port: int(tuple.DestinationPort),
   138  		}),
   139  		Payload: buf,
   140  		WriteBack: func(b []byte, addr net.Addr) (int, error) {
   141  			address, err := netapi.ParseSysAddr(addr)
   142  			if err != nil {
   143  				return 0, err
   144  			}
   145  
   146  			daddr, err := address.IP(context.TODO())
   147  			if err != nil {
   148  				return 0, err
   149  			}
   150  
   151  			if tuple.SourceAddr.Len() == 16 {
   152  				daddr = daddr.To16()
   153  			}
   154  
   155  			return h.nat.UDP.WriteTo(b, nat.Tuple{
   156  				DestinationAddr: tcpip.AddrFromSlice(daddr),
   157  				DestinationPort: address.Port().Port(),
   158  				SourceAddr:      tuple.SourceAddr,
   159  				SourcePort:      tuple.SourcePort,
   160  			})
   161  		},
   162  	})
   163  }