github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/services/wireguard/endpoint/netstack/netstack.go (about)

     1  /*
     2   * Copyright (C) 2022 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  /* SPDX-License-Identifier: MIT
    19   *
    20   * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
    21   */
    22  
    23  package netstack
    24  
    25  import (
    26  	"bytes"
    27  	"context"
    28  	"crypto/rand"
    29  	"encoding/binary"
    30  	"errors"
    31  	"fmt"
    32  	"io"
    33  	"net"
    34  	"net/netip"
    35  	"os"
    36  	"regexp"
    37  	"strconv"
    38  	"strings"
    39  	"syscall"
    40  	"time"
    41  
    42  	"golang.org/x/net/dns/dnsmessage"
    43  	"golang.zx2c4.com/wireguard/tun"
    44  
    45  	"gvisor.dev/gvisor/pkg/buffer"
    46  	"gvisor.dev/gvisor/pkg/refs"
    47  	"gvisor.dev/gvisor/pkg/tcpip"
    48  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    49  	"gvisor.dev/gvisor/pkg/tcpip/header"
    50  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    51  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    52  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    53  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    54  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    55  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    56  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    57  	"gvisor.dev/gvisor/pkg/waiter"
    58  )
    59  
// MaxDNSResponseMsgSize is the size of the receive buffer used for UDP DNS
// responses and the EDNS0 UDP payload size advertised in outgoing queries.
const MaxDNSResponseMsgSize = 4096
    63  
    64  type netTun struct {
    65  	ep             *channel.Endpoint
    66  	stack          *stack.Stack
    67  	dispatcher     stack.NetworkDispatcher
    68  	events         chan tun.Event
    69  	incomingPacket chan *buffer.View
    70  	mtu            int
    71  	dnsServers     []netip.Addr
    72  	hasV4, hasV6   bool
    73  }
    74  
    75  type (
    76  	endpoint netTun
    77  	Net      netTun
    78  )
    79  
    80  func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
    81  	refs.SetLeakMode(refs.NoLeakChecking)
    82  
    83  	opts := stack.Options{
    84  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    85  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
    86  		HandleLocal:        true,
    87  	}
    88  	dev := &netTun{
    89  		ep:             channel.New(1024, uint32(mtu), ""),
    90  		stack:          stack.New(opts),
    91  		events:         make(chan tun.Event, 10),
    92  		incomingPacket: make(chan *buffer.View),
    93  		dnsServers:     dnsServers,
    94  		mtu:            mtu,
    95  	}
    96  
    97  	sack := tcpip.TCPSACKEnabled(true)
    98  	if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sack); err != nil {
    99  		return nil, nil, fmt.Errorf("enabling SACK failed: %v", err)
   100  	}
   101  
   102  	moderation := tcpip.TCPModerateReceiveBufferOption(true)
   103  	if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &moderation); err != nil {
   104  		return nil, nil, fmt.Errorf("failed to set TCP moderate receive buffer option: %v", err)
   105  	}
   106  
   107  	dev.ep.AddNotify(dev)
   108  	tcpipErr := dev.stack.CreateNIC(1, dev.ep)
   109  	if tcpipErr != nil {
   110  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
   111  	}
   112  	for _, ip := range localAddresses {
   113  		var protoNumber tcpip.NetworkProtocolNumber
   114  		if ip.Is4() {
   115  			protoNumber = ipv4.ProtocolNumber
   116  		} else if ip.Is6() {
   117  			protoNumber = ipv6.ProtocolNumber
   118  		}
   119  		protoAddr := tcpip.ProtocolAddress{
   120  			Protocol:          protoNumber,
   121  			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
   122  		}
   123  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
   124  		if tcpipErr != nil {
   125  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
   126  		}
   127  		if ip.Is4() {
   128  			dev.hasV4 = true
   129  		} else if ip.Is6() {
   130  			dev.hasV6 = true
   131  		}
   132  	}
   133  	if dev.hasV4 {
   134  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
   135  	}
   136  	if dev.hasV6 {
   137  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
   138  	}
   139  
   140  	dev.events <- tun.EventUp
   141  	return dev, (*Net)(dev), nil
   142  }
   143  
   144  func (tun *netTun) BatchSize() int {
   145  	return 1
   146  }
   147  
   148  func (tun *netTun) Name() (string, error) {
   149  	return "go", nil
   150  }
   151  
   152  func (tun *netTun) File() *os.File {
   153  	return nil
   154  }
   155  
   156  func (tun *netTun) Events() <-chan tun.Event {
   157  	return tun.events
   158  }
   159  
   160  func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
   161  	view, ok := <-tun.incomingPacket
   162  	if !ok {
   163  		return 0, os.ErrClosed
   164  	}
   165  
   166  	n, err := view.Read(buf[0][offset:])
   167  	if err != nil {
   168  		return 0, err
   169  	}
   170  	sizes[0] = n
   171  	return 1, nil
   172  }
   173  
   174  func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
   175  	for _, buf := range buf {
   176  		packet := buf[offset:]
   177  		if len(packet) == 0 {
   178  			continue
   179  		}
   180  
   181  		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
   182  		switch packet[0] >> 4 {
   183  		case 4:
   184  			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
   185  		case 6:
   186  			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
   187  		default:
   188  			return 0, syscall.EAFNOSUPPORT
   189  		}
   190  	}
   191  	return len(buf), nil
   192  }
   193  
   194  func (tun *netTun) WriteNotify() {
   195  	pkt := tun.ep.Read()
   196  	if pkt.IsNil() {
   197  		return
   198  	}
   199  
   200  	view := pkt.ToView()
   201  	pkt.DecRef()
   202  
   203  	tun.incomingPacket <- view
   204  }
   205  
   206  func (tun *netTun) Flush() error {
   207  	return nil
   208  }
   209  
   210  func (tun *netTun) Close() error {
   211  	tun.stack.RemoveNIC(1)
   212  	tun.stack.Close()
   213  
   214  	if tun.events != nil {
   215  		close(tun.events)
   216  	}
   217  
   218  	tun.ep.Close()
   219  	if tun.incomingPacket != nil {
   220  		close(tun.incomingPacket)
   221  	}
   222  	return nil
   223  }
   224  
   225  func (tun *netTun) MTU() (int, error) {
   226  	return tun.mtu, nil
   227  }
   228  
   229  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   230  	var protoNumber tcpip.NetworkProtocolNumber
   231  	if endpoint.Addr().Is4() {
   232  		protoNumber = ipv4.ProtocolNumber
   233  	} else {
   234  		protoNumber = ipv6.ProtocolNumber
   235  	}
   236  	return tcpip.FullAddress{
   237  		NIC:  1,
   238  		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
   239  		Port: endpoint.Port(),
   240  	}, protoNumber
   241  }
   242  
   243  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   244  	fa, pn := convertToFullAddr(addr)
   245  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   246  }
   247  
   248  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
   249  	if addr == nil {
   250  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
   251  	}
   252  	ip, _ := netip.AddrFromSlice(addr.IP)
   253  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
   254  }
   255  
   256  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
   257  	fa, pn := convertToFullAddr(addr)
   258  	return gonet.DialTCP(net.stack, fa, pn)
   259  }
   260  
   261  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
   262  	if addr == nil {
   263  		return net.DialTCPAddrPort(netip.AddrPort{})
   264  	}
   265  	ip, _ := netip.AddrFromSlice(addr.IP)
   266  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   267  }
   268  
   269  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
   270  	fa, pn := convertToFullAddr(addr)
   271  	return gonet.ListenTCP(net.stack, fa, pn)
   272  }
   273  
   274  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
   275  	if addr == nil {
   276  		return net.ListenTCPAddrPort(netip.AddrPort{})
   277  	}
   278  	ip, _ := netip.AddrFromSlice(addr.IP)
   279  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   280  }
   281  
   282  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   283  	var lfa, rfa *tcpip.FullAddress
   284  	var pn tcpip.NetworkProtocolNumber
   285  	if laddr.IsValid() || laddr.Port() > 0 {
   286  		var addr tcpip.FullAddress
   287  		addr, pn = convertToFullAddr(laddr)
   288  		lfa = &addr
   289  	}
   290  	if raddr.IsValid() || raddr.Port() > 0 {
   291  		var addr tcpip.FullAddress
   292  		addr, pn = convertToFullAddr(raddr)
   293  		rfa = &addr
   294  	}
   295  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   296  }
   297  
   298  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
   299  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
   300  }
   301  
   302  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
   303  	var la, ra netip.AddrPort
   304  	if laddr != nil {
   305  		ip, _ := netip.AddrFromSlice(laddr.IP)
   306  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
   307  	}
   308  	if raddr != nil {
   309  		ip, _ := netip.AddrFromSlice(raddr.IP)
   310  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
   311  	}
   312  	return net.DialUDPAddrPort(la, ra)
   313  }
   314  
   315  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
   316  	return net.DialUDP(laddr, nil)
   317  }
   318  
   319  type PingConn struct {
   320  	laddr    PingAddr
   321  	raddr    PingAddr
   322  	wq       waiter.Queue
   323  	ep       tcpip.Endpoint
   324  	deadline *time.Timer
   325  }
   326  
   327  type PingAddr struct{ addr netip.Addr }
   328  
   329  func (ia PingAddr) String() string {
   330  	return ia.addr.String()
   331  }
   332  
   333  func (ia PingAddr) Network() string {
   334  	if ia.addr.Is4() {
   335  		return "ping4"
   336  	} else if ia.addr.Is6() {
   337  		return "ping6"
   338  	}
   339  	return "ping"
   340  }
   341  
   342  func (ia PingAddr) Addr() netip.Addr {
   343  	return ia.addr
   344  }
   345  
   346  func PingAddrFromAddr(addr netip.Addr) *PingAddr {
   347  	return &PingAddr{addr}
   348  }
   349  
   350  func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
   351  	if !laddr.IsValid() && !raddr.IsValid() {
   352  		return nil, errors.New("ping dial: invalid address")
   353  	}
   354  	v6 := laddr.Is6() || raddr.Is6()
   355  	bind := laddr.IsValid()
   356  	if !bind {
   357  		if v6 {
   358  			laddr = netip.IPv6Unspecified()
   359  		} else {
   360  			laddr = netip.IPv4Unspecified()
   361  		}
   362  	}
   363  
   364  	tn := icmp.ProtocolNumber4
   365  	pn := ipv4.ProtocolNumber
   366  	if v6 {
   367  		tn = icmp.ProtocolNumber6
   368  		pn = ipv6.ProtocolNumber
   369  	}
   370  
   371  	pc := &PingConn{
   372  		laddr:    PingAddr{laddr},
   373  		deadline: time.NewTimer(time.Hour << 10),
   374  	}
   375  	pc.deadline.Stop()
   376  
   377  	ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
   378  	if tcpipErr != nil {
   379  		return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
   380  	}
   381  	pc.ep = ep
   382  
   383  	if bind {
   384  		fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
   385  		if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
   386  			return nil, fmt.Errorf("ping bind: %s", tcpipErr)
   387  		}
   388  	}
   389  
   390  	if raddr.IsValid() {
   391  		pc.raddr = PingAddr{raddr}
   392  		fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
   393  		if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
   394  			return nil, fmt.Errorf("ping connect: %s", tcpipErr)
   395  		}
   396  	}
   397  
   398  	return pc, nil
   399  }
   400  
   401  func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
   402  	return net.DialPingAddr(laddr, netip.Addr{})
   403  }
   404  
   405  func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
   406  	var la, ra netip.Addr
   407  	if laddr != nil {
   408  		la = laddr.addr
   409  	}
   410  	if raddr != nil {
   411  		ra = raddr.addr
   412  	}
   413  	return net.DialPingAddr(la, ra)
   414  }
   415  
   416  func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
   417  	var la netip.Addr
   418  	if laddr != nil {
   419  		la = laddr.addr
   420  	}
   421  	return net.ListenPingAddr(la)
   422  }
   423  
   424  func (pc *PingConn) LocalAddr() net.Addr {
   425  	return pc.laddr
   426  }
   427  
   428  func (pc *PingConn) RemoteAddr() net.Addr {
   429  	return pc.raddr
   430  }
   431  
   432  func (pc *PingConn) Close() error {
   433  	pc.deadline.Reset(0)
   434  	pc.ep.Close()
   435  	return nil
   436  }
   437  
   438  func (pc *PingConn) SetWriteDeadline(t time.Time) error {
   439  	return errors.New("not implemented")
   440  }
   441  
   442  func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   443  	var na netip.Addr
   444  	switch v := addr.(type) {
   445  	case *PingAddr:
   446  		na = v.addr
   447  	case *net.IPAddr:
   448  		na, _ = netip.AddrFromSlice(v.IP)
   449  	default:
   450  		return 0, fmt.Errorf("ping write: wrong net.Addr type")
   451  	}
   452  	if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
   453  		return 0, fmt.Errorf("ping write: mismatched protocols")
   454  	}
   455  
   456  	buf := bytes.NewBuffer(p)
   457  	rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
   458  	// won't block, no deadlines
   459  	n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
   460  		To: &rfa,
   461  	})
   462  	if tcpipErr != nil {
   463  		return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
   464  	}
   465  
   466  	return int(n64), nil
   467  }
   468  
   469  func (pc *PingConn) Write(p []byte) (n int, err error) {
   470  	return pc.WriteTo(p, &pc.raddr)
   471  }
   472  
   473  func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   474  	e, notifyCh := waiter.NewChannelEntry(0)
   475  	pc.wq.EventRegister(&e)
   476  	defer pc.wq.EventUnregister(&e)
   477  
   478  	select {
   479  	case <-pc.deadline.C:
   480  		return 0, nil, os.ErrDeadlineExceeded
   481  	case <-notifyCh:
   482  	}
   483  
   484  	w := tcpip.SliceWriter(p)
   485  
   486  	res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
   487  		NeedRemoteAddr: true,
   488  	})
   489  	if tcpipErr != nil {
   490  		return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
   491  	}
   492  
   493  	remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
   494  	return res.Count, &PingAddr{remoteAddr}, nil
   495  }
   496  
   497  func (pc *PingConn) Read(p []byte) (n int, err error) {
   498  	n, _, err = pc.ReadFrom(p)
   499  	return
   500  }
   501  
   502  func (pc *PingConn) SetDeadline(t time.Time) error {
   503  	// pc.SetWriteDeadline is unimplemented
   504  
   505  	return pc.SetReadDeadline(t)
   506  }
   507  
   508  func (pc *PingConn) SetReadDeadline(t time.Time) error {
   509  	pc.deadline.Reset(t.Sub(time.Now()))
   510  	return nil
   511  }
   512  
   513  var (
   514  	errNoSuchHost                   = errors.New("no such host")
   515  	errLameReferral                 = errors.New("lame referral")
   516  	errCannotUnmarshalDNSMessage    = errors.New("cannot unmarshal DNS message")
   517  	errCannotMarshalDNSMessage      = errors.New("cannot marshal DNS message")
   518  	errServerMisbehaving            = errors.New("server misbehaving")
   519  	errInvalidDNSResponse           = errors.New("invalid DNS response")
   520  	errNoAnswerFromDNSServer        = errors.New("no answer from DNS server")
   521  	errServerTemporarilyMisbehaving = errors.New("server misbehaving")
   522  	errCanceled                     = errors.New("operation was canceled")
   523  	errTimeout                      = errors.New("i/o timeout")
   524  	errNumericPort                  = errors.New("port must be numeric")
   525  	errNoSuitableAddress            = errors.New("no suitable address found")
   526  	errMissingAddress               = errors.New("missing address")
   527  )
   528  
   529  func (net *Net) LookupHost(host string) (addrs []string, err error) {
   530  	return net.LookupContextHost(context.Background(), host)
   531  }
   532  
   533  func isDomainName(s string) bool {
   534  	l := len(s)
   535  	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
   536  		return false
   537  	}
   538  	last := byte('.')
   539  	nonNumeric := false
   540  	partlen := 0
   541  	for i := 0; i < len(s); i++ {
   542  		c := s[i]
   543  		switch {
   544  		default:
   545  			return false
   546  		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
   547  			nonNumeric = true
   548  			partlen++
   549  		case '0' <= c && c <= '9':
   550  			partlen++
   551  		case c == '-':
   552  			if last == '.' {
   553  				return false
   554  			}
   555  			partlen++
   556  			nonNumeric = true
   557  		case c == '.':
   558  			if last == '.' || last == '-' {
   559  				return false
   560  			}
   561  			if partlen > 63 || partlen == 0 {
   562  				return false
   563  			}
   564  			partlen = 0
   565  		}
   566  		last = c
   567  	}
   568  	if last == '-' || partlen > 63 {
   569  		return false
   570  	}
   571  	return nonNumeric
   572  }
   573  
   574  func randU16() uint16 {
   575  	var b [2]byte
   576  	_, err := rand.Read(b[:])
   577  	if err != nil {
   578  		panic(err)
   579  	}
   580  	return binary.LittleEndian.Uint16(b[:])
   581  }
   582  
   583  func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
   584  	id = randU16()
   585  	b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
   586  	b.EnableCompression()
   587  	if err := b.StartQuestions(); err != nil {
   588  		return 0, nil, nil, err
   589  	}
   590  	if err := b.Question(q); err != nil {
   591  		return 0, nil, nil, err
   592  	}
   593  	// Accept packets up to MaxDNSResponseMsgSize.  RFC 6891.
   594  	if err := b.StartAdditionals(); err != nil {
   595  		return 0, nil, nil, err
   596  	}
   597  	var rh dnsmessage.ResourceHeader
   598  	if err := rh.SetEDNS0(MaxDNSResponseMsgSize, dnsmessage.RCodeSuccess, false); err != nil {
   599  		return 0, nil, nil, err
   600  	}
   601  	if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
   602  		return 0, nil, nil, err
   603  	}
   604  	tcpReq, err = b.Finish()
   605  	udpReq = tcpReq[2:]
   606  	l := len(tcpReq) - 2
   607  	tcpReq[0] = byte(l >> 8)
   608  	tcpReq[1] = byte(l)
   609  	return id, udpReq, tcpReq, err
   610  }
   611  
   612  func equalASCIIName(x, y dnsmessage.Name) bool {
   613  	if x.Length != y.Length {
   614  		return false
   615  	}
   616  	for i := 0; i < int(x.Length); i++ {
   617  		a := x.Data[i]
   618  		b := y.Data[i]
   619  		if 'A' <= a && a <= 'Z' {
   620  			a += 0x20
   621  		}
   622  		if 'A' <= b && b <= 'Z' {
   623  			b += 0x20
   624  		}
   625  		if a != b {
   626  			return false
   627  		}
   628  	}
   629  	return true
   630  }
   631  
   632  func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
   633  	if !respHdr.Response {
   634  		return false
   635  	}
   636  	if reqID != respHdr.ID {
   637  		return false
   638  	}
   639  	if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
   640  		return false
   641  	}
   642  	return true
   643  }
   644  
   645  func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
   646  	if _, err := c.Write(b); err != nil {
   647  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   648  	}
   649  	b = make([]byte, MaxDNSResponseMsgSize)
   650  	for {
   651  		n, err := c.Read(b)
   652  		if err != nil {
   653  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   654  		}
   655  		var p dnsmessage.Parser
   656  		h, err := p.Start(b[:n])
   657  		if err != nil {
   658  			continue
   659  		}
   660  		q, err := p.Question()
   661  		if err != nil || !checkResponse(id, query, h, q) {
   662  			continue
   663  		}
   664  		return p, h, nil
   665  	}
   666  }
   667  
   668  func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
   669  	if _, err := c.Write(b); err != nil {
   670  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   671  	}
   672  	b = make([]byte, 1280)
   673  	if _, err := io.ReadFull(c, b[:2]); err != nil {
   674  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   675  	}
   676  	l := int(b[0])<<8 | int(b[1])
   677  	if l > len(b) {
   678  		b = make([]byte, l)
   679  	}
   680  	n, err := io.ReadFull(c, b[:l])
   681  	if err != nil {
   682  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   683  	}
   684  	var p dnsmessage.Parser
   685  	h, err := p.Start(b[:n])
   686  	if err != nil {
   687  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
   688  	}
   689  	q, err := p.Question()
   690  	if err != nil {
   691  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
   692  	}
   693  	if !checkResponse(id, query, h, q) {
   694  		return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
   695  	}
   696  	return p, h, nil
   697  }
   698  
   699  func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
   700  	q.Class = dnsmessage.ClassINET
   701  	id, udpReq, tcpReq, err := newRequest(q)
   702  	if err != nil {
   703  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
   704  	}
   705  
   706  	for _, useUDP := range []bool{true, false} {
   707  		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
   708  		defer cancel()
   709  
   710  		var c net.Conn
   711  		var err error
   712  		if useUDP {
   713  			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
   714  		} else {
   715  			c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
   716  		}
   717  
   718  		if err != nil {
   719  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   720  		}
   721  		if d, ok := ctx.Deadline(); ok && !d.IsZero() {
   722  			err := c.SetDeadline(d)
   723  			if err != nil {
   724  				return dnsmessage.Parser{}, dnsmessage.Header{}, err
   725  			}
   726  		}
   727  		var p dnsmessage.Parser
   728  		var h dnsmessage.Header
   729  		if useUDP {
   730  			p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
   731  		} else {
   732  			p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
   733  		}
   734  		c.Close()
   735  		if err != nil {
   736  			if err == context.Canceled {
   737  				err = errCanceled
   738  			} else if err == context.DeadlineExceeded {
   739  				err = errTimeout
   740  			}
   741  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   742  		}
   743  		if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
   744  			return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
   745  		}
   746  		if h.Truncated {
   747  			continue
   748  		}
   749  		return p, h, nil
   750  	}
   751  	return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
   752  }
   753  
   754  func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
   755  	if h.RCode == dnsmessage.RCodeNameError {
   756  		return errNoSuchHost
   757  	}
   758  	_, err := p.AnswerHeader()
   759  	if err != nil && err != dnsmessage.ErrSectionDone {
   760  		return errCannotUnmarshalDNSMessage
   761  	}
   762  	if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
   763  		return errLameReferral
   764  	}
   765  	if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
   766  		if h.RCode == dnsmessage.RCodeServerFailure {
   767  			return errServerTemporarilyMisbehaving
   768  		}
   769  		return errServerMisbehaving
   770  	}
   771  	return nil
   772  }
   773  
   774  func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
   775  	for {
   776  		h, err := p.AnswerHeader()
   777  		if err == dnsmessage.ErrSectionDone {
   778  			return errNoSuchHost
   779  		}
   780  		if err != nil {
   781  			return errCannotUnmarshalDNSMessage
   782  		}
   783  		if h.Type == qtype {
   784  			return nil
   785  		}
   786  		if err := p.SkipAnswer(); err != nil {
   787  			return errCannotUnmarshalDNSMessage
   788  		}
   789  	}
   790  }
   791  
   792  func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
   793  	var lastErr error
   794  
   795  	n, err := dnsmessage.NewName(name)
   796  	if err != nil {
   797  		return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
   798  	}
   799  	q := dnsmessage.Question{
   800  		Name:  n,
   801  		Type:  qtype,
   802  		Class: dnsmessage.ClassINET,
   803  	}
   804  
   805  	for i := 0; i < 2; i++ {
   806  		for _, server := range tnet.dnsServers {
   807  			p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
   808  			if err != nil {
   809  				dnsErr := &net.DNSError{
   810  					Err:    err.Error(),
   811  					Name:   name,
   812  					Server: server.String(),
   813  				}
   814  				if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   815  					dnsErr.IsTimeout = true
   816  				}
   817  				if _, ok := err.(*net.OpError); ok {
   818  					dnsErr.IsTemporary = true
   819  				}
   820  				lastErr = dnsErr
   821  				continue
   822  			}
   823  
   824  			if err := checkHeader(&p, h); err != nil {
   825  				dnsErr := &net.DNSError{
   826  					Err:    err.Error(),
   827  					Name:   name,
   828  					Server: server.String(),
   829  				}
   830  				if err == errServerTemporarilyMisbehaving {
   831  					dnsErr.IsTemporary = true
   832  				}
   833  				if err == errNoSuchHost {
   834  					dnsErr.IsNotFound = true
   835  					return p, server.String(), dnsErr
   836  				}
   837  				lastErr = dnsErr
   838  				continue
   839  			}
   840  
   841  			err = skipToAnswer(&p, qtype)
   842  			if err == nil {
   843  				return p, server.String(), nil
   844  			}
   845  			lastErr = &net.DNSError{
   846  				Err:    err.Error(),
   847  				Name:   name,
   848  				Server: server.String(),
   849  			}
   850  			if err == errNoSuchHost {
   851  				lastErr.(*net.DNSError).IsNotFound = true
   852  				return p, server.String(), lastErr
   853  			}
   854  		}
   855  	}
   856  	return dnsmessage.Parser{}, "", lastErr
   857  }
   858  
   859  func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
   860  	if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
   861  		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
   862  	}
   863  	zlen := len(host)
   864  	if strings.IndexByte(host, ':') != -1 {
   865  		if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
   866  			zlen = zidx
   867  		}
   868  	}
   869  	if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
   870  		return []string{ip.String()}, nil
   871  	}
   872  
   873  	if !isDomainName(host) {
   874  		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
   875  	}
   876  	type result struct {
   877  		p      dnsmessage.Parser
   878  		server string
   879  		error
   880  	}
   881  	var addrsV4, addrsV6 []netip.Addr
   882  	lanes := 0
   883  	if tnet.hasV4 {
   884  		lanes++
   885  	}
   886  	if tnet.hasV6 {
   887  		lanes++
   888  	}
   889  	lane := make(chan result, lanes)
   890  	var lastErr error
   891  	if tnet.hasV4 {
   892  		go func() {
   893  			p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
   894  			lane <- result{p, server, err}
   895  		}()
   896  	}
   897  	if tnet.hasV6 {
   898  		go func() {
   899  			p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
   900  			lane <- result{p, server, err}
   901  		}()
   902  	}
   903  	for l := 0; l < lanes; l++ {
   904  		result := <-lane
   905  		if result.error != nil {
   906  			if lastErr == nil {
   907  				lastErr = result.error
   908  			}
   909  			continue
   910  		}
   911  
   912  	loop:
   913  		for {
   914  			h, err := result.p.AnswerHeader()
   915  			if err != nil && err != dnsmessage.ErrSectionDone {
   916  				lastErr = &net.DNSError{
   917  					Err:    errCannotMarshalDNSMessage.Error(),
   918  					Name:   host,
   919  					Server: result.server,
   920  				}
   921  			}
   922  			if err != nil {
   923  				break
   924  			}
   925  			switch h.Type {
   926  			case dnsmessage.TypeA:
   927  				a, err := result.p.AResource()
   928  				if err != nil {
   929  					lastErr = &net.DNSError{
   930  						Err:    errCannotMarshalDNSMessage.Error(),
   931  						Name:   host,
   932  						Server: result.server,
   933  					}
   934  					break loop
   935  				}
   936  				addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
   937  
   938  			case dnsmessage.TypeAAAA:
   939  				aaaa, err := result.p.AAAAResource()
   940  				if err != nil {
   941  					lastErr = &net.DNSError{
   942  						Err:    errCannotMarshalDNSMessage.Error(),
   943  						Name:   host,
   944  						Server: result.server,
   945  					}
   946  					break loop
   947  				}
   948  				addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
   949  
   950  			default:
   951  				if err := result.p.SkipAnswer(); err != nil {
   952  					lastErr = &net.DNSError{
   953  						Err:    errCannotMarshalDNSMessage.Error(),
   954  						Name:   host,
   955  						Server: result.server,
   956  					}
   957  					break loop
   958  				}
   959  				continue
   960  			}
   961  		}
   962  	}
   963  	// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
   964  	var addrs []netip.Addr
   965  	if tnet.hasV6 {
   966  		addrs = append(addrsV6, addrsV4...)
   967  	} else {
   968  		addrs = append(addrsV4, addrsV6...)
   969  	}
   970  
   971  	if len(addrs) == 0 && lastErr != nil {
   972  		return nil, lastErr
   973  	}
   974  	saddrs := make([]string, 0, len(addrs))
   975  	for _, ip := range addrs {
   976  		saddrs = append(saddrs, ip.String())
   977  	}
   978  	return saddrs, nil
   979  }
   980  
   981  func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
   982  	if deadline.IsZero() {
   983  		return deadline, nil
   984  	}
   985  	timeRemaining := deadline.Sub(now)
   986  	if timeRemaining <= 0 {
   987  		return time.Time{}, errTimeout
   988  	}
   989  	timeout := timeRemaining / time.Duration(addrsRemaining)
   990  	const saneMinimum = 2 * time.Second
   991  	if timeout < saneMinimum {
   992  		if timeRemaining < saneMinimum {
   993  			timeout = timeRemaining
   994  		} else {
   995  			timeout = saneMinimum
   996  		}
   997  	}
   998  	return now.Add(timeout), nil
   999  }
  1000  
  1001  var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
  1002  
  1003  func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  1004  	if ctx == nil {
  1005  		panic("nil context")
  1006  	}
  1007  	var acceptV4, acceptV6 bool
  1008  	matches := protoSplitter.FindStringSubmatch(network)
  1009  	if matches == nil {
  1010  		return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
  1011  	} else if len(matches[2]) == 0 {
  1012  		acceptV4 = true
  1013  		acceptV6 = true
  1014  	} else {
  1015  		acceptV4 = matches[2][0] == '4'
  1016  		acceptV6 = !acceptV4
  1017  	}
  1018  	var host string
  1019  	var port int
  1020  	if matches[1] == "ping" {
  1021  		host = address
  1022  	} else {
  1023  		var sport string
  1024  		var err error
  1025  		host, sport, err = net.SplitHostPort(address)
  1026  		if err != nil {
  1027  			return nil, &net.OpError{Op: "dial", Err: err}
  1028  		}
  1029  		port, err = strconv.Atoi(sport)
  1030  		if err != nil || port < 0 || port > 65535 {
  1031  			return nil, &net.OpError{Op: "dial", Err: errNumericPort}
  1032  		}
  1033  	}
  1034  	allAddr, err := tnet.LookupContextHost(ctx, host)
  1035  	if err != nil {
  1036  		return nil, &net.OpError{Op: "dial", Err: err}
  1037  	}
  1038  	var addrs []netip.AddrPort
  1039  	for _, addr := range allAddr {
  1040  		ip, err := netip.ParseAddr(addr)
  1041  		if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
  1042  			addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
  1043  		}
  1044  	}
  1045  	if len(addrs) == 0 && len(allAddr) != 0 {
  1046  		return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
  1047  	}
  1048  
  1049  	var firstErr error
  1050  	for i, addr := range addrs {
  1051  		select {
  1052  		case <-ctx.Done():
  1053  			err := ctx.Err()
  1054  			if err == context.Canceled {
  1055  				err = errCanceled
  1056  			} else if err == context.DeadlineExceeded {
  1057  				err = errTimeout
  1058  			}
  1059  			return nil, &net.OpError{Op: "dial", Err: err}
  1060  		default:
  1061  		}
  1062  
  1063  		dialCtx := ctx
  1064  		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
  1065  			partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
  1066  			if err != nil {
  1067  				if firstErr == nil {
  1068  					firstErr = &net.OpError{Op: "dial", Err: err}
  1069  				}
  1070  				break
  1071  			}
  1072  			if partialDeadline.Before(deadline) {
  1073  				var cancel context.CancelFunc
  1074  				dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
  1075  				defer cancel()
  1076  			}
  1077  		}
  1078  
  1079  		var c net.Conn
  1080  		switch matches[1] {
  1081  		case "tcp":
  1082  			c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
  1083  		case "udp":
  1084  			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
  1085  		case "ping":
  1086  			c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
  1087  		}
  1088  		if err == nil {
  1089  			return c, nil
  1090  		}
  1091  		if firstErr == nil {
  1092  			firstErr = err
  1093  		}
  1094  	}
  1095  	if firstErr == nil {
  1096  		firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
  1097  	}
  1098  	return nil, firstErr
  1099  }
  1100  
  1101  func (tnet *Net) Dial(network, address string) (net.Conn, error) {
  1102  	return tnet.DialContext(context.Background(), network, address)
  1103  }