github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/netstack/tun.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package netstack
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"crypto/rand"
    12  	"encoding/binary"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"net"
    17  	"net/netip"
    18  	"os"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"syscall"
    23  	"time"
    24  
    25  	"github.com/amnezia-vpn/amneziawg-go/tun"
    26  
    27  	"golang.org/x/net/dns/dnsmessage"
    28  	"gvisor.dev/gvisor/pkg/buffer"
    29  	"gvisor.dev/gvisor/pkg/tcpip"
    30  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    31  	"gvisor.dev/gvisor/pkg/tcpip/header"
    32  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    33  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    34  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    35  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    36  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    37  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    38  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    39  	"gvisor.dev/gvisor/pkg/waiter"
    40  )
    41  
    42  type netTun struct {
    43  	ep             *channel.Endpoint
    44  	stack          *stack.Stack
    45  	events         chan tun.Event
    46  	incomingPacket chan *buffer.View
    47  	mtu            int
    48  	dnsServers     []netip.Addr
    49  	hasV4, hasV6   bool
    50  }
    51  
    52  type Net netTun
    53  
    54  func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
    55  	opts := stack.Options{
    56  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    57  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
    58  		HandleLocal:        true,
    59  	}
    60  	dev := &netTun{
    61  		ep:             channel.New(1024, uint32(mtu), ""),
    62  		stack:          stack.New(opts),
    63  		events:         make(chan tun.Event, 10),
    64  		incomingPacket: make(chan *buffer.View),
    65  		dnsServers:     dnsServers,
    66  		mtu:            mtu,
    67  	}
    68  	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
    69  	tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
    70  	if tcpipErr != nil {
    71  		return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
    72  	}
    73  	dev.ep.AddNotify(dev)
    74  	tcpipErr = dev.stack.CreateNIC(1, dev.ep)
    75  	if tcpipErr != nil {
    76  		return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
    77  	}
    78  	for _, ip := range localAddresses {
    79  		var protoNumber tcpip.NetworkProtocolNumber
    80  		if ip.Is4() {
    81  			protoNumber = ipv4.ProtocolNumber
    82  		} else if ip.Is6() {
    83  			protoNumber = ipv6.ProtocolNumber
    84  		}
    85  		protoAddr := tcpip.ProtocolAddress{
    86  			Protocol:          protoNumber,
    87  			AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
    88  		}
    89  		tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
    90  		if tcpipErr != nil {
    91  			return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
    92  		}
    93  		if ip.Is4() {
    94  			dev.hasV4 = true
    95  		} else if ip.Is6() {
    96  			dev.hasV6 = true
    97  		}
    98  	}
    99  	if dev.hasV4 {
   100  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
   101  	}
   102  	if dev.hasV6 {
   103  		dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
   104  	}
   105  
   106  	dev.events <- tun.EventUp
   107  	return dev, (*Net)(dev), nil
   108  }
   109  
   110  func (tun *netTun) Name() (string, error) {
   111  	return "go", nil
   112  }
   113  
   114  func (tun *netTun) File() *os.File {
   115  	return nil
   116  }
   117  
   118  func (tun *netTun) Events() <-chan tun.Event {
   119  	return tun.events
   120  }
   121  
   122  func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
   123  	view, ok := <-tun.incomingPacket
   124  	if !ok {
   125  		return 0, os.ErrClosed
   126  	}
   127  
   128  	n, err := view.Read(buf[0][offset:])
   129  	if err != nil {
   130  		return 0, err
   131  	}
   132  	sizes[0] = n
   133  	return 1, nil
   134  }
   135  
   136  func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
   137  	for _, buf := range buf {
   138  		packet := buf[offset:]
   139  		if len(packet) == 0 {
   140  			continue
   141  		}
   142  
   143  		pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
   144  		switch packet[0] >> 4 {
   145  		case 4:
   146  			tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
   147  		case 6:
   148  			tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
   149  		default:
   150  			return 0, syscall.EAFNOSUPPORT
   151  		}
   152  	}
   153  	return len(buf), nil
   154  }
   155  
   156  func (tun *netTun) WriteNotify() {
   157  	pkt := tun.ep.Read()
   158  	if pkt.IsNil() {
   159  		return
   160  	}
   161  
   162  	view := pkt.ToView()
   163  	pkt.DecRef()
   164  
   165  	tun.incomingPacket <- view
   166  }
   167  
   168  func (tun *netTun) Close() error {
   169  	tun.stack.RemoveNIC(1)
   170  
   171  	if tun.events != nil {
   172  		close(tun.events)
   173  	}
   174  
   175  	tun.ep.Close()
   176  
   177  	if tun.incomingPacket != nil {
   178  		close(tun.incomingPacket)
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  func (tun *netTun) MTU() (int, error) {
   185  	return tun.mtu, nil
   186  }
   187  
   188  func (tun *netTun) BatchSize() int {
   189  	return 1
   190  }
   191  
   192  func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
   193  	var protoNumber tcpip.NetworkProtocolNumber
   194  	if endpoint.Addr().Is4() {
   195  		protoNumber = ipv4.ProtocolNumber
   196  	} else {
   197  		protoNumber = ipv6.ProtocolNumber
   198  	}
   199  	return tcpip.FullAddress{
   200  		NIC:  1,
   201  		Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
   202  		Port: endpoint.Port(),
   203  	}, protoNumber
   204  }
   205  
   206  func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
   207  	fa, pn := convertToFullAddr(addr)
   208  	return gonet.DialContextTCP(ctx, net.stack, fa, pn)
   209  }
   210  
   211  func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
   212  	if addr == nil {
   213  		return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
   214  	}
   215  	ip, _ := netip.AddrFromSlice(addr.IP)
   216  	return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
   217  }
   218  
   219  func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
   220  	fa, pn := convertToFullAddr(addr)
   221  	return gonet.DialTCP(net.stack, fa, pn)
   222  }
   223  
   224  func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
   225  	if addr == nil {
   226  		return net.DialTCPAddrPort(netip.AddrPort{})
   227  	}
   228  	ip, _ := netip.AddrFromSlice(addr.IP)
   229  	return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   230  }
   231  
   232  func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
   233  	fa, pn := convertToFullAddr(addr)
   234  	return gonet.ListenTCP(net.stack, fa, pn)
   235  }
   236  
   237  func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
   238  	if addr == nil {
   239  		return net.ListenTCPAddrPort(netip.AddrPort{})
   240  	}
   241  	ip, _ := netip.AddrFromSlice(addr.IP)
   242  	return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
   243  }
   244  
   245  func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
   246  	var lfa, rfa *tcpip.FullAddress
   247  	var pn tcpip.NetworkProtocolNumber
   248  	if laddr.IsValid() || laddr.Port() > 0 {
   249  		var addr tcpip.FullAddress
   250  		addr, pn = convertToFullAddr(laddr)
   251  		lfa = &addr
   252  	}
   253  	if raddr.IsValid() || raddr.Port() > 0 {
   254  		var addr tcpip.FullAddress
   255  		addr, pn = convertToFullAddr(raddr)
   256  		rfa = &addr
   257  	}
   258  	return gonet.DialUDP(net.stack, lfa, rfa, pn)
   259  }
   260  
   261  func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
   262  	return net.DialUDPAddrPort(laddr, netip.AddrPort{})
   263  }
   264  
   265  func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
   266  	var la, ra netip.AddrPort
   267  	if laddr != nil {
   268  		ip, _ := netip.AddrFromSlice(laddr.IP)
   269  		la = netip.AddrPortFrom(ip, uint16(laddr.Port))
   270  	}
   271  	if raddr != nil {
   272  		ip, _ := netip.AddrFromSlice(raddr.IP)
   273  		ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
   274  	}
   275  	return net.DialUDPAddrPort(la, ra)
   276  }
   277  
   278  func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
   279  	return net.DialUDP(laddr, nil)
   280  }
   281  
   282  type PingConn struct {
   283  	laddr    PingAddr
   284  	raddr    PingAddr
   285  	wq       waiter.Queue
   286  	ep       tcpip.Endpoint
   287  	deadline *time.Timer
   288  }
   289  
   290  type PingAddr struct{ addr netip.Addr }
   291  
   292  func (ia PingAddr) String() string {
   293  	return ia.addr.String()
   294  }
   295  
   296  func (ia PingAddr) Network() string {
   297  	if ia.addr.Is4() {
   298  		return "ping4"
   299  	} else if ia.addr.Is6() {
   300  		return "ping6"
   301  	}
   302  	return "ping"
   303  }
   304  
   305  func (ia PingAddr) Addr() netip.Addr {
   306  	return ia.addr
   307  }
   308  
   309  func PingAddrFromAddr(addr netip.Addr) *PingAddr {
   310  	return &PingAddr{addr}
   311  }
   312  
   313  func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
   314  	if !laddr.IsValid() && !raddr.IsValid() {
   315  		return nil, errors.New("ping dial: invalid address")
   316  	}
   317  	v6 := laddr.Is6() || raddr.Is6()
   318  	bind := laddr.IsValid()
   319  	if !bind {
   320  		if v6 {
   321  			laddr = netip.IPv6Unspecified()
   322  		} else {
   323  			laddr = netip.IPv4Unspecified()
   324  		}
   325  	}
   326  
   327  	tn := icmp.ProtocolNumber4
   328  	pn := ipv4.ProtocolNumber
   329  	if v6 {
   330  		tn = icmp.ProtocolNumber6
   331  		pn = ipv6.ProtocolNumber
   332  	}
   333  
   334  	pc := &PingConn{
   335  		laddr:    PingAddr{laddr},
   336  		deadline: time.NewTimer(time.Hour << 10),
   337  	}
   338  	pc.deadline.Stop()
   339  
   340  	ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
   341  	if tcpipErr != nil {
   342  		return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
   343  	}
   344  	pc.ep = ep
   345  
   346  	if bind {
   347  		fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
   348  		if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
   349  			return nil, fmt.Errorf("ping bind: %s", tcpipErr)
   350  		}
   351  	}
   352  
   353  	if raddr.IsValid() {
   354  		pc.raddr = PingAddr{raddr}
   355  		fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
   356  		if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
   357  			return nil, fmt.Errorf("ping connect: %s", tcpipErr)
   358  		}
   359  	}
   360  
   361  	return pc, nil
   362  }
   363  
   364  func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
   365  	return net.DialPingAddr(laddr, netip.Addr{})
   366  }
   367  
   368  func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
   369  	var la, ra netip.Addr
   370  	if laddr != nil {
   371  		la = laddr.addr
   372  	}
   373  	if raddr != nil {
   374  		ra = raddr.addr
   375  	}
   376  	return net.DialPingAddr(la, ra)
   377  }
   378  
   379  func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
   380  	var la netip.Addr
   381  	if laddr != nil {
   382  		la = laddr.addr
   383  	}
   384  	return net.ListenPingAddr(la)
   385  }
   386  
   387  func (pc *PingConn) LocalAddr() net.Addr {
   388  	return pc.laddr
   389  }
   390  
   391  func (pc *PingConn) RemoteAddr() net.Addr {
   392  	return pc.raddr
   393  }
   394  
   395  func (pc *PingConn) Close() error {
   396  	pc.deadline.Reset(0)
   397  	pc.ep.Close()
   398  	return nil
   399  }
   400  
   401  func (pc *PingConn) SetWriteDeadline(t time.Time) error {
   402  	return errors.New("not implemented")
   403  }
   404  
   405  func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   406  	var na netip.Addr
   407  	switch v := addr.(type) {
   408  	case *PingAddr:
   409  		na = v.addr
   410  	case *net.IPAddr:
   411  		na, _ = netip.AddrFromSlice(v.IP)
   412  	default:
   413  		return 0, fmt.Errorf("ping write: wrong net.Addr type")
   414  	}
   415  	if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
   416  		return 0, fmt.Errorf("ping write: mismatched protocols")
   417  	}
   418  
   419  	buf := bytes.NewReader(p)
   420  	rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
   421  	// won't block, no deadlines
   422  	n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
   423  		To: &rfa,
   424  	})
   425  	if tcpipErr != nil {
   426  		return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
   427  	}
   428  
   429  	return int(n64), nil
   430  }
   431  
   432  func (pc *PingConn) Write(p []byte) (n int, err error) {
   433  	return pc.WriteTo(p, &pc.raddr)
   434  }
   435  
   436  func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   437  	e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
   438  	pc.wq.EventRegister(&e)
   439  	defer pc.wq.EventUnregister(&e)
   440  
   441  	select {
   442  	case <-pc.deadline.C:
   443  		return 0, nil, os.ErrDeadlineExceeded
   444  	case <-notifyCh:
   445  	}
   446  
   447  	w := tcpip.SliceWriter(p)
   448  
   449  	res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
   450  		NeedRemoteAddr: true,
   451  	})
   452  	if tcpipErr != nil {
   453  		return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
   454  	}
   455  
   456  	remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
   457  	return res.Count, &PingAddr{remoteAddr}, nil
   458  }
   459  
   460  func (pc *PingConn) Read(p []byte) (n int, err error) {
   461  	n, _, err = pc.ReadFrom(p)
   462  	return
   463  }
   464  
   465  func (pc *PingConn) SetDeadline(t time.Time) error {
   466  	// pc.SetWriteDeadline is unimplemented
   467  
   468  	return pc.SetReadDeadline(t)
   469  }
   470  
   471  func (pc *PingConn) SetReadDeadline(t time.Time) error {
   472  	pc.deadline.Reset(time.Until(t))
   473  	return nil
   474  }
   475  
   476  var (
   477  	errNoSuchHost                   = errors.New("no such host")
   478  	errLameReferral                 = errors.New("lame referral")
   479  	errCannotUnmarshalDNSMessage    = errors.New("cannot unmarshal DNS message")
   480  	errCannotMarshalDNSMessage      = errors.New("cannot marshal DNS message")
   481  	errServerMisbehaving            = errors.New("server misbehaving")
   482  	errInvalidDNSResponse           = errors.New("invalid DNS response")
   483  	errNoAnswerFromDNSServer        = errors.New("no answer from DNS server")
   484  	errServerTemporarilyMisbehaving = errors.New("server misbehaving")
   485  	errCanceled                     = errors.New("operation was canceled")
   486  	errTimeout                      = errors.New("i/o timeout")
   487  	errNumericPort                  = errors.New("port must be numeric")
   488  	errNoSuitableAddress            = errors.New("no suitable address found")
   489  	errMissingAddress               = errors.New("missing address")
   490  )
   491  
   492  func (net *Net) LookupHost(host string) (addrs []string, err error) {
   493  	return net.LookupContextHost(context.Background(), host)
   494  }
   495  
   496  func isDomainName(s string) bool {
   497  	l := len(s)
   498  	if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
   499  		return false
   500  	}
   501  	last := byte('.')
   502  	nonNumeric := false
   503  	partlen := 0
   504  	for i := 0; i < len(s); i++ {
   505  		c := s[i]
   506  		switch {
   507  		default:
   508  			return false
   509  		case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
   510  			nonNumeric = true
   511  			partlen++
   512  		case '0' <= c && c <= '9':
   513  			partlen++
   514  		case c == '-':
   515  			if last == '.' {
   516  				return false
   517  			}
   518  			partlen++
   519  			nonNumeric = true
   520  		case c == '.':
   521  			if last == '.' || last == '-' {
   522  				return false
   523  			}
   524  			if partlen > 63 || partlen == 0 {
   525  				return false
   526  			}
   527  			partlen = 0
   528  		}
   529  		last = c
   530  	}
   531  	if last == '-' || partlen > 63 {
   532  		return false
   533  	}
   534  	return nonNumeric
   535  }
   536  
   537  func randU16() uint16 {
   538  	var b [2]byte
   539  	_, err := rand.Read(b[:])
   540  	if err != nil {
   541  		panic(err)
   542  	}
   543  	return binary.LittleEndian.Uint16(b[:])
   544  }
   545  
   546  func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
   547  	id = randU16()
   548  	b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
   549  	b.EnableCompression()
   550  	if err := b.StartQuestions(); err != nil {
   551  		return 0, nil, nil, err
   552  	}
   553  	if err := b.Question(q); err != nil {
   554  		return 0, nil, nil, err
   555  	}
   556  	tcpReq, err = b.Finish()
   557  	udpReq = tcpReq[2:]
   558  	l := len(tcpReq) - 2
   559  	tcpReq[0] = byte(l >> 8)
   560  	tcpReq[1] = byte(l)
   561  	return id, udpReq, tcpReq, err
   562  }
   563  
   564  func equalASCIIName(x, y dnsmessage.Name) bool {
   565  	if x.Length != y.Length {
   566  		return false
   567  	}
   568  	for i := 0; i < int(x.Length); i++ {
   569  		a := x.Data[i]
   570  		b := y.Data[i]
   571  		if 'A' <= a && a <= 'Z' {
   572  			a += 0x20
   573  		}
   574  		if 'A' <= b && b <= 'Z' {
   575  			b += 0x20
   576  		}
   577  		if a != b {
   578  			return false
   579  		}
   580  	}
   581  	return true
   582  }
   583  
   584  func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
   585  	if !respHdr.Response {
   586  		return false
   587  	}
   588  	if reqID != respHdr.ID {
   589  		return false
   590  	}
   591  	if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
   592  		return false
   593  	}
   594  	return true
   595  }
   596  
   597  func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
   598  	if _, err := c.Write(b); err != nil {
   599  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   600  	}
   601  	b = make([]byte, 512)
   602  	for {
   603  		n, err := c.Read(b)
   604  		if err != nil {
   605  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   606  		}
   607  		var p dnsmessage.Parser
   608  		h, err := p.Start(b[:n])
   609  		if err != nil {
   610  			continue
   611  		}
   612  		q, err := p.Question()
   613  		if err != nil || !checkResponse(id, query, h, q) {
   614  			continue
   615  		}
   616  		return p, h, nil
   617  	}
   618  }
   619  
   620  func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
   621  	if _, err := c.Write(b); err != nil {
   622  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   623  	}
   624  	b = make([]byte, 1280)
   625  	if _, err := io.ReadFull(c, b[:2]); err != nil {
   626  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   627  	}
   628  	l := int(b[0])<<8 | int(b[1])
   629  	if l > len(b) {
   630  		b = make([]byte, l)
   631  	}
   632  	n, err := io.ReadFull(c, b[:l])
   633  	if err != nil {
   634  		return dnsmessage.Parser{}, dnsmessage.Header{}, err
   635  	}
   636  	var p dnsmessage.Parser
   637  	h, err := p.Start(b[:n])
   638  	if err != nil {
   639  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
   640  	}
   641  	q, err := p.Question()
   642  	if err != nil {
   643  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
   644  	}
   645  	if !checkResponse(id, query, h, q) {
   646  		return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
   647  	}
   648  	return p, h, nil
   649  }
   650  
   651  func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
   652  	q.Class = dnsmessage.ClassINET
   653  	id, udpReq, tcpReq, err := newRequest(q)
   654  	if err != nil {
   655  		return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
   656  	}
   657  
   658  	for _, useUDP := range []bool{true, false} {
   659  		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
   660  		defer cancel()
   661  
   662  		var c net.Conn
   663  		var err error
   664  		if useUDP {
   665  			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
   666  		} else {
   667  			c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
   668  		}
   669  
   670  		if err != nil {
   671  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   672  		}
   673  		if d, ok := ctx.Deadline(); ok && !d.IsZero() {
   674  			err := c.SetDeadline(d)
   675  			if err != nil {
   676  				return dnsmessage.Parser{}, dnsmessage.Header{}, err
   677  			}
   678  		}
   679  		var p dnsmessage.Parser
   680  		var h dnsmessage.Header
   681  		if useUDP {
   682  			p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
   683  		} else {
   684  			p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
   685  		}
   686  		c.Close()
   687  		if err != nil {
   688  			if err == context.Canceled {
   689  				err = errCanceled
   690  			} else if err == context.DeadlineExceeded {
   691  				err = errTimeout
   692  			}
   693  			return dnsmessage.Parser{}, dnsmessage.Header{}, err
   694  		}
   695  		if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
   696  			return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
   697  		}
   698  		if h.Truncated {
   699  			continue
   700  		}
   701  		return p, h, nil
   702  	}
   703  	return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
   704  }
   705  
   706  func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
   707  	if h.RCode == dnsmessage.RCodeNameError {
   708  		return errNoSuchHost
   709  	}
   710  	_, err := p.AnswerHeader()
   711  	if err != nil && err != dnsmessage.ErrSectionDone {
   712  		return errCannotUnmarshalDNSMessage
   713  	}
   714  	if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
   715  		return errLameReferral
   716  	}
   717  	if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
   718  		if h.RCode == dnsmessage.RCodeServerFailure {
   719  			return errServerTemporarilyMisbehaving
   720  		}
   721  		return errServerMisbehaving
   722  	}
   723  	return nil
   724  }
   725  
   726  func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
   727  	for {
   728  		h, err := p.AnswerHeader()
   729  		if err == dnsmessage.ErrSectionDone {
   730  			return errNoSuchHost
   731  		}
   732  		if err != nil {
   733  			return errCannotUnmarshalDNSMessage
   734  		}
   735  		if h.Type == qtype {
   736  			return nil
   737  		}
   738  		if err := p.SkipAnswer(); err != nil {
   739  			return errCannotUnmarshalDNSMessage
   740  		}
   741  	}
   742  }
   743  
   744  func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
   745  	var lastErr error
   746  
   747  	n, err := dnsmessage.NewName(name)
   748  	if err != nil {
   749  		return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
   750  	}
   751  	q := dnsmessage.Question{
   752  		Name:  n,
   753  		Type:  qtype,
   754  		Class: dnsmessage.ClassINET,
   755  	}
   756  
   757  	for i := 0; i < 2; i++ {
   758  		for _, server := range tnet.dnsServers {
   759  			p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
   760  			if err != nil {
   761  				dnsErr := &net.DNSError{
   762  					Err:    err.Error(),
   763  					Name:   name,
   764  					Server: server.String(),
   765  				}
   766  				if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   767  					dnsErr.IsTimeout = true
   768  				}
   769  				if _, ok := err.(*net.OpError); ok {
   770  					dnsErr.IsTemporary = true
   771  				}
   772  				lastErr = dnsErr
   773  				continue
   774  			}
   775  
   776  			if err := checkHeader(&p, h); err != nil {
   777  				dnsErr := &net.DNSError{
   778  					Err:    err.Error(),
   779  					Name:   name,
   780  					Server: server.String(),
   781  				}
   782  				if err == errServerTemporarilyMisbehaving {
   783  					dnsErr.IsTemporary = true
   784  				}
   785  				if err == errNoSuchHost {
   786  					dnsErr.IsNotFound = true
   787  					return p, server.String(), dnsErr
   788  				}
   789  				lastErr = dnsErr
   790  				continue
   791  			}
   792  
   793  			err = skipToAnswer(&p, qtype)
   794  			if err == nil {
   795  				return p, server.String(), nil
   796  			}
   797  			lastErr = &net.DNSError{
   798  				Err:    err.Error(),
   799  				Name:   name,
   800  				Server: server.String(),
   801  			}
   802  			if err == errNoSuchHost {
   803  				lastErr.(*net.DNSError).IsNotFound = true
   804  				return p, server.String(), lastErr
   805  			}
   806  		}
   807  	}
   808  	return dnsmessage.Parser{}, "", lastErr
   809  }
   810  
   811  func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
   812  	if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
   813  		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
   814  	}
   815  	zlen := len(host)
   816  	if strings.IndexByte(host, ':') != -1 {
   817  		if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
   818  			zlen = zidx
   819  		}
   820  	}
   821  	if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
   822  		return []string{ip.String()}, nil
   823  	}
   824  
   825  	if !isDomainName(host) {
   826  		return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
   827  	}
   828  	type result struct {
   829  		p      dnsmessage.Parser
   830  		server string
   831  		error
   832  	}
   833  	var addrsV4, addrsV6 []netip.Addr
   834  	lanes := 0
   835  	if tnet.hasV4 {
   836  		lanes++
   837  	}
   838  	if tnet.hasV6 {
   839  		lanes++
   840  	}
   841  	lane := make(chan result, lanes)
   842  	var lastErr error
   843  	if tnet.hasV4 {
   844  		go func() {
   845  			p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
   846  			lane <- result{p, server, err}
   847  		}()
   848  	}
   849  	if tnet.hasV6 {
   850  		go func() {
   851  			p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
   852  			lane <- result{p, server, err}
   853  		}()
   854  	}
   855  	for l := 0; l < lanes; l++ {
   856  		result := <-lane
   857  		if result.error != nil {
   858  			if lastErr == nil {
   859  				lastErr = result.error
   860  			}
   861  			continue
   862  		}
   863  
   864  	loop:
   865  		for {
   866  			h, err := result.p.AnswerHeader()
   867  			if err != nil && err != dnsmessage.ErrSectionDone {
   868  				lastErr = &net.DNSError{
   869  					Err:    errCannotMarshalDNSMessage.Error(),
   870  					Name:   host,
   871  					Server: result.server,
   872  				}
   873  			}
   874  			if err != nil {
   875  				break
   876  			}
   877  			switch h.Type {
   878  			case dnsmessage.TypeA:
   879  				a, err := result.p.AResource()
   880  				if err != nil {
   881  					lastErr = &net.DNSError{
   882  						Err:    errCannotMarshalDNSMessage.Error(),
   883  						Name:   host,
   884  						Server: result.server,
   885  					}
   886  					break loop
   887  				}
   888  				addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
   889  
   890  			case dnsmessage.TypeAAAA:
   891  				aaaa, err := result.p.AAAAResource()
   892  				if err != nil {
   893  					lastErr = &net.DNSError{
   894  						Err:    errCannotMarshalDNSMessage.Error(),
   895  						Name:   host,
   896  						Server: result.server,
   897  					}
   898  					break loop
   899  				}
   900  				addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
   901  
   902  			default:
   903  				if err := result.p.SkipAnswer(); err != nil {
   904  					lastErr = &net.DNSError{
   905  						Err:    errCannotMarshalDNSMessage.Error(),
   906  						Name:   host,
   907  						Server: result.server,
   908  					}
   909  					break loop
   910  				}
   911  				continue
   912  			}
   913  		}
   914  	}
   915  	// We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
   916  	var addrs []netip.Addr
   917  	if tnet.hasV6 {
   918  		addrs = append(addrsV6, addrsV4...)
   919  	} else {
   920  		addrs = append(addrsV4, addrsV6...)
   921  	}
   922  
   923  	if len(addrs) == 0 && lastErr != nil {
   924  		return nil, lastErr
   925  	}
   926  	saddrs := make([]string, 0, len(addrs))
   927  	for _, ip := range addrs {
   928  		saddrs = append(saddrs, ip.String())
   929  	}
   930  	return saddrs, nil
   931  }
   932  
   933  func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
   934  	if deadline.IsZero() {
   935  		return deadline, nil
   936  	}
   937  	timeRemaining := deadline.Sub(now)
   938  	if timeRemaining <= 0 {
   939  		return time.Time{}, errTimeout
   940  	}
   941  	timeout := timeRemaining / time.Duration(addrsRemaining)
   942  	const saneMinimum = 2 * time.Second
   943  	if timeout < saneMinimum {
   944  		if timeRemaining < saneMinimum {
   945  			timeout = timeRemaining
   946  		} else {
   947  			timeout = saneMinimum
   948  		}
   949  	}
   950  	return now.Add(timeout), nil
   951  }
   952  
   953  var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
   954  
   955  func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   956  	if ctx == nil {
   957  		panic("nil context")
   958  	}
   959  	var acceptV4, acceptV6 bool
   960  	matches := protoSplitter.FindStringSubmatch(network)
   961  	if matches == nil {
   962  		return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
   963  	} else if len(matches[2]) == 0 {
   964  		acceptV4 = true
   965  		acceptV6 = true
   966  	} else {
   967  		acceptV4 = matches[2][0] == '4'
   968  		acceptV6 = !acceptV4
   969  	}
   970  	var host string
   971  	var port int
   972  	if matches[1] == "ping" {
   973  		host = address
   974  	} else {
   975  		var sport string
   976  		var err error
   977  		host, sport, err = net.SplitHostPort(address)
   978  		if err != nil {
   979  			return nil, &net.OpError{Op: "dial", Err: err}
   980  		}
   981  		port, err = strconv.Atoi(sport)
   982  		if err != nil || port < 0 || port > 65535 {
   983  			return nil, &net.OpError{Op: "dial", Err: errNumericPort}
   984  		}
   985  	}
   986  	allAddr, err := tnet.LookupContextHost(ctx, host)
   987  	if err != nil {
   988  		return nil, &net.OpError{Op: "dial", Err: err}
   989  	}
   990  	var addrs []netip.AddrPort
   991  	for _, addr := range allAddr {
   992  		ip, err := netip.ParseAddr(addr)
   993  		if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
   994  			addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
   995  		}
   996  	}
   997  	if len(addrs) == 0 && len(allAddr) != 0 {
   998  		return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
   999  	}
  1000  
  1001  	var firstErr error
  1002  	for i, addr := range addrs {
  1003  		select {
  1004  		case <-ctx.Done():
  1005  			err := ctx.Err()
  1006  			if err == context.Canceled {
  1007  				err = errCanceled
  1008  			} else if err == context.DeadlineExceeded {
  1009  				err = errTimeout
  1010  			}
  1011  			return nil, &net.OpError{Op: "dial", Err: err}
  1012  		default:
  1013  		}
  1014  
  1015  		dialCtx := ctx
  1016  		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
  1017  			partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
  1018  			if err != nil {
  1019  				if firstErr == nil {
  1020  					firstErr = &net.OpError{Op: "dial", Err: err}
  1021  				}
  1022  				break
  1023  			}
  1024  			if partialDeadline.Before(deadline) {
  1025  				var cancel context.CancelFunc
  1026  				dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
  1027  				defer cancel()
  1028  			}
  1029  		}
  1030  
  1031  		var c net.Conn
  1032  		switch matches[1] {
  1033  		case "tcp":
  1034  			c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
  1035  		case "udp":
  1036  			c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
  1037  		case "ping":
  1038  			c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
  1039  		}
  1040  		if err == nil {
  1041  			return c, nil
  1042  		}
  1043  		if firstErr == nil {
  1044  			firstErr = err
  1045  		}
  1046  	}
  1047  	if firstErr == nil {
  1048  		firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
  1049  	}
  1050  	return nil, firstErr
  1051  }
  1052  
  1053  func (tnet *Net) Dial(network, address string) (net.Conn, error) {
  1054  	return tnet.DialContext(context.Background(), network, address)
  1055  }