github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/trie/trie.go (about) 1 package trie 2 3 import ( 4 "context" 5 "errors" 6 "log/slog" 7 "net/netip" 8 9 "github.com/Asutorufa/yuhaiin/pkg/log" 10 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 11 "github.com/Asutorufa/yuhaiin/pkg/net/trie/cidr" 12 "github.com/Asutorufa/yuhaiin/pkg/net/trie/domain" 13 "github.com/Asutorufa/yuhaiin/pkg/utils/yerror" 14 ) 15 16 type Trie[T any] struct { 17 cidr *cidr.Cidr[T] 18 domain *domain.Fqdn[T] 19 } 20 21 func (x *Trie[T]) Insert(str string, mark T) { 22 if str == "" { 23 return 24 } 25 26 ipNet, err := netip.ParsePrefix(str) 27 if err == nil { 28 x.cidr.InsertCIDR(ipNet, mark) 29 return 30 } 31 32 if ip, err := netip.ParseAddr(str); err == nil { 33 mask := 128 34 if ip.Is4() { 35 mask = 32 36 } 37 x.cidr.InsertIP(ip, mask, mark) 38 return 39 } 40 41 x.domain.Insert(str, mark) 42 } 43 44 var ErrSkipResolver = errors.New("skip resolve domain") 45 46 var SkipResolver = netapi.ErrorResolver(func(domain string) error { return ErrSkipResolver }) 47 48 func (x *Trie[T]) Search(ctx context.Context, addr netapi.Address) (mark T, ok bool) { 49 if addr.Type() == netapi.IP { 50 return x.cidr.SearchIP(yerror.Must(addr.IP(ctx))) 51 } 52 53 if mark, ok = x.domain.Search(addr); ok { 54 return 55 } 56 57 if ips, err := addr.IP(ctx); err == nil { 58 mark, ok = x.cidr.SearchIP(ips) 59 } else if !errors.Is(err, ErrSkipResolver) { 60 log.Warn("dns lookup failed, skip match ip", slog.Any("addr", addr), slog.Any("err", err)) 61 } 62 63 return 64 } 65 66 func (x *Trie[T]) Remove(str string) { 67 if str == "" { 68 return 69 } 70 71 ipNet, err := netip.ParsePrefix(str) 72 if err == nil { 73 x.cidr.RemoveCIDR(ipNet) 74 return 75 } 76 77 if ip, err := netip.ParseAddr(str); err == nil { 78 mask := 128 79 if ip.Is4() { 80 mask = 32 81 } 82 x.cidr.RemoveIP(ip, mask) 83 return 84 } 85 86 x.domain.Remove(str) 87 } 88 89 func (x *Trie[T]) SearchWithDefault(ctx context.Context, addr netapi.Address, defaultT T) T { 90 t, ok := x.Search(ctx, addr) 91 if ok { 92 return t 93 } 94 95 return defaultT 96 } 97 98 func (x *Trie[T]) Clear() error { 99 x.cidr = cidr.NewCidrMapper[T]() 100 x.domain = domain.NewDomainMapper[T]() 101 return nil 102 } 103 104 func NewTrie[T any]() *Trie[T] { 105 return &Trie[T]{cidr: cidr.NewCidrMapper[T](), domain: domain.NewDomainMapper[T]()} 106 }