github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/dialer/default.go (about)

     1  package dialer
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"time"
     7  
     8  	"github.com/inazumav/sing-box/adapter"
     9  	"github.com/inazumav/sing-box/common/dialer/conntrack"
    10  	C "github.com/inazumav/sing-box/constant"
    11  	"github.com/inazumav/sing-box/option"
    12  	"github.com/sagernet/sing/common/control"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  )
    17  
    18  type DefaultDialer struct {
    19  	dialer4     tcpDialer
    20  	dialer6     tcpDialer
    21  	udpDialer4  net.Dialer
    22  	udpDialer6  net.Dialer
    23  	udpListener net.ListenConfig
    24  	udpAddr4    string
    25  	udpAddr6    string
    26  }
    27  
    28  func NewDefault(router adapter.Router, options option.DialerOptions) (*DefaultDialer, error) {
    29  	var dialer net.Dialer
    30  	var listener net.ListenConfig
    31  	if options.BindInterface != "" {
    32  		bindFunc := control.BindToInterface(router.InterfaceFinder(), options.BindInterface, -1)
    33  		dialer.Control = control.Append(dialer.Control, bindFunc)
    34  		listener.Control = control.Append(listener.Control, bindFunc)
    35  	} else if router.AutoDetectInterface() {
    36  		bindFunc := router.AutoDetectInterfaceFunc()
    37  		dialer.Control = control.Append(dialer.Control, bindFunc)
    38  		listener.Control = control.Append(listener.Control, bindFunc)
    39  	} else if router.DefaultInterface() != "" {
    40  		bindFunc := control.BindToInterface(router.InterfaceFinder(), router.DefaultInterface(), -1)
    41  		dialer.Control = control.Append(dialer.Control, bindFunc)
    42  		listener.Control = control.Append(listener.Control, bindFunc)
    43  	}
    44  	if options.RoutingMark != 0 {
    45  		dialer.Control = control.Append(dialer.Control, control.RoutingMark(options.RoutingMark))
    46  		listener.Control = control.Append(listener.Control, control.RoutingMark(options.RoutingMark))
    47  	} else if router.DefaultMark() != 0 {
    48  		dialer.Control = control.Append(dialer.Control, control.RoutingMark(router.DefaultMark()))
    49  		listener.Control = control.Append(listener.Control, control.RoutingMark(router.DefaultMark()))
    50  	}
    51  	if options.ReuseAddr {
    52  		listener.Control = control.Append(listener.Control, control.ReuseAddr())
    53  	}
    54  	if options.ProtectPath != "" {
    55  		dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath))
    56  		listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath))
    57  	}
    58  	if options.ConnectTimeout != 0 {
    59  		dialer.Timeout = time.Duration(options.ConnectTimeout)
    60  	} else {
    61  		dialer.Timeout = C.TCPTimeout
    62  	}
    63  	var udpFragment bool
    64  	if options.UDPFragment != nil {
    65  		udpFragment = *options.UDPFragment
    66  	} else {
    67  		udpFragment = options.UDPFragmentDefault
    68  	}
    69  	if !udpFragment {
    70  		dialer.Control = control.Append(dialer.Control, control.DisableUDPFragment())
    71  		listener.Control = control.Append(listener.Control, control.DisableUDPFragment())
    72  	}
    73  	var (
    74  		dialer4    = dialer
    75  		udpDialer4 = dialer
    76  		udpAddr4   string
    77  	)
    78  	if options.Inet4BindAddress != nil {
    79  		bindAddr := options.Inet4BindAddress.Build()
    80  		dialer4.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
    81  		udpDialer4.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
    82  		udpAddr4 = M.SocksaddrFrom(bindAddr, 0).String()
    83  	}
    84  	var (
    85  		dialer6    = dialer
    86  		udpDialer6 = dialer
    87  		udpAddr6   string
    88  	)
    89  	if options.Inet6BindAddress != nil {
    90  		bindAddr := options.Inet6BindAddress.Build()
    91  		dialer6.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()}
    92  		udpDialer6.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()}
    93  		udpAddr6 = M.SocksaddrFrom(bindAddr, 0).String()
    94  	}
    95  	if options.TCPMultiPath {
    96  		if !go121Available {
    97  			return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.")
    98  		}
    99  		setMultiPathTCP(&dialer4)
   100  	}
   101  	tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	return &DefaultDialer{
   110  		tcpDialer4,
   111  		tcpDialer6,
   112  		udpDialer4,
   113  		udpDialer6,
   114  		listener,
   115  		udpAddr4,
   116  		udpAddr6,
   117  	}, nil
   118  }
   119  
   120  func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) {
   121  	if !address.IsValid() {
   122  		return nil, E.New("invalid address")
   123  	}
   124  	switch N.NetworkName(network) {
   125  	case N.NetworkUDP:
   126  		if !address.IsIPv6() {
   127  			return trackConn(d.udpDialer4.DialContext(ctx, network, address.String()))
   128  		} else {
   129  			return trackConn(d.udpDialer6.DialContext(ctx, network, address.String()))
   130  		}
   131  	}
   132  	if !address.IsIPv6() {
   133  		return trackConn(DialSlowContext(&d.dialer4, ctx, network, address))
   134  	} else {
   135  		return trackConn(DialSlowContext(&d.dialer6, ctx, network, address))
   136  	}
   137  }
   138  
   139  func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   140  	if !destination.IsIPv6() {
   141  		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4))
   142  	} else {
   143  		return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6))
   144  	}
   145  }
   146  
   147  func trackConn(conn net.Conn, err error) (net.Conn, error) {
   148  	if !conntrack.Enabled || err != nil {
   149  		return conn, err
   150  	}
   151  	return conntrack.NewConn(conn)
   152  }
   153  
   154  func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) {
   155  	if !conntrack.Enabled || err != nil {
   156  		return conn, err
   157  	}
   158  	return conntrack.NewPacketConn(conn)
   159  }