github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/client.go (about)

     1  package fastdns
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"net/netip"
     7  	"os"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  var (
    13  	// ErrMaxConns is returned when dns client reaches the max connections limitation.
    14  	ErrMaxConns = errors.New("dns client reaches the max connections limitation")
    15  )
    16  
    17  // Client is an UDP client that supports DNS protocol.
    18  type Client struct {
    19  	AddrPort netip.AddrPort
    20  
    21  	// MaxIdleConns controls the maximum number of idle (keep-alive)
    22  	// connections. Zero means no limit.
    23  	MaxIdleConns int
    24  
    25  	// MaxConns optionally limits the total number of
    26  	// connections per host, including connections in the dialing,
    27  	// active, and idle states. On limit violation, ErrMaxConns will be return.
    28  	//
    29  	// Zero means no limit.
    30  	MaxConns int
    31  
    32  	// ReadTimeout is the maximum duration for reading the dns server response.
    33  	ReadTimeout time.Duration
    34  
    35  	mu    sync.Mutex
    36  	conns []*net.UDPConn
    37  }
    38  
    39  // Exchange executes a single DNS transaction, returning
    40  // a Response for the provided Request.
    41  func (c *Client) Exchange(req, resp *Message) (err error) {
    42  	err = c.exchange(req, resp)
    43  	if err != nil && os.IsTimeout(err) {
    44  		err = c.exchange(req, resp)
    45  	}
    46  	return err
    47  }
    48  
    49  func (c *Client) exchange(req, resp *Message) error {
    50  	var fresh bool
    51  	conn, err := c.get()
    52  	if conn == nil && err == nil {
    53  		conn, err = c.dial()
    54  		fresh = true
    55  	}
    56  	if err != nil {
    57  		return err
    58  	}
    59  
    60  	_, err = conn.Write(req.Raw)
    61  	if err != nil && !fresh {
    62  		// if error is a pooled conn, let's close it & retry again
    63  		conn.Close()
    64  		if conn, err = c.dial(); err != nil {
    65  			return err
    66  		}
    67  		if _, err = conn.Write(req.Raw); err != nil {
    68  			return err
    69  		}
    70  	}
    71  
    72  	if c.ReadTimeout > 0 {
    73  		err = conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
    74  		if err != nil {
    75  			return err
    76  		}
    77  	}
    78  
    79  	resp.Raw = resp.Raw[:cap(resp.Raw)]
    80  	n, err := conn.Read(resp.Raw)
    81  	if err == nil {
    82  		resp.Raw = resp.Raw[:n]
    83  		err = ParseMessage(resp, resp.Raw, false)
    84  	}
    85  
    86  	c.put(conn)
    87  
    88  	return err
    89  }
    90  
    91  func (c *Client) dial() (conn *net.UDPConn, err error) {
    92  	conn, err = net.DialUDP("udp", nil, net.UDPAddrFromAddrPort(c.AddrPort))
    93  	return
    94  }
    95  
    96  func (c *Client) get() (conn *net.UDPConn, err error) {
    97  	c.mu.Lock()
    98  	defer c.mu.Unlock()
    99  
   100  	count := len(c.conns)
   101  	if c.MaxConns != 0 && count > c.MaxConns {
   102  		err = ErrMaxConns
   103  
   104  		return
   105  	}
   106  	if count > 0 {
   107  		conn = c.conns[len(c.conns)-1]
   108  		c.conns = c.conns[:len(c.conns)-1]
   109  	}
   110  
   111  	return
   112  }
   113  
   114  func (c *Client) put(conn *net.UDPConn) {
   115  	c.mu.Lock()
   116  	defer c.mu.Unlock()
   117  
   118  	if (c.MaxIdleConns != 0 && len(c.conns) > c.MaxIdleConns) ||
   119  		(c.MaxConns != 0 && len(c.conns) > c.MaxConns) {
   120  		conn.Close()
   121  
   122  		return
   123  	}
   124  
   125  	c.conns = append(c.conns, conn)
   126  }