github.com/anuvu/tyk@v2.9.0-beta9-dl-apic+incompatible/test/dns.go (about)

     1  package test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"reflect"
     8  	"regexp"
     9  	"strings"
    10  
    11  	"time"
    12  
    13  	"sync"
    14  
    15  	"github.com/miekg/dns"
    16  )
    17  
    18  var (
    19  	muDefaultResolver  sync.RWMutex
    20  	DomainsToAddresses = map[string][]string{
    21  		"host1.local.": {"127.0.0.1"},
    22  		"host2.local.": {"127.0.0.1"},
    23  		"host3.local.": {"127.0.0.1"},
    24  	}
    25  	DomainsToIgnore = []string{
    26  		"redis.",
    27  		"tyk-redis.",
    28  		"mongo.", // For dashboard integration tests
    29  		"tyk-mongo.",
    30  	}
    31  )
    32  
    33  type dnsMockHandler struct {
    34  	domainsToAddresses map[string][]string
    35  	domainsToErrors    map[string]int
    36  
    37  	muDomainsToAddresses sync.RWMutex
    38  }
    39  
    40  func (d *dnsMockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
    41  	msg := dns.Msg{}
    42  	msg.SetReply(r)
    43  	switch r.Question[0].Qtype {
    44  	case dns.TypeA:
    45  		msg.Authoritative = true
    46  		domain := msg.Question[0].Name
    47  
    48  		d.muDomainsToAddresses.RLock()
    49  		defer d.muDomainsToAddresses.RUnlock()
    50  
    51  		if rcode, ok := d.domainsToErrors[domain]; ok {
    52  			m := new(dns.Msg)
    53  			m.SetRcode(r, rcode)
    54  			w.WriteMsg(m)
    55  			return
    56  		}
    57  
    58  		for _, ignore := range DomainsToIgnore {
    59  			if strings.HasPrefix(domain, ignore) {
    60  				resolver := &net.Resolver{}
    61  				ipAddrs, err := resolver.LookupIPAddr(context.Background(), domain)
    62  				if err != nil {
    63  					m := new(dns.Msg)
    64  					m.SetRcode(r, dns.RcodeServerFailure)
    65  					w.WriteMsg(m)
    66  					return
    67  				}
    68  				msg.Answer = append(msg.Answer, &dns.A{
    69  					Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
    70  					A:   ipAddrs[0].IP,
    71  				})
    72  				w.WriteMsg(&msg)
    73  				return
    74  			}
    75  		}
    76  
    77  		addresses, ok := d.domainsToAddresses[domain]
    78  		if !ok {
    79  			// ^ 				start of line
    80  			// localhost\.		match literally
    81  			// ()* 				match between 0 and unlimited times
    82  			// [[:alnum:]]+\.	match single character in [a-zA-Z0-9] minimum one time and ending in . literally
    83  			reg := regexp.MustCompile(`^localhost\.([[:alnum:]]+\.)*`)
    84  			if matched := reg.MatchString(domain); !matched {
    85  				panic(fmt.Sprintf("domain not mocked: %s", domain))
    86  			}
    87  
    88  			addresses = []string{"127.0.0.1"}
    89  		}
    90  
    91  		for _, addr := range addresses {
    92  			msg.Answer = append(msg.Answer, &dns.A{
    93  				Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
    94  				A:   net.ParseIP(addr),
    95  			})
    96  		}
    97  	}
    98  	w.WriteMsg(&msg)
    99  }
   100  
   101  type DnsMockHandle struct {
   102  	id              string
   103  	mockServer      *dns.Server
   104  	ShutdownDnsMock func() error
   105  }
   106  
   107  func (h *DnsMockHandle) PushDomains(domainsMap map[string][]string, domainsErrorMap map[string]int) func() {
   108  	handler := h.mockServer.Handler.(*dnsMockHandler)
   109  	handler.muDomainsToAddresses.Lock()
   110  	defer handler.muDomainsToAddresses.Unlock()
   111  
   112  	dta := handler.domainsToAddresses
   113  	dte := handler.domainsToErrors
   114  
   115  	prevDta := map[string][]string{}
   116  	prevDte := map[string]int{}
   117  
   118  	for key, value := range dta {
   119  		prevDta[key] = value
   120  	}
   121  
   122  	for key, value := range dte {
   123  		prevDte[key] = value
   124  	}
   125  
   126  	pullDomainsFunc := func() {
   127  		handler := h.mockServer.Handler.(*dnsMockHandler)
   128  		handler.muDomainsToAddresses.Lock()
   129  		defer handler.muDomainsToAddresses.Unlock()
   130  
   131  		handler.domainsToAddresses = prevDta
   132  		handler.domainsToErrors = prevDte
   133  	}
   134  
   135  	for key, ips := range domainsMap {
   136  		addr, ok := dta[key]
   137  		if !ok {
   138  			dta[key] = ips
   139  		} else {
   140  			dta[key] = append(addr, ips...)
   141  		}
   142  	}
   143  
   144  	for key, rCode := range domainsErrorMap {
   145  		dte[key] = rCode
   146  	}
   147  
   148  	return pullDomainsFunc
   149  }
   150  
   151  // InitDNSMock initializes dns server on udp:0 address and replaces net.DefaultResolver in order
   152  // to route all dns queries within tests to this server.
   153  // InitDNSMock returns handle, which can be used to add/remove dns query mock responses or initialization error.
   154  func InitDNSMock(domainsMap map[string][]string, domainsErrorMap map[string]int) (*DnsMockHandle, error) {
   155  	addr, _ := net.ResolveUDPAddr("udp", ":0")
   156  	conn, err := net.ListenUDP("udp", addr)
   157  	if err != nil {
   158  		return &DnsMockHandle{}, err
   159  	}
   160  
   161  	startResultChannel := make(chan error)
   162  	started := func() {
   163  		startResultChannel <- nil
   164  	}
   165  
   166  	mockServer := &dns.Server{PacketConn: conn, NotifyStartedFunc: started}
   167  	handle := &DnsMockHandle{id: time.Now().String(), mockServer: mockServer}
   168  
   169  	dnsMux := &dnsMockHandler{muDomainsToAddresses: sync.RWMutex{}}
   170  
   171  	if domainsMap != nil {
   172  		dnsMux.domainsToAddresses = domainsMap
   173  	} else {
   174  		dnsMux.domainsToAddresses = DomainsToAddresses
   175  	}
   176  
   177  	if domainsErrorMap != nil {
   178  		dnsMux.domainsToErrors = domainsErrorMap
   179  	}
   180  
   181  	mockServer.Handler = dnsMux
   182  
   183  	go func() {
   184  		startResultChannel <- mockServer.ActivateAndServe()
   185  	}()
   186  
   187  	err = <-startResultChannel
   188  	if err != nil {
   189  		close(startResultChannel)
   190  		return handle, err
   191  	}
   192  
   193  	muDefaultResolver.RLock()
   194  	defaultResolver := net.DefaultResolver
   195  	muDefaultResolver.RUnlock()
   196  	mockResolver := &net.Resolver{
   197  		PreferGo: true,
   198  		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
   199  			d := net.Dialer{}
   200  
   201  			//Use write lock to prevent unsafe d.DialContext update of net.DefaultResolver
   202  			muDefaultResolver.Lock()
   203  			defer muDefaultResolver.Unlock()
   204  			return d.DialContext(ctx, network, mockServer.PacketConn.LocalAddr().String())
   205  		},
   206  	}
   207  
   208  	muDefaultResolver.Lock()
   209  	net.DefaultResolver = mockResolver
   210  	muDefaultResolver.Unlock()
   211  
   212  	handle.ShutdownDnsMock = func() error {
   213  		muDefaultResolver.Lock()
   214  		net.DefaultResolver = defaultResolver
   215  		muDefaultResolver.Unlock()
   216  
   217  		return mockServer.Shutdown()
   218  	}
   219  
   220  	return handle, nil
   221  }
   222  
   223  func IsDnsRecordsAddrsEqualsTo(itemAddrs, addrs []string) bool {
   224  	return reflect.DeepEqual(itemAddrs, addrs)
   225  }