storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/cmd/http/dial_dnscache.go (about)

     1  /*
     2   * MinIO Cloud Storage, (C) 2020 MinIO, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package http
    18  
    19  import (
    20  	"context"
    21  	"math/rand"
    22  	"net"
    23  	"sync"
    24  	"time"
    25  )
    26  
    27  var randPerm = func(n int) []int {
    28  	return rand.Perm(n)
    29  }
    30  
    31  // DialContextWithDNSCache is a helper function which returns `net.DialContext` function.
    32  // It randomly fetches an IP from the DNS cache and dials it by the given dial
    33  // function. It dials one by one and returns first connected `net.Conn`.
    34  // If it fails to dial all IPs from cache it returns first error. If no baseDialFunc
    35  // is given, it sets default dial function.
    36  //
    37  // You can use returned dial function for `http.Transport.DialContext`.
    38  //
    39  // In this function, it uses functions from `rand` package. To make it really random,
    40  // you MUST call `rand.Seed` and change the value from the default in your application
    41  func DialContextWithDNSCache(cache *DNSCache, baseDialCtx DialContext) DialContext {
    42  	if baseDialCtx == nil {
    43  		// This is same as which `http.DefaultTransport` uses.
    44  		baseDialCtx = (&net.Dialer{
    45  			Timeout:   30 * time.Second,
    46  			KeepAlive: 30 * time.Second,
    47  		}).DialContext
    48  	}
    49  	return func(ctx context.Context, network, host string) (net.Conn, error) {
    50  		h, p, err := net.SplitHostPort(host)
    51  		if err != nil {
    52  			return nil, err
    53  		}
    54  
    55  		// Fetch DNS result from cache.
    56  		//
    57  		// ctxLookup is only used for canceling DNS Lookup.
    58  		ctxLookup, cancelF := context.WithTimeout(ctx, cache.lookupTimeout)
    59  		defer cancelF()
    60  		addrs, err := cache.Fetch(ctxLookup, h)
    61  		if err != nil {
    62  			return nil, err
    63  		}
    64  
    65  		var firstErr error
    66  		for _, randomIndex := range randPerm(len(addrs)) {
    67  			conn, err := baseDialCtx(ctx, "tcp", net.JoinHostPort(addrs[randomIndex], p))
    68  			if err == nil {
    69  				return conn, nil
    70  			}
    71  			if firstErr == nil {
    72  				firstErr = err
    73  			}
    74  		}
    75  
    76  		return nil, firstErr
    77  	}
    78  }
    79  
    80  const (
    81  	// cacheSize is initial size of addr and IP list cache map.
    82  	cacheSize = 64
    83  )
    84  
    85  // defaultFreq is default frequency a resolver refreshes DNS cache.
    86  var (
    87  	defaultFreq          = 3 * time.Second
    88  	defaultLookupTimeout = 10 * time.Second
    89  )
    90  
// DNSCache is a DNS cache resolver that keeps resolved address lists in memory.
type DNSCache struct {
	// RWMutex guards cache. Because it is embedded, its Lock/Unlock/RLock/
	// RUnlock methods are promoted onto *DNSCache.
	sync.RWMutex

	// lookupHostFn performs the actual resolution; NewDNSCache sets it to
	// net.DefaultResolver.LookupHost.
	lookupHostFn  func(ctx context.Context, host string) ([]string, error)
	// lookupTimeout bounds each individual lookup (used both by Refresh and
	// by the dialer returned from DialContextWithDNSCache).
	lookupTimeout time.Duration
	// loggerOnce receives errors from background refreshes, with the host
	// passed as id.
	loggerOnce    func(ctx context.Context, err error, id interface{}, errKind ...interface{})

	// cache maps a host name to its most recently successfully resolved
	// addresses.
	cache    map[string][]string
	// doneOnce ensures doneCh is closed at most once.
	doneOnce sync.Once
	// doneCh is closed to stop the background refresh goroutine.
	doneCh   chan struct{}
}
   103  
   104  // NewDNSCache initializes DNS cache resolver and starts auto refreshing
   105  // in a new goroutine. To stop auto refreshing, call `Stop()` function.
   106  // Once `Stop()` is called auto refreshing cannot be resumed.
   107  func NewDNSCache(freq time.Duration, lookupTimeout time.Duration, loggerOnce func(ctx context.Context, err error, id interface{}, errKind ...interface{})) *DNSCache {
   108  	if freq <= 0 {
   109  		freq = defaultFreq
   110  	}
   111  
   112  	if lookupTimeout <= 0 {
   113  		lookupTimeout = defaultLookupTimeout
   114  	}
   115  
   116  	r := &DNSCache{
   117  		lookupHostFn:  net.DefaultResolver.LookupHost,
   118  		lookupTimeout: lookupTimeout,
   119  		loggerOnce:    loggerOnce,
   120  		cache:         make(map[string][]string, cacheSize),
   121  		doneCh:        make(chan struct{}),
   122  	}
   123  
   124  	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
   125  
   126  	timer := time.NewTimer(freq)
   127  	go func() {
   128  		defer timer.Stop()
   129  
   130  		for {
   131  			select {
   132  			case <-timer.C:
   133  				// Make sure that refreshes on DNS do not be attempted
   134  				// at the same time, allows for reduced load on the
   135  				// DNS servers.
   136  				timer.Reset(time.Duration(rnd.Float64() * float64(freq)))
   137  
   138  				r.Refresh()
   139  			case <-r.doneCh:
   140  				return
   141  			}
   142  		}
   143  	}()
   144  
   145  	return r
   146  }
   147  
   148  // LookupHost lookups address list from DNS server, persist the results
   149  // in-memory cache. `Fetch` is used to obtain the values for a given host.
   150  func (r *DNSCache) LookupHost(ctx context.Context, host string) ([]string, error) {
   151  	addrs, err := r.lookupHostFn(ctx, host)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	r.Lock()
   157  	r.cache[host] = addrs
   158  	r.Unlock()
   159  
   160  	return addrs, nil
   161  }
   162  
   163  // Fetch fetches IP list from the cache. If IP list of the given addr is not in the cache,
   164  // then it lookups from DNS server by `Lookup` function.
   165  func (r *DNSCache) Fetch(ctx context.Context, host string) ([]string, error) {
   166  	r.RLock()
   167  	addrs, ok := r.cache[host]
   168  	r.RUnlock()
   169  	if ok {
   170  		return addrs, nil
   171  	}
   172  	return r.LookupHost(ctx, host)
   173  }
   174  
   175  // Refresh refreshes IP list cache, automatically.
   176  func (r *DNSCache) Refresh() {
   177  	r.RLock()
   178  	hosts := make([]string, 0, len(r.cache))
   179  	for host := range r.cache {
   180  		hosts = append(hosts, host)
   181  	}
   182  	r.RUnlock()
   183  
   184  	for _, host := range hosts {
   185  		ctx, cancelF := context.WithTimeout(context.Background(), r.lookupTimeout)
   186  		if _, err := r.LookupHost(ctx, host); err != nil {
   187  			r.loggerOnce(ctx, err, host)
   188  		}
   189  		cancelF()
   190  	}
   191  }
   192  
   193  // Stop stops auto refreshing.
   194  func (r *DNSCache) Stop() {
   195  	r.doneOnce.Do(func() {
   196  		close(r.doneCh)
   197  	})
   198  }