github.com/yaling888/clash@v1.53.0/listener/tun/ipstack/system/stack.go (about)

     1  package system
     2  
import (
	"encoding/binary"
	"io"
	"math"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"
	"time"

	"github.com/phuslu/log"

	"github.com/yaling888/clash/adapter/inbound"
	"github.com/yaling888/clash/common/nnip"
	"github.com/yaling888/clash/common/pool"
	C "github.com/yaling888/clash/constant"
	"github.com/yaling888/clash/listener/tun/device"
	"github.com/yaling888/clash/listener/tun/ipstack"
	D "github.com/yaling888/clash/listener/tun/ipstack/commons"
	"github.com/yaling888/clash/listener/tun/ipstack/system/mars"
	"github.com/yaling888/clash/listener/tun/ipstack/system/mars/nat"
)
    23  
    24  type sysStack struct {
    25  	stack  io.Closer
    26  	device device.Device
    27  
    28  	closed bool
    29  	once   sync.Once
    30  	wg     sync.WaitGroup
    31  }
    32  
    33  func (s *sysStack) Close() error {
    34  	D.StopDefaultInterfaceChangeMonitor()
    35  
    36  	defer func() {
    37  		if s.device != nil {
    38  			_ = s.device.Close()
    39  		}
    40  	}()
    41  
    42  	s.closed = true
    43  
    44  	err := s.stack.Close()
    45  
    46  	s.wg.Wait()
    47  
    48  	return err
    49  }
    50  
    51  func New(device device.Device, dnsHijack []C.DNSUrl, tunAddress netip.Prefix, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.Stack, error) {
    52  	var (
    53  		gateway   = tunAddress.Masked().Addr().Next()
    54  		portal    = gateway.Next()
    55  		broadcast = nnip.UnMasked(tunAddress)
    56  	)
    57  
    58  	stack, err := mars.StartListener(device, gateway, portal, broadcast)
    59  	if err != nil {
    60  		_ = device.Close()
    61  
    62  		return nil, err
    63  	}
    64  
    65  	ipStack := &sysStack{stack: stack, device: device}
    66  
    67  	dnsAddr := dnsHijack
    68  
    69  	tcp := func() {
    70  		defer func(tcp *nat.TCP) {
    71  			_ = tcp.Close()
    72  		}(stack.TCP())
    73  
    74  		for {
    75  			conn, err0 := stack.TCP().Accept()
    76  			if err0 != nil {
    77  				if ipStack.closed {
    78  					break
    79  				}
    80  				log.Warn().
    81  					Err(err0).
    82  					Msg("[Stack] accept connection failed")
    83  				continue
    84  			}
    85  
    86  			lAddrPort := conn.LocalAddr().(*net.TCPAddr).AddrPort()
    87  			rAddrPort := conn.RemoteAddr().(*net.TCPAddr).AddrPort()
    88  
    89  			if rAddrPort.Addr().IsLoopback() {
    90  				_ = conn.Close()
    91  
    92  				continue
    93  			}
    94  
    95  			if D.ShouldHijackDns(dnsAddr, rAddrPort, "tcp") {
    96  				go func(dnsConn net.Conn, addr netip.AddrPort) {
    97  					log.Debug().NetIPAddrPort("addr", addr).Msg("[TUN] hijack tcp dns")
    98  
    99  					defer func(c net.Conn) {
   100  						_ = c.Close()
   101  					}(dnsConn)
   102  
   103  					err1 := dnsConn.SetReadDeadline(time.Now().Add(C.DefaultTCPTimeout))
   104  					if err1 != nil {
   105  						return
   106  					}
   107  
   108  					buf := pool.NewBuffer()
   109  					defer buf.Release()
   110  
   111  					length, err1 := buf.ReadUint16be(dnsConn)
   112  					if err1 != nil {
   113  						return
   114  					}
   115  
   116  					_, err1 = buf.ReadFullFrom(dnsConn, int64(length))
   117  					if err1 != nil {
   118  						return
   119  					}
   120  
   121  					msg, err1 := D.RelayDnsPacket(buf.Bytes())
   122  					if err1 != nil {
   123  						return
   124  					}
   125  
   126  					buf.Reset()
   127  
   128  					length = uint16(len(msg))
   129  					_ = binary.Write(buf, binary.BigEndian, length)
   130  
   131  					_, err1 = buf.Write(msg)
   132  					if err1 != nil {
   133  						return
   134  					}
   135  
   136  					_, _ = buf.WriteTo(dnsConn)
   137  				}(conn, rAddrPort)
   138  
   139  				continue
   140  			}
   141  
   142  			tcpIn <- inbound.NewSocketBy(conn, lAddrPort, rAddrPort, C.TUN)
   143  		}
   144  
   145  		ipStack.wg.Done()
   146  	}
   147  
   148  	udp := func() {
   149  		defer func(udp *nat.UDP) {
   150  			_ = udp.Close()
   151  		}(stack.UDP())
   152  
   153  		for {
   154  			ue, err0 := stack.UDP().ReadFrom()
   155  			if err0 != nil {
   156  				if ipStack.closed {
   157  					break
   158  				}
   159  
   160  				log.Warn().Err(err0).Msg("[Stack] accept udp failed")
   161  				continue
   162  			}
   163  
   164  			rAddrPort := ue.Destination
   165  			if rAddrPort.Addr().IsLoopback() || rAddrPort.Addr() == gateway {
   166  				stack.UDP().PutUDPElement(ue)
   167  				continue
   168  			}
   169  
   170  			if D.ShouldHijackDns(dnsAddr, rAddrPort, "udp") {
   171  				go func() {
   172  					defer stack.UDP().PutUDPElement(ue)
   173  
   174  					log.Debug().NetIPAddrPort("addr", ue.Destination).Msg("[TUN] hijack udp dns")
   175  
   176  					msg, err1 := D.RelayDnsPacket(*ue.Packet)
   177  					if err1 != nil {
   178  						return
   179  					}
   180  
   181  					_, _ = stack.UDP().WriteTo(msg, ue.Destination, ue.Source)
   182  				}()
   183  
   184  				continue
   185  			}
   186  
   187  			pkt := &packet{
   188  				sender: stack.UDP(),
   189  				data:   ue,
   190  				lAddr:  ue.Source,
   191  			}
   192  
   193  			select {
   194  			case udpIn <- inbound.NewPacketBy(pkt, ue.Source, rAddrPort, C.TUN):
   195  			default:
   196  				log.Debug().
   197  					NetIPAddrPort("lAddrPort", ue.Source).
   198  					NetIPAddrPort("rAddrPort", rAddrPort).
   199  					Msg("[Stack] drop udp packet, because inbound queue is full")
   200  				pkt.Drop()
   201  			}
   202  		}
   203  
   204  		ipStack.wg.Done()
   205  	}
   206  
   207  	ipStack.once.Do(func() {
   208  		ipStack.wg.Add(2)
   209  		go tcp()
   210  		go udp()
   211  	})
   212  
   213  	return ipStack, nil
   214  }