github.com/sagernet/sing-box@v1.9.0-rc.20/common/dialer/default.go (about)

     1  package dialer
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"time"
     7  
     8  	"github.com/sagernet/sing-box/adapter"
     9  	"github.com/sagernet/sing-box/common/conntrack"
    10  	C "github.com/sagernet/sing-box/constant"
    11  	"github.com/sagernet/sing-box/option"
    12  	"github.com/sagernet/sing/common/control"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  )
    17  
// Compile-time assertion that *DefaultDialer implements WireGuardListener.
var _ WireGuardListener = (*DefaultDialer)(nil)

// DefaultDialer dials and listens using the host's own network stack, with
// separate pre-configured dialers per address family so that IPv4 and IPv6
// traffic can each use their own local bind address.
//
// NOTE(review): keep the field order in sync with any positional composite
// literal that constructs this type.
type DefaultDialer struct {
	// TCP (stream) dialer used for any destination that is not IPv6.
	dialer4             tcpDialer
	// TCP (stream) dialer used for IPv6 destinations.
	dialer6             tcpDialer
	// Connected-UDP dialer used for any destination that is not IPv6.
	udpDialer4          net.Dialer
	// Connected-UDP dialer used for IPv6 destinations.
	udpDialer6          net.Dialer
	// Listen config shared by ListenPacket and ListenPacketCompat for
	// unconnected UDP sockets.
	udpListener         net.ListenConfig
	// Local "ip:0" bind address derived from the IPv4 bind option; empty when
	// no IPv4 bind address was configured.
	udpAddr4            string
	// Local "ip:0" bind address derived from the IPv6 bind option; empty when
	// no IPv6 bind address was configured.
	udpAddr6            string
	// Mirrors options.IsWireGuardListener; when set, the WireGuard control
	// functions were appended to udpListener.
	isWireGuardListener bool
}
    30  
    31  func NewDefault(router adapter.Router, options option.DialerOptions) (*DefaultDialer, error) {
    32  	var dialer net.Dialer
    33  	var listener net.ListenConfig
    34  	if options.BindInterface != "" {
    35  		var interfaceFinder control.InterfaceFinder
    36  		if router != nil {
    37  			interfaceFinder = router.InterfaceFinder()
    38  		} else {
    39  			interfaceFinder = control.NewDefaultInterfaceFinder()
    40  		}
    41  		bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1)
    42  		dialer.Control = control.Append(dialer.Control, bindFunc)
    43  		listener.Control = control.Append(listener.Control, bindFunc)
    44  	} else if router != nil && router.AutoDetectInterface() {
    45  		bindFunc := router.AutoDetectInterfaceFunc()
    46  		dialer.Control = control.Append(dialer.Control, bindFunc)
    47  		listener.Control = control.Append(listener.Control, bindFunc)
    48  	} else if router != nil && router.DefaultInterface() != "" {
    49  		bindFunc := control.BindToInterface(router.InterfaceFinder(), router.DefaultInterface(), -1)
    50  		dialer.Control = control.Append(dialer.Control, bindFunc)
    51  		listener.Control = control.Append(listener.Control, bindFunc)
    52  	}
    53  	if options.RoutingMark != 0 {
    54  		dialer.Control = control.Append(dialer.Control, control.RoutingMark(options.RoutingMark))
    55  		listener.Control = control.Append(listener.Control, control.RoutingMark(options.RoutingMark))
    56  	} else if router != nil && router.DefaultMark() != 0 {
    57  		dialer.Control = control.Append(dialer.Control, control.RoutingMark(router.DefaultMark()))
    58  		listener.Control = control.Append(listener.Control, control.RoutingMark(router.DefaultMark()))
    59  	}
    60  	if options.ReuseAddr {
    61  		listener.Control = control.Append(listener.Control, control.ReuseAddr())
    62  	}
    63  	if options.ProtectPath != "" {
    64  		dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath))
    65  		listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath))
    66  	}
    67  	if options.ConnectTimeout != 0 {
    68  		dialer.Timeout = time.Duration(options.ConnectTimeout)
    69  	} else {
    70  		dialer.Timeout = C.TCPTimeout
    71  	}
    72  	// TODO: Add an option to customize the keep alive period
    73  	dialer.KeepAlive = C.TCPKeepAliveInitial
    74  	dialer.Control = control.Append(dialer.Control, control.SetKeepAlivePeriod(C.TCPKeepAliveInitial, C.TCPKeepAliveInterval))
    75  	var udpFragment bool
    76  	if options.UDPFragment != nil {
    77  		udpFragment = *options.UDPFragment
    78  	} else {
    79  		udpFragment = options.UDPFragmentDefault
    80  	}
    81  	if !udpFragment {
    82  		dialer.Control = control.Append(dialer.Control, control.DisableUDPFragment())
    83  		listener.Control = control.Append(listener.Control, control.DisableUDPFragment())
    84  	}
    85  	var (
    86  		dialer4    = dialer
    87  		udpDialer4 = dialer
    88  		udpAddr4   string
    89  	)
    90  	if options.Inet4BindAddress != nil {
    91  		bindAddr := options.Inet4BindAddress.Build()
    92  		dialer4.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
    93  		udpDialer4.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
    94  		udpAddr4 = M.SocksaddrFrom(bindAddr, 0).String()
    95  	}
    96  	var (
    97  		dialer6    = dialer
    98  		udpDialer6 = dialer
    99  		udpAddr6   string
   100  	)
   101  	if options.Inet6BindAddress != nil {
   102  		bindAddr := options.Inet6BindAddress.Build()
   103  		dialer6.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
   104  		udpDialer6.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
   105  		udpAddr6 = M.SocksaddrFrom(bindAddr, 0).String()
   106  	}
   107  	if options.TCPMultiPath {
   108  		if !go121Available {
   109  			return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.")
   110  		}
   111  		setMultiPathTCP(&dialer4)
   112  	}
   113  	if options.IsWireGuardListener {
   114  		for _, controlFn := range wgControlFns {
   115  			listener.Control = control.Append(listener.Control, controlFn)
   116  		}
   117  	}
   118  	tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	return &DefaultDialer{
   127  		tcpDialer4,
   128  		tcpDialer6,
   129  		udpDialer4,
   130  		udpDialer6,
   131  		listener,
   132  		udpAddr4,
   133  		udpAddr6,
   134  		options.IsWireGuardListener,
   135  	}, nil
   136  }
   137  
   138  func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
   139  	if !address.IsValid() {
   140  		return nil, E.New("invalid address")
   141  	}
   142  	switch N.NetworkName(network) {
   143  	case N.NetworkUDP:
   144  		if !address.IsIPv6() {
   145  			return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
   146  		} else {
   147  			return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
   148  		}
   149  	}
   150  	if !address.IsIPv6() {
   151  		return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
   152  	} else {
   153  		return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
   154  	}
   155  }
   156  
   157  func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   158  	if destination.IsIPv6() {
   159  		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
   160  	} else if destination.IsIPv4() && !destination.Addr.IsUnspecified() {
   161  		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4))
   162  	} else {
   163  		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
   164  	}
   165  }
   166  
   167  func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) {
   168  	return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address))
   169  }
   170  
   171  func trackConn(conn net.Conn, err error) (net.Conn, error) {
   172  	if !conntrack.Enabled || err != nil {
   173  		return conn, err
   174  	}
   175  	return conntrack.NewConn(conn)
   176  }
   177  
   178  func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
   179  	if !conntrack.Enabled || err != nil {
   180  		return conn, err
   181  	}
   182  	return conntrack.NewPacketConn(conn)
   183  }