github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/server/dns.go (about)

     1  /*
     2   * Copyright (c) 2016, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package server
    21  
    22  import (
    23  	"bufio"
    24  	"bytes"
    25  	"math/rand"
    26  	"net"
    27  	"strings"
    28  	"sync/atomic"
    29  	"time"
    30  
    31  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    32  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
    33  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/monotime"
    34  )
    35  
    36  const (
    37  	DNS_SYSTEM_CONFIG_FILENAME      = "/etc/resolv.conf"
    38  	DNS_SYSTEM_CONFIG_RELOAD_PERIOD = 5 * time.Second
    39  	DNS_RESOLVER_PORT               = 53
    40  )
    41  
    42  // DNSResolver maintains fresh DNS resolver values, monitoring
    43  // "/etc/resolv.conf" on platforms where it is available; and
    44  // otherwise using a default value.
    45  type DNSResolver struct {
    46  	// Note: 64-bit ints used with atomic operations are placed
    47  	// at the start of struct to ensure 64-bit alignment.
    48  	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
    49  	lastReloadTime int64
    50  	common.ReloadableFile
    51  	isReloading int32
    52  	resolvers   []net.IP
    53  }
    54  
    55  // NewDNSResolver initializes a new DNSResolver, loading it with
    56  // fresh resolver values. The load must succeed, so either
    57  // "/etc/resolv.conf" must contain valid "nameserver" lines with
    58  // a DNS server IP address, or a valid "defaultResolver" default
    59  // value must be provided.
    60  // On systems without "/etc/resolv.conf", "defaultResolver" is
    61  // required.
    62  //
    63  // The resolver is considered stale and reloaded if last checked
    64  // more than 5 seconds before the last Get(), which is similar to
    65  // frequencies in other implementations:
    66  //
    67  // - https://golang.org/src/net/dnsclient_unix.go,
    68  //   resolverConfig.tryUpdate: 5 seconds
    69  //
    70  // - https://github.com/ambrop72/badvpn/blob/master/udpgw/udpgw.c,
    71  //   maybe_update_dns: 2 seconds
    72  //
    73  func NewDNSResolver(defaultResolver string) (*DNSResolver, error) {
    74  
    75  	dns := &DNSResolver{
    76  		lastReloadTime: int64(monotime.Now()),
    77  	}
    78  
    79  	dns.ReloadableFile = common.NewReloadableFile(
    80  		DNS_SYSTEM_CONFIG_FILENAME,
    81  		true,
    82  		func(fileContent []byte, _ time.Time) error {
    83  
    84  			resolvers, err := parseResolveConf(fileContent)
    85  			if err != nil {
    86  				// On error, state remains the same
    87  				return errors.Trace(err)
    88  			}
    89  
    90  			dns.resolvers = resolvers
    91  
    92  			log.WithTraceFields(
    93  				LogFields{
    94  					"resolvers": resolvers,
    95  				}).Debug("loaded system DNS resolvers")
    96  
    97  			return nil
    98  		})
    99  
   100  	_, err := dns.Reload()
   101  	if err != nil {
   102  		if defaultResolver == "" {
   103  			return nil, errors.Trace(err)
   104  		}
   105  
   106  		log.WithTraceFields(
   107  			LogFields{"err": err}).Info(
   108  			"failed to load system DNS resolver; using default")
   109  
   110  		resolver, err := parseResolver(defaultResolver)
   111  		if err != nil {
   112  			return nil, errors.Trace(err)
   113  		}
   114  
   115  		dns.resolvers = []net.IP{resolver}
   116  	}
   117  
   118  	return dns, nil
   119  }
   120  
   121  // Get returns one of the cached resolvers, selected at random,
   122  // after first updating the cached values if they're stale. If
   123  // reloading fails, the previous values are used.
   124  //
   125  // Randomly selecting any one of the configured resolvers is
   126  // expected to be more resiliant to failure; e.g., if one of
   127  // the resolvers becomes unavailable.
   128  func (dns *DNSResolver) Get() net.IP {
   129  
   130  	dns.reloadWhenStale()
   131  
   132  	dns.ReloadableFile.RLock()
   133  	defer dns.ReloadableFile.RUnlock()
   134  
   135  	return dns.resolvers[rand.Intn(len(dns.resolvers))]
   136  }
   137  
   138  func (dns *DNSResolver) reloadWhenStale() {
   139  
   140  	// Every UDP DNS port forward frequently calls Get(), so this code
   141  	// is intended to minimize blocking. Most callers will hit just the
   142  	// atomic.LoadInt64 reload time check and the RLock (an atomic.AddInt32
   143  	// when no write lock is pending). An atomic.CompareAndSwapInt32 is
   144  	// used to ensure only one goroutine enters Reload() and blocks on
   145  	// its write lock. Finally, since since ReloadableFile.Reload
   146  	// checks whether the underlying file has changed _before_ acquiring a
   147  	// write lock, we only incur write lock blocking when "/etc/resolv.conf"
   148  	// has actually changed.
   149  
   150  	lastReloadTime := monotime.Time(atomic.LoadInt64(&dns.lastReloadTime))
   151  	stale := monotime.Now().After(lastReloadTime.Add(DNS_SYSTEM_CONFIG_RELOAD_PERIOD))
   152  
   153  	if stale {
   154  
   155  		isReloader := atomic.CompareAndSwapInt32(&dns.isReloading, 0, 1)
   156  
   157  		if isReloader {
   158  
   159  			// Unconditionally set last reload time. Even on failure only
   160  			// want to retry after another DNS_SYSTEM_CONFIG_RELOAD_PERIOD.
   161  			atomic.StoreInt64(&dns.lastReloadTime, int64(monotime.Now()))
   162  
   163  			_, err := dns.Reload()
   164  			if err != nil {
   165  				log.WithTraceFields(
   166  					LogFields{"err": err}).Info(
   167  					"failed to reload system DNS resolver")
   168  			}
   169  
   170  			atomic.StoreInt32(&dns.isReloading, 0)
   171  		}
   172  	}
   173  }
   174  
   175  // GetAll returns a list of all DNS resolver addresses. Cached values are
   176  // updated if they're stale. If reloading fails, the previous values are
   177  // used.
   178  func (dns *DNSResolver) GetAll() []net.IP {
   179  	return dns.getAll(true, true)
   180  }
   181  
   182  // GetAllIPv4 returns a list of all IPv4 DNS resolver addresses.
   183  // Cached values are updated if they're stale. If reloading fails,
   184  // the previous values are used.
   185  func (dns *DNSResolver) GetAllIPv4() []net.IP {
   186  	return dns.getAll(true, false)
   187  }
   188  
   189  // GetAllIPv6 returns a list of all IPv6 DNS resolver addresses.
   190  // Cached values are updated if they're stale. If reloading fails,
   191  // the previous values are used.
   192  func (dns *DNSResolver) GetAllIPv6() []net.IP {
   193  	return dns.getAll(false, true)
   194  }
   195  
   196  func (dns *DNSResolver) getAll(wantIPv4, wantIPv6 bool) []net.IP {
   197  
   198  	dns.reloadWhenStale()
   199  
   200  	dns.ReloadableFile.RLock()
   201  	defer dns.ReloadableFile.RUnlock()
   202  
   203  	resolvers := make([]net.IP, 0)
   204  	for _, resolver := range dns.resolvers {
   205  		if resolver.To4() != nil {
   206  			if wantIPv4 {
   207  				resolvers = append(resolvers, resolver)
   208  			}
   209  		} else {
   210  			if wantIPv6 {
   211  				resolvers = append(resolvers, resolver)
   212  			}
   213  		}
   214  	}
   215  	return resolvers
   216  }
   217  
   218  func parseResolveConf(fileContent []byte) ([]net.IP, error) {
   219  
   220  	scanner := bufio.NewScanner(bytes.NewReader(fileContent))
   221  
   222  	var resolvers []net.IP
   223  
   224  	for scanner.Scan() {
   225  		line := scanner.Text()
   226  		if strings.HasPrefix(line, ";") || strings.HasPrefix(line, "#") {
   227  			continue
   228  		}
   229  		fields := strings.Fields(line)
   230  		if len(fields) == 2 && fields[0] == "nameserver" {
   231  			resolver, err := parseResolver(fields[1])
   232  			if err == nil {
   233  				resolvers = append(resolvers, resolver)
   234  			}
   235  		}
   236  	}
   237  
   238  	if err := scanner.Err(); err != nil {
   239  		return nil, errors.Trace(err)
   240  	}
   241  
   242  	if len(resolvers) == 0 {
   243  		return nil, errors.TraceNew("no nameservers found")
   244  	}
   245  
   246  	return resolvers, nil
   247  }
   248  
   249  func parseResolver(resolver string) (net.IP, error) {
   250  
   251  	ipAddress := net.ParseIP(resolver)
   252  	if ipAddress == nil {
   253  		return nil, errors.TraceNew("invalid IP address")
   254  	}
   255  
   256  	return ipAddress, nil
   257  }