github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/dns/resolver.go (about) 1 package dns 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "math/rand" 9 "net" 10 "strings" 11 "time" 12 13 "github.com/laof/lite-speed-test/common/cache" 14 "github.com/laof/lite-speed-test/common/picker" 15 "github.com/laof/lite-speed-test/transport/resolver" 16 "golang.org/x/sync/singleflight" 17 18 // "github.com/Dreamacro/clash/component/trie" 19 20 D "github.com/miekg/dns" 21 ) 22 23 var ( 24 globalSessionCache = tls.NewLRUClientSessionCache(64) 25 ) 26 27 type dnsClient interface { 28 Exchange(m *D.Msg) (msg *D.Msg, err error) 29 ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) 30 } 31 32 type result struct { 33 Msg *D.Msg 34 Error error 35 } 36 37 type Resolver struct { 38 ipv6 bool 39 main []dnsClient 40 fallback []dnsClient 41 group singleflight.Group 42 lruCache *cache.LruCache 43 } 44 45 // ResolveIP request with TypeA and TypeAAAA, priority return TypeA 46 func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) { 47 ch := make(chan net.IP, 1) 48 go func() { 49 defer close(ch) 50 ip, err := r.resolveIP(host, D.TypeAAAA) 51 if err != nil { 52 return 53 } 54 ch <- ip 55 }() 56 57 ip, err = r.resolveIP(host, D.TypeA) 58 if err == nil && ip != nil { 59 return 60 } 61 62 ip, open := <-ch 63 if !open { 64 return nil, resolver.ErrIPNotFound 65 } 66 67 return ip, nil 68 } 69 70 // ResolveIPv4 request with TypeA 71 func (r *Resolver) ResolveIPv4(host string) (ip net.IP, err error) { 72 return r.resolveIP(host, D.TypeA) 73 } 74 75 // ResolveIPv6 request with TypeAAAA 76 func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { 77 return r.resolveIP(host, D.TypeAAAA) 78 } 79 80 func (r *Resolver) shouldIPFallback(ip net.IP) bool { 81 return false 82 } 83 84 // Exchange a batch of dns request, and it use cache 85 func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { 86 if len(m.Question) == 0 { 87 return nil, errors.New("should have one question at least") 88 } 89 90 q := m.Question[0] 91 cache, expireTime, hit := r.lruCache.GetWithExpire(q.String()) 92 if hit { 93 now := time.Now() 94 msg = cache.(*D.Msg).Copy() 95 if expireTime.Before(now) { 96 setMsgTTL(msg, uint32(1)) // Continue fetch 97 go r.exchangeWithoutCache(m) 98 } else { 99 setMsgTTL(msg, uint32(time.Until(expireTime).Seconds())) 100 } 101 return 102 } 103 return r.exchangeWithoutCache(m) 104 } 105 106 // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache 107 func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { 108 q := m.Question[0] 109 110 ret, err, shared := r.group.Do(q.String(), func() (result interface{}, err error) { 111 defer func() { 112 if err != nil { 113 return 114 } 115 116 msg := result.(*D.Msg) 117 118 putMsgToCache(r.lruCache, q.String(), msg) 119 }() 120 121 isIPReq := isIPRequest(q) 122 if isIPReq { 123 return r.ipExchange(m) 124 } 125 126 return r.batchExchange(r.main, m) 127 }) 128 129 if err == nil { 130 msg = ret.(*D.Msg) 131 if shared { 132 msg = msg.Copy() 133 } 134 } 135 136 return 137 } 138 139 func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) { 140 fast, ctx := picker.WithTimeout(context.Background(), time.Second*5) 141 for _, client := range clients { 142 r := client 143 fast.Go(func() (interface{}, error) { 144 m, err := r.ExchangeContext(ctx, m) 145 if err != nil { 146 return nil, err 147 } else if m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused { 148 return nil, errors.New("server failure") 149 } 150 return m, nil 151 }) 152 } 153 154 elm := fast.Wait() 155 if elm == nil { 156 err := errors.New("all DNS requests failed") 157 if fErr := fast.Error(); fErr != nil { 158 err = fmt.Errorf("%w, first error: %s", err, fErr.Error()) 159 } 160 return nil, err 161 } 162 163 msg = elm.(*D.Msg) 164 return 165 } 166 167 func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { 168 if r.fallback == nil { 169 return false 170 } 171 172 domain := r.msgToDomain(m) 173 174 if domain == "" { 175 return false 176 } 177 178 return false 179 } 180 181 func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) { 182 183 onlyFallback := r.shouldOnlyQueryFallback(m) 184 185 if onlyFallback { 186 res := <-r.asyncExchange(r.fallback, m) 187 return res.Msg, res.Error 188 } 189 190 msgCh := r.asyncExchange(r.main, m) 191 192 if r.fallback == nil { // directly return if no fallback servers are available 193 res := <-msgCh 194 msg, err = res.Msg, res.Error 195 return 196 } 197 198 fallbackMsg := r.asyncExchange(r.fallback, m) 199 res := <-msgCh 200 if res.Error == nil { 201 if ips := r.msgToIP(res.Msg); len(ips) != 0 { 202 if !r.shouldIPFallback(ips[0]) { 203 msg = res.Msg // no need to wait for fallback result 204 err = res.Error 205 return msg, err 206 } 207 } 208 } 209 210 res = <-fallbackMsg 211 msg, err = res.Msg, res.Error 212 return 213 } 214 215 func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { 216 ip = net.ParseIP(host) 217 if ip != nil { 218 isIPv4 := ip.To4() != nil 219 if dnsType == D.TypeAAAA && !isIPv4 { 220 return ip, nil 221 } else if dnsType == D.TypeA && isIPv4 { 222 return ip, nil 223 } else { 224 return nil, resolver.ErrIPVersion 225 } 226 } 227 228 query := &D.Msg{} 229 query.SetQuestion(D.Fqdn(host), dnsType) 230 231 msg, err := r.Exchange(query) 232 if err != nil { 233 return nil, err 234 } 235 236 ips := r.msgToIP(msg) 237 ipLength := len(ips) 238 if ipLength == 0 { 239 return nil, resolver.ErrIPNotFound 240 } 241 242 ip = ips[rand.Intn(ipLength)] 243 return 244 } 245 246 func (r *Resolver) msgToIP(msg *D.Msg) []net.IP { 247 ips := []net.IP{} 248 249 for _, answer := range msg.Answer { 250 switch ans := answer.(type) { 251 case *D.AAAA: 252 ips = append(ips, ans.AAAA) 253 case *D.A: 254 ips = append(ips, ans.A) 255 } 256 } 257 258 return ips 259 } 260 261 func (r *Resolver) msgToDomain(msg *D.Msg) string { 262 if len(msg.Question) > 0 { 263 return strings.TrimRight(msg.Question[0].Name, ".") 264 } 265 266 return "" 267 } 268 269 func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result { 270 ch := make(chan *result, 1) 271 go func() { 272 res, err := r.batchExchange(client, msg) 273 ch <- &result{Msg: res, Error: err} 274 }() 275 return ch 276 } 277 278 type NameServer struct { 279 Net string 280 Addr string 281 } 282 283 type FallbackFilter struct { 284 GeoIP bool 285 IPCIDR []*net.IPNet 286 Domain []string 287 } 288 289 type Config struct { 290 Main, Fallback []NameServer 291 Default []NameServer 292 IPv6 bool 293 FallbackFilter FallbackFilter 294 // Pool *fakeip.Pool 295 // Hosts *trie.DomainTrie 296 } 297 298 func NewResolver(config Config) *Resolver { 299 defaultResolver := &Resolver{ 300 main: transform(config.Default, nil), 301 lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), 302 } 303 304 r := &Resolver{ 305 ipv6: config.IPv6, 306 lruCache: cache.NewLRUCache(cache.WithSize(4096), cache.WithStale(true)), 307 main: transform(config.Main, defaultResolver), 308 } 309 310 if len(config.Fallback) != 0 { 311 r.fallback = transform(config.Fallback, defaultResolver) 312 } 313 314 return r 315 } 316 317 func DefaultResolver() *Resolver { 318 servers := []NameServer{ 319 { 320 Net: "udp", 321 Addr: "223.5.5.5:53", 322 }, 323 { 324 Net: "udp", 325 Addr: "8.8.8.8:53", 326 }, 327 } 328 c := Config{ 329 Main: servers, 330 Default: servers, 331 } 332 return NewResolver(c) 333 }