github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/requests/dialer_swarm.go (about) 1 /* 2 * Copyright (C) 2020 The "MysteriumNetwork/node" Authors. 3 * 4 * This program is free software: you can redistribute it and/or modify 5 * it under the terms of the GNU General Public License as published by 6 * the Free Software Foundation, either version 3 of the License, or 7 * (at your option) any later version. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 */ 17 18 package requests 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "net" 25 "strings" 26 "syscall" 27 "time" 28 29 "github.com/rs/zerolog/log" 30 31 "github.com/mysteriumnetwork/node/requests/resolver" 32 "github.com/mysteriumnetwork/node/router" 33 ) 34 35 // ErrAllDialsFailed is returned when connecting to a peer has ultimately failed. 36 var ErrAllDialsFailed = errors.New("all dials failed") 37 38 // DialerSwarm is a dials to multiple addresses in parallel and earliest successful connection wins. 39 type DialerSwarm struct { 40 // ResolveContext specifies the resolve function for doing custom DNS lookup. 41 // If ResolveContext is nil, then the transport dials using package net. 42 ResolveContext resolver.ResolveContext 43 44 // Dialer specifies the dial function for creating unencrypted TCP connections. 45 Dialer DialContext 46 47 // dnsHeadstart specifies the time delay that requests via IP incur. 48 dnsHeadstart time.Duration 49 } 50 51 // NewDialerSwarm creates swarm dialer with default configuration. 
52 func NewDialerSwarm(srcIP string, dnsHeadstart time.Duration) *DialerSwarm { 53 return &DialerSwarm{ 54 dnsHeadstart: dnsHeadstart, 55 Dialer: (wrapDialer(&net.Dialer{ 56 Timeout: 60 * time.Second, 57 KeepAlive: 30 * time.Second, 58 LocalAddr: &net.TCPAddr{IP: net.ParseIP(srcIP)}, 59 Control: func(net, address string, c syscall.RawConn) (err error) { 60 if net == "tcp6" { 61 return fmt.Errorf("ipv6 not supported") 62 } 63 64 err = c.Control(func(f uintptr) { 65 log.Trace().Msgf("Protecting connection to: %s (%s)", address, net) 66 67 fd := int(f) 68 err := router.Protect(fd) 69 if err != nil { 70 log.Error().Err(err).Msgf("Failed to protect connection to: %s (%s)", address, net) 71 } 72 }) 73 return err 74 }, 75 })).DialContext, 76 } 77 } 78 79 // DialContext connects to the address on the named network using the provided context. 80 func (ds *DialerSwarm) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 81 if ds.ResolveContext != nil { 82 addrs, err := ds.ResolveContext(ctx, network, addr) 83 if err != nil { 84 return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} 85 } 86 87 conn, errDial := ds.dialAddrs(ctx, network, addrs) 88 if errDial != nil { 89 errDial.OriginalAddr = addr 90 91 return nil, errDial 92 } 93 94 return conn, nil 95 } 96 97 return ds.Dialer(ctx, network, addr) 98 } 99 100 func (ds *DialerSwarm) dialAddrs(ctx context.Context, network string, addrs []string) (net.Conn, *ErrorSwarmDial) { 101 addrChan := make(chan string, len(addrs)) 102 for _, addr := range addrs { 103 addrChan <- addr 104 } 105 106 close(addrChan) 107 108 ctx, cancel := context.WithCancel(ctx) 109 defer cancel() 110 111 resultCh := make(chan dialResult) 112 err := &ErrorSwarmDial{} 113 114 var active int 115 dialLoop: 116 for addrChan != nil || active > 0 { 117 // Check for context cancellations and/or responses first. 118 select { 119 // Overall dialing canceled. 
120 case <-ctx.Done(): 121 break dialLoop 122 123 // Some dial result arrived. 124 case resp := <-resultCh: 125 active-- 126 if resp.Err != nil { 127 err.addErr(resp.Addr, resp.Err) 128 } else if resp.Conn != nil { 129 return resp.Conn, nil 130 } 131 132 continue 133 134 default: 135 } 136 137 // Now, attempt to dial. 138 select { 139 case addr, ok := <-addrChan: 140 if !ok { 141 addrChan = nil 142 143 continue 144 } 145 146 // Prefer dialing via dns, give them a head start. 147 if !isIP(addr) { 148 go ds.dialAddr(ctx, network, addr, resultCh) 149 } else { 150 go func() { 151 select { 152 case <-time.After(ds.dnsHeadstart): 153 break 154 case <-ctx.Done(): 155 return 156 } 157 ds.dialAddr(ctx, network, addr, resultCh) 158 }() 159 } 160 161 active++ 162 163 case <-ctx.Done(): 164 break dialLoop 165 166 case resp := <-resultCh: 167 active-- 168 if resp.Err != nil { 169 err.addErr(resp.Addr, resp.Err) 170 } else if resp.Conn != nil { 171 return resp.Conn, nil 172 } 173 } 174 } 175 176 if ctxErr := ctx.Err(); ctxErr != nil { 177 err.Cause = ctxErr 178 } else { 179 err.Cause = ErrAllDialsFailed 180 } 181 182 return nil, err 183 } 184 185 func isIP(addr string) bool { 186 host, _, err := net.SplitHostPort(addr) 187 if err != nil { 188 ip := net.ParseIP(addr) 189 return ip != nil 190 } 191 ip := net.ParseIP(host) 192 return ip != nil 193 } 194 195 func (ds *DialerSwarm) dialAddr(ctx context.Context, network, addr string, resp chan dialResult) { 196 // Dialing might be canceled already. 197 if ctx.Err() != nil { 198 return 199 } 200 201 conn, err := ds.Dialer(ctx, network, addr) 202 select { 203 case resp <- dialResult{Conn: conn, Addr: addr, Err: err}: 204 case <-ctx.Done(): 205 if err == nil { 206 conn.Close() 207 } 208 } 209 } 210 211 type dialResult struct { 212 Conn net.Conn 213 Addr string 214 Err error 215 } 216 217 // ErrorSwarmDial is the error type returned when dialing multiple addresses. 
type ErrorSwarmDial struct {
	// OriginalAddr is the address the caller asked to dial.
	OriginalAddr string
	// DialErrors holds one entry per address attempt that failed.
	DialErrors []ErrorDial
	// Cause is the overall reason dialing gave up.
	Cause error
}

// addErr appends the failure of dialing addr to the per-address error list.
func (e *ErrorSwarmDial) addErr(addr string, err error) {
	failure := ErrorDial{Addr: addr, Cause: err}
	e.DialErrors = append(e.DialErrors, failure)
}

// Error formats the overall failure on the first line, followed by one
// " * [addr] cause" line per failed address attempt.
func (e *ErrorSwarmDial) Error() string {
	var sb strings.Builder
	sb.WriteString("failed to dial " + e.OriginalAddr + ":")
	if e.Cause != nil {
		fmt.Fprintf(&sb, " %s", e.Cause)
	}
	for i := range e.DialErrors {
		failure := &e.DialErrors[i]
		fmt.Fprintf(&sb, "\n * [%s] %s", failure.Addr, failure.Cause)
	}
	return sb.String()
}

// Unwrap returns Cause so that errors.Is and errors.As can inspect it.
func (e *ErrorSwarmDial) Unwrap() error {
	return e.Cause
}

// ErrorDial is the error returned when dialing a specific address.
type ErrorDial struct {
	// Addr is the address that was dialed.
	Addr string
	// Cause is the error that dialing Addr produced.
	Cause error
}

// Error returns a "failed to dial <addr>: <cause>" message.
func (e *ErrorDial) Error() string {
	msg := fmt.Sprintf("failed to dial %s: %s", e.Addr, e.Cause)
	return msg
}

// Unwrap returns Cause so that errors.Is and errors.As can inspect it.
265 func (e *ErrorDial) Unwrap() error { 266 return e.Cause 267 } 268 269 type dialerWithDNSCache struct { 270 dialer *net.Dialer 271 } 272 273 func wrapDialer(dialer *net.Dialer) *dialerWithDNSCache { 274 return &dialerWithDNSCache{ 275 dialer: dialer, 276 } 277 } 278 279 func (wd *dialerWithDNSCache) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 280 go func() { 281 if !isIP(addr) { 282 addrHost, _, err := net.SplitHostPort(addr) 283 if err != nil { 284 log.Warn().Msgf("Failed to get host from: %s (%s)", addr, network) 285 return 286 } 287 288 lookupCtx, cancel := context.WithTimeout(ctx, 10*time.Second) 289 defer cancel() 290 291 addrs, err := net.DefaultResolver.LookupHost(lookupCtx, addrHost) 292 if err != nil { 293 log.Warn().Err(err).Msgf("Failed to lookup host: %q", addrHost) 294 return 295 } 296 297 resolver.CacheDNSRecord(addrHost, addrs) 298 } 299 }() 300 301 return wd.dialer.DialContext(ctx, network, addr) 302 }