github.com/rclone/rclone@v1.66.1-0.20240517100346-7b89735ae726/fs/fshttp/dialer.go (about)

     1  package fshttp
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"runtime"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/rclone/rclone/fs"
    12  	"github.com/rclone/rclone/fs/accounting"
    13  	"golang.org/x/net/ipv4"
    14  	"golang.org/x/net/ipv6"
    15  )
    16  
    17  // Dialer structure contains default dialer and timeout, tclass support
    18  type Dialer struct {
    19  	net.Dialer
    20  	timeout time.Duration
    21  	tclass  int
    22  }
    23  
    24  // NewDialer creates a Dialer structure with Timeout, Keepalive,
    25  // LocalAddr and DSCP set from rclone flags.
    26  func NewDialer(ctx context.Context) *Dialer {
    27  	ci := fs.GetConfig(ctx)
    28  	dialer := &Dialer{
    29  		Dialer: net.Dialer{
    30  			Timeout:   ci.ConnectTimeout,
    31  			KeepAlive: 30 * time.Second,
    32  		},
    33  		timeout: ci.Timeout,
    34  		tclass:  int(ci.TrafficClass),
    35  	}
    36  	if ci.BindAddr != nil {
    37  		dialer.Dialer.LocalAddr = &net.TCPAddr{IP: ci.BindAddr}
    38  	}
    39  	return dialer
    40  }
    41  
    42  // Dial connects to the network address.
    43  func (d *Dialer) Dial(network, address string) (net.Conn, error) {
    44  	return d.DialContext(context.Background(), network, address)
    45  }
    46  
    47  var warnDSCPFail, warnDSCPWindows sync.Once
    48  
    49  // DialContext connects to the network address using the provided context.
    50  func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
    51  	// If local address is 0.0.0.0 or ::0 force IPv4 or IPv6
    52  	// This works around https://github.com/golang/go/issues/48723
    53  	// Which means 0.0.0.0 and ::0 both bind to both IPv4 and IPv6
    54  	if ip, ok := d.Dialer.LocalAddr.(*net.TCPAddr); ok && ip.IP.IsUnspecified() && (network == "tcp" || network == "udp") {
    55  		if ip.IP.To4() != nil {
    56  			network += "4" // IPv4 address
    57  		} else {
    58  			network += "6" // IPv6 address
    59  		}
    60  	}
    61  
    62  	c, err := d.Dialer.DialContext(ctx, network, address)
    63  	if err != nil {
    64  		return c, err
    65  	}
    66  
    67  	if d.tclass != 0 {
    68  		// IPv6 addresses must have two or more ":"
    69  		if strings.Count(c.RemoteAddr().String(), ":") > 1 {
    70  			err = ipv6.NewConn(c).SetTrafficClass(d.tclass)
    71  		} else {
    72  			err = ipv4.NewConn(c).SetTOS(d.tclass)
    73  			// Warn of silent failure on Windows (IPv4 only, IPv6 caught by error handler)
    74  			if runtime.GOOS == "windows" {
    75  				warnDSCPWindows.Do(func() {
    76  					fs.LogLevelPrintf(fs.LogLevelWarning, nil, "dialer: setting DSCP on Windows/IPv4 fails silently; see https://github.com/golang/go/issues/42728")
    77  				})
    78  			}
    79  		}
    80  		if err != nil {
    81  			warnDSCPFail.Do(func() {
    82  				fs.LogLevelPrintf(fs.LogLevelWarning, nil, "dialer: failed to set DSCP socket options: %v", err)
    83  			})
    84  		}
    85  	}
    86  
    87  	t := &timeoutConn{
    88  		Conn:    c,
    89  		timeout: d.timeout,
    90  	}
    91  	return t, t.nudgeDeadline()
    92  }
    93  
    94  // A net.Conn that sets deadline for every Read/Write operation
    95  type timeoutConn struct {
    96  	net.Conn
    97  	timeout time.Duration
    98  }
    99  
   100  // Nudge the deadline for an idle timeout on by c.timeout if non-zero
   101  func (c *timeoutConn) nudgeDeadline() error {
   102  	if c.timeout > 0 {
   103  		return c.SetDeadline(time.Now().Add(c.timeout))
   104  	}
   105  	return nil
   106  }
   107  
   108  // Read bytes with rate limiting and idle timeouts
   109  func (c *timeoutConn) Read(b []byte) (n int, err error) {
   110  	// Ideally we would LimitBandwidth(len(b)) here and replace tokens we didn't use
   111  	n, err = c.Conn.Read(b)
   112  	accounting.TokenBucket.LimitBandwidth(accounting.TokenBucketSlotTransportRx, n)
   113  	if err == nil && n > 0 && c.timeout > 0 {
   114  		err = c.nudgeDeadline()
   115  	}
   116  	return n, err
   117  }
   118  
   119  // Write bytes with rate limiting and idle timeouts
   120  func (c *timeoutConn) Write(b []byte) (n int, err error) {
   121  	accounting.TokenBucket.LimitBandwidth(accounting.TokenBucketSlotTransportTx, len(b))
   122  	n, err = c.Conn.Write(b)
   123  	if err == nil && n > 0 && c.timeout > 0 {
   124  		err = c.nudgeDeadline()
   125  	}
   126  	return n, err
   127  }