github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/requests/dialer_swarm.go (about)

     1  /*
     2   * Copyright (C) 2020 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package requests
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"strings"
    26  	"syscall"
    27  	"time"
    28  
    29  	"github.com/rs/zerolog/log"
    30  
    31  	"github.com/mysteriumnetwork/node/requests/resolver"
    32  	"github.com/mysteriumnetwork/node/router"
    33  )
    34  
    35  // ErrAllDialsFailed is returned when connecting to a peer has ultimately failed.
    36  var ErrAllDialsFailed = errors.New("all dials failed")
    37  
    38  // DialerSwarm is a dials to multiple addresses in parallel and earliest successful connection wins.
    39  type DialerSwarm struct {
    40  	// ResolveContext specifies the resolve function for doing custom DNS lookup.
    41  	// If ResolveContext is nil, then the transport dials using package net.
    42  	ResolveContext resolver.ResolveContext
    43  
    44  	// Dialer specifies the dial function for creating unencrypted TCP connections.
    45  	Dialer DialContext
    46  
    47  	// dnsHeadstart specifies the time delay that requests via IP incur.
    48  	dnsHeadstart time.Duration
    49  }
    50  
    51  // NewDialerSwarm creates swarm dialer with default configuration.
    52  func NewDialerSwarm(srcIP string, dnsHeadstart time.Duration) *DialerSwarm {
    53  	return &DialerSwarm{
    54  		dnsHeadstart: dnsHeadstart,
    55  		Dialer: (wrapDialer(&net.Dialer{
    56  			Timeout:   60 * time.Second,
    57  			KeepAlive: 30 * time.Second,
    58  			LocalAddr: &net.TCPAddr{IP: net.ParseIP(srcIP)},
    59  			Control: func(net, address string, c syscall.RawConn) (err error) {
    60  				if net == "tcp6" {
    61  					return fmt.Errorf("ipv6 not supported")
    62  				}
    63  
    64  				err = c.Control(func(f uintptr) {
    65  					log.Trace().Msgf("Protecting connection to: %s (%s)", address, net)
    66  
    67  					fd := int(f)
    68  					err := router.Protect(fd)
    69  					if err != nil {
    70  						log.Error().Err(err).Msgf("Failed to protect connection to: %s (%s)", address, net)
    71  					}
    72  				})
    73  				return err
    74  			},
    75  		})).DialContext,
    76  	}
    77  }
    78  
    79  // DialContext connects to the address on the named network using the provided context.
    80  func (ds *DialerSwarm) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
    81  	if ds.ResolveContext != nil {
    82  		addrs, err := ds.ResolveContext(ctx, network, addr)
    83  		if err != nil {
    84  			return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
    85  		}
    86  
    87  		conn, errDial := ds.dialAddrs(ctx, network, addrs)
    88  		if errDial != nil {
    89  			errDial.OriginalAddr = addr
    90  
    91  			return nil, errDial
    92  		}
    93  
    94  		return conn, nil
    95  	}
    96  
    97  	return ds.Dialer(ctx, network, addr)
    98  }
    99  
   100  func (ds *DialerSwarm) dialAddrs(ctx context.Context, network string, addrs []string) (net.Conn, *ErrorSwarmDial) {
   101  	addrChan := make(chan string, len(addrs))
   102  	for _, addr := range addrs {
   103  		addrChan <- addr
   104  	}
   105  
   106  	close(addrChan)
   107  
   108  	ctx, cancel := context.WithCancel(ctx)
   109  	defer cancel()
   110  
   111  	resultCh := make(chan dialResult)
   112  	err := &ErrorSwarmDial{}
   113  
   114  	var active int
   115  dialLoop:
   116  	for addrChan != nil || active > 0 {
   117  		// Check for context cancellations and/or responses first.
   118  		select {
   119  		// Overall dialing canceled.
   120  		case <-ctx.Done():
   121  			break dialLoop
   122  
   123  		// Some dial result arrived.
   124  		case resp := <-resultCh:
   125  			active--
   126  			if resp.Err != nil {
   127  				err.addErr(resp.Addr, resp.Err)
   128  			} else if resp.Conn != nil {
   129  				return resp.Conn, nil
   130  			}
   131  
   132  			continue
   133  
   134  		default:
   135  		}
   136  
   137  		// Now, attempt to dial.
   138  		select {
   139  		case addr, ok := <-addrChan:
   140  			if !ok {
   141  				addrChan = nil
   142  
   143  				continue
   144  			}
   145  
   146  			// Prefer dialing via dns, give them a head start.
   147  			if !isIP(addr) {
   148  				go ds.dialAddr(ctx, network, addr, resultCh)
   149  			} else {
   150  				go func() {
   151  					select {
   152  					case <-time.After(ds.dnsHeadstart):
   153  						break
   154  					case <-ctx.Done():
   155  						return
   156  					}
   157  					ds.dialAddr(ctx, network, addr, resultCh)
   158  				}()
   159  			}
   160  
   161  			active++
   162  
   163  		case <-ctx.Done():
   164  			break dialLoop
   165  
   166  		case resp := <-resultCh:
   167  			active--
   168  			if resp.Err != nil {
   169  				err.addErr(resp.Addr, resp.Err)
   170  			} else if resp.Conn != nil {
   171  				return resp.Conn, nil
   172  			}
   173  		}
   174  	}
   175  
   176  	if ctxErr := ctx.Err(); ctxErr != nil {
   177  		err.Cause = ctxErr
   178  	} else {
   179  		err.Cause = ErrAllDialsFailed
   180  	}
   181  
   182  	return nil, err
   183  }
   184  
   185  func isIP(addr string) bool {
   186  	host, _, err := net.SplitHostPort(addr)
   187  	if err != nil {
   188  		ip := net.ParseIP(addr)
   189  		return ip != nil
   190  	}
   191  	ip := net.ParseIP(host)
   192  	return ip != nil
   193  }
   194  
   195  func (ds *DialerSwarm) dialAddr(ctx context.Context, network, addr string, resp chan dialResult) {
   196  	// Dialing might be canceled already.
   197  	if ctx.Err() != nil {
   198  		return
   199  	}
   200  
   201  	conn, err := ds.Dialer(ctx, network, addr)
   202  	select {
   203  	case resp <- dialResult{Conn: conn, Addr: addr, Err: err}:
   204  	case <-ctx.Done():
   205  		if err == nil {
   206  			conn.Close()
   207  		}
   208  	}
   209  }
   210  
   211  type dialResult struct {
   212  	Conn net.Conn
   213  	Addr string
   214  	Err  error
   215  }
   216  
   217  // ErrorSwarmDial is the error type returned when dialing multiple addresses.
   218  type ErrorSwarmDial struct {
   219  	OriginalAddr string
   220  	DialErrors   []ErrorDial
   221  	Cause        error
   222  }
   223  
   224  func (e *ErrorSwarmDial) addErr(addr string, err error) {
   225  	e.DialErrors = append(e.DialErrors, ErrorDial{
   226  		Addr:  addr,
   227  		Cause: err,
   228  	})
   229  }
   230  
   231  // Error returns string equivalent for error.
   232  func (e *ErrorSwarmDial) Error() string {
   233  	var builder strings.Builder
   234  
   235  	fmt.Fprintf(&builder, "failed to dial %s:", e.OriginalAddr)
   236  
   237  	if e.Cause != nil {
   238  		fmt.Fprintf(&builder, " %s", e.Cause)
   239  	}
   240  
   241  	for _, te := range e.DialErrors {
   242  		fmt.Fprintf(&builder, "\n  * [%s] %s", te.Addr, te.Cause)
   243  	}
   244  
   245  	return builder.String()
   246  }
   247  
   248  // Unwrap unwraps the original err for use with errors.Unwrap.
   249  func (e *ErrorSwarmDial) Unwrap() error {
   250  	return e.Cause
   251  }
   252  
   253  // ErrorDial is the error returned when dialing a specific address.
   254  type ErrorDial struct {
   255  	Addr  string
   256  	Cause error
   257  }
   258  
   259  // Error returns string equivalent for error.
   260  func (e *ErrorDial) Error() string {
   261  	return fmt.Sprintf("failed to dial %s: %s", e.Addr, e.Cause)
   262  }
   263  
   264  // Unwrap unwraps the original err for use with errors.Unwrap.
   265  func (e *ErrorDial) Unwrap() error {
   266  	return e.Cause
   267  }
   268  
   269  type dialerWithDNSCache struct {
   270  	dialer *net.Dialer
   271  }
   272  
   273  func wrapDialer(dialer *net.Dialer) *dialerWithDNSCache {
   274  	return &dialerWithDNSCache{
   275  		dialer: dialer,
   276  	}
   277  }
   278  
   279  func (wd *dialerWithDNSCache) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   280  	go func() {
   281  		if !isIP(addr) {
   282  			addrHost, _, err := net.SplitHostPort(addr)
   283  			if err != nil {
   284  				log.Warn().Msgf("Failed to get host from: %s (%s)", addr, network)
   285  				return
   286  			}
   287  
   288  			lookupCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
   289  			defer cancel()
   290  
   291  			addrs, err := net.DefaultResolver.LookupHost(lookupCtx, addrHost)
   292  			if err != nil {
   293  				log.Warn().Err(err).Msgf("Failed to lookup host: %q", addrHost)
   294  				return
   295  			}
   296  
   297  			resolver.CacheDNSRecord(addrHost, addrs)
   298  		}
   299  	}()
   300  
   301  	return wd.dialer.DialContext(ctx, network, addr)
   302  }