github.com/fumiama/terasu@v0.0.0-20240507144117-547a591149c0/dns/doh.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"encoding/json"
     7  	"errors"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/FloatTech/ttl"
    16  	"golang.org/x/net/http2"
    17  
    18  	"github.com/fumiama/terasu"
    19  	"github.com/fumiama/terasu/ip"
    20  )
    21  
    22  var (
    23  	ErrEmptyHostAddress = errors.New("empty host addr")
    24  )
    25  
    26  type recordType uint16
    27  
    28  const (
    29  	recordTypeNone recordType = 0
    30  	recordTypeA    recordType = 1
    31  	recordTypeAAAA recordType = 28
    32  )
    33  
    34  type dohjsonresponse struct {
    35  	Status   uint32
    36  	TC       bool
    37  	RD       bool
    38  	RA       bool
    39  	AD       bool
    40  	CD       bool
    41  	Question []struct {
    42  		Name string     `json:"name"`
    43  		Type recordType `json:"type"`
    44  	}
    45  	Answer []struct {
    46  		Name string     `json:"name"`
    47  		Type recordType `json:"type"`
    48  		TTL  uint16
    49  		Data string `json:"data"`
    50  	}
    51  	EdnsClientSubnet string `json:"edns_client_subnet"`
    52  	Comment          string
    53  }
    54  
    55  func (jr *dohjsonresponse) hosts() []string {
    56  	if len(jr.Answer) == 0 {
    57  		return nil
    58  	}
    59  	hosts := make([]string, 0, len(jr.Answer))
    60  	for _, ans := range jr.Answer {
    61  		if ans.Type == recordTypeA || ans.Type == recordTypeAAAA {
    62  			hosts = append(hosts, ans.Data)
    63  		}
    64  	}
    65  	return hosts
    66  }
    67  
// lookupTable memoizes host -> resolved addresses for the DoH client's
// dialer below, built with a one-hour TTL. NOTE(review): assumes
// github.com/FloatTech/ttl expires each entry an hour after it is set —
// confirm against that package.
var lookupTable = ttl.NewCache[string, []string](time.Hour)
    69  
    70  var trsHTTP2ClientWithSystemDNS = http.Client{
    71  	Transport: &http2.Transport{
    72  		DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
    73  			if defaultDialer.Timeout != 0 {
    74  				var cancel context.CancelFunc
    75  				ctx, cancel = context.WithTimeout(ctx, defaultDialer.Timeout)
    76  				defer cancel()
    77  			}
    78  
    79  			if !defaultDialer.Deadline.IsZero() {
    80  				var cancel context.CancelFunc
    81  				ctx, cancel = context.WithDeadline(ctx, defaultDialer.Deadline)
    82  				defer cancel()
    83  			}
    84  
    85  			host, port, err := net.SplitHostPort(addr)
    86  			if err != nil {
    87  				return nil, err
    88  			}
    89  			addrs := lookupTable.Get(host)
    90  			if len(addrs) == 0 {
    91  				addrs, err = net.DefaultResolver.LookupHost(ctx, host)
    92  				if err != nil {
    93  					return nil, err
    94  				}
    95  				lookupTable.Set(host, addrs)
    96  			}
    97  			if len(addr) == 0 {
    98  				return nil, ErrEmptyHostAddress
    99  			}
   100  			var conn net.Conn
   101  			var tlsConn *tls.Conn
   102  			for _, a := range addrs {
   103  				conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
   104  				if err != nil {
   105  					continue
   106  				}
   107  				tlsConn = tls.Client(conn, cfg)
   108  				err = terasu.Use(tlsConn).HandshakeContext(ctx, terasu.DefaultFirstFragmentLen)
   109  				if err == nil {
   110  					break
   111  				}
   112  				_ = tlsConn.Close()
   113  				tlsConn = nil
   114  				conn, err = defaultDialer.DialContext(ctx, network, net.JoinHostPort(a, port))
   115  				if err != nil {
   116  					continue
   117  				}
   118  				tlsConn = tls.Client(conn, cfg)
   119  				err = tlsConn.HandshakeContext(ctx)
   120  				if err == nil {
   121  					break
   122  				}
   123  				_ = tlsConn.Close()
   124  				tlsConn = nil
   125  			}
   126  			return tlsConn, err
   127  		},
   128  	},
   129  }
   130  
   131  func lookupdoh(server, u string) (jr dohjsonresponse, err error) {
   132  	jr, err = lookupdohwithtype(server, u, preferreddohtype())
   133  	if err == nil {
   134  		return
   135  	}
   136  	if ip.IsIPv6Available.Get() {
   137  		jr, err = lookupdohwithtype(server, u, recordTypeA)
   138  	}
   139  	return
   140  }
   141  
   142  func lookupdohwithtype(server, u string, typ recordType) (jr dohjsonresponse, err error) {
   143  	sb := strings.Builder{}
   144  	sb.WriteString(server)
   145  	sb.WriteString("?name=")
   146  	sb.WriteString(url.QueryEscape(u))
   147  	if typ != recordTypeNone {
   148  		sb.WriteString("&type=")
   149  		sb.WriteString(strconv.Itoa(int(typ)))
   150  	}
   151  	req, err := http.NewRequest("GET", sb.String(), nil)
   152  	if err != nil {
   153  		return
   154  	}
   155  	req.Header.Add("accept", "application/dns-json")
   156  	resp, err := trsHTTP2ClientWithSystemDNS.Do(req)
   157  	if err != nil {
   158  		return
   159  	}
   160  	defer resp.Body.Close()
   161  	err = json.NewDecoder(resp.Body).Decode(&jr)
   162  	if err != nil {
   163  		return
   164  	}
   165  	if jr.Status != 0 {
   166  		err = errors.New("comment: " + jr.Comment)
   167  	}
   168  	return
   169  }
   170  
   171  func preferreddohtype() recordType {
   172  	if ip.IsIPv6Available.Get() {
   173  		return recordTypeAAAA
   174  	}
   175  	return recordTypeA
   176  }