github.com/phuslu/fastdns@v0.8.3-0.20240310041952-69506fc67dd1/client.go (about) 1 package fastdns 2 3 import ( 4 "errors" 5 "net" 6 "net/netip" 7 "os" 8 "sync" 9 "time" 10 ) 11 12 var ( 13 // ErrMaxConns is returned when dns client reaches the max connections limitation. 14 ErrMaxConns = errors.New("dns client reaches the max connections limitation") 15 ) 16 17 // Client is an UDP client that supports DNS protocol. 18 type Client struct { 19 AddrPort netip.AddrPort 20 21 // MaxIdleConns controls the maximum number of idle (keep-alive) 22 // connections. Zero means no limit. 23 MaxIdleConns int 24 25 // MaxConns optionally limits the total number of 26 // connections per host, including connections in the dialing, 27 // active, and idle states. On limit violation, ErrMaxConns will be return. 28 // 29 // Zero means no limit. 30 MaxConns int 31 32 // ReadTimeout is the maximum duration for reading the dns server response. 33 ReadTimeout time.Duration 34 35 mu sync.Mutex 36 conns []*net.UDPConn 37 } 38 39 // Exchange executes a single DNS transaction, returning 40 // a Response for the provided Request. 41 func (c *Client) Exchange(req, resp *Message) (err error) { 42 err = c.exchange(req, resp) 43 if err != nil && os.IsTimeout(err) { 44 err = c.exchange(req, resp) 45 } 46 return err 47 } 48 49 func (c *Client) exchange(req, resp *Message) error { 50 var fresh bool 51 conn, err := c.get() 52 if conn == nil && err == nil { 53 conn, err = c.dial() 54 fresh = true 55 } 56 if err != nil { 57 return err 58 } 59 60 _, err = conn.Write(req.Raw) 61 if err != nil && !fresh { 62 // if error is a pooled conn, let's close it & retry again 63 conn.Close() 64 if conn, err = c.dial(); err != nil { 65 return err 66 } 67 if _, err = conn.Write(req.Raw); err != nil { 68 return err 69 } 70 } 71 72 if c.ReadTimeout > 0 { 73 err = conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)) 74 if err != nil { 75 return err 76 } 77 } 78 79 resp.Raw = resp.Raw[:cap(resp.Raw)] 80 n, err := conn.Read(resp.Raw) 81 if err == nil { 82 resp.Raw = resp.Raw[:n] 83 err = ParseMessage(resp, resp.Raw, false) 84 } 85 86 c.put(conn) 87 88 return err 89 } 90 91 func (c *Client) dial() (conn *net.UDPConn, err error) { 92 conn, err = net.DialUDP("udp", nil, net.UDPAddrFromAddrPort(c.AddrPort)) 93 return 94 } 95 96 func (c *Client) get() (conn *net.UDPConn, err error) { 97 c.mu.Lock() 98 defer c.mu.Unlock() 99 100 count := len(c.conns) 101 if c.MaxConns != 0 && count > c.MaxConns { 102 err = ErrMaxConns 103 104 return 105 } 106 if count > 0 { 107 conn = c.conns[len(c.conns)-1] 108 c.conns = c.conns[:len(c.conns)-1] 109 } 110 111 return 112 } 113 114 func (c *Client) put(conn *net.UDPConn) { 115 c.mu.Lock() 116 defer c.mu.Unlock() 117 118 if (c.MaxIdleConns != 0 && len(c.conns) > c.MaxIdleConns) || 119 (c.MaxConns != 0 && len(c.conns) > c.MaxConns) { 120 conn.Close() 121 122 return 123 } 124 125 c.conns = append(c.conns, conn) 126 }