github.com/sagernet/sing-box@v1.9.0-rc.20/common/dialer/default.go (about) 1 package dialer 2 3 import ( 4 "context" 5 "net" 6 "time" 7 8 "github.com/sagernet/sing-box/adapter" 9 "github.com/sagernet/sing-box/common/conntrack" 10 C "github.com/sagernet/sing-box/constant" 11 "github.com/sagernet/sing-box/option" 12 "github.com/sagernet/sing/common/control" 13 E "github.com/sagernet/sing/common/exceptions" 14 M "github.com/sagernet/sing/common/metadata" 15 N "github.com/sagernet/sing/common/network" 16 ) 17 18 var _ WireGuardListener = (*DefaultDialer)(nil) 19 20 type DefaultDialer struct { 21 dialer4 tcpDialer 22 dialer6 tcpDialer 23 udpDialer4 net.Dialer 24 udpDialer6 net.Dialer 25 udpListener net.ListenConfig 26 udpAddr4 string 27 udpAddr6 string 28 isWireGuardListener bool 29 } 30 31 func NewDefault(router adapter.Router, options option.DialerOptions) (*DefaultDialer, error) { 32 var dialer net.Dialer 33 var listener net.ListenConfig 34 if options.BindInterface != "" { 35 var interfaceFinder control.InterfaceFinder 36 if router != nil { 37 interfaceFinder = router.InterfaceFinder() 38 } else { 39 interfaceFinder = control.NewDefaultInterfaceFinder() 40 } 41 bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) 42 dialer.Control = control.Append(dialer.Control, bindFunc) 43 listener.Control = control.Append(listener.Control, bindFunc) 44 } else if router != nil && router.AutoDetectInterface() { 45 bindFunc := router.AutoDetectInterfaceFunc() 46 dialer.Control = control.Append(dialer.Control, bindFunc) 47 listener.Control = control.Append(listener.Control, bindFunc) 48 } else if router != nil && router.DefaultInterface() != "" { 49 bindFunc := control.BindToInterface(router.InterfaceFinder(), router.DefaultInterface(), -1) 50 dialer.Control = control.Append(dialer.Control, bindFunc) 51 listener.Control = control.Append(listener.Control, bindFunc) 52 } 53 if options.RoutingMark != 0 { 54 dialer.Control = control.Append(dialer.Control, control.RoutingMark(options.RoutingMark)) 55 listener.Control = control.Append(listener.Control, control.RoutingMark(options.RoutingMark)) 56 } else if router != nil && router.DefaultMark() != 0 { 57 dialer.Control = control.Append(dialer.Control, control.RoutingMark(router.DefaultMark())) 58 listener.Control = control.Append(listener.Control, control.RoutingMark(router.DefaultMark())) 59 } 60 if options.ReuseAddr { 61 listener.Control = control.Append(listener.Control, control.ReuseAddr()) 62 } 63 if options.ProtectPath != "" { 64 dialer.Control = control.Append(dialer.Control, control.ProtectPath(options.ProtectPath)) 65 listener.Control = control.Append(listener.Control, control.ProtectPath(options.ProtectPath)) 66 } 67 if options.ConnectTimeout != 0 { 68 dialer.Timeout = time.Duration(options.ConnectTimeout) 69 } else { 70 dialer.Timeout = C.TCPTimeout 71 } 72 // TODO: Add an option to customize the keep alive period 73 dialer.KeepAlive = C.TCPKeepAliveInitial 74 dialer.Control = control.Append(dialer.Control, control.SetKeepAlivePeriod(C.TCPKeepAliveInitial, C.TCPKeepAliveInterval)) 75 var udpFragment bool 76 if options.UDPFragment != nil { 77 udpFragment = *options.UDPFragment 78 } else { 79 udpFragment = options.UDPFragmentDefault 80 } 81 if !udpFragment { 82 dialer.Control = control.Append(dialer.Control, control.DisableUDPFragment()) 83 listener.Control = control.Append(listener.Control, control.DisableUDPFragment()) 84 } 85 var ( 86 dialer4 = dialer 87 udpDialer4 = dialer 88 udpAddr4 string 89 ) 90 if options.Inet4BindAddress != nil { 91 bindAddr := options.Inet4BindAddress.Build() 92 dialer4.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()} 93 udpDialer4.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()} 94 udpAddr4 = M.SocksaddrFrom(bindAddr, 0).String() 95 } 96 var ( 97 dialer6 = dialer 98 udpDialer6 = dialer 99 udpAddr6 string 100 ) 101 if options.Inet6BindAddress != nil { 102 bindAddr := options.Inet6BindAddress.Build() 103 dialer6.LocalAddr = &net.TCPAddr{IP: bindAddr.AsSlice()} 104 udpDialer6.LocalAddr = &net.UDPAddr{IP: bindAddr.AsSlice()} 105 udpAddr6 = M.SocksaddrFrom(bindAddr, 0).String() 106 } 107 if options.TCPMultiPath { 108 if !go121Available { 109 return nil, E.New("MultiPath TCP requires go1.21, please recompile your binary.") 110 } 111 setMultiPathTCP(&dialer4) 112 } 113 if options.IsWireGuardListener { 114 for _, controlFn := range wgControlFns { 115 listener.Control = control.Append(listener.Control, controlFn) 116 } 117 } 118 tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) 119 if err != nil { 120 return nil, err 121 } 122 tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen) 123 if err != nil { 124 return nil, err 125 } 126 return &DefaultDialer{ 127 tcpDialer4, 128 tcpDialer6, 129 udpDialer4, 130 udpDialer6, 131 listener, 132 udpAddr4, 133 udpAddr6, 134 options.IsWireGuardListener, 135 }, nil 136 } 137 138 func (d *DefaultDialer) DialContext(ctx context.Context, network string, address M.Socksaddr) (net.Conn, error) { 139 if !address.IsValid() { 140 return nil, E.New("invalid address") 141 } 142 switch N.NetworkName(network) { 143 case N.NetworkUDP: 144 if !address.IsIPv6() { 145 return trackConn(d.udpDialer4.DialContext(ctx, network, address.String())) 146 } else { 147 return trackConn(d.udpDialer6.DialContext(ctx, network, address.String())) 148 } 149 } 150 if !address.IsIPv6() { 151 return trackConn(DialSlowContext(&d.dialer4, ctx, network, address)) 152 } else { 153 return trackConn(DialSlowContext(&d.dialer6, ctx, network, address)) 154 } 155 } 156 157 func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 158 if destination.IsIPv6() { 159 return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)) 160 } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() { 161 return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4)) 162 } else { 163 return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4)) 164 } 165 } 166 167 func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { 168 return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address)) 169 } 170 171 func trackConn(conn net.Conn, err error) (net.Conn, error) { 172 if !conntrack.Enabled || err != nil { 173 return conn, err 174 } 175 return conntrack.NewConn(conn) 176 } 177 178 func trackPacketConn(conn net.PacketConn, err error) (net.PacketConn, error) { 179 if !conntrack.Enabled || err != nil { 180 return conn, err 181 } 182 return conntrack.NewPacketConn(conn) 183 }