github.com/anuvu/tyk@v2.9.0-beta9-dl-apic+incompatible/test/dns.go (about) 1 package test 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "reflect" 8 "regexp" 9 "strings" 10 11 "time" 12 13 "sync" 14 15 "github.com/miekg/dns" 16 ) 17 18 var ( 19 muDefaultResolver sync.RWMutex 20 DomainsToAddresses = map[string][]string{ 21 "host1.local.": {"127.0.0.1"}, 22 "host2.local.": {"127.0.0.1"}, 23 "host3.local.": {"127.0.0.1"}, 24 } 25 DomainsToIgnore = []string{ 26 "redis.", 27 "tyk-redis.", 28 "mongo.", // For dashboard integration tests 29 "tyk-mongo.", 30 } 31 ) 32 33 type dnsMockHandler struct { 34 domainsToAddresses map[string][]string 35 domainsToErrors map[string]int 36 37 muDomainsToAddresses sync.RWMutex 38 } 39 40 func (d *dnsMockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { 41 msg := dns.Msg{} 42 msg.SetReply(r) 43 switch r.Question[0].Qtype { 44 case dns.TypeA: 45 msg.Authoritative = true 46 domain := msg.Question[0].Name 47 48 d.muDomainsToAddresses.RLock() 49 defer d.muDomainsToAddresses.RUnlock() 50 51 if rcode, ok := d.domainsToErrors[domain]; ok { 52 m := new(dns.Msg) 53 m.SetRcode(r, rcode) 54 w.WriteMsg(m) 55 return 56 } 57 58 for _, ignore := range DomainsToIgnore { 59 if strings.HasPrefix(domain, ignore) { 60 resolver := &net.Resolver{} 61 ipAddrs, err := resolver.LookupIPAddr(context.Background(), domain) 62 if err != nil { 63 m := new(dns.Msg) 64 m.SetRcode(r, dns.RcodeServerFailure) 65 w.WriteMsg(m) 66 return 67 } 68 msg.Answer = append(msg.Answer, &dns.A{ 69 Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, 70 A: ipAddrs[0].IP, 71 }) 72 w.WriteMsg(&msg) 73 return 74 } 75 } 76 77 addresses, ok := d.domainsToAddresses[domain] 78 if !ok { 79 // ^ start of line 80 // localhost\. match literally 81 // ()* match between 0 and unlimited times 82 // [[:alnum:]]+\. match single character in [a-zA-Z0-9] minimum one time and ending in . literally 83 reg := regexp.MustCompile(`^localhost\.([[:alnum:]]+\.)*`) 84 if matched := reg.MatchString(domain); !matched { 85 panic(fmt.Sprintf("domain not mocked: %s", domain)) 86 } 87 88 addresses = []string{"127.0.0.1"} 89 } 90 91 for _, addr := range addresses { 92 msg.Answer = append(msg.Answer, &dns.A{ 93 Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, 94 A: net.ParseIP(addr), 95 }) 96 } 97 } 98 w.WriteMsg(&msg) 99 } 100 101 type DnsMockHandle struct { 102 id string 103 mockServer *dns.Server 104 ShutdownDnsMock func() error 105 } 106 107 func (h *DnsMockHandle) PushDomains(domainsMap map[string][]string, domainsErrorMap map[string]int) func() { 108 handler := h.mockServer.Handler.(*dnsMockHandler) 109 handler.muDomainsToAddresses.Lock() 110 defer handler.muDomainsToAddresses.Unlock() 111 112 dta := handler.domainsToAddresses 113 dte := handler.domainsToErrors 114 115 prevDta := map[string][]string{} 116 prevDte := map[string]int{} 117 118 for key, value := range dta { 119 prevDta[key] = value 120 } 121 122 for key, value := range dte { 123 prevDte[key] = value 124 } 125 126 pullDomainsFunc := func() { 127 handler := h.mockServer.Handler.(*dnsMockHandler) 128 handler.muDomainsToAddresses.Lock() 129 defer handler.muDomainsToAddresses.Unlock() 130 131 handler.domainsToAddresses = prevDta 132 handler.domainsToErrors = prevDte 133 } 134 135 for key, ips := range domainsMap { 136 addr, ok := dta[key] 137 if !ok { 138 dta[key] = ips 139 } else { 140 dta[key] = append(addr, ips...) 141 } 142 } 143 144 for key, rCode := range domainsErrorMap { 145 dte[key] = rCode 146 } 147 148 return pullDomainsFunc 149 } 150 151 // InitDNSMock initializes dns server on udp:0 address and replaces net.DefaultResolver in order 152 // to route all dns queries within tests to this server. 153 // InitDNSMock returns handle, which can be used to add/remove dns query mock responses or initialization error. 154 func InitDNSMock(domainsMap map[string][]string, domainsErrorMap map[string]int) (*DnsMockHandle, error) { 155 addr, _ := net.ResolveUDPAddr("udp", ":0") 156 conn, err := net.ListenUDP("udp", addr) 157 if err != nil { 158 return &DnsMockHandle{}, err 159 } 160 161 startResultChannel := make(chan error) 162 started := func() { 163 startResultChannel <- nil 164 } 165 166 mockServer := &dns.Server{PacketConn: conn, NotifyStartedFunc: started} 167 handle := &DnsMockHandle{id: time.Now().String(), mockServer: mockServer} 168 169 dnsMux := &dnsMockHandler{muDomainsToAddresses: sync.RWMutex{}} 170 171 if domainsMap != nil { 172 dnsMux.domainsToAddresses = domainsMap 173 } else { 174 dnsMux.domainsToAddresses = DomainsToAddresses 175 } 176 177 if domainsErrorMap != nil { 178 dnsMux.domainsToErrors = domainsErrorMap 179 } 180 181 mockServer.Handler = dnsMux 182 183 go func() { 184 startResultChannel <- mockServer.ActivateAndServe() 185 }() 186 187 err = <-startResultChannel 188 if err != nil { 189 close(startResultChannel) 190 return handle, err 191 } 192 193 muDefaultResolver.RLock() 194 defaultResolver := net.DefaultResolver 195 muDefaultResolver.RUnlock() 196 mockResolver := &net.Resolver{ 197 PreferGo: true, 198 Dial: func(ctx context.Context, network, address string) (net.Conn, error) { 199 d := net.Dialer{} 200 201 //Use write lock to prevent unsafe d.DialContext update of net.DefaultResolver 202 muDefaultResolver.Lock() 203 defer muDefaultResolver.Unlock() 204 return d.DialContext(ctx, network, mockServer.PacketConn.LocalAddr().String()) 205 }, 206 } 207 208 muDefaultResolver.Lock() 209 net.DefaultResolver = mockResolver 210 muDefaultResolver.Unlock() 211 212 handle.ShutdownDnsMock = func() error { 213 muDefaultResolver.Lock() 214 net.DefaultResolver = defaultResolver 215 muDefaultResolver.Unlock() 216 217 return mockServer.Shutdown() 218 } 219 220 return handle, nil 221 } 222 223 func IsDnsRecordsAddrsEqualsTo(itemAddrs, addrs []string) bool { 224 return reflect.DeepEqual(itemAddrs, addrs) 225 }