github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/netstack/tun.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package netstack 7 8 import ( 9 "bytes" 10 "context" 11 "crypto/rand" 12 "encoding/binary" 13 "errors" 14 "fmt" 15 "io" 16 "net" 17 "net/netip" 18 "os" 19 "regexp" 20 "strconv" 21 "strings" 22 "syscall" 23 "time" 24 25 "github.com/amnezia-vpn/amneziawg-go/tun" 26 27 "golang.org/x/net/dns/dnsmessage" 28 "gvisor.dev/gvisor/pkg/buffer" 29 "gvisor.dev/gvisor/pkg/tcpip" 30 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 31 "gvisor.dev/gvisor/pkg/tcpip/header" 32 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 33 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 34 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 35 "gvisor.dev/gvisor/pkg/tcpip/stack" 36 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 37 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 38 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 39 "gvisor.dev/gvisor/pkg/waiter" 40 ) 41 42 type netTun struct { 43 ep *channel.Endpoint 44 stack *stack.Stack 45 events chan tun.Event 46 incomingPacket chan *buffer.View 47 mtu int 48 dnsServers []netip.Addr 49 hasV4, hasV6 bool 50 } 51 52 type Net netTun 53 54 func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { 55 opts := stack.Options{ 56 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 57 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, 58 HandleLocal: true, 59 } 60 dev := &netTun{ 61 ep: channel.New(1024, uint32(mtu), ""), 62 stack: stack.New(opts), 63 events: make(chan tun.Event, 10), 64 incomingPacket: make(chan *buffer.View), 65 dnsServers: dnsServers, 66 mtu: mtu, 67 } 68 sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default 69 tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) 70 if tcpipErr != nil { 71 return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) 72 } 73 dev.ep.AddNotify(dev) 74 tcpipErr = dev.stack.CreateNIC(1, dev.ep) 75 if tcpipErr != nil { 76 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) 77 } 78 for _, ip := range localAddresses { 79 var protoNumber tcpip.NetworkProtocolNumber 80 if ip.Is4() { 81 protoNumber = ipv4.ProtocolNumber 82 } else if ip.Is6() { 83 protoNumber = ipv6.ProtocolNumber 84 } 85 protoAddr := tcpip.ProtocolAddress{ 86 Protocol: protoNumber, 87 AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), 88 } 89 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 90 if tcpipErr != nil { 91 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 92 } 93 if ip.Is4() { 94 dev.hasV4 = true 95 } else if ip.Is6() { 96 dev.hasV6 = true 97 } 98 } 99 if dev.hasV4 { 100 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 101 } 102 if dev.hasV6 { 103 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 104 } 105 106 dev.events <- tun.EventUp 107 return dev, (*Net)(dev), nil 108 } 109 110 func (tun *netTun) Name() (string, error) { 111 return "go", nil 112 } 113 114 func (tun *netTun) File() *os.File { 115 return nil 116 } 117 118 func (tun *netTun) Events() <-chan tun.Event { 119 return tun.events 120 } 121 122 func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { 123 view, ok := <-tun.incomingPacket 124 if !ok { 125 return 0, os.ErrClosed 126 } 127 128 n, err := view.Read(buf[0][offset:]) 129 if err != nil { 130 return 0, err 131 } 132 sizes[0] = n 133 return 1, nil 134 } 135 136 func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { 137 for _, buf := range buf { 138 packet := buf[offset:] 139 if len(packet) == 0 { 140 continue 141 } 142 143 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) 144 switch packet[0] >> 4 { 145 case 4: 146 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) 147 case 6: 148 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) 149 default: 150 return 0, syscall.EAFNOSUPPORT 151 } 152 } 153 return len(buf), nil 154 } 155 156 func (tun *netTun) WriteNotify() { 157 pkt := tun.ep.Read() 158 if pkt.IsNil() { 159 return 160 } 161 162 view := pkt.ToView() 163 pkt.DecRef() 164 165 tun.incomingPacket <- view 166 } 167 168 func (tun *netTun) Close() error { 169 tun.stack.RemoveNIC(1) 170 171 if tun.events != nil { 172 close(tun.events) 173 } 174 175 tun.ep.Close() 176 177 if tun.incomingPacket != nil { 178 close(tun.incomingPacket) 179 } 180 181 return nil 182 } 183 184 func (tun *netTun) MTU() (int, error) { 185 return tun.mtu, nil 186 } 187 188 func (tun *netTun) BatchSize() int { 189 return 1 190 } 191 192 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 193 var protoNumber tcpip.NetworkProtocolNumber 194 if endpoint.Addr().Is4() { 195 protoNumber = ipv4.ProtocolNumber 196 } else { 197 protoNumber = ipv6.ProtocolNumber 198 } 199 return tcpip.FullAddress{ 200 NIC: 1, 201 Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), 202 Port: endpoint.Port(), 203 }, protoNumber 204 } 205 206 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 207 fa, pn := convertToFullAddr(addr) 208 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 209 } 210 211 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { 212 if addr == nil { 213 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) 214 } 215 ip, _ := netip.AddrFromSlice(addr.IP) 216 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) 217 } 218 219 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { 220 fa, pn := convertToFullAddr(addr) 221 return gonet.DialTCP(net.stack, fa, pn) 222 } 223 224 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { 225 if addr == nil { 226 return net.DialTCPAddrPort(netip.AddrPort{}) 227 } 228 ip, _ := netip.AddrFromSlice(addr.IP) 229 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 230 } 231 232 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { 233 fa, pn := convertToFullAddr(addr) 234 return gonet.ListenTCP(net.stack, fa, pn) 235 } 236 237 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { 238 if addr == nil { 239 return net.ListenTCPAddrPort(netip.AddrPort{}) 240 } 241 ip, _ := netip.AddrFromSlice(addr.IP) 242 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 243 } 244 245 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 246 var lfa, rfa *tcpip.FullAddress 247 var pn tcpip.NetworkProtocolNumber 248 if laddr.IsValid() || laddr.Port() > 0 { 249 var addr tcpip.FullAddress 250 addr, pn = convertToFullAddr(laddr) 251 lfa = &addr 252 } 253 if raddr.IsValid() || raddr.Port() > 0 { 254 var addr tcpip.FullAddress 255 addr, pn = convertToFullAddr(raddr) 256 rfa = &addr 257 } 258 return gonet.DialUDP(net.stack, lfa, rfa, pn) 259 } 260 261 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { 262 return net.DialUDPAddrPort(laddr, netip.AddrPort{}) 263 } 264 265 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { 266 var la, ra netip.AddrPort 267 if laddr != nil { 268 ip, _ := netip.AddrFromSlice(laddr.IP) 269 la = netip.AddrPortFrom(ip, uint16(laddr.Port)) 270 } 271 if raddr != nil { 272 ip, _ := netip.AddrFromSlice(raddr.IP) 273 ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) 274 } 275 return net.DialUDPAddrPort(la, ra) 276 } 277 278 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { 279 return net.DialUDP(laddr, nil) 280 } 281 282 type PingConn struct { 283 laddr PingAddr 284 raddr PingAddr 285 wq waiter.Queue 286 ep tcpip.Endpoint 287 deadline *time.Timer 288 } 289 290 type PingAddr struct{ addr netip.Addr } 291 292 func (ia PingAddr) String() string { 293 return ia.addr.String() 294 } 295 296 func (ia PingAddr) Network() string { 297 if ia.addr.Is4() { 298 return "ping4" 299 } else if ia.addr.Is6() { 300 return "ping6" 301 } 302 return "ping" 303 } 304 305 func (ia PingAddr) Addr() netip.Addr { 306 return ia.addr 307 } 308 309 func PingAddrFromAddr(addr netip.Addr) *PingAddr { 310 return &PingAddr{addr} 311 } 312 313 func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { 314 if !laddr.IsValid() && !raddr.IsValid() { 315 return nil, errors.New("ping dial: invalid address") 316 } 317 v6 := laddr.Is6() || raddr.Is6() 318 bind := laddr.IsValid() 319 if !bind { 320 if v6 { 321 laddr = netip.IPv6Unspecified() 322 } else { 323 laddr = netip.IPv4Unspecified() 324 } 325 } 326 327 tn := icmp.ProtocolNumber4 328 pn := ipv4.ProtocolNumber 329 if v6 { 330 tn = icmp.ProtocolNumber6 331 pn = ipv6.ProtocolNumber 332 } 333 334 pc := &PingConn{ 335 laddr: PingAddr{laddr}, 336 deadline: time.NewTimer(time.Hour << 10), 337 } 338 pc.deadline.Stop() 339 340 ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) 341 if tcpipErr != nil { 342 return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) 343 } 344 pc.ep = ep 345 346 if bind { 347 fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) 348 if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { 349 return nil, fmt.Errorf("ping bind: %s", tcpipErr) 350 } 351 } 352 353 if raddr.IsValid() { 354 pc.raddr = PingAddr{raddr} 355 fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) 356 if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { 357 return nil, fmt.Errorf("ping connect: %s", tcpipErr) 358 } 359 } 360 361 return pc, nil 362 } 363 364 func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { 365 return net.DialPingAddr(laddr, netip.Addr{}) 366 } 367 368 func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { 369 var la, ra netip.Addr 370 if laddr != nil { 371 la = laddr.addr 372 } 373 if raddr != nil { 374 ra = raddr.addr 375 } 376 return net.DialPingAddr(la, ra) 377 } 378 379 func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { 380 var la netip.Addr 381 if laddr != nil { 382 la = laddr.addr 383 } 384 return net.ListenPingAddr(la) 385 } 386 387 func (pc *PingConn) LocalAddr() net.Addr { 388 return pc.laddr 389 } 390 391 func (pc *PingConn) RemoteAddr() net.Addr { 392 return pc.raddr 393 } 394 395 func (pc *PingConn) Close() error { 396 pc.deadline.Reset(0) 397 pc.ep.Close() 398 return nil 399 } 400 401 func (pc *PingConn) SetWriteDeadline(t time.Time) error { 402 return errors.New("not implemented") 403 } 404 405 func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 406 var na netip.Addr 407 switch v := addr.(type) { 408 case *PingAddr: 409 na = v.addr 410 case *net.IPAddr: 411 na, _ = netip.AddrFromSlice(v.IP) 412 default: 413 return 0, fmt.Errorf("ping write: wrong net.Addr type") 414 } 415 if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { 416 return 0, fmt.Errorf("ping write: mismatched protocols") 417 } 418 419 buf := bytes.NewReader(p) 420 rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) 421 // won't block, no deadlines 422 n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ 423 To: &rfa, 424 }) 425 if tcpipErr != nil { 426 return int(n64), fmt.Errorf("ping write: %s", tcpipErr) 427 } 428 429 return int(n64), nil 430 } 431 432 func (pc *PingConn) Write(p []byte) (n int, err error) { 433 return pc.WriteTo(p, &pc.raddr) 434 } 435 436 func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 437 e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) 438 pc.wq.EventRegister(&e) 439 defer pc.wq.EventUnregister(&e) 440 441 select { 442 case <-pc.deadline.C: 443 return 0, nil, os.ErrDeadlineExceeded 444 case <-notifyCh: 445 } 446 447 w := tcpip.SliceWriter(p) 448 449 res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ 450 NeedRemoteAddr: true, 451 }) 452 if tcpipErr != nil { 453 return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) 454 } 455 456 remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) 457 return res.Count, &PingAddr{remoteAddr}, nil 458 } 459 460 func (pc *PingConn) Read(p []byte) (n int, err error) { 461 n, _, err = pc.ReadFrom(p) 462 return 463 } 464 465 func (pc *PingConn) SetDeadline(t time.Time) error { 466 // pc.SetWriteDeadline is unimplemented 467 468 return pc.SetReadDeadline(t) 469 } 470 471 func (pc *PingConn) SetReadDeadline(t time.Time) error { 472 pc.deadline.Reset(time.Until(t)) 473 return nil 474 } 475 476 var ( 477 errNoSuchHost = errors.New("no such host") 478 errLameReferral = errors.New("lame referral") 479 errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") 480 errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") 481 errServerMisbehaving = errors.New("server misbehaving") 482 errInvalidDNSResponse = errors.New("invalid DNS response") 483 errNoAnswerFromDNSServer = errors.New("no answer from DNS server") 484 errServerTemporarilyMisbehaving = errors.New("server misbehaving") 485 errCanceled = errors.New("operation was canceled") 486 errTimeout = errors.New("i/o timeout") 487 errNumericPort = errors.New("port must be numeric") 488 errNoSuitableAddress = errors.New("no suitable address found") 489 errMissingAddress = errors.New("missing address") 490 ) 491 492 func (net *Net) LookupHost(host string) (addrs []string, err error) { 493 return net.LookupContextHost(context.Background(), host) 494 } 495 496 func isDomainName(s string) bool { 497 l := len(s) 498 if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { 499 return false 500 } 501 last := byte('.') 502 nonNumeric := false 503 partlen := 0 504 for i := 0; i < len(s); i++ { 505 c := s[i] 506 switch { 507 default: 508 return false 509 case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': 510 nonNumeric = true 511 partlen++ 512 case '0' <= c && c <= '9': 513 partlen++ 514 case c == '-': 515 if last == '.' { 516 return false 517 } 518 partlen++ 519 nonNumeric = true 520 case c == '.': 521 if last == '.' || last == '-' { 522 return false 523 } 524 if partlen > 63 || partlen == 0 { 525 return false 526 } 527 partlen = 0 528 } 529 last = c 530 } 531 if last == '-' || partlen > 63 { 532 return false 533 } 534 return nonNumeric 535 } 536 537 func randU16() uint16 { 538 var b [2]byte 539 _, err := rand.Read(b[:]) 540 if err != nil { 541 panic(err) 542 } 543 return binary.LittleEndian.Uint16(b[:]) 544 } 545 546 func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { 547 id = randU16() 548 b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) 549 b.EnableCompression() 550 if err := b.StartQuestions(); err != nil { 551 return 0, nil, nil, err 552 } 553 if err := b.Question(q); err != nil { 554 return 0, nil, nil, err 555 } 556 tcpReq, err = b.Finish() 557 udpReq = tcpReq[2:] 558 l := len(tcpReq) - 2 559 tcpReq[0] = byte(l >> 8) 560 tcpReq[1] = byte(l) 561 return id, udpReq, tcpReq, err 562 } 563 564 func equalASCIIName(x, y dnsmessage.Name) bool { 565 if x.Length != y.Length { 566 return false 567 } 568 for i := 0; i < int(x.Length); i++ { 569 a := x.Data[i] 570 b := y.Data[i] 571 if 'A' <= a && a <= 'Z' { 572 a += 0x20 573 } 574 if 'A' <= b && b <= 'Z' { 575 b += 0x20 576 } 577 if a != b { 578 return false 579 } 580 } 581 return true 582 } 583 584 func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { 585 if !respHdr.Response { 586 return false 587 } 588 if reqID != respHdr.ID { 589 return false 590 } 591 if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { 592 return false 593 } 594 return true 595 } 596 597 func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { 598 if _, err := c.Write(b); err != nil { 599 return dnsmessage.Parser{}, dnsmessage.Header{}, err 600 } 601 b = make([]byte, 512) 602 for { 603 n, err := c.Read(b) 604 if err != nil { 605 return dnsmessage.Parser{}, dnsmessage.Header{}, err 606 } 607 var p dnsmessage.Parser 608 h, err := p.Start(b[:n]) 609 if err != nil { 610 continue 611 } 612 q, err := p.Question() 613 if err != nil || !checkResponse(id, query, h, q) { 614 continue 615 } 616 return p, h, nil 617 } 618 } 619 620 func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { 621 if _, err := c.Write(b); err != nil { 622 return dnsmessage.Parser{}, dnsmessage.Header{}, err 623 } 624 b = make([]byte, 1280) 625 if _, err := io.ReadFull(c, b[:2]); err != nil { 626 return dnsmessage.Parser{}, dnsmessage.Header{}, err 627 } 628 l := int(b[0])<<8 | int(b[1]) 629 if l > len(b) { 630 b = make([]byte, l) 631 } 632 n, err := io.ReadFull(c, b[:l]) 633 if err != nil { 634 return dnsmessage.Parser{}, dnsmessage.Header{}, err 635 } 636 var p dnsmessage.Parser 637 h, err := p.Start(b[:n]) 638 if err != nil { 639 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage 640 } 641 q, err := p.Question() 642 if err != nil { 643 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage 644 } 645 if !checkResponse(id, query, h, q) { 646 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse 647 } 648 return p, h, nil 649 } 650 651 func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { 652 q.Class = dnsmessage.ClassINET 653 id, udpReq, tcpReq, err := newRequest(q) 654 if err != nil { 655 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage 656 } 657 658 for _, useUDP := range []bool{true, false} { 659 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) 660 defer cancel() 661 662 var c net.Conn 663 var err error 664 if useUDP { 665 c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) 666 } else { 667 c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) 668 } 669 670 if err != nil { 671 return dnsmessage.Parser{}, dnsmessage.Header{}, err 672 } 673 if d, ok := ctx.Deadline(); ok && !d.IsZero() { 674 err := c.SetDeadline(d) 675 if err != nil { 676 return dnsmessage.Parser{}, dnsmessage.Header{}, err 677 } 678 } 679 var p dnsmessage.Parser 680 var h dnsmessage.Header 681 if useUDP { 682 p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) 683 } else { 684 p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) 685 } 686 c.Close() 687 if err != nil { 688 if err == context.Canceled { 689 err = errCanceled 690 } else if err == context.DeadlineExceeded { 691 err = errTimeout 692 } 693 return dnsmessage.Parser{}, dnsmessage.Header{}, err 694 } 695 if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { 696 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse 697 } 698 if h.Truncated { 699 continue 700 } 701 return p, h, nil 702 } 703 return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer 704 } 705 706 func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { 707 if h.RCode == dnsmessage.RCodeNameError { 708 return errNoSuchHost 709 } 710 _, err := p.AnswerHeader() 711 if err != nil && err != dnsmessage.ErrSectionDone { 712 return errCannotUnmarshalDNSMessage 713 } 714 if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { 715 return errLameReferral 716 } 717 if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { 718 if h.RCode == dnsmessage.RCodeServerFailure { 719 return errServerTemporarilyMisbehaving 720 } 721 return errServerMisbehaving 722 } 723 return nil 724 } 725 726 func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { 727 for { 728 h, err := p.AnswerHeader() 729 if err == dnsmessage.ErrSectionDone { 730 return errNoSuchHost 731 } 732 if err != nil { 733 return errCannotUnmarshalDNSMessage 734 } 735 if h.Type == qtype { 736 return nil 737 } 738 if err := p.SkipAnswer(); err != nil { 739 return errCannotUnmarshalDNSMessage 740 } 741 } 742 } 743 744 func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { 745 var lastErr error 746 747 n, err := dnsmessage.NewName(name) 748 if err != nil { 749 return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage 750 } 751 q := dnsmessage.Question{ 752 Name: n, 753 Type: qtype, 754 Class: dnsmessage.ClassINET, 755 } 756 757 for i := 0; i < 2; i++ { 758 for _, server := range tnet.dnsServers { 759 p, h, err := tnet.exchange(ctx, server, q, time.Second*5) 760 if err != nil { 761 dnsErr := &net.DNSError{ 762 Err: err.Error(), 763 Name: name, 764 Server: server.String(), 765 } 766 if nerr, ok := err.(net.Error); ok && nerr.Timeout() { 767 dnsErr.IsTimeout = true 768 } 769 if _, ok := err.(*net.OpError); ok { 770 dnsErr.IsTemporary = true 771 } 772 lastErr = dnsErr 773 continue 774 } 775 776 if err := checkHeader(&p, h); err != nil { 777 dnsErr := &net.DNSError{ 778 Err: err.Error(), 779 Name: name, 780 Server: server.String(), 781 } 782 if err == errServerTemporarilyMisbehaving { 783 dnsErr.IsTemporary = true 784 } 785 if err == errNoSuchHost { 786 dnsErr.IsNotFound = true 787 return p, server.String(), dnsErr 788 } 789 lastErr = dnsErr 790 continue 791 } 792 793 err = skipToAnswer(&p, qtype) 794 if err == nil { 795 return p, server.String(), nil 796 } 797 lastErr = &net.DNSError{ 798 Err: err.Error(), 799 Name: name, 800 Server: server.String(), 801 } 802 if err == errNoSuchHost { 803 lastErr.(*net.DNSError).IsNotFound = true 804 return p, server.String(), lastErr 805 } 806 } 807 } 808 return dnsmessage.Parser{}, "", lastErr 809 } 810 811 func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { 812 if host == "" || (!tnet.hasV6 && !tnet.hasV4) { 813 return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} 814 } 815 zlen := len(host) 816 if strings.IndexByte(host, ':') != -1 { 817 if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { 818 zlen = zidx 819 } 820 } 821 if ip, err := netip.ParseAddr(host[:zlen]); err == nil { 822 return []string{ip.String()}, nil 823 } 824 825 if !isDomainName(host) { 826 return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} 827 } 828 type result struct { 829 p dnsmessage.Parser 830 server string 831 error 832 } 833 var addrsV4, addrsV6 []netip.Addr 834 lanes := 0 835 if tnet.hasV4 { 836 lanes++ 837 } 838 if tnet.hasV6 { 839 lanes++ 840 } 841 lane := make(chan result, lanes) 842 var lastErr error 843 if tnet.hasV4 { 844 go func() { 845 p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) 846 lane <- result{p, server, err} 847 }() 848 } 849 if tnet.hasV6 { 850 go func() { 851 p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) 852 lane <- result{p, server, err} 853 }() 854 } 855 for l := 0; l < lanes; l++ { 856 result := <-lane 857 if result.error != nil { 858 if lastErr == nil { 859 lastErr = result.error 860 } 861 continue 862 } 863 864 loop: 865 for { 866 h, err := result.p.AnswerHeader() 867 if err != nil && err != dnsmessage.ErrSectionDone { 868 lastErr = &net.DNSError{ 869 Err: errCannotMarshalDNSMessage.Error(), 870 Name: host, 871 Server: result.server, 872 } 873 } 874 if err != nil { 875 break 876 } 877 switch h.Type { 878 case dnsmessage.TypeA: 879 a, err := result.p.AResource() 880 if err != nil { 881 lastErr = &net.DNSError{ 882 Err: errCannotMarshalDNSMessage.Error(), 883 Name: host, 884 Server: result.server, 885 } 886 break loop 887 } 888 addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) 889 890 case dnsmessage.TypeAAAA: 891 aaaa, err := result.p.AAAAResource() 892 if err != nil { 893 lastErr = &net.DNSError{ 894 Err: errCannotMarshalDNSMessage.Error(), 895 Name: host, 896 Server: result.server, 897 } 898 break loop 899 } 900 addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) 901 902 default: 903 if err := result.p.SkipAnswer(); err != nil { 904 lastErr = &net.DNSError{ 905 Err: errCannotMarshalDNSMessage.Error(), 906 Name: host, 907 Server: result.server, 908 } 909 break loop 910 } 911 continue 912 } 913 } 914 } 915 // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled 916 var addrs []netip.Addr 917 if tnet.hasV6 { 918 addrs = append(addrsV6, addrsV4...) 919 } else { 920 addrs = append(addrsV4, addrsV6...) 921 } 922 923 if len(addrs) == 0 && lastErr != nil { 924 return nil, lastErr 925 } 926 saddrs := make([]string, 0, len(addrs)) 927 for _, ip := range addrs { 928 saddrs = append(saddrs, ip.String()) 929 } 930 return saddrs, nil 931 } 932 933 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { 934 if deadline.IsZero() { 935 return deadline, nil 936 } 937 timeRemaining := deadline.Sub(now) 938 if timeRemaining <= 0 { 939 return time.Time{}, errTimeout 940 } 941 timeout := timeRemaining / time.Duration(addrsRemaining) 942 const saneMinimum = 2 * time.Second 943 if timeout < saneMinimum { 944 if timeRemaining < saneMinimum { 945 timeout = timeRemaining 946 } else { 947 timeout = saneMinimum 948 } 949 } 950 return now.Add(timeout), nil 951 } 952 953 var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) 954 955 func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 956 if ctx == nil { 957 panic("nil context") 958 } 959 var acceptV4, acceptV6 bool 960 matches := protoSplitter.FindStringSubmatch(network) 961 if matches == nil { 962 return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} 963 } else if len(matches[2]) == 0 { 964 acceptV4 = true 965 acceptV6 = true 966 } else { 967 acceptV4 = matches[2][0] == '4' 968 acceptV6 = !acceptV4 969 } 970 var host string 971 var port int 972 if matches[1] == "ping" { 973 host = address 974 } else { 975 var sport string 976 var err error 977 host, sport, err = net.SplitHostPort(address) 978 if err != nil { 979 return nil, &net.OpError{Op: "dial", Err: err} 980 } 981 port, err = strconv.Atoi(sport) 982 if err != nil || port < 0 || port > 65535 { 983 return nil, &net.OpError{Op: "dial", Err: errNumericPort} 984 } 985 } 986 allAddr, err := tnet.LookupContextHost(ctx, host) 987 if err != nil { 988 return nil, &net.OpError{Op: "dial", Err: err} 989 } 990 var addrs []netip.AddrPort 991 for _, addr := range allAddr { 992 ip, err := netip.ParseAddr(addr) 993 if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { 994 addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) 995 } 996 } 997 if len(addrs) == 0 && len(allAddr) != 0 { 998 return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} 999 } 1000 1001 var firstErr error 1002 for i, addr := range addrs { 1003 select { 1004 case <-ctx.Done(): 1005 err := ctx.Err() 1006 if err == context.Canceled { 1007 err = errCanceled 1008 } else if err == context.DeadlineExceeded { 1009 err = errTimeout 1010 } 1011 return nil, &net.OpError{Op: "dial", Err: err} 1012 default: 1013 } 1014 1015 dialCtx := ctx 1016 if deadline, hasDeadline := ctx.Deadline(); hasDeadline { 1017 partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) 1018 if err != nil { 1019 if firstErr == nil { 1020 firstErr = &net.OpError{Op: "dial", Err: err} 1021 } 1022 break 1023 } 1024 if partialDeadline.Before(deadline) { 1025 var cancel context.CancelFunc 1026 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) 1027 defer cancel() 1028 } 1029 } 1030 1031 var c net.Conn 1032 switch matches[1] { 1033 case "tcp": 1034 c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) 1035 case "udp": 1036 c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) 1037 case "ping": 1038 c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) 1039 } 1040 if err == nil { 1041 return c, nil 1042 } 1043 if firstErr == nil { 1044 firstErr = err 1045 } 1046 } 1047 if firstErr == nil { 1048 firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} 1049 } 1050 return nil, firstErr 1051 } 1052 1053 func (tnet *Net) Dial(network, address string) (net.Conn, error) { 1054 return tnet.DialContext(context.Background(), network, address) 1055 }