github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/net/rrdialer/impl.go (about) 1 package rrdialer 2 3 import ( 4 "bufio" 5 "context" 6 "fmt" 7 "math" 8 "net" 9 "os" 10 "os/user" 11 "path/filepath" 12 "runtime" 13 "strconv" 14 "strings" 15 "time" 16 17 "github.com/Cloud-Foundations/Dominator/lib/json" 18 "github.com/Cloud-Foundations/Dominator/lib/log" 19 ) 20 21 const ( 22 weight = 0.2 23 24 dirPerms = 0755 25 privateFilePerms = 0600 26 publicFilePerms = 0644 27 ) 28 29 var ( 30 pid = strconv.FormatInt(int64(os.Getpid()), 10) 31 ) 32 33 type endpointType struct { 34 address string // Host:port 35 conn net.Conn 36 dialing bool 37 err error 38 LastUpdate time.Time 39 LatencyVariance float64 // Seconds^2. 40 MaximumLatency float64 // Seconds. 41 MeanLatency float64 // Seconds. 42 MinimumLatency float64 // Seconds. 43 standardDeviationOfLatency float64 // Seconds. 44 } 45 46 func getFastestEndpoint(endpoints []*endpointType) *endpointType { 47 var fastestEndpoint *endpointType 48 for _, endpoint := range endpoints { 49 if endpoint.dialing { 50 continue 51 } 52 if (fastestEndpoint == nil) || 53 (endpoint.MeanLatency > 0 && 54 endpoint.MeanLatency < fastestEndpoint.MeanLatency) { 55 fastestEndpoint = endpoint 56 } 57 } 58 return fastestEndpoint 59 } 60 61 func getHomeDirectory() (string, error) { 62 if homeDir := os.Getenv("HOME"); homeDir != "" { 63 return homeDir, nil 64 } 65 if usr, err := user.Current(); err != nil { 66 return "", err 67 } else { 68 return usr.HomeDir, nil 69 } 70 } 71 72 func getMostStaleEndpoint(endpoints []*endpointType) *endpointType { 73 var mostStaleEndpoint *endpointType 74 for _, endpoint := range endpoints { 75 if endpoint.dialing { 76 continue 77 } 78 if (mostStaleEndpoint == nil) || 79 endpoint.LastUpdate.Before(mostStaleEndpoint.LastUpdate) { 80 mostStaleEndpoint = endpoint 81 } 82 } 83 return mostStaleEndpoint 84 } 85 86 func newDialer(dialer *net.Dialer, cacheDir string, 87 logger log.DebugLogger) (*Dialer, error) { 88 rrDialer := &Dialer{ 89 logger: logger, 90 rawDialer: dialer, 91 } 92 if cacheDir == "" { 93 homedir, err := getHomeDirectory() 94 if err != nil { 95 return nil, err 96 } 97 cacheDir = filepath.Join(homedir, ".cache") 98 } 99 rrDialer.dirname = filepath.Join(cacheDir, "round-robin-dialer") 100 return rrDialer, nil 101 } 102 103 func makeFilename(dirname, address string) string { 104 if runtime.GOOS == "windows" { 105 address = strings.Replace(address, ":", "_", -1) 106 } 107 return filepath.Join(dirname, address) 108 } 109 110 func (d *Dialer) loadEndpointHistories(hostAddrs []string, 111 port string) ([]*endpointType, error) { 112 endpoints := make([]*endpointType, 0, len(hostAddrs)) 113 for _, hostAddr := range hostAddrs { 114 address := hostAddr + ":" + port 115 if endpoint, err := d.loadEndpointHistory(address); err != nil { 116 return nil, err 117 } else { 118 endpoints = append(endpoints, endpoint) 119 } 120 } 121 return endpoints, nil 122 } 123 124 func (d *Dialer) loadEndpointHistory(address string) (*endpointType, error) { 125 filename := makeFilename(d.dirname, address) 126 var endpoint endpointType 127 if err := json.ReadFromFile(filename, &endpoint); err != nil { 128 if !os.IsNotExist(err) { 129 return nil, err 130 } 131 return &endpointType{address: address}, nil 132 } else { 133 endpoint.address = address 134 endpoint.computeStandardDeviationOfLatency() 135 return &endpoint, nil 136 } 137 } 138 139 func (d *Dialer) dialContext(ctx context.Context, network, 140 address string) (net.Conn, error) { 141 host, port, err := net.SplitHostPort(address) 142 if err != nil { 143 return nil, err 144 } 145 resolver := d.rawDialer.Resolver 146 if resolver == nil { 147 resolver = net.DefaultResolver 148 } 149 hostAddrs, err := resolver.LookupHost(context.Background(), host) 150 if err != nil { 151 return nil, err 152 } 153 if len(hostAddrs) < 1 { 154 return nil, fmt.Errorf("no addresses found for: %s", host) 155 } else if len(hostAddrs) == 1 { 156 return d.rawDialer.DialContext(ctx, network, hostAddrs[0]+":"+port) 157 } 158 logLevel := int16(-1) 159 if getter, ok := d.logger.(log.DebugLogLevelGetter); ok { 160 logLevel = getter.GetLevel() 161 } 162 endpoints, err := d.loadEndpointHistories(hostAddrs, port) 163 if err != nil { 164 return nil, err 165 } 166 return d.dialEndpoints(ctx, network, address, endpoints, logLevel) 167 } 168 169 func (d *Dialer) dialEndpoints(ctx context.Context, network, address string, 170 endpoints []*endpointType, logLevel int16) (net.Conn, error) { 171 timeoutTimer := time.NewTimer(d.rawDialer.Timeout) 172 results := make(chan *endpointType, len(endpoints)) 173 // Immediately dial the historically fastest endpoint. 174 fastestEndpoint := getFastestEndpoint(endpoints) 175 d.goDialEndpoint(ctx, network, fastestEndpoint, "fastest", results) 176 impatienceTimerFastest := fastestEndpoint.makeImpatienceTimer() 177 stalestEndpoint := getMostStaleEndpoint(endpoints) 178 d.goDialEndpoint(ctx, network, stalestEndpoint, "oldest", results) 179 impatienceTimerStalest := stalestEndpoint.makeImpatienceTimer() 180 // Dial all endpoints without history or if debug mode is enabled. 181 for _, endpoint := range endpoints { 182 if logLevel >= 3 || endpoint.MeanLatency <= 0 { 183 d.goDialEndpoint(ctx, network, endpoint, "all", results) 184 } 185 } 186 failureCounter := 0 187 problemCounter := 0 188 for { 189 select { 190 case endpoint := <-results: 191 if endpoint.err != nil { 192 failureCounter++ 193 problemCounter++ 194 if failureCounter >= len(endpoints) { 195 for _, endpoint := range endpoints { 196 d.logger.Printf("error dialing: %s: %s\n", 197 endpoint.address, endpoint.err) 198 } 199 return nil, fmt.Errorf("failed connecting to: %s", address) 200 } 201 for _, endpoint := range endpoints { 202 d.goDialEndpoint(ctx, network, endpoint, "backups", 203 results) 204 } 205 if problemCounter == 2 { 206 d.logger.Println( 207 "At least 2 endpoints have issues, dialed remaining endpoints") 208 } 209 break 210 } 211 d.logger.Debugf(2, "connected: %s\n", endpoint.conn.RemoteAddr()) 212 return endpoint.conn, nil 213 case <-impatienceTimerFastest.C: 214 problemCounter++ 215 for _, endpoint := range endpoints { 216 d.goDialEndpoint(ctx, network, endpoint, "impatiently", results) 217 } 218 if problemCounter == 2 { 219 d.logger.Println( 220 "At least 2 endpoints have issues, dialed remaining endpoints") 221 } 222 case <-impatienceTimerStalest.C: 223 problemCounter++ 224 for _, endpoint := range endpoints { 225 d.goDialEndpoint(ctx, network, endpoint, "impatiently", results) 226 } 227 if problemCounter == 2 { 228 d.logger.Println( 229 "At least 2 endpoints have issues, dialed remaining endpoints") 230 } 231 case <-timeoutTimer.C: 232 return nil, fmt.Errorf("timed out connecting to: %s", address) 233 } 234 } 235 } 236 237 func (d *Dialer) goDialEndpoint(ctx context.Context, network string, 238 endpoint *endpointType, reason string, result chan<- *endpointType) { 239 if endpoint.dialing { 240 return 241 } 242 endpoint.dialing = true 243 endpoint.LastUpdate = time.Now() 244 d.logger.Debugf(2, "dialing %s: %s\n", reason, endpoint.address) 245 d.waitGroup.Add(1) 246 go func() { 247 defer d.waitGroup.Done() 248 startTime := time.Now() 249 conn, err := d.rawDialer.DialContext(ctx, network, endpoint.address) 250 if err != nil { 251 endpoint.err = err 252 } else { 253 endpoint.conn = conn 254 d.recordEvent(endpoint, time.Since(startTime).Seconds()) 255 } 256 result <- endpoint 257 }() 258 } 259 260 func (d *Dialer) recordEvent(endpoint *endpointType, latency float64) { 261 if d.dirname == "" { // When testing. 262 return 263 } 264 filename := makeFilename(d.dirname, endpoint.address) 265 tmpFilename := makeFilename(d.dirname, endpoint.address+pid) 266 endpoint.LastUpdate = time.Now() 267 if endpoint.MeanLatency <= 0 { 268 endpoint.MeanLatency = latency 269 } else { 270 delta := latency - endpoint.MeanLatency 271 endpoint.MeanLatency = latency*weight + 272 (1.0-weight)*endpoint.MeanLatency 273 endpoint.LatencyVariance = (1.0 - weight) * 274 (endpoint.LatencyVariance + weight*delta*delta) 275 } 276 endpoint.computeStandardDeviationOfLatency() 277 d.logger.Debugf(3, "%s: L: %f ms, Lm: %f ms, Lsd: %f ms\n", 278 endpoint.address, latency*1e3, endpoint.MeanLatency*1e3, 279 endpoint.standardDeviationOfLatency*1e3) 280 if latency > endpoint.MaximumLatency { 281 endpoint.MaximumLatency = latency 282 } 283 if endpoint.MinimumLatency <= 0 || latency < endpoint.MinimumLatency { 284 endpoint.MinimumLatency = latency 285 } 286 file, err := os.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 287 publicFilePerms) 288 if err != nil { 289 if os.IsNotExist(err) { 290 if e := os.MkdirAll(d.dirname, dirPerms); e != nil { 291 d.logger.Println(err) 292 d.logger.Println(e) 293 return 294 } 295 } 296 file, err = os.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 297 publicFilePerms) 298 } 299 if err != nil { 300 d.logger.Println(err) 301 return 302 } 303 defer file.Close() 304 defer os.Remove(tmpFilename) 305 writer := bufio.NewWriter(file) 306 defer writer.Flush() 307 if err := json.WriteWithIndent(writer, " ", endpoint); err != nil { 308 d.logger.Println(err) 309 return 310 } 311 if err := writer.Flush(); err != nil { 312 d.logger.Println(err) 313 return 314 } 315 if err := file.Close(); err != nil { 316 d.logger.Println(err) 317 return 318 } 319 if err := os.Rename(tmpFilename, filename); err != nil { 320 d.logger.Println(err) 321 return 322 } 323 } 324 325 func (d *Dialer) waitForBackgroundResults(timeout time.Duration) { 326 finished := make(chan struct{}, 1) 327 timer := time.NewTimer(timeout) 328 go func(finished chan<- struct{}) { 329 d.waitGroup.Wait() 330 finished <- struct{}{} 331 }(finished) 332 select { 333 case <-finished: 334 timer.Stop() 335 case <-timer.C: 336 } 337 } 338 339 func (e *endpointType) computeStandardDeviationOfLatency() { 340 if e.LatencyVariance <= 0 { 341 return 342 } 343 e.standardDeviationOfLatency = math.Sqrt(e.LatencyVariance) 344 } 345 346 func (e *endpointType) makeImpatienceTimer() *time.Timer { 347 if e.LatencyVariance <= 0 { 348 timer := time.NewTimer(time.Second) 349 timer.Stop() 350 return timer 351 } 352 timeoutDelta := e.MeanLatency * 0.1 353 if td := 2 * e.standardDeviationOfLatency; td > timeoutDelta { 354 timeoutDelta = td 355 } 356 return time.NewTimer(time.Duration(float64(time.Second) * 357 (e.MeanLatency + timeoutDelta))) 358 }