github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/services/wireguard/endpoint/netstack/netstack.go (about) 1 /* 2 * Copyright (C) 2022 The "MysteriumNetwork/node" Authors. 3 * 4 * This program is free software: you can redistribute it and/or modify 5 * it under the terms of the GNU General Public License as published by 6 * the Free Software Foundation, either version 3 of the License, or 7 * (at your option) any later version. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 */ 17 18 /* SPDX-License-Identifier: MIT 19 * 20 * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. 21 */ 22 23 package netstack 24 25 import ( 26 "bytes" 27 "context" 28 "crypto/rand" 29 "encoding/binary" 30 "errors" 31 "fmt" 32 "io" 33 "net" 34 "net/netip" 35 "os" 36 "regexp" 37 "strconv" 38 "strings" 39 "syscall" 40 "time" 41 42 "golang.org/x/net/dns/dnsmessage" 43 "golang.zx2c4.com/wireguard/tun" 44 45 "gvisor.dev/gvisor/pkg/buffer" 46 "gvisor.dev/gvisor/pkg/refs" 47 "gvisor.dev/gvisor/pkg/tcpip" 48 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 49 "gvisor.dev/gvisor/pkg/tcpip/header" 50 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 51 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 52 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 53 "gvisor.dev/gvisor/pkg/tcpip/stack" 54 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 55 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 56 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 57 "gvisor.dev/gvisor/pkg/waiter" 58 ) 59 60 // MaxDNSMsgSize defines UDP recv buffer for resolver and announced 61 // EDNS0 buffer size 62 const MaxDNSResponseMsgSize = 4096 63 64 type netTun struct { 65 ep *channel.Endpoint 66 stack *stack.Stack 67 dispatcher stack.NetworkDispatcher 68 events chan tun.Event 69 incomingPacket chan *buffer.View 70 mtu int 71 dnsServers []netip.Addr 72 hasV4, hasV6 bool 73 } 74 75 type ( 76 endpoint netTun 77 Net netTun 78 ) 79 80 func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { 81 refs.SetLeakMode(refs.NoLeakChecking) 82 83 opts := stack.Options{ 84 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 85 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, 86 HandleLocal: true, 87 } 88 dev := &netTun{ 89 ep: channel.New(1024, uint32(mtu), ""), 90 stack: stack.New(opts), 91 events: make(chan tun.Event, 10), 92 incomingPacket: make(chan *buffer.View), 93 dnsServers: dnsServers, 94 mtu: mtu, 95 } 96 97 sack := tcpip.TCPSACKEnabled(true) 98 if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sack); err != nil { 99 return nil, nil, fmt.Errorf("enabling SACK failed: %v", err) 100 } 101 102 moderation := tcpip.TCPModerateReceiveBufferOption(true) 103 if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &moderation); err != nil { 104 return nil, nil, fmt.Errorf("failed to set TCP moderate receive buffer option: %v", err) 105 } 106 107 dev.ep.AddNotify(dev) 108 tcpipErr := dev.stack.CreateNIC(1, dev.ep) 109 if tcpipErr != nil { 110 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) 111 } 112 for _, ip := range localAddresses { 113 var protoNumber tcpip.NetworkProtocolNumber 114 if ip.Is4() { 115 protoNumber = ipv4.ProtocolNumber 116 } else if ip.Is6() { 117 protoNumber = ipv6.ProtocolNumber 118 } 119 protoAddr := tcpip.ProtocolAddress{ 120 Protocol: protoNumber, 121 AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), 122 } 123 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 124 if tcpipErr != nil { 125 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 126 } 127 if ip.Is4() { 128 dev.hasV4 = true 129 } else if ip.Is6() { 130 dev.hasV6 = true 131 } 132 } 133 if dev.hasV4 { 134 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 135 } 136 if dev.hasV6 { 137 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 138 } 139 140 dev.events <- tun.EventUp 141 return dev, (*Net)(dev), nil 142 } 143 144 func (tun *netTun) BatchSize() int { 145 return 1 146 } 147 148 func (tun *netTun) Name() (string, error) { 149 return "go", nil 150 } 151 152 func (tun *netTun) File() *os.File { 153 return nil 154 } 155 156 func (tun *netTun) Events() <-chan tun.Event { 157 return tun.events 158 } 159 160 func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { 161 view, ok := <-tun.incomingPacket 162 if !ok { 163 return 0, os.ErrClosed 164 } 165 166 n, err := view.Read(buf[0][offset:]) 167 if err != nil { 168 return 0, err 169 } 170 sizes[0] = n 171 return 1, nil 172 } 173 174 func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { 175 for _, buf := range buf { 176 packet := buf[offset:] 177 if len(packet) == 0 { 178 continue 179 } 180 181 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) 182 switch packet[0] >> 4 { 183 case 4: 184 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) 185 case 6: 186 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) 187 default: 188 return 0, syscall.EAFNOSUPPORT 189 } 190 } 191 return len(buf), nil 192 } 193 194 func (tun *netTun) WriteNotify() { 195 pkt := tun.ep.Read() 196 if pkt.IsNil() { 197 return 198 } 199 200 view := pkt.ToView() 201 pkt.DecRef() 202 203 tun.incomingPacket <- view 204 } 205 206 func (tun *netTun) Flush() error { 207 return nil 208 } 209 210 func (tun *netTun) Close() error { 211 tun.stack.RemoveNIC(1) 212 tun.stack.Close() 213 214 if tun.events != nil { 215 close(tun.events) 216 } 217 218 tun.ep.Close() 219 if tun.incomingPacket != nil { 220 close(tun.incomingPacket) 221 } 222 return nil 223 } 224 225 func (tun *netTun) MTU() (int, error) { 226 return tun.mtu, nil 227 } 228 229 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 230 var protoNumber tcpip.NetworkProtocolNumber 231 if endpoint.Addr().Is4() { 232 protoNumber = ipv4.ProtocolNumber 233 } else { 234 protoNumber = ipv6.ProtocolNumber 235 } 236 return tcpip.FullAddress{ 237 NIC: 1, 238 Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), 239 Port: endpoint.Port(), 240 }, protoNumber 241 } 242 243 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 244 fa, pn := convertToFullAddr(addr) 245 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 246 } 247 248 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { 249 if addr == nil { 250 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) 251 } 252 ip, _ := netip.AddrFromSlice(addr.IP) 253 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) 254 } 255 256 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { 257 fa, pn := convertToFullAddr(addr) 258 return gonet.DialTCP(net.stack, fa, pn) 259 } 260 261 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { 262 if addr == nil { 263 return net.DialTCPAddrPort(netip.AddrPort{}) 264 } 265 ip, _ := netip.AddrFromSlice(addr.IP) 266 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 267 } 268 269 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { 270 fa, pn := convertToFullAddr(addr) 271 return gonet.ListenTCP(net.stack, fa, pn) 272 } 273 274 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { 275 if addr == nil { 276 return net.ListenTCPAddrPort(netip.AddrPort{}) 277 } 278 ip, _ := netip.AddrFromSlice(addr.IP) 279 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 280 } 281 282 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 283 var lfa, rfa *tcpip.FullAddress 284 var pn tcpip.NetworkProtocolNumber 285 if laddr.IsValid() || laddr.Port() > 0 { 286 var addr tcpip.FullAddress 287 addr, pn = convertToFullAddr(laddr) 288 lfa = &addr 289 } 290 if raddr.IsValid() || raddr.Port() > 0 { 291 var addr tcpip.FullAddress 292 addr, pn = convertToFullAddr(raddr) 293 rfa = &addr 294 } 295 return gonet.DialUDP(net.stack, lfa, rfa, pn) 296 } 297 298 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { 299 return net.DialUDPAddrPort(laddr, netip.AddrPort{}) 300 } 301 302 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { 303 var la, ra netip.AddrPort 304 if laddr != nil { 305 ip, _ := netip.AddrFromSlice(laddr.IP) 306 la = netip.AddrPortFrom(ip, uint16(laddr.Port)) 307 } 308 if raddr != nil { 309 ip, _ := netip.AddrFromSlice(raddr.IP) 310 ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) 311 } 312 return net.DialUDPAddrPort(la, ra) 313 } 314 315 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { 316 return net.DialUDP(laddr, nil) 317 } 318 319 type PingConn struct { 320 laddr PingAddr 321 raddr PingAddr 322 wq waiter.Queue 323 ep tcpip.Endpoint 324 deadline *time.Timer 325 } 326 327 type PingAddr struct{ addr netip.Addr } 328 329 func (ia PingAddr) String() string { 330 return ia.addr.String() 331 } 332 333 func (ia PingAddr) Network() string { 334 if ia.addr.Is4() { 335 return "ping4" 336 } else if ia.addr.Is6() { 337 return "ping6" 338 } 339 return "ping" 340 } 341 342 func (ia PingAddr) Addr() netip.Addr { 343 return ia.addr 344 } 345 346 func PingAddrFromAddr(addr netip.Addr) *PingAddr { 347 return &PingAddr{addr} 348 } 349 350 func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { 351 if !laddr.IsValid() && !raddr.IsValid() { 352 return nil, errors.New("ping dial: invalid address") 353 } 354 v6 := laddr.Is6() || raddr.Is6() 355 bind := laddr.IsValid() 356 if !bind { 357 if v6 { 358 laddr = netip.IPv6Unspecified() 359 } else { 360 laddr = netip.IPv4Unspecified() 361 } 362 } 363 364 tn := icmp.ProtocolNumber4 365 pn := ipv4.ProtocolNumber 366 if v6 { 367 tn = icmp.ProtocolNumber6 368 pn = ipv6.ProtocolNumber 369 } 370 371 pc := &PingConn{ 372 laddr: PingAddr{laddr}, 373 deadline: time.NewTimer(time.Hour << 10), 374 } 375 pc.deadline.Stop() 376 377 ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) 378 if tcpipErr != nil { 379 return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) 380 } 381 pc.ep = ep 382 383 if bind { 384 fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) 385 if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { 386 return nil, fmt.Errorf("ping bind: %s", tcpipErr) 387 } 388 } 389 390 if raddr.IsValid() { 391 pc.raddr = PingAddr{raddr} 392 fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) 393 if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { 394 return nil, fmt.Errorf("ping connect: %s", tcpipErr) 395 } 396 } 397 398 return pc, nil 399 } 400 401 func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { 402 return net.DialPingAddr(laddr, netip.Addr{}) 403 } 404 405 func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { 406 var la, ra netip.Addr 407 if laddr != nil { 408 la = laddr.addr 409 } 410 if raddr != nil { 411 ra = raddr.addr 412 } 413 return net.DialPingAddr(la, ra) 414 } 415 416 func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { 417 var la netip.Addr 418 if laddr != nil { 419 la = laddr.addr 420 } 421 return net.ListenPingAddr(la) 422 } 423 424 func (pc *PingConn) LocalAddr() net.Addr { 425 return pc.laddr 426 } 427 428 func (pc *PingConn) RemoteAddr() net.Addr { 429 return pc.raddr 430 } 431 432 func (pc *PingConn) Close() error { 433 pc.deadline.Reset(0) 434 pc.ep.Close() 435 return nil 436 } 437 438 func (pc *PingConn) SetWriteDeadline(t time.Time) error { 439 return errors.New("not implemented") 440 } 441 442 func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 443 var na netip.Addr 444 switch v := addr.(type) { 445 case *PingAddr: 446 na = v.addr 447 case *net.IPAddr: 448 na, _ = netip.AddrFromSlice(v.IP) 449 default: 450 return 0, fmt.Errorf("ping write: wrong net.Addr type") 451 } 452 if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { 453 return 0, fmt.Errorf("ping write: mismatched protocols") 454 } 455 456 buf := bytes.NewBuffer(p) 457 rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) 458 // won't block, no deadlines 459 n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ 460 To: &rfa, 461 }) 462 if tcpipErr != nil { 463 return int(n64), fmt.Errorf("ping write: %s", tcpipErr) 464 } 465 466 return int(n64), nil 467 } 468 469 func (pc *PingConn) Write(p []byte) (n int, err error) { 470 return pc.WriteTo(p, &pc.raddr) 471 } 472 473 func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 474 e, notifyCh := waiter.NewChannelEntry(0) 475 pc.wq.EventRegister(&e) 476 defer pc.wq.EventUnregister(&e) 477 478 select { 479 case <-pc.deadline.C: 480 return 0, nil, os.ErrDeadlineExceeded 481 case <-notifyCh: 482 } 483 484 w := tcpip.SliceWriter(p) 485 486 res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ 487 NeedRemoteAddr: true, 488 }) 489 if tcpipErr != nil { 490 return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) 491 } 492 493 remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) 494 return res.Count, &PingAddr{remoteAddr}, nil 495 } 496 497 func (pc *PingConn) Read(p []byte) (n int, err error) { 498 n, _, err = pc.ReadFrom(p) 499 return 500 } 501 502 func (pc *PingConn) SetDeadline(t time.Time) error { 503 // pc.SetWriteDeadline is unimplemented 504 505 return pc.SetReadDeadline(t) 506 } 507 508 func (pc *PingConn) SetReadDeadline(t time.Time) error { 509 pc.deadline.Reset(t.Sub(time.Now())) 510 return nil 511 } 512 513 var ( 514 errNoSuchHost = errors.New("no such host") 515 errLameReferral = errors.New("lame referral") 516 errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") 517 errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") 518 errServerMisbehaving = errors.New("server misbehaving") 519 errInvalidDNSResponse = errors.New("invalid DNS response") 520 errNoAnswerFromDNSServer = errors.New("no answer from DNS server") 521 errServerTemporarilyMisbehaving = errors.New("server misbehaving") 522 errCanceled = errors.New("operation was canceled") 523 errTimeout = errors.New("i/o timeout") 524 errNumericPort = errors.New("port must be numeric") 525 errNoSuitableAddress = errors.New("no suitable address found") 526 errMissingAddress = errors.New("missing address") 527 ) 528 529 func (net *Net) LookupHost(host string) (addrs []string, err error) { 530 return net.LookupContextHost(context.Background(), host) 531 } 532 533 func isDomainName(s string) bool { 534 l := len(s) 535 if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { 536 return false 537 } 538 last := byte('.') 539 nonNumeric := false 540 partlen := 0 541 for i := 0; i < len(s); i++ { 542 c := s[i] 543 switch { 544 default: 545 return false 546 case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': 547 nonNumeric = true 548 partlen++ 549 case '0' <= c && c <= '9': 550 partlen++ 551 case c == '-': 552 if last == '.' { 553 return false 554 } 555 partlen++ 556 nonNumeric = true 557 case c == '.': 558 if last == '.' || last == '-' { 559 return false 560 } 561 if partlen > 63 || partlen == 0 { 562 return false 563 } 564 partlen = 0 565 } 566 last = c 567 } 568 if last == '-' || partlen > 63 { 569 return false 570 } 571 return nonNumeric 572 } 573 574 func randU16() uint16 { 575 var b [2]byte 576 _, err := rand.Read(b[:]) 577 if err != nil { 578 panic(err) 579 } 580 return binary.LittleEndian.Uint16(b[:]) 581 } 582 583 func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { 584 id = randU16() 585 b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) 586 b.EnableCompression() 587 if err := b.StartQuestions(); err != nil { 588 return 0, nil, nil, err 589 } 590 if err := b.Question(q); err != nil { 591 return 0, nil, nil, err 592 } 593 // Accept packets up to MaxDNSResponseMsgSize. RFC 6891. 594 if err := b.StartAdditionals(); err != nil { 595 return 0, nil, nil, err 596 } 597 var rh dnsmessage.ResourceHeader 598 if err := rh.SetEDNS0(MaxDNSResponseMsgSize, dnsmessage.RCodeSuccess, false); err != nil { 599 return 0, nil, nil, err 600 } 601 if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil { 602 return 0, nil, nil, err 603 } 604 tcpReq, err = b.Finish() 605 udpReq = tcpReq[2:] 606 l := len(tcpReq) - 2 607 tcpReq[0] = byte(l >> 8) 608 tcpReq[1] = byte(l) 609 return id, udpReq, tcpReq, err 610 } 611 612 func equalASCIIName(x, y dnsmessage.Name) bool { 613 if x.Length != y.Length { 614 return false 615 } 616 for i := 0; i < int(x.Length); i++ { 617 a := x.Data[i] 618 b := y.Data[i] 619 if 'A' <= a && a <= 'Z' { 620 a += 0x20 621 } 622 if 'A' <= b && b <= 'Z' { 623 b += 0x20 624 } 625 if a != b { 626 return false 627 } 628 } 629 return true 630 } 631 632 func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { 633 if !respHdr.Response { 634 return false 635 } 636 if reqID != respHdr.ID { 637 return false 638 } 639 if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { 640 return false 641 } 642 return true 643 } 644 645 func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { 646 if _, err := c.Write(b); err != nil { 647 return dnsmessage.Parser{}, dnsmessage.Header{}, err 648 } 649 b = make([]byte, MaxDNSResponseMsgSize) 650 for { 651 n, err := c.Read(b) 652 if err != nil { 653 return dnsmessage.Parser{}, dnsmessage.Header{}, err 654 } 655 var p dnsmessage.Parser 656 h, err := p.Start(b[:n]) 657 if err != nil { 658 continue 659 } 660 q, err := p.Question() 661 if err != nil || !checkResponse(id, query, h, q) { 662 continue 663 } 664 return p, h, nil 665 } 666 } 667 668 func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { 669 if _, err := c.Write(b); err != nil { 670 return dnsmessage.Parser{}, dnsmessage.Header{}, err 671 } 672 b = make([]byte, 1280) 673 if _, err := io.ReadFull(c, b[:2]); err != nil { 674 return dnsmessage.Parser{}, dnsmessage.Header{}, err 675 } 676 l := int(b[0])<<8 | int(b[1]) 677 if l > len(b) { 678 b = make([]byte, l) 679 } 680 n, err := io.ReadFull(c, b[:l]) 681 if err != nil { 682 return dnsmessage.Parser{}, dnsmessage.Header{}, err 683 } 684 var p dnsmessage.Parser 685 h, err := p.Start(b[:n]) 686 if err != nil { 687 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage 688 } 689 q, err := p.Question() 690 if err != nil { 691 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage 692 } 693 if !checkResponse(id, query, h, q) { 694 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse 695 } 696 return p, h, nil 697 } 698 699 func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { 700 q.Class = dnsmessage.ClassINET 701 id, udpReq, tcpReq, err := newRequest(q) 702 if err != nil { 703 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage 704 } 705 706 for _, useUDP := range []bool{true, false} { 707 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) 708 defer cancel() 709 710 var c net.Conn 711 var err error 712 if useUDP { 713 c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) 714 } else { 715 c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) 716 } 717 718 if err != nil { 719 return dnsmessage.Parser{}, dnsmessage.Header{}, err 720 } 721 if d, ok := ctx.Deadline(); ok && !d.IsZero() { 722 err := c.SetDeadline(d) 723 if err != nil { 724 return dnsmessage.Parser{}, dnsmessage.Header{}, err 725 } 726 } 727 var p dnsmessage.Parser 728 var h dnsmessage.Header 729 if useUDP { 730 p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) 731 } else { 732 p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) 733 } 734 c.Close() 735 if err != nil { 736 if err == context.Canceled { 737 err = errCanceled 738 } else if err == context.DeadlineExceeded { 739 err = errTimeout 740 } 741 return dnsmessage.Parser{}, dnsmessage.Header{}, err 742 } 743 if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { 744 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse 745 } 746 if h.Truncated { 747 continue 748 } 749 return p, h, nil 750 } 751 return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer 752 } 753 754 func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { 755 if h.RCode == dnsmessage.RCodeNameError { 756 return errNoSuchHost 757 } 758 _, err := p.AnswerHeader() 759 if err != nil && err != dnsmessage.ErrSectionDone { 760 return errCannotUnmarshalDNSMessage 761 } 762 if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { 763 return errLameReferral 764 } 765 if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { 766 if h.RCode == dnsmessage.RCodeServerFailure { 767 return errServerTemporarilyMisbehaving 768 } 769 return errServerMisbehaving 770 } 771 return nil 772 } 773 774 func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { 775 for { 776 h, err := p.AnswerHeader() 777 if err == dnsmessage.ErrSectionDone { 778 return errNoSuchHost 779 } 780 if err != nil { 781 return errCannotUnmarshalDNSMessage 782 } 783 if h.Type == qtype { 784 return nil 785 } 786 if err := p.SkipAnswer(); err != nil { 787 return errCannotUnmarshalDNSMessage 788 } 789 } 790 } 791 792 func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { 793 var lastErr error 794 795 n, err := dnsmessage.NewName(name) 796 if err != nil { 797 return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage 798 } 799 q := dnsmessage.Question{ 800 Name: n, 801 Type: qtype, 802 Class: dnsmessage.ClassINET, 803 } 804 805 for i := 0; i < 2; i++ { 806 for _, server := range tnet.dnsServers { 807 p, h, err := tnet.exchange(ctx, server, q, time.Second*5) 808 if err != nil { 809 dnsErr := &net.DNSError{ 810 Err: err.Error(), 811 Name: name, 812 Server: server.String(), 813 } 814 if nerr, ok := err.(net.Error); ok && nerr.Timeout() { 815 dnsErr.IsTimeout = true 816 } 817 if _, ok := err.(*net.OpError); ok { 818 dnsErr.IsTemporary = true 819 } 820 lastErr = dnsErr 821 continue 822 } 823 824 if err := checkHeader(&p, h); err != nil { 825 dnsErr := &net.DNSError{ 826 Err: err.Error(), 827 Name: name, 828 Server: server.String(), 829 } 830 if err == errServerTemporarilyMisbehaving { 831 dnsErr.IsTemporary = true 832 } 833 if err == errNoSuchHost { 834 dnsErr.IsNotFound = true 835 return p, server.String(), dnsErr 836 } 837 lastErr = dnsErr 838 continue 839 } 840 841 err = skipToAnswer(&p, qtype) 842 if err == nil { 843 return p, server.String(), nil 844 } 845 lastErr = &net.DNSError{ 846 Err: err.Error(), 847 Name: name, 848 Server: server.String(), 849 } 850 if err == errNoSuchHost { 851 lastErr.(*net.DNSError).IsNotFound = true 852 return p, server.String(), lastErr 853 } 854 } 855 } 856 return dnsmessage.Parser{}, "", lastErr 857 } 858 859 func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { 860 if host == "" || (!tnet.hasV6 && !tnet.hasV4) { 861 return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} 862 } 863 zlen := len(host) 864 if strings.IndexByte(host, ':') != -1 { 865 if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { 866 zlen = zidx 867 } 868 } 869 if ip, err := netip.ParseAddr(host[:zlen]); err == nil { 870 return []string{ip.String()}, nil 871 } 872 873 if !isDomainName(host) { 874 return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} 875 } 876 type result struct { 877 p dnsmessage.Parser 878 server string 879 error 880 } 881 var addrsV4, addrsV6 []netip.Addr 882 lanes := 0 883 if tnet.hasV4 { 884 lanes++ 885 } 886 if tnet.hasV6 { 887 lanes++ 888 } 889 lane := make(chan result, lanes) 890 var lastErr error 891 if tnet.hasV4 { 892 go func() { 893 p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) 894 lane <- result{p, server, err} 895 }() 896 } 897 if tnet.hasV6 { 898 go func() { 899 p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) 900 lane <- result{p, server, err} 901 }() 902 } 903 for l := 0; l < lanes; l++ { 904 result := <-lane 905 if result.error != nil { 906 if lastErr == nil { 907 lastErr = result.error 908 } 909 continue 910 } 911 912 loop: 913 for { 914 h, err := result.p.AnswerHeader() 915 if err != nil && err != dnsmessage.ErrSectionDone { 916 lastErr = &net.DNSError{ 917 Err: errCannotMarshalDNSMessage.Error(), 918 Name: host, 919 Server: result.server, 920 } 921 } 922 if err != nil { 923 break 924 } 925 switch h.Type { 926 case dnsmessage.TypeA: 927 a, err := result.p.AResource() 928 if err != nil { 929 lastErr = &net.DNSError{ 930 Err: errCannotMarshalDNSMessage.Error(), 931 Name: host, 932 Server: result.server, 933 } 934 break loop 935 } 936 addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) 937 938 case dnsmessage.TypeAAAA: 939 aaaa, err := result.p.AAAAResource() 940 if err != nil { 941 lastErr = &net.DNSError{ 942 Err: errCannotMarshalDNSMessage.Error(), 943 Name: host, 944 Server: result.server, 945 } 946 break loop 947 } 948 addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) 949 950 default: 951 if err := result.p.SkipAnswer(); err != nil { 952 lastErr = &net.DNSError{ 953 Err: errCannotMarshalDNSMessage.Error(), 954 Name: host, 955 Server: result.server, 956 } 957 break loop 958 } 959 continue 960 } 961 } 962 } 963 // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled 964 var addrs []netip.Addr 965 if tnet.hasV6 { 966 addrs = append(addrsV6, addrsV4...) 967 } else { 968 addrs = append(addrsV4, addrsV6...) 969 } 970 971 if len(addrs) == 0 && lastErr != nil { 972 return nil, lastErr 973 } 974 saddrs := make([]string, 0, len(addrs)) 975 for _, ip := range addrs { 976 saddrs = append(saddrs, ip.String()) 977 } 978 return saddrs, nil 979 } 980 981 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { 982 if deadline.IsZero() { 983 return deadline, nil 984 } 985 timeRemaining := deadline.Sub(now) 986 if timeRemaining <= 0 { 987 return time.Time{}, errTimeout 988 } 989 timeout := timeRemaining / time.Duration(addrsRemaining) 990 const saneMinimum = 2 * time.Second 991 if timeout < saneMinimum { 992 if timeRemaining < saneMinimum { 993 timeout = timeRemaining 994 } else { 995 timeout = saneMinimum 996 } 997 } 998 return now.Add(timeout), nil 999 } 1000 1001 var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) 1002 1003 func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 1004 if ctx == nil { 1005 panic("nil context") 1006 } 1007 var acceptV4, acceptV6 bool 1008 matches := protoSplitter.FindStringSubmatch(network) 1009 if matches == nil { 1010 return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} 1011 } else if len(matches[2]) == 0 { 1012 acceptV4 = true 1013 acceptV6 = true 1014 } else { 1015 acceptV4 = matches[2][0] == '4' 1016 acceptV6 = !acceptV4 1017 } 1018 var host string 1019 var port int 1020 if matches[1] == "ping" { 1021 host = address 1022 } else { 1023 var sport string 1024 var err error 1025 host, sport, err = net.SplitHostPort(address) 1026 if err != nil { 1027 return nil, &net.OpError{Op: "dial", Err: err} 1028 } 1029 port, err = strconv.Atoi(sport) 1030 if err != nil || port < 0 || port > 65535 { 1031 return nil, &net.OpError{Op: "dial", Err: errNumericPort} 1032 } 1033 } 1034 allAddr, err := tnet.LookupContextHost(ctx, host) 1035 if err != nil { 1036 return nil, &net.OpError{Op: "dial", Err: err} 1037 } 1038 var addrs []netip.AddrPort 1039 for _, addr := range allAddr { 1040 ip, err := netip.ParseAddr(addr) 1041 if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { 1042 addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) 1043 } 1044 } 1045 if len(addrs) == 0 && len(allAddr) != 0 { 1046 return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} 1047 } 1048 1049 var firstErr error 1050 for i, addr := range addrs { 1051 select { 1052 case <-ctx.Done(): 1053 err := ctx.Err() 1054 if err == context.Canceled { 1055 err = errCanceled 1056 } else if err == context.DeadlineExceeded { 1057 err = errTimeout 1058 } 1059 return nil, &net.OpError{Op: "dial", Err: err} 1060 default: 1061 } 1062 1063 dialCtx := ctx 1064 if deadline, hasDeadline := ctx.Deadline(); hasDeadline { 1065 partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) 1066 if err != nil { 1067 if firstErr == nil { 1068 firstErr = &net.OpError{Op: "dial", Err: err} 1069 } 1070 break 1071 } 1072 if partialDeadline.Before(deadline) { 1073 var cancel context.CancelFunc 1074 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) 1075 defer cancel() 1076 } 1077 } 1078 1079 var c net.Conn 1080 switch matches[1] { 1081 case "tcp": 1082 c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) 1083 case "udp": 1084 c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) 1085 case "ping": 1086 c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) 1087 } 1088 if err == nil { 1089 return c, nil 1090 } 1091 if firstErr == nil { 1092 firstErr = err 1093 } 1094 } 1095 if firstErr == nil { 1096 firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} 1097 } 1098 return nil, firstErr 1099 } 1100 1101 func (tnet *Net) Dial(network, address string) (net.Conn, error) { 1102 return tnet.DialContext(context.Background(), network, address) 1103 }