github.com/chwjbn/xclash@v0.2.0/dns/dhcp.go (about)

     1  package dns
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/chwjbn/xclash/component/dhcp"
    11  	"github.com/chwjbn/xclash/component/iface"
    12  	"github.com/chwjbn/xclash/component/resolver"
    13  
    14  	D "github.com/miekg/dns"
    15  )
    16  
    17  const (
    18  	IfaceTTL    = time.Second * 20
    19  	DHCPTTL     = time.Hour
    20  	DHCPTimeout = time.Minute
    21  )
    22  
// dhcpClient answers DNS queries using the nameservers that DHCP advertises
// on one network interface. The DHCP answer is cached and refreshed in a
// background goroutine when it becomes stale.
type dhcpClient struct {
	// ifaceName is the interface used for the DHCP lookup and set as the
	// Interface of every upstream NameServer built from its answer. It is
	// assigned once in newDHCPClient and only read afterwards.
	ifaceName string

	// lock guards the mutable fields below.
	lock            sync.Mutex
	// ifaceInvalidate is the time after which invalidate re-checks the
	// interface; before it, invalidate returns immediately.
	ifaceInvalidate time.Time
	// dnsInvalidate is the time after which the DHCP-provided servers are
	// considered stale even if the interface address has not changed.
	dnsInvalidate   time.Time

	// ifaceAddr is the IPv4 address/mask recorded when the last refresh was
	// scheduled; a different address or mask forces a new DHCP lookup.
	ifaceAddr *net.IPNet
	// done is non-nil while a background refresh is running and is closed
	// when that refresh publishes its result.
	done      chan struct{}
	// resolver is built from the last DHCP answer; it is nil until the first
	// refresh succeeds and is reset to nil when a refresh fails.
	resolver  *Resolver
	// err is the last error from the interface check or the DHCP lookup.
	err       error
}
    35  
    36  func (d *dhcpClient) Exchange(m *D.Msg) (msg *D.Msg, err error) {
    37  	ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
    38  	defer cancel()
    39  
    40  	return d.ExchangeContext(ctx, m)
    41  }
    42  
    43  func (d *dhcpClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
    44  	res, err := d.resolve(ctx)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	return res.ExchangeContext(ctx, m)
    50  }
    51  
    52  func (d *dhcpClient) resolve(ctx context.Context) (*Resolver, error) {
    53  	d.lock.Lock()
    54  
    55  	invalidated, err := d.invalidate()
    56  	if err != nil {
    57  		d.err = err
    58  	} else if invalidated {
    59  		done := make(chan struct{})
    60  
    61  		d.done = done
    62  
    63  		go func() {
    64  			ctx, cancel := context.WithTimeout(context.Background(), DHCPTimeout)
    65  			defer cancel()
    66  
    67  			var res *Resolver
    68  			dns, err := dhcp.ResolveDNSFromDHCP(ctx, d.ifaceName)
    69  			if err == nil {
    70  				nameserver := make([]NameServer, 0, len(dns))
    71  				for _, item := range dns {
    72  					nameserver = append(nameserver, NameServer{
    73  						Addr:      net.JoinHostPort(item.String(), "53"),
    74  						Interface: d.ifaceName,
    75  					})
    76  				}
    77  
    78  				res = NewResolver(Config{
    79  					Main: nameserver,
    80  				})
    81  			}
    82  
    83  			d.lock.Lock()
    84  			defer d.lock.Unlock()
    85  
    86  			close(done)
    87  
    88  			d.done = nil
    89  			d.resolver = res
    90  			d.err = err
    91  		}()
    92  	}
    93  
    94  	d.lock.Unlock()
    95  
    96  	for {
    97  		d.lock.Lock()
    98  
    99  		res, err, done := d.resolver, d.err, d.done
   100  
   101  		d.lock.Unlock()
   102  
   103  		// initializing
   104  		if res == nil && err == nil {
   105  			select {
   106  			case <-done:
   107  				continue
   108  			case <-ctx.Done():
   109  				return nil, ctx.Err()
   110  			}
   111  		}
   112  
   113  		// dirty return
   114  		return res, err
   115  	}
   116  }
   117  
   118  func (d *dhcpClient) invalidate() (bool, error) {
   119  	if time.Now().Before(d.ifaceInvalidate) {
   120  		return false, nil
   121  	}
   122  
   123  	d.ifaceInvalidate = time.Now().Add(IfaceTTL)
   124  
   125  	ifaceObj, err := iface.ResolveInterface(d.ifaceName)
   126  	if err != nil {
   127  		return false, err
   128  	}
   129  
   130  	addr, err := ifaceObj.PickIPv4Addr(nil)
   131  	if err != nil {
   132  		return false, err
   133  	}
   134  
   135  	if time.Now().Before(d.dnsInvalidate) && d.ifaceAddr.IP.Equal(addr.IP) && bytes.Equal(d.ifaceAddr.Mask, addr.Mask) {
   136  		return false, nil
   137  	}
   138  
   139  	d.dnsInvalidate = time.Now().Add(DHCPTTL)
   140  	d.ifaceAddr = addr
   141  
   142  	return d.done == nil, nil
   143  }
   144  
   145  func newDHCPClient(ifaceName string) *dhcpClient {
   146  	return &dhcpClient{ifaceName: ifaceName}
   147  }