github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/trie/trie.go (about)

     1  package trie
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"log/slog"
     7  	"net/netip"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/log"
    10  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/trie/cidr"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/trie/domain"
    13  	"github.com/Asutorufa/yuhaiin/pkg/utils/yerror"
    14  )
    15  
// Trie maps IP prefixes, single IP addresses and domain names to values of
// type T. IP-based entries live in a CIDR trie; everything else lives in a
// domain trie. Construct one with NewTrie: in the zero value both sub-trie
// pointers are nil.
type Trie[T any] struct {
	// cidr holds entries inserted as CIDR prefixes or bare IP addresses.
	cidr *cidr.Cidr[T]
	// domain holds every entry that parsed as neither a prefix nor an address.
	domain *domain.Fqdn[T]
}
    20  
    21  func (x *Trie[T]) Insert(str string, mark T) {
    22  	if str == "" {
    23  		return
    24  	}
    25  
    26  	ipNet, err := netip.ParsePrefix(str)
    27  	if err == nil {
    28  		x.cidr.InsertCIDR(ipNet, mark)
    29  		return
    30  	}
    31  
    32  	if ip, err := netip.ParseAddr(str); err == nil {
    33  		mask := 128
    34  		if ip.Is4() {
    35  			mask = 32
    36  		}
    37  		x.cidr.InsertIP(ip, mask, mark)
    38  		return
    39  	}
    40  
    41  	x.domain.Insert(str, mark)
    42  }
    43  
    44  var ErrSkipResolver = errors.New("skip resolve domain")
    45  
    46  var SkipResolver = netapi.ErrorResolver(func(domain string) error { return ErrSkipResolver })
    47  
    48  func (x *Trie[T]) Search(ctx context.Context, addr netapi.Address) (mark T, ok bool) {
    49  	if addr.Type() == netapi.IP {
    50  		return x.cidr.SearchIP(yerror.Must(addr.IP(ctx)))
    51  	}
    52  
    53  	if mark, ok = x.domain.Search(addr); ok {
    54  		return
    55  	}
    56  
    57  	if ips, err := addr.IP(ctx); err == nil {
    58  		mark, ok = x.cidr.SearchIP(ips)
    59  	} else if !errors.Is(err, ErrSkipResolver) {
    60  		log.Warn("dns lookup failed, skip match ip", slog.Any("addr", addr), slog.Any("err", err))
    61  	}
    62  
    63  	return
    64  }
    65  
    66  func (x *Trie[T]) Remove(str string) {
    67  	if str == "" {
    68  		return
    69  	}
    70  
    71  	ipNet, err := netip.ParsePrefix(str)
    72  	if err == nil {
    73  		x.cidr.RemoveCIDR(ipNet)
    74  		return
    75  	}
    76  
    77  	if ip, err := netip.ParseAddr(str); err == nil {
    78  		mask := 128
    79  		if ip.Is4() {
    80  			mask = 32
    81  		}
    82  		x.cidr.RemoveIP(ip, mask)
    83  		return
    84  	}
    85  
    86  	x.domain.Remove(str)
    87  }
    88  
    89  func (x *Trie[T]) SearchWithDefault(ctx context.Context, addr netapi.Address, defaultT T) T {
    90  	t, ok := x.Search(ctx, addr)
    91  	if ok {
    92  		return t
    93  	}
    94  
    95  	return defaultT
    96  }
    97  
    98  func (x *Trie[T]) Clear() error {
    99  	x.cidr = cidr.NewCidrMapper[T]()
   100  	x.domain = domain.NewDomainMapper[T]()
   101  	return nil
   102  }
   103  
   104  func NewTrie[T any]() *Trie[T] {
   105  	return &Trie[T]{cidr: cidr.NewCidrMapper[T](), domain: domain.NewDomainMapper[T]()}
   106  }