github.com/zooyer/miskit@v1.0.71/dns/dns.go (about)

     1  /**
     2   * @Author: zzy
     3   * @Email: zhangzhongyuan@didiglobal.com
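     4   * @Description: local UDP DNS server that answers A/AAAA queries from a static hosts table and forwards the rest to the system resolver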
     5   * @File: dns.go
     6   * @Package: dns
     7   * @Version: 1.0.0
     8   * @Date: 2022/9/28 16:34
     9   */
    10  
    11  package dns
    12  
    13  import (
    14  	"net"
    15  	"strings"
    16  
    17  	"golang.org/x/net/dns/dnsmessage"
    18  )
    19  
    20  var (
    21  	ttl  uint32 = 600
    22  	conn *net.UDPConn
    23  )
    24  
    25  func setForwardTTL(t uint32) {
    26  	ttl = t
    27  }
    28  
    29  func SetForwardTTL(ttl uint32) {
    30  	setForwardTTL(ttl)
    31  }
    32  
    33  func handleHook(conn *net.UDPConn, addr *net.UDPAddr, msg dnsmessage.Message, hosts map[string][]dnsmessage.Resource) {
    34  	if conn == nil || addr == nil || len(msg.Questions) < 1 {
    35  		return
    36  	}
    37  
    38  	var questions = make([]dnsmessage.Question, 0, len(msg.Questions))
    39  	for _, question := range msg.Questions {
    40  		var hit bool
    41  
    42  		name := strings.TrimRight(question.Name.String(), ".")
    43  		for _, host := range hosts[name] {
    44  			var (
    45  				ok  bool
    46  				res dnsmessage.ResourceBody
    47  			)
    48  
    49  			if host.Header.Class != question.Class {
    50  				continue
    51  			}
    52  
    53  			switch question.Type {
    54  			case dnsmessage.TypeA:
    55  				res, ok = host.Body.(*dnsmessage.AResource)
    56  			case dnsmessage.TypeAAAA:
    57  				res, ok = host.Body.(*dnsmessage.AAAAResource)
    58  			}
    59  
    60  			if !ok {
    61  				continue
    62  			}
    63  
    64  			hit = true
    65  
    66  			msg.Answers = append(msg.Answers, dnsmessage.Resource{
    67  				Header: dnsmessage.ResourceHeader{
    68  					Name:  question.Name,
    69  					Class: question.Class,
    70  					TTL:   host.Header.TTL,
    71  				},
    72  				Body: res,
    73  			})
    74  		}
    75  
    76  		if !hit {
    77  			questions = append(questions, question)
    78  		}
    79  	}
    80  
    81  	if len(questions) > 0 {
    82  		for _, question := range questions {
    83  			ips, err := net.LookupIP(question.Name.String())
    84  			if err != nil {
    85  				continue
    86  			}
    87  
    88  			for _, ip := range ips {
    89  				var resource = dnsmessage.Resource{
    90  					Header: dnsmessage.ResourceHeader{
    91  						Name:  question.Name,
    92  						Class: question.Class,
    93  						TTL:   ttl,
    94  					},
    95  					Body: nil,
    96  				}
    97  
    98  				if ipv4 := ip.To4(); len(ipv4) == net.IPv4len {
    99  					if question.Type != dnsmessage.TypeA {
   100  						continue
   101  					}
   102  					var v4 dnsmessage.AResource
   103  					copy(v4.A[:], ipv4[:net.IPv4len])
   104  					resource.Body = &v4
   105  					resource.Header.Type = dnsmessage.TypeA
   106  				} else if ipv6 := ip.To16(); len(ipv6) == net.IPv6len {
   107  					if question.Type != dnsmessage.TypeAAAA {
   108  						continue
   109  					}
   110  					var v6 dnsmessage.AAAAResource
   111  					copy(v6.AAAA[:], ipv6[:net.IPv6len])
   112  					resource.Body = &v6
   113  					resource.Header.Type = dnsmessage.TypeAAAA
   114  				}
   115  
   116  				msg.Answers = append(msg.Answers, resource)
   117  			}
   118  		}
   119  	}
   120  
   121  	if len(msg.Answers) > 0 {
   122  		msg.Response = true
   123  	}
   124  
   125  	pkg, err := msg.Pack()
   126  	if err != nil {
   127  		return
   128  	}
   129  
   130  	if _, err = conn.WriteToUDP(pkg, addr); err != nil {
   131  		return
   132  	}
   133  }
   134  
   135  func HookHosts(hosts map[string][]dnsmessage.Resource) (err error) {
   136  	conn, err = net.ListenUDP("udp", &net.UDPAddr{Port: 53})
   137  	if err != nil {
   138  		return
   139  	}
   140  
   141  	var buf = make([]byte, 512)
   142  	for {
   143  		_, addr, err := conn.ReadFromUDP(buf)
   144  		if err != nil {
   145  			continue
   146  		}
   147  
   148  		var msg dnsmessage.Message
   149  		if err = msg.Unpack(buf); err != nil {
   150  			continue
   151  		}
   152  
   153  		go handleHook(conn, addr, msg, hosts)
   154  	}
   155  }
   156  
   157  func HookHostsByText(hosts map[string][]string) (err error) {
   158  	var resources = make(map[string][]dnsmessage.Resource)
   159  	for name, hosts := range hosts {
   160  		for _, host := range hosts {
   161  			var resource = dnsmessage.Resource{
   162  				Header: dnsmessage.ResourceHeader{
   163  					Name:   dnsmessage.MustNewName(name),
   164  					Type:   0,
   165  					Class:  dnsmessage.ClassINET,
   166  					TTL:    ttl,
   167  					Length: 0,
   168  				},
   169  				Body: nil,
   170  			}
   171  
   172  			ip := net.ParseIP(host)
   173  
   174  			if ipv4 := ip.To4(); len(ipv4) == net.IPv4len {
   175  				var v4 dnsmessage.AResource
   176  				copy(v4.A[:], ipv4[:net.IPv4len])
   177  				resource.Body = &v4
   178  				resource.Header.Type = dnsmessage.TypeA
   179  			} else if ipv6 := ip.To16(); len(ipv6) == net.IPv6len {
   180  				var v6 dnsmessage.AAAAResource
   181  				copy(v6.AAAA[:], ipv6[:net.IPv6len])
   182  				resource.Body = &v6
   183  				resource.Header.Type = dnsmessage.TypeAAAA
   184  			}
   185  
   186  			resources[name] = append(resources[name], resource)
   187  		}
   188  	}
   189  
   190  	return HookHosts(resources)
   191  }
   192  
   193  func HookHostsByLocal(names ...string) (err error) {
   194  	address, err := net.InterfaceAddrs()
   195  	if err != nil {
   196  		return
   197  	}
   198  
   199  	var hosts = make(map[string][]string)
   200  	for _, addr := range address {
   201  		if addr, ok := addr.(*net.IPNet); ok && addr.IP.IsGlobalUnicast() {
   202  			for _, name := range names {
   203  				hosts[name] = append(hosts[name], addr.IP.String())
   204  			}
   205  		}
   206  	}
   207  
   208  	return HookHostsByText(hosts)
   209  }