github.com/mtsmfm/go/src@v0.0.0-20221020090648-44bdcb9f8fde/net/resolverdialfunc_test.go (about) 1 // Copyright 2022 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build !js 6 7 // Test that Resolver.Dial can be a func returning an in-memory net.Conn 8 // speaking DNS. 9 10 package net 11 12 import ( 13 "bytes" 14 "context" 15 "errors" 16 "fmt" 17 "reflect" 18 "sort" 19 "testing" 20 "time" 21 22 "golang.org/x/net/dns/dnsmessage" 23 ) 24 25 func TestResolverDialFunc(t *testing.T) { 26 r := &Resolver{ 27 PreferGo: true, 28 Dial: newResolverDialFunc(&resolverDialHandler{ 29 StartDial: func(network, address string) error { 30 t.Logf("StartDial(%q, %q) ...", network, address) 31 return nil 32 }, 33 Question: func(h dnsmessage.Header, q dnsmessage.Question) { 34 t.Logf("Header: %+v for %q (type=%v, class=%v)", h, 35 q.Name.String(), q.Type, q.Class) 36 }, 37 // TODO: add test without HandleA* hooks specified at all, that Go 38 // doesn't issue retries; map to something terminal. 39 HandleA: func(w AWriter, name string) error { 40 w.AddIP([4]byte{1, 2, 3, 4}) 41 w.AddIP([4]byte{5, 6, 7, 8}) 42 return nil 43 }, 44 HandleAAAA: func(w AAAAWriter, name string) error { 45 w.AddIP([16]byte{1: 1, 15: 15}) 46 w.AddIP([16]byte{2: 2, 14: 14}) 47 return nil 48 }, 49 HandleSRV: func(w SRVWriter, name string) error { 50 w.AddSRV(1, 2, 80, "foo.bar.") 51 w.AddSRV(2, 3, 81, "bar.baz.") 52 return nil 53 }, 54 }), 55 } 56 ctx := context.Background() 57 const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld." 58 59 t.Run("LookupIP", func(t *testing.T) { 60 ips, err := r.LookupIP(ctx, "ip", fakeDomain) 61 if err != nil { 62 t.Fatal(err) 63 } 64 if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) { 65 t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want) 66 } 67 }) 68 69 t.Run("LookupSRV", func(t *testing.T) { 70 _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain) 71 if err != nil { 72 t.Fatal(err) 73 } 74 want := []*SRV{ 75 { 76 Target: "foo.bar.", 77 Port: 80, 78 Priority: 1, 79 Weight: 2, 80 }, 81 { 82 Target: "bar.baz.", 83 Port: 81, 84 Priority: 2, 85 Weight: 3, 86 }, 87 } 88 if !reflect.DeepEqual(got, want) { 89 t.Errorf("wrong result. got:") 90 for _, r := range got { 91 t.Logf(" - %+v", r) 92 } 93 } 94 }) 95 } 96 97 func sortedIPStrings(ips []IP) []string { 98 ret := make([]string, len(ips)) 99 for i, ip := range ips { 100 ret[i] = ip.String() 101 } 102 sort.Strings(ret) 103 return ret 104 } 105 106 func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) { 107 return func(ctx context.Context, network, address string) (Conn, error) { 108 a := &resolverFuncConn{ 109 h: h, 110 network: network, 111 address: address, 112 ttl: 10, // 10 second default if unset 113 } 114 if h.StartDial != nil { 115 if err := h.StartDial(network, address); err != nil { 116 return nil, err 117 } 118 } 119 return a, nil 120 } 121 } 122 123 type resolverDialHandler struct { 124 // StartDial, if non-nil, is called when Go first calls Resolver.Dial. 125 // Any error returned aborts the dial and is returned unwrapped. 126 StartDial func(network, address string) error 127 128 Question func(dnsmessage.Header, dnsmessage.Question) 129 130 // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2). 131 // A nil error means success. 132 HandleA func(w AWriter, name string) error 133 HandleAAAA func(w AAAAWriter, name string) error 134 HandleSRV func(w SRVWriter, name string) error 135 } 136 137 type ResponseWriter struct{ a *resolverFuncConn } 138 139 func (w ResponseWriter) header() dnsmessage.ResourceHeader { 140 q := w.a.q 141 return dnsmessage.ResourceHeader{ 142 Name: q.Name, 143 Type: q.Type, 144 Class: q.Class, 145 TTL: w.a.ttl, 146 } 147 } 148 149 // SetTTL sets the TTL for subsequent written resources. 150 // Once a resource has been written, SetTTL calls are no-ops. 151 // That is, it can only be called at most once, before anything 152 // else is written. 153 func (w ResponseWriter) SetTTL(seconds uint32) { 154 // ... intention is last one wins and mutates all previously 155 // written records too, but that's a little annoying. 156 // But it's also annoying if the requirement is it needs to be set 157 // last. 158 // And it's also annoying if it's possible for users to set 159 // different TTLs per Answer. 160 if w.a.wrote { 161 return 162 } 163 w.a.ttl = seconds 164 165 } 166 167 type AWriter struct{ ResponseWriter } 168 169 func (w AWriter) AddIP(v4 [4]byte) { 170 w.a.wrote = true 171 err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4}) 172 if err != nil { 173 panic(err) 174 } 175 } 176 177 type AAAAWriter struct{ ResponseWriter } 178 179 func (w AAAAWriter) AddIP(v6 [16]byte) { 180 w.a.wrote = true 181 err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6}) 182 if err != nil { 183 panic(err) 184 } 185 } 186 187 type SRVWriter struct{ ResponseWriter } 188 189 // AddSRV adds a SRV record. The target name must end in a period and 190 // be 63 bytes or fewer. 191 func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error { 192 targetName, err := dnsmessage.NewName(target) 193 if err != nil { 194 return err 195 } 196 w.a.wrote = true 197 err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{ 198 Priority: priority, 199 Weight: weight, 200 Port: port, 201 Target: targetName, 202 }) 203 if err != nil { 204 panic(err) // internal fault, not user 205 } 206 return nil 207 } 208 209 var ( 210 ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN 211 ErrRefused = errors.New("refused") // maps to RCode5, REFUSED 212 ) 213 214 type resolverFuncConn struct { 215 h *resolverDialHandler 216 network string 217 address string 218 builder *dnsmessage.Builder 219 q dnsmessage.Question 220 ttl uint32 221 wrote bool 222 223 rbuf bytes.Buffer 224 } 225 226 func (*resolverFuncConn) Close() error { return nil } 227 func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} } 228 func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} } 229 func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil } 230 func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil } 231 func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil } 232 233 func (a *resolverFuncConn) Read(p []byte) (n int, err error) { 234 return a.rbuf.Read(p) 235 } 236 237 func (a *resolverFuncConn) Write(packet []byte) (n int, err error) { 238 if len(packet) < 2 { 239 return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet)) 240 } 241 reqLen := int(packet[0])<<8 | int(packet[1]) 242 req := packet[2:] 243 if len(req) != reqLen { 244 return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req)) 245 } 246 247 var parser dnsmessage.Parser 248 h, err := parser.Start(req) 249 if err != nil { 250 // TODO: hook 251 return 0, err 252 } 253 q, err := parser.Question() 254 hadQ := (err == nil) 255 if err == nil && a.h.Question != nil { 256 a.h.Question(h, q) 257 } 258 if err != nil && err != dnsmessage.ErrSectionDone { 259 return 0, err 260 } 261 262 resh := h 263 resh.Response = true 264 resh.Authoritative = true 265 if hadQ { 266 resh.RCode = dnsmessage.RCodeSuccess 267 } else { 268 resh.RCode = dnsmessage.RCodeNotImplemented 269 } 270 a.rbuf.Grow(514) 271 a.rbuf.WriteByte('X') // reserved header for beu16 length 272 a.rbuf.WriteByte('Y') // reserved header for beu16 length 273 builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh) 274 a.builder = &builder 275 if hadQ { 276 a.q = q 277 a.builder.StartQuestions() 278 err := a.builder.Question(q) 279 if err != nil { 280 return 0, fmt.Errorf("Question: %w", err) 281 } 282 a.builder.StartAnswers() 283 switch q.Type { 284 case dnsmessage.TypeA: 285 if a.h.HandleA != nil { 286 resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String())) 287 } 288 case dnsmessage.TypeAAAA: 289 if a.h.HandleAAAA != nil { 290 resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String())) 291 } 292 case dnsmessage.TypeSRV: 293 if a.h.HandleSRV != nil { 294 resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String())) 295 } 296 } 297 } 298 tcpRes, err := builder.Finish() 299 if err != nil { 300 return 0, fmt.Errorf("Finish: %w", err) 301 } 302 303 n = len(tcpRes) - 2 304 tcpRes[0] = byte(n >> 8) 305 tcpRes[1] = byte(n) 306 a.rbuf.Write(tcpRes[2:]) 307 308 return len(packet), nil 309 } 310 311 type someaddr struct{} 312 313 func (someaddr) Network() string { return "unused" } 314 func (someaddr) String() string { return "unused-someaddr" } 315 316 func mapRCode(err error) dnsmessage.RCode { 317 switch err { 318 case nil: 319 return dnsmessage.RCodeSuccess 320 case ErrNotExist: 321 return dnsmessage.RCodeNameError 322 case ErrRefused: 323 return dnsmessage.RCodeRefused 324 default: 325 return dnsmessage.RCodeServerFailure 326 } 327 }