github.com/metacubex/mihomo@v1.18.5/listener/sing_tun/dns.go (about)

     1  package sing_tun
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/metacubex/mihomo/component/resolver"
    11  	"github.com/metacubex/mihomo/listener/sing"
    12  	"github.com/metacubex/mihomo/log"
    13  
    14  	"github.com/sagernet/sing/common/buf"
    15  	"github.com/sagernet/sing/common/bufio"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	"github.com/sagernet/sing/common/network"
    18  )
    19  
// ListenerHandler wraps a *sing.ListenerHandler (embedded, so its methods are
// promoted) and overrides NewConnection / NewPacketConnection so that flows
// addressed to a hijacked DNS endpoint are answered by the built-in resolver
// instead of being forwarded.
type ListenerHandler struct {
	*sing.ListenerHandler
	// DnsAdds lists the destinations whose traffic is hijacked. An entry whose
	// address is unspecified acts as a wildcard for any destination on port 53
	// (see ShouldHijackDns).
	DnsAdds []netip.AddrPort
}
    24  
    25  func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool {
    26  	if targetAddr.Addr().IsLoopback() && targetAddr.Port() == 53 { // cause by system stack
    27  		return true
    28  	}
    29  	for _, addrPort := range h.DnsAdds {
    30  		if addrPort == targetAddr || (addrPort.Addr().IsUnspecified() && targetAddr.Port() == 53) {
    31  			return true
    32  		}
    33  	}
    34  	return false
    35  }
    36  
    37  func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    38  	if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
    39  		log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String())
    40  		return resolver.RelayDnsConn(ctx, conn, resolver.DefaultDnsReadTimeout)
    41  	}
    42  	return h.ListenerHandler.NewConnection(ctx, conn, metadata)
    43  }
    44  
    45  func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error {
    46  	if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
    47  		log.Debugln("[DNS] hijack udp:%s from %s", metadata.Destination.String(), metadata.Source.String())
    48  		defer func() { _ = conn.Close() }()
    49  		mutex := sync.Mutex{}
    50  		conn2 := conn // a new interface to set nil in defer
    51  		defer func() {
    52  			mutex.Lock() // this goroutine must exit after all conn.WritePacket() is not running
    53  			defer mutex.Unlock()
    54  			conn2 = nil
    55  		}()
    56  		rwOptions := network.ReadWaitOptions{
    57  			FrontHeadroom: network.CalculateFrontHeadroom(conn),
    58  			RearHeadroom:  network.CalculateRearHeadroom(conn),
    59  			MTU:           resolver.SafeDnsPacketSize,
    60  		}
    61  		readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
    62  		if isReadWaiter {
    63  			readWaiter.InitializeReadWaiter(rwOptions)
    64  		}
    65  		for {
    66  			var (
    67  				readBuff *buf.Buffer
    68  				dest     M.Socksaddr
    69  				err      error
    70  			)
    71  			_ = conn.SetReadDeadline(time.Now().Add(resolver.DefaultDnsReadTimeout))
    72  			readBuff = nil // clear last loop status, avoid repeat release
    73  			if isReadWaiter {
    74  				readBuff, dest, err = readWaiter.WaitReadPacket()
    75  			} else {
    76  				readBuff = rwOptions.NewPacketBuffer()
    77  				dest, err = conn.ReadPacket(readBuff)
    78  				if readBuff != nil {
    79  					rwOptions.PostReturn(readBuff)
    80  				}
    81  			}
    82  			if err != nil {
    83  				if readBuff != nil {
    84  					readBuff.Release()
    85  				}
    86  				if sing.ShouldIgnorePacketError(err) {
    87  					break
    88  				}
    89  				return err
    90  			}
    91  			go func() {
    92  				ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDnsRelayTimeout)
    93  				defer cancel()
    94  				inData := readBuff.Bytes()
    95  				writeBuff := readBuff
    96  				writeBuff.Resize(writeBuff.Start(), 0)
    97  				if len(writeBuff.FreeBytes()) < resolver.SafeDnsPacketSize { // only create a new buffer when space don't enough
    98  					writeBuff = rwOptions.NewPacketBuffer()
    99  				}
   100  				msg, err := resolver.RelayDnsPacket(ctx, inData, writeBuff.FreeBytes())
   101  				if writeBuff != readBuff {
   102  					readBuff.Release()
   103  				}
   104  				if err != nil {
   105  					writeBuff.Release()
   106  					return
   107  				}
   108  				writeBuff.Truncate(len(msg))
   109  				mutex.Lock()
   110  				defer mutex.Unlock()
   111  				conn := conn2
   112  				if conn == nil {
   113  					writeBuff.Release()
   114  					return
   115  				}
   116  				err = conn.WritePacket(writeBuff, dest) // WritePacket will release writeBuff
   117  				if err != nil {
   118  					writeBuff.Release()
   119  					return
   120  				}
   121  			}()
   122  		}
   123  		return nil
   124  	}
   125  	return h.ListenerHandler.NewPacketConnection(ctx, conn, metadata)
   126  }