github.com/mailru/activerecord@v1.12.2/pkg/iproto/netutil/dialer.go (about)

     1  package netutil
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/mailru/activerecord/pkg/iproto/syncutil"
    10  	egotime "github.com/mailru/activerecord/pkg/iproto/util/time"
    11  	"golang.org/x/net/context"
    12  )
    13  
    14  const DefaultLoopInterval = time.Millisecond * 50
    15  
    16  var (
    17  	ErrClosed = errors.New("dialer owner has been gone")
    18  )
    19  
// BackgroundDialer is a wrapper around Dialer that contains the logic for
// gluing together and cancelling dial requests.
//
// Dial requests are executed through TaskRunner, and an optional deadline
// (see SetDeadline) is enforced by a timer whose callback is Cancel.
type BackgroundDialer struct {
	// mu guards timer and deadline; setDeadline must be called with it held.
	// timer is created lazily by setDeadline and fires Cancel; deadline is
	// the last value stored by setDeadline (zero means "no deadline").
	mu       sync.Mutex
	timer    *time.Timer
	deadline time.Time

	// Dialer performs the actual connection attempts. TaskRunner runs and
	// cancels the dial routine. TaskGroup is not referenced by the code
	// visible in this file.
	Dialer     *Dialer
	TaskGroup  *syncutil.TaskGroup
	TaskRunner *syncutil.TaskRunner
}
    31  
    32  // Dial begins dial routine if no one was started yet. It returns channel
    33  // that signals about routine is done. If some routine was started before and
    34  // not done yet, it returns done channel of that goroutine.
    35  // It returns non-nil error only if dial routine was not started.
    36  //
    37  // Started routine could be cancelled by calling Cancel method.
    38  //
    39  // Note that cb is called only once. That is, if caller A calls Dial and caller
    40  // B calls Dial immediately after, both of them will receive the same done
    41  // channel, but only A's callback will be called in the end.
    42  func (d *BackgroundDialer) Dial(ctx context.Context, cb func(net.Conn, error)) <-chan error {
    43  	ps := d.TaskRunner.Do(ctx, func(ctx context.Context) error {
    44  		conn, err := d.Dialer.Dial(ctx)
    45  		cb(conn, err)
    46  		return err
    47  	})
    48  
    49  	return ps
    50  }
    51  
// Cancel stops the current background dial routine by delegating to
// d.TaskRunner.Cancel.
//
// setDeadline also installs this method as the callback of the deadline
// timer, so it may be invoked from the timer's goroutine.
func (d *BackgroundDialer) Cancel() {
	d.TaskRunner.Cancel()
}
    56  
// SetDeadline sets the dial deadline.
//
// A deadline is an absolute time after which all dial routines fail.
// The deadline applies to all future and pending dials, not just the
// immediately following call to Dial.
// Cancelling some routine by calling Cancel method will not affect deadline.
// After a deadline has been exceeded, the dialer can be refreshed by setting a
// deadline in the future.
//
// A zero value for t means dial routines will not time out.
//
// The update is performed while holding d.mu.
//
// NOTE(review): in the code visible here the deadline is enforced by a timer
// that calls Cancel when it fires; Dial itself does not consult the stored
// deadline. Confirm how dials started after the timer has fired are handled.
func (d *BackgroundDialer) SetDeadline(t time.Time) {
	d.mu.Lock()
	defer d.mu.Unlock()
	// setDeadline requires the mutex to be held.
	d.setDeadline(t)
}
    72  
    73  // SetDeadlineAtLeast sets the dial deadline if current deadline is zero or
    74  // less than t. The other behaviour is the same as in SetDeadline.
    75  //
    76  // A zero value for t is ignored.
    77  //
    78  // It returns actual deadline value.
    79  func (d *BackgroundDialer) SetDeadlineAtLeast(t time.Time) time.Time {
    80  	d.mu.Lock()
    81  	defer d.mu.Unlock()
    82  
    83  	if d.deadline.Before(t) {
    84  		d.setDeadline(t)
    85  	}
    86  
    87  	return d.deadline
    88  }
    89  
    90  // Mutex must be held.
    91  func (d *BackgroundDialer) setDeadline(t time.Time) {
    92  	d.deadline = t
    93  
    94  	if t.IsZero() {
    95  		if d.timer != nil {
    96  			d.timer.Stop()
    97  		}
    98  
    99  		return
   100  	}
   101  
   102  	//nolint:gosimple
   103  	tm := t.Sub(time.Now())
   104  	if tm < 0 {
   105  		tm = 0
   106  	}
   107  
   108  	if d.timer == nil {
   109  		d.timer = time.AfterFunc(tm, d.Cancel)
   110  	} else {
   111  		// We do not check d.timer.Stop() here cause it is not a problem, if
   112  		// deadline has been reached and some dialing routine was cancelled.
   113  		d.timer.Reset(tm)
   114  	}
   115  }
   116  
   117  // Dialer contains options for connecting to an address.
   118  type Dialer struct {
   119  	// Network and Addr are destination credentials.
   120  	Network, Addr string
   121  
   122  	// Timeout is the maximum amount of time a dial will wait for a single
   123  	// connect to complete.
   124  	Timeout time.Duration
   125  
   126  	// LoopTimeout is the maximum amount of time a dial loop will wait for a
   127  	// successful established connection. It may fail earlier if Closed option
   128  	// is set or Cancel method is called.
   129  	LoopTimeout time.Duration
   130  
   131  	// LoopInterval is used to delay dial attepmts between each other.
   132  	LoopInterval time.Duration
   133  
   134  	// MaxLoopInterval is the maximum delay before next attempt to connect is
   135  	// prepared. Note that LoopInterval is used as initial delay, and could be
   136  	// increased by every dial attempt up to MaxLoopInterval.
   137  	MaxLoopInterval time.Duration
   138  
   139  	// Closed signals that Dialer owner is closed forever and will never want
   140  	// to dial again.
   141  	Closed chan struct{}
   142  
   143  	// OnAttempt will be called with every dial attempt error. Nil error means
   144  	// that dial succeed.
   145  	OnAttempt func(error)
   146  
   147  	// NetDial could be set to override dial function. By default net.Dial is
   148  	// used.
   149  	NetDial func(ctx context.Context, network, addr string) (net.Conn, error)
   150  
   151  	// Logf could be set to receive log messages from Dialer.
   152  	Logf   func(string, ...interface{})
   153  	Debugf func(string, ...interface{})
   154  
   155  	// DisableLogAddr removes addr part in log message prefix.
   156  	DisableLogAddr bool
   157  }
   158  
   159  // Dial tries to connect until some of events occur:
   160  // - successful connect;
   161  // – ctx is cancelled;
   162  // – dialer owner is closed;
   163  // – loop timeout exceeded (if set);
   164  func (d *Dialer) Dial(ctx context.Context) (conn net.Conn, err error) {
   165  	var (
   166  		maxInterval = d.MaxLoopInterval
   167  		step        = d.LoopInterval
   168  	)
   169  
   170  	if step == 0 {
   171  		step = DefaultLoopInterval
   172  	}
   173  
   174  	interval := step
   175  	if maxInterval < interval {
   176  		maxInterval = interval
   177  	}
   178  
   179  	loopTimer := egotime.AcquireTimer(interval)
   180  	defer egotime.ReleaseTimer(loopTimer)
   181  
   182  	if tm := d.LoopTimeout; tm != 0 {
   183  		ctx, _ = context.WithTimeout(ctx, tm)
   184  	}
   185  
   186  	var attempts int
   187  
   188  	for {
   189  		d.debugf("dialing (%d)", attempts)
   190  		attempts++
   191  
   192  		conn, err = d.dial(ctx)
   193  		if cb := d.OnAttempt; cb != nil {
   194  			cb(err)
   195  		}
   196  
   197  		if err == nil {
   198  			d.debugf("dial ok: local addr is %s", conn.LocalAddr().String())
   199  			return
   200  		}
   201  
   202  		if ctx.Err() == nil {
   203  			d.logf("dial error: %v; delaying next attempt for %s", err, interval)
   204  		} else {
   205  			d.logf("dial error: %v;", err)
   206  		}
   207  
   208  		select {
   209  		case <-loopTimer.C:
   210  			//
   211  		case <-ctx.Done():
   212  			err = ctx.Err()
   213  			return
   214  		case <-d.Closed:
   215  			err = ErrClosed
   216  			return
   217  		}
   218  
   219  		interval += step
   220  		if interval > maxInterval {
   221  			interval = maxInterval
   222  		}
   223  
   224  		loopTimer.Reset(interval)
   225  	}
   226  }
   227  
   228  func (d *Dialer) dial(ctx context.Context) (conn net.Conn, err error) {
   229  	if tm := d.Timeout; tm != 0 {
   230  		ctx, _ = context.WithTimeout(ctx, tm)
   231  	}
   232  
   233  	netDial := d.NetDial
   234  	if netDial == nil {
   235  		netDial = defaultNetDial
   236  	}
   237  
   238  	return netDial(ctx, d.Network, d.Addr)
   239  }
   240  
   241  func (d *Dialer) getLogPrefix() string {
   242  	if d.DisableLogAddr {
   243  		return "dialer: "
   244  	}
   245  
   246  	return `dialer to "` + d.Network + `:` + d.Addr + `": `
   247  }
   248  
   249  func (d *Dialer) logf(fmt string, args ...interface{}) {
   250  	prefix := d.getLogPrefix()
   251  
   252  	if logf := d.Logf; logf != nil {
   253  		logf(prefix+fmt, args...)
   254  	}
   255  }
   256  
   257  func (d *Dialer) debugf(fmt string, args ...interface{}) {
   258  	prefix := d.getLogPrefix()
   259  
   260  	if debugf := d.Debugf; debugf != nil {
   261  		debugf(prefix+fmt, args...)
   262  	}
   263  }
   264  
   265  var emptyDialer net.Dialer
   266  
   267  func defaultNetDial(ctx context.Context, network, addr string) (net.Conn, error) {
   268  	return emptyDialer.DialContext(ctx, network, addr)
   269  }