github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/network/multi.go (about)

     1  package network
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"time"
     8  
     9  	"github.com/sagernet/sing/common"
    10  	E "github.com/sagernet/sing/common/exceptions"
    11  	M "github.com/sagernet/sing/common/metadata"
    12  )
    13  
// DefaultFallbackDelay is the delay DialParallel waits before starting the
// fallback address family when it is called with a zero fallbackDelay.
const DefaultFallbackDelay = 300 * time.Millisecond
    15  
    16  func DialSerial(ctx context.Context, dialer Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
    17  	if parallelDialer, isParallel := dialer.(ParallelDialer); isParallel {
    18  		return parallelDialer.DialParallel(ctx, network, destination, destinationAddresses)
    19  	}
    20  	var conn net.Conn
    21  	var err error
    22  	var connErrors []error
    23  	for _, address := range destinationAddresses {
    24  		conn, err = dialer.DialContext(ctx, network, M.SocksaddrFrom(address, destination.Port))
    25  		if err != nil {
    26  			connErrors = append(connErrors, err)
    27  			continue
    28  		}
    29  		return conn, nil
    30  	}
    31  	return nil, E.Errors(connErrors...)
    32  }
    33  
    34  func ListenSerial(ctx context.Context, dialer Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.PacketConn, netip.Addr, error) {
    35  	var conn net.PacketConn
    36  	var err error
    37  	var connErrors []error
    38  	for _, address := range destinationAddresses {
    39  		conn, err = dialer.ListenPacket(ctx, M.SocksaddrFrom(address, destination.Port))
    40  		if err != nil {
    41  			connErrors = append(connErrors, err)
    42  			continue
    43  		}
    44  		return conn, address, nil
    45  	}
    46  	return nil, netip.Addr{}, E.Errors(connErrors...)
    47  }
    48  
    49  func DialParallel(ctx context.Context, dialer Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, fallbackDelay time.Duration) (net.Conn, error) {
    50  	// kanged form net.Dial
    51  
    52  	if fallbackDelay == 0 {
    53  		fallbackDelay = DefaultFallbackDelay
    54  	}
    55  
    56  	returned := make(chan struct{})
    57  	defer close(returned)
    58  
    59  	addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
    60  		return address.Is4() || address.Is4In6()
    61  	})
    62  	addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool {
    63  		return address.Is6() && !address.Is4In6()
    64  	})
    65  	if len(addresses4) == 0 || len(addresses6) == 0 {
    66  		return DialSerial(ctx, dialer, network, destination, destinationAddresses)
    67  	}
    68  	var primaries, fallbacks []netip.Addr
    69  	if preferIPv6 {
    70  		primaries = addresses6
    71  		fallbacks = addresses4
    72  	} else {
    73  		primaries = addresses4
    74  		fallbacks = addresses6
    75  	}
    76  	type dialResult struct {
    77  		net.Conn
    78  		error
    79  		primary bool
    80  		done    bool
    81  	}
    82  	results := make(chan dialResult) // unbuffered
    83  	startRacer := func(ctx context.Context, primary bool) {
    84  		ras := primaries
    85  		if !primary {
    86  			ras = fallbacks
    87  		}
    88  		c, err := DialSerial(ctx, dialer, network, destination, ras)
    89  		select {
    90  		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
    91  		case <-returned:
    92  			if c != nil {
    93  				c.Close()
    94  			}
    95  		}
    96  	}
    97  	var primary, fallback dialResult
    98  	primaryCtx, primaryCancel := context.WithCancel(ctx)
    99  	defer primaryCancel()
   100  	go startRacer(primaryCtx, true)
   101  	fallbackTimer := time.NewTimer(fallbackDelay)
   102  	defer fallbackTimer.Stop()
   103  	for {
   104  		select {
   105  		case <-fallbackTimer.C:
   106  			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
   107  			defer fallbackCancel()
   108  			go startRacer(fallbackCtx, false)
   109  
   110  		case res := <-results:
   111  			if res.error == nil {
   112  				return res.Conn, nil
   113  			}
   114  			if res.primary {
   115  				primary = res
   116  			} else {
   117  				fallback = res
   118  			}
   119  			if primary.done && fallback.done {
   120  				return nil, primary.error
   121  			}
   122  			if res.primary && fallbackTimer.Stop() {
   123  				fallbackTimer.Reset(0)
   124  			}
   125  		}
   126  	}
   127  }