github.com/mailru/activerecord@v1.12.2/pkg/iproto/netutil/dialer.go (about) 1 package netutil 2 3 import ( 4 "errors" 5 "net" 6 "sync" 7 "time" 8 9 "github.com/mailru/activerecord/pkg/iproto/syncutil" 10 egotime "github.com/mailru/activerecord/pkg/iproto/util/time" 11 "golang.org/x/net/context" 12 ) 13 14 const DefaultLoopInterval = time.Millisecond * 50 15 16 var ( 17 ErrClosed = errors.New("dialer owner has been gone") 18 ) 19 20 // BackgroundDialer is a wrapper around Dialer that contains logic of glueing 21 // and cancellation of dial requests. 22 type BackgroundDialer struct { 23 mu sync.Mutex 24 timer *time.Timer 25 deadline time.Time 26 27 Dialer *Dialer 28 TaskGroup *syncutil.TaskGroup 29 TaskRunner *syncutil.TaskRunner 30 } 31 32 // Dial begins dial routine if no one was started yet. It returns channel 33 // that signals about routine is done. If some routine was started before and 34 // not done yet, it returns done channel of that goroutine. 35 // It returns non-nil error only if dial routine was not started. 36 // 37 // Started routine could be cancelled by calling Cancel method. 38 // 39 // Note that cb is called only once. That is, if caller A calls Dial and caller 40 // B calls Dial immediately after, both of them will receive the same done 41 // channel, but only A's callback will be called in the end. 42 func (d *BackgroundDialer) Dial(ctx context.Context, cb func(net.Conn, error)) <-chan error { 43 ps := d.TaskRunner.Do(ctx, func(ctx context.Context) error { 44 conn, err := d.Dialer.Dial(ctx) 45 cb(conn, err) 46 return err 47 }) 48 49 return ps 50 } 51 52 // Cancel stops current background dial routine. 53 func (d *BackgroundDialer) Cancel() { 54 d.TaskRunner.Cancel() 55 } 56 57 // SetDeadline sets the dial deadline. 58 // 59 // A deadline is an absolute time after which all dial routines fail. 60 // The deadline applies to all future and pending dials, not just the 61 // immediately following call to Dial. 62 // Cancelling some routine by calling Cancel method will not affect deadline. 63 // After a deadline has been exceeded, the dialer can be refreshed by setting a 64 // deadline in the future. 65 // 66 // A zero value for t means dial routines will not time out. 67 func (d *BackgroundDialer) SetDeadline(t time.Time) { 68 d.mu.Lock() 69 defer d.mu.Unlock() 70 d.setDeadline(t) 71 } 72 73 // SetDeadlineAtLeast sets the dial deadline if current deadline is zero or 74 // less than t. The other behaviour is the same as in SetDeadline. 75 // 76 // A zero value for t is ignored. 77 // 78 // It returns actual deadline value. 79 func (d *BackgroundDialer) SetDeadlineAtLeast(t time.Time) time.Time { 80 d.mu.Lock() 81 defer d.mu.Unlock() 82 83 if d.deadline.Before(t) { 84 d.setDeadline(t) 85 } 86 87 return d.deadline 88 } 89 90 // Mutex must be held. 91 func (d *BackgroundDialer) setDeadline(t time.Time) { 92 d.deadline = t 93 94 if t.IsZero() { 95 if d.timer != nil { 96 d.timer.Stop() 97 } 98 99 return 100 } 101 102 //nolint:gosimple 103 tm := t.Sub(time.Now()) 104 if tm < 0 { 105 tm = 0 106 } 107 108 if d.timer == nil { 109 d.timer = time.AfterFunc(tm, d.Cancel) 110 } else { 111 // We do not check d.timer.Stop() here cause it is not a problem, if 112 // deadline has been reached and some dialing routine was cancelled. 113 d.timer.Reset(tm) 114 } 115 } 116 117 // Dialer contains options for connecting to an address. 118 type Dialer struct { 119 // Network and Addr are destination credentials. 120 Network, Addr string 121 122 // Timeout is the maximum amount of time a dial will wait for a single 123 // connect to complete. 124 Timeout time.Duration 125 126 // LoopTimeout is the maximum amount of time a dial loop will wait for a 127 // successful established connection. It may fail earlier if Closed option 128 // is set or Cancel method is called. 129 LoopTimeout time.Duration 130 131 // LoopInterval is used to delay dial attepmts between each other. 132 LoopInterval time.Duration 133 134 // MaxLoopInterval is the maximum delay before next attempt to connect is 135 // prepared. Note that LoopInterval is used as initial delay, and could be 136 // increased by every dial attempt up to MaxLoopInterval. 137 MaxLoopInterval time.Duration 138 139 // Closed signals that Dialer owner is closed forever and will never want 140 // to dial again. 141 Closed chan struct{} 142 143 // OnAttempt will be called with every dial attempt error. Nil error means 144 // that dial succeed. 145 OnAttempt func(error) 146 147 // NetDial could be set to override dial function. By default net.Dial is 148 // used. 149 NetDial func(ctx context.Context, network, addr string) (net.Conn, error) 150 151 // Logf could be set to receive log messages from Dialer. 152 Logf func(string, ...interface{}) 153 Debugf func(string, ...interface{}) 154 155 // DisableLogAddr removes addr part in log message prefix. 156 DisableLogAddr bool 157 } 158 159 // Dial tries to connect until some of events occur: 160 // - successful connect; 161 // – ctx is cancelled; 162 // – dialer owner is closed; 163 // – loop timeout exceeded (if set); 164 func (d *Dialer) Dial(ctx context.Context) (conn net.Conn, err error) { 165 var ( 166 maxInterval = d.MaxLoopInterval 167 step = d.LoopInterval 168 ) 169 170 if step == 0 { 171 step = DefaultLoopInterval 172 } 173 174 interval := step 175 if maxInterval < interval { 176 maxInterval = interval 177 } 178 179 loopTimer := egotime.AcquireTimer(interval) 180 defer egotime.ReleaseTimer(loopTimer) 181 182 if tm := d.LoopTimeout; tm != 0 { 183 ctx, _ = context.WithTimeout(ctx, tm) 184 } 185 186 var attempts int 187 188 for { 189 d.debugf("dialing (%d)", attempts) 190 attempts++ 191 192 conn, err = d.dial(ctx) 193 if cb := d.OnAttempt; cb != nil { 194 cb(err) 195 } 196 197 if err == nil { 198 d.debugf("dial ok: local addr is %s", conn.LocalAddr().String()) 199 return 200 } 201 202 if ctx.Err() == nil { 203 d.logf("dial error: %v; delaying next attempt for %s", err, interval) 204 } else { 205 d.logf("dial error: %v;", err) 206 } 207 208 select { 209 case <-loopTimer.C: 210 // 211 case <-ctx.Done(): 212 err = ctx.Err() 213 return 214 case <-d.Closed: 215 err = ErrClosed 216 return 217 } 218 219 interval += step 220 if interval > maxInterval { 221 interval = maxInterval 222 } 223 224 loopTimer.Reset(interval) 225 } 226 } 227 228 func (d *Dialer) dial(ctx context.Context) (conn net.Conn, err error) { 229 if tm := d.Timeout; tm != 0 { 230 ctx, _ = context.WithTimeout(ctx, tm) 231 } 232 233 netDial := d.NetDial 234 if netDial == nil { 235 netDial = defaultNetDial 236 } 237 238 return netDial(ctx, d.Network, d.Addr) 239 } 240 241 func (d *Dialer) getLogPrefix() string { 242 if d.DisableLogAddr { 243 return "dialer: " 244 } 245 246 return `dialer to "` + d.Network + `:` + d.Addr + `": ` 247 } 248 249 func (d *Dialer) logf(fmt string, args ...interface{}) { 250 prefix := d.getLogPrefix() 251 252 if logf := d.Logf; logf != nil { 253 logf(prefix+fmt, args...) 254 } 255 } 256 257 func (d *Dialer) debugf(fmt string, args ...interface{}) { 258 prefix := d.getLogPrefix() 259 260 if debugf := d.Debugf; debugf != nil { 261 debugf(prefix+fmt, args...) 262 } 263 } 264 265 var emptyDialer net.Dialer 266 267 func defaultNetDial(ctx context.Context, network, addr string) (net.Conn, error) { 268 return emptyDialer.DialContext(ctx, network, addr) 269 }