github.com/metacubex/mihomo@v1.18.5/component/sniffer/dispatcher.go (about)

     1  package sniffer
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"net/netip"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/metacubex/mihomo/common/lru"
    12  	N "github.com/metacubex/mihomo/common/net"
    13  	"github.com/metacubex/mihomo/component/trie"
    14  	C "github.com/metacubex/mihomo/constant"
    15  	"github.com/metacubex/mihomo/constant/sniffer"
    16  	"github.com/metacubex/mihomo/log"
    17  )
    18  
    19  var (
    20  	ErrorUnsupportedSniffer = errors.New("unsupported sniffer")
    21  	ErrorSniffFailed        = errors.New("all sniffer failed")
    22  	ErrNoClue               = errors.New("not enough information for making a decision")
    23  )
    24  
    25  var Dispatcher *SnifferDispatcher
    26  
    27  type SnifferDispatcher struct {
    28  	enable          bool
    29  	sniffers        map[sniffer.Sniffer]SnifferConfig
    30  	forceDomain     *trie.DomainSet
    31  	skipSNI         *trie.DomainSet
    32  	skipList        *lru.LruCache[string, uint8]
    33  	rwMux           sync.RWMutex
    34  	forceDnsMapping bool
    35  	parsePureIp     bool
    36  }
    37  
    38  func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
    39  	return (metadata.Host == "" && sd.parsePureIp) ||
    40  		sd.forceDomain.Has(metadata.Host) ||
    41  		(metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping)
    42  }
    43  
    44  func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
    45  	metadata := packet.Metadata()
    46  
    47  	if sd.shouldOverride(packet.Metadata()) {
    48  		for sniffer, config := range sd.sniffers {
    49  			if sniffer.SupportNetwork() == C.UDP || sniffer.SupportNetwork() == C.ALLNet {
    50  				inWhitelist := sniffer.SupportPort(metadata.DstPort)
    51  				overrideDest := config.OverrideDest
    52  
    53  				if inWhitelist {
    54  					host, err := sniffer.SniffData(packet.Data())
    55  					if err != nil {
    56  						continue
    57  					}
    58  
    59  					sd.replaceDomain(metadata, host, overrideDest)
    60  					return true
    61  				}
    62  			}
    63  		}
    64  	}
    65  
    66  	return false
    67  }
    68  
    69  // TCPSniff returns true if the connection is sniffed to have a domain
    70  func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
    71  	if sd.shouldOverride(metadata) {
    72  		inWhitelist := false
    73  		overrideDest := false
    74  		for sniffer, config := range sd.sniffers {
    75  			if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet {
    76  				inWhitelist = sniffer.SupportPort(metadata.DstPort)
    77  				if inWhitelist {
    78  					overrideDest = config.OverrideDest
    79  					break
    80  				}
    81  			}
    82  		}
    83  
    84  		if !inWhitelist {
    85  			return false
    86  		}
    87  
    88  		sd.rwMux.RLock()
    89  		dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
    90  		if count, ok := sd.skipList.Get(dst); ok && count > 5 {
    91  			log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
    92  			defer sd.rwMux.RUnlock()
    93  			return false
    94  		}
    95  		sd.rwMux.RUnlock()
    96  
    97  		if host, err := sd.sniffDomain(conn, metadata); err != nil {
    98  			sd.cacheSniffFailed(metadata)
    99  			log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
   100  			return false
   101  		} else {
   102  			if sd.skipSNI.Has(host) {
   103  				log.Debugln("[Sniffer] Skip sni[%s]", host)
   104  				return false
   105  			}
   106  
   107  			sd.rwMux.RLock()
   108  			sd.skipList.Delete(dst)
   109  			sd.rwMux.RUnlock()
   110  
   111  			sd.replaceDomain(metadata, host, overrideDest)
   112  			return true
   113  		}
   114  	}
   115  	return false
   116  }
   117  
   118  func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
   119  	// show log early, since the following code may mutate `metadata.Host`
   120  	log.Debugln("[Sniffer] Sniff %s [%s]-->[%s] success, replace domain [%s]-->[%s]",
   121  		metadata.NetWork,
   122  		metadata.SourceDetail(),
   123  		metadata.RemoteAddress(),
   124  		metadata.Host, host)
   125  	metadata.SniffHost = host
   126  	if overrideDest {
   127  		metadata.Host = host
   128  	}
   129  	metadata.DNSMode = C.DNSNormal
   130  }
   131  
   132  func (sd *SnifferDispatcher) Enable() bool {
   133  	return sd.enable
   134  }
   135  
   136  func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
   137  	for s := range sd.sniffers {
   138  		if s.SupportNetwork() == C.TCP {
   139  			_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
   140  			_, err := conn.Peek(1)
   141  			_ = conn.SetReadDeadline(time.Time{})
   142  			if err != nil {
   143  				_, ok := err.(*net.OpError)
   144  				if ok {
   145  					sd.cacheSniffFailed(metadata)
   146  					log.Errorln("[Sniffer] [%s] may not have any sent data, Consider adding skip", metadata.DstIP.String())
   147  					_ = conn.Close()
   148  				}
   149  
   150  				return "", err
   151  			}
   152  
   153  			bufferedLen := conn.Buffered()
   154  			bytes, err := conn.Peek(bufferedLen)
   155  			if err != nil {
   156  				log.Debugln("[Sniffer] the data length not enough")
   157  				continue
   158  			}
   159  
   160  			host, err := s.SniffData(bytes)
   161  			if err != nil {
   162  				//log.Debugln("[Sniffer] [%s] Sniff data failed %s", s.Protocol(), metadata.DstIP)
   163  				continue
   164  			}
   165  
   166  			_, err = netip.ParseAddr(host)
   167  			if err == nil {
   168  				//log.Debugln("[Sniffer] [%s] Sniff data failed %s", s.Protocol(), metadata.DstIP)
   169  				continue
   170  			}
   171  
   172  			return host, nil
   173  		}
   174  	}
   175  
   176  	return "", ErrorSniffFailed
   177  }
   178  
   179  func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
   180  	sd.rwMux.Lock()
   181  	dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
   182  	count, _ := sd.skipList.Get(dst)
   183  	if count <= 5 {
   184  		count++
   185  	}
   186  	sd.skipList.Set(dst, count)
   187  	sd.rwMux.Unlock()
   188  }
   189  
   190  func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
   191  	dispatcher := SnifferDispatcher{
   192  		enable: false,
   193  	}
   194  
   195  	return &dispatcher, nil
   196  }
   197  
   198  func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig,
   199  	forceDomain *trie.DomainSet, skipSNI *trie.DomainSet,
   200  	forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
   201  	dispatcher := SnifferDispatcher{
   202  		enable:          true,
   203  		forceDomain:     forceDomain,
   204  		skipSNI:         skipSNI,
   205  		skipList:        lru.New(lru.WithSize[string, uint8](128), lru.WithAge[string, uint8](600)),
   206  		forceDnsMapping: forceDnsMapping,
   207  		parsePureIp:     parsePureIp,
   208  		sniffers:        make(map[sniffer.Sniffer]SnifferConfig, 0),
   209  	}
   210  
   211  	for snifferName, config := range snifferConfig {
   212  		s, err := NewSniffer(snifferName, config)
   213  		if err != nil {
   214  			log.Errorln("Sniffer name[%s] is error", snifferName)
   215  			return &SnifferDispatcher{enable: false}, err
   216  		}
   217  		dispatcher.sniffers[s] = config
   218  	}
   219  
   220  	return &dispatcher, nil
   221  }
   222  
   223  func NewSniffer(name sniffer.Type, snifferConfig SnifferConfig) (sniffer.Sniffer, error) {
   224  	switch name {
   225  	case sniffer.TLS:
   226  		return NewTLSSniffer(snifferConfig)
   227  	case sniffer.HTTP:
   228  		return NewHTTPSniffer(snifferConfig)
   229  	case sniffer.QUIC:
   230  		return NewQuicSniffer(snifferConfig)
   231  	default:
   232  		return nil, ErrorUnsupportedSniffer
   233  	}
   234  }