
     1  package rrdialer
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"fmt"
     7  	"math"
     8  	"net"
     9  	"os"
    10  	"os/user"
    11  	"path/filepath"
    12  	"runtime"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    17  	""
    18  	""
    19  )
    21  const (
    22  	weight = 0.2
    24  	dirPerms         = 0755
    25  	privateFilePerms = 0600
    26  	publicFilePerms  = 0644
    27  )
    29  var (
    30  	pid = strconv.FormatInt(int64(os.Getpid()), 10)
    31  )
// endpointType describes one resolved address of a dial target: the state of
// the current dial attempt plus the endpoint's dial-latency history.
//
// The exported fields form the history. recordEvent updates them after a
// successful dial and serialises the endpoint to a per-endpoint file in the
// cache directory; loadEndpointHistory reads that file back.
// NOTE(review): presumably the project json package, like encoding/json,
// persists only exported fields — loadEndpointHistory re-sets address and
// recomputes standardDeviationOfLatency after loading, consistent with that;
// confirm.
//
// Unexported fields:
//   - address: the "host:port" string handed to the raw dialer and used to
//     derive the cache file name.
//   - conn, err: the outcome of this endpoint's dial, stored by the dial
//     goroutine before it reports the endpoint on the results channel.
//   - dialing: set when the endpoint is dialed; nothing shown here clears
//     it, so an endpoint value is dialed at most once.
//   - standardDeviationOfLatency: sqrt(LatencyVariance), derived on load and
//     after each recorded event.
//
// LastUpdate is set when a dial starts and again when a latency sample is
// recorded; a freshly created endpoint with no history has the zero time.
type endpointType struct {
	address                    string // Host:port
	conn                       net.Conn
	dialing                    bool
	err                        error
	LastUpdate                 time.Time
	LatencyVariance            float64 // Seconds^2.
	MaximumLatency             float64 // Seconds.
	MeanLatency                float64 // Seconds.
	MinimumLatency             float64 // Seconds.
	standardDeviationOfLatency float64 // Seconds.
}
    46  func getFastestEndpoint(endpoints []*endpointType) *endpointType {
    47  	var fastestEndpoint *endpointType
    48  	for _, endpoint := range endpoints {
    49  		if endpoint.dialing {
    50  			continue
    51  		}
    52  		if (fastestEndpoint == nil) ||
    53  			(endpoint.MeanLatency > 0 &&
    54  				endpoint.MeanLatency < fastestEndpoint.MeanLatency) {
    55  			fastestEndpoint = endpoint
    56  		}
    57  	}
    58  	return fastestEndpoint
    59  }
    61  func getHomeDirectory() (string, error) {
    62  	if homeDir := os.Getenv("HOME"); homeDir != "" {
    63  		return homeDir, nil
    64  	}
    65  	if usr, err := user.Current(); err != nil {
    66  		return "", err
    67  	} else {
    68  		return usr.HomeDir, nil
    69  	}
    70  }
    72  func getMostStaleEndpoint(endpoints []*endpointType) *endpointType {
    73  	var mostStaleEndpoint *endpointType
    74  	for _, endpoint := range endpoints {
    75  		if endpoint.dialing {
    76  			continue
    77  		}
    78  		if (mostStaleEndpoint == nil) ||
    79  			endpoint.LastUpdate.Before(mostStaleEndpoint.LastUpdate) {
    80  			mostStaleEndpoint = endpoint
    81  		}
    82  	}
    83  	return mostStaleEndpoint
    84  }
    86  func newDialer(dialer *net.Dialer, cacheDir string,
    87  	logger log.DebugLogger) (*Dialer, error) {
    88  	rrDialer := &Dialer{
    89  		logger:    logger,
    90  		rawDialer: dialer,
    91  	}
    92  	if cacheDir == "" {
    93  		homedir, err := getHomeDirectory()
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  		cacheDir = filepath.Join(homedir, ".cache")
    98  	}
    99  	rrDialer.dirname = filepath.Join(cacheDir, "round-robin-dialer")
   100  	return rrDialer, nil
   101  }
   103  func makeFilename(dirname, address string) string {
   104  	if runtime.GOOS == "windows" {
   105  		address = strings.Replace(address, ":", "_", -1)
   106  	}
   107  	return filepath.Join(dirname, address)
   108  }
   110  func (d *Dialer) loadEndpointHistories(hostAddrs []string,
   111  	port string) ([]*endpointType, error) {
   112  	endpoints := make([]*endpointType, 0, len(hostAddrs))
   113  	for _, hostAddr := range hostAddrs {
   114  		address := hostAddr + ":" + port
   115  		if endpoint, err := d.loadEndpointHistory(address); err != nil {
   116  			return nil, err
   117  		} else {
   118  			endpoints = append(endpoints, endpoint)
   119  		}
   120  	}
   121  	return endpoints, nil
   122  }
   124  func (d *Dialer) loadEndpointHistory(address string) (*endpointType, error) {
   125  	filename := makeFilename(d.dirname, address)
   126  	var endpoint endpointType
   127  	if err := json.ReadFromFile(filename, &endpoint); err != nil {
   128  		if !os.IsNotExist(err) {
   129  			return nil, err
   130  		}
   131  		return &endpointType{address: address}, nil
   132  	} else {
   133  		endpoint.address = address
   134  		endpoint.computeStandardDeviationOfLatency()
   135  		return &endpoint, nil
   136  	}
   137  }
   139  func (d *Dialer) dialContext(ctx context.Context, network,
   140  	address string) (net.Conn, error) {
   141  	host, port, err := net.SplitHostPort(address)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  	resolver := d.rawDialer.Resolver
   146  	if resolver == nil {
   147  		resolver = net.DefaultResolver
   148  	}
   149  	hostAddrs, err := resolver.LookupHost(context.Background(), host)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	if len(hostAddrs) < 1 {
   154  		return nil, fmt.Errorf("no addresses found for: %s", host)
   155  	} else if len(hostAddrs) == 1 {
   156  		return d.rawDialer.DialContext(ctx, network, hostAddrs[0]+":"+port)
   157  	}
   158  	logLevel := int16(-1)
   159  	if getter, ok := d.logger.(log.DebugLogLevelGetter); ok {
   160  		logLevel = getter.GetLevel()
   161  	}
   162  	endpoints, err := d.loadEndpointHistories(hostAddrs, port)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	return d.dialEndpoints(ctx, network, address, endpoints, logLevel)
   167  }
   169  func (d *Dialer) dialEndpoints(ctx context.Context, network, address string,
   170  	endpoints []*endpointType, logLevel int16) (net.Conn, error) {
   171  	timeoutTimer := time.NewTimer(d.rawDialer.Timeout)
   172  	results := make(chan *endpointType, len(endpoints))
   173  	// Immediately dial the historically fastest endpoint.
   174  	fastestEndpoint := getFastestEndpoint(endpoints)
   175  	d.goDialEndpoint(ctx, network, fastestEndpoint, "fastest", results)
   176  	impatienceTimerFastest := fastestEndpoint.makeImpatienceTimer()
   177  	stalestEndpoint := getMostStaleEndpoint(endpoints)
   178  	d.goDialEndpoint(ctx, network, stalestEndpoint, "oldest", results)
   179  	impatienceTimerStalest := stalestEndpoint.makeImpatienceTimer()
   180  	// Dial all endpoints without history or if debug mode is enabled.
   181  	for _, endpoint := range endpoints {
   182  		if logLevel >= 3 || endpoint.MeanLatency <= 0 {
   183  			d.goDialEndpoint(ctx, network, endpoint, "all", results)
   184  		}
   185  	}
   186  	failureCounter := 0
   187  	problemCounter := 0
   188  	for {
   189  		select {
   190  		case endpoint := <-results:
   191  			if endpoint.err != nil {
   192  				failureCounter++
   193  				problemCounter++
   194  				if failureCounter >= len(endpoints) {
   195  					for _, endpoint := range endpoints {
   196  						d.logger.Printf("error dialing: %s: %s\n",
   197  							endpoint.address, endpoint.err)
   198  					}
   199  					return nil, fmt.Errorf("failed connecting to: %s", address)
   200  				}
   201  				for _, endpoint := range endpoints {
   202  					d.goDialEndpoint(ctx, network, endpoint, "backups",
   203  						results)
   204  				}
   205  				if problemCounter == 2 {
   206  					d.logger.Println(
   207  						"At least 2 endpoints have issues, dialed remaining endpoints")
   208  				}
   209  				break
   210  			}
   211  			d.logger.Debugf(2, "connected: %s\n", endpoint.conn.RemoteAddr())
   212  			return endpoint.conn, nil
   213  		case <-impatienceTimerFastest.C:
   214  			problemCounter++
   215  			for _, endpoint := range endpoints {
   216  				d.goDialEndpoint(ctx, network, endpoint, "impatiently", results)
   217  			}
   218  			if problemCounter == 2 {
   219  				d.logger.Println(
   220  					"At least 2 endpoints have issues, dialed remaining endpoints")
   221  			}
   222  		case <-impatienceTimerStalest.C:
   223  			problemCounter++
   224  			for _, endpoint := range endpoints {
   225  				d.goDialEndpoint(ctx, network, endpoint, "impatiently", results)
   226  			}
   227  			if problemCounter == 2 {
   228  				d.logger.Println(
   229  					"At least 2 endpoints have issues, dialed remaining endpoints")
   230  			}
   231  		case <-timeoutTimer.C:
   232  			return nil, fmt.Errorf("timed out connecting to: %s", address)
   233  		}
   234  	}
   235  }
   237  func (d *Dialer) goDialEndpoint(ctx context.Context, network string,
   238  	endpoint *endpointType, reason string, result chan<- *endpointType) {
   239  	if endpoint.dialing {
   240  		return
   241  	}
   242  	endpoint.dialing = true
   243  	endpoint.LastUpdate = time.Now()
   244  	d.logger.Debugf(2, "dialing %s: %s\n", reason, endpoint.address)
   245  	d.waitGroup.Add(1)
   246  	go func() {
   247  		defer d.waitGroup.Done()
   248  		startTime := time.Now()
   249  		conn, err := d.rawDialer.DialContext(ctx, network, endpoint.address)
   250  		if err != nil {
   251  			endpoint.err = err
   252  		} else {
   253  			endpoint.conn = conn
   254  			d.recordEvent(endpoint, time.Since(startTime).Seconds())
   255  		}
   256  		result <- endpoint
   257  	}()
   258  }
   260  func (d *Dialer) recordEvent(endpoint *endpointType, latency float64) {
   261  	if d.dirname == "" { // When testing.
   262  		return
   263  	}
   264  	filename := makeFilename(d.dirname, endpoint.address)
   265  	tmpFilename := makeFilename(d.dirname, endpoint.address+pid)
   266  	endpoint.LastUpdate = time.Now()
   267  	if endpoint.MeanLatency <= 0 {
   268  		endpoint.MeanLatency = latency
   269  	} else {
   270  		delta := latency - endpoint.MeanLatency
   271  		endpoint.MeanLatency = latency*weight +
   272  			(1.0-weight)*endpoint.MeanLatency
   273  		endpoint.LatencyVariance = (1.0 - weight) *
   274  			(endpoint.LatencyVariance + weight*delta*delta)
   275  	}
   276  	endpoint.computeStandardDeviationOfLatency()
   277  	d.logger.Debugf(3, "%s: L: %f ms, Lm: %f ms, Lsd: %f ms\n",
   278  		endpoint.address, latency*1e3, endpoint.MeanLatency*1e3,
   279  		endpoint.standardDeviationOfLatency*1e3)
   280  	if latency > endpoint.MaximumLatency {
   281  		endpoint.MaximumLatency = latency
   282  	}
   283  	if endpoint.MinimumLatency <= 0 || latency < endpoint.MinimumLatency {
   284  		endpoint.MinimumLatency = latency
   285  	}
   286  	file, err := os.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY,
   287  		publicFilePerms)
   288  	if err != nil {
   289  		if os.IsNotExist(err) {
   290  			if e := os.MkdirAll(d.dirname, dirPerms); e != nil {
   291  				d.logger.Println(err)
   292  				d.logger.Println(e)
   293  				return
   294  			}
   295  		}
   296  		file, err = os.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY,
   297  			publicFilePerms)
   298  	}
   299  	if err != nil {
   300  		d.logger.Println(err)
   301  		return
   302  	}
   303  	defer file.Close()
   304  	defer os.Remove(tmpFilename)
   305  	writer := bufio.NewWriter(file)
   306  	defer writer.Flush()
   307  	if err := json.WriteWithIndent(writer, "    ", endpoint); err != nil {
   308  		d.logger.Println(err)
   309  		return
   310  	}
   311  	if err := writer.Flush(); err != nil {
   312  		d.logger.Println(err)
   313  		return
   314  	}
   315  	if err := file.Close(); err != nil {
   316  		d.logger.Println(err)
   317  		return
   318  	}
   319  	if err := os.Rename(tmpFilename, filename); err != nil {
   320  		d.logger.Println(err)
   321  		return
   322  	}
   323  }
   325  func (d *Dialer) waitForBackgroundResults(timeout time.Duration) {
   326  	finished := make(chan struct{}, 1)
   327  	timer := time.NewTimer(timeout)
   328  	go func(finished chan<- struct{}) {
   329  		d.waitGroup.Wait()
   330  		finished <- struct{}{}
   331  	}(finished)
   332  	select {
   333  	case <-finished:
   334  		timer.Stop()
   335  	case <-timer.C:
   336  	}
   337  }
   339  func (e *endpointType) computeStandardDeviationOfLatency() {
   340  	if e.LatencyVariance <= 0 {
   341  		return
   342  	}
   343  	e.standardDeviationOfLatency = math.Sqrt(e.LatencyVariance)
   344  }
   346  func (e *endpointType) makeImpatienceTimer() *time.Timer {
   347  	if e.LatencyVariance <= 0 {
   348  		timer := time.NewTimer(time.Second)
   349  		timer.Stop()
   350  		return timer
   351  	}
   352  	timeoutDelta := e.MeanLatency * 0.1
   353  	if td := 2 * e.standardDeviationOfLatency; td > timeoutDelta {
   354  		timeoutDelta = td
   355  	}
   356  	return time.NewTimer(time.Duration(float64(time.Second) *
   357  		(e.MeanLatency + timeoutDelta)))
   358  }