github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/dnsx/resolver.go (about)

     1  package dnsx
     2  
     3  import (
     4  	"fmt"
     5  	logger "log"
     6  	"net"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    11  	dns "github.com/miekg/dns"
    12  )
    13  
    14  type DomainResolver struct {
    15  	ttl         int
    16  	dnsAddrress string
    17  	data        mapx.ConcurrentMap
    18  	log         *logger.Logger
    19  }
    20  type DomainResolverItem struct {
    21  	ip        string
    22  	domain    string
    23  	expiredAt int64
    24  }
    25  
    26  func NewDomainResolver(dnsAddrress string, ttl int, log *logger.Logger) DomainResolver {
    27  	return DomainResolver{
    28  		ttl:         ttl,
    29  		dnsAddrress: dnsAddrress,
    30  		data:        mapx.NewConcurrentMap(),
    31  		log:         log,
    32  	}
    33  }
    34  func (a *DomainResolver) DnsAddress() (address string) {
    35  	address = a.dnsAddrress
    36  	return
    37  }
    38  func (a *DomainResolver) MustResolve(address string) (ip string) {
    39  	ip, _ = a.Resolve(address)
    40  	return
    41  }
    42  func (a *DomainResolver) Resolve(address string) (ip string, err error) {
    43  	domain := address
    44  	port := ""
    45  	fromCache := "false"
    46  	defer func() {
    47  		if port != "" {
    48  			ip = net.JoinHostPort(ip, port)
    49  		}
    50  		a.log.Printf("dns:%s->%s,cache:%s", address, ip, fromCache)
    51  		//a.PrintData()
    52  	}()
    53  	if strings.Contains(domain, ":") {
    54  		domain, port, err = net.SplitHostPort(domain)
    55  		if err != nil {
    56  			return
    57  		}
    58  	}
    59  	if net.ParseIP(domain) != nil {
    60  		ip = domain
    61  		fromCache = "ip ignore"
    62  		return
    63  	}
    64  	item, ok := a.data.Get(domain)
    65  	if ok {
    66  		//log.Println("find ", domain)
    67  		if (*item.(*DomainResolverItem)).expiredAt > time.Now().Unix() {
    68  			ip = (*item.(*DomainResolverItem)).ip
    69  			fromCache = "true"
    70  			//log.Println("from cache ", domain)
    71  			return
    72  		}
    73  	} else {
    74  		item = &DomainResolverItem{
    75  			domain: domain,
    76  		}
    77  
    78  	}
    79  	c := new(dns.Client)
    80  	c.DialTimeout = time.Millisecond * 5000
    81  	c.ReadTimeout = time.Millisecond * 5000
    82  	c.WriteTimeout = time.Millisecond * 5000
    83  	m := new(dns.Msg)
    84  	m.SetQuestion(dns.Fqdn(domain), dns.TypeA)
    85  	m.RecursionDesired = true
    86  	r, _, err := c.Exchange(m, a.dnsAddrress)
    87  	if r == nil {
    88  		return
    89  	}
    90  	if r.Rcode != dns.RcodeSuccess {
    91  		err = fmt.Errorf(" *** invalid answer name %s after A query for %s", domain, a.dnsAddrress)
    92  		return
    93  	}
    94  	for _, answer := range r.Answer {
    95  		if answer.Header().Rrtype == dns.TypeA {
    96  			info := strings.Fields(answer.String())
    97  			if len(info) >= 5 {
    98  				ip = info[4]
    99  				_item := item.(*DomainResolverItem)
   100  				(*_item).expiredAt = time.Now().Unix() + int64(a.ttl)
   101  				(*_item).ip = ip
   102  				a.data.Set(domain, item)
   103  				return
   104  			}
   105  		}
   106  	}
   107  	return
   108  }
   109  func (a *DomainResolver) PrintData() {
   110  	for k, item := range a.data.Items() {
   111  		d := item.(*DomainResolverItem)
   112  		a.log.Printf("%s:ip[%s],domain[%s],expired at[%d]\n", k, (*d).ip, (*d).domain, (*d).expiredAt)
   113  	}
   114  }