github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/outbound/direct.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"time"
     8  
     9  	"github.com/inazumav/sing-box/adapter"
    10  	"github.com/inazumav/sing-box/common/dialer"
    11  	C "github.com/inazumav/sing-box/constant"
    12  	"github.com/inazumav/sing-box/log"
    13  	"github.com/inazumav/sing-box/option"
    14  	"github.com/sagernet/sing-dns"
    15  	"github.com/sagernet/sing/common/buf"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	E "github.com/sagernet/sing/common/exceptions"
    18  	M "github.com/sagernet/sing/common/metadata"
    19  	N "github.com/sagernet/sing/common/network"
    20  
    21  	"github.com/pires/go-proxyproto"
    22  )
    23  
    24  var (
    25  	_ adapter.Outbound = (*Direct)(nil)
    26  	_ N.ParallelDialer = (*Direct)(nil)
    27  )
    28  
    29  type Direct struct {
    30  	myOutboundAdapter
    31  	dialer              N.Dialer
    32  	domainStrategy      dns.DomainStrategy
    33  	fallbackDelay       time.Duration
    34  	overrideOption      int
    35  	overrideDestination M.Socksaddr
    36  	proxyProto          uint8
    37  }
    38  
    39  func NewDirect(router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (*Direct, error) {
    40  	options.UDPFragmentDefault = true
    41  	outboundDialer, err := dialer.New(router, options.DialerOptions)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	outbound := &Direct{
    46  		myOutboundAdapter: myOutboundAdapter{
    47  			protocol:     C.TypeDirect,
    48  			network:      []string{N.NetworkTCP, N.NetworkUDP},
    49  			router:       router,
    50  			logger:       logger,
    51  			tag:          tag,
    52  			dependencies: withDialerDependency(options.DialerOptions),
    53  		},
    54  		domainStrategy: dns.DomainStrategy(options.DomainStrategy),
    55  		fallbackDelay:  time.Duration(options.FallbackDelay),
    56  		dialer:         outboundDialer,
    57  		proxyProto:     options.ProxyProtocol,
    58  	}
    59  	if options.ProxyProtocol > 2 {
    60  		return nil, E.New("invalid proxy protocol option: ", options.ProxyProtocol)
    61  	}
    62  	if options.OverrideAddress != "" && options.OverridePort != 0 {
    63  		outbound.overrideOption = 1
    64  		outbound.overrideDestination = M.ParseSocksaddrHostPort(options.OverrideAddress, options.OverridePort)
    65  	} else if options.OverrideAddress != "" {
    66  		outbound.overrideOption = 2
    67  		outbound.overrideDestination = M.ParseSocksaddrHostPort(options.OverrideAddress, options.OverridePort)
    68  	} else if options.OverridePort != 0 {
    69  		outbound.overrideOption = 3
    70  		outbound.overrideDestination = M.Socksaddr{Port: options.OverridePort}
    71  	}
    72  	return outbound, nil
    73  }
    74  
    75  func (h *Direct) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    76  	ctx, metadata := adapter.AppendContext(ctx)
    77  	originDestination := metadata.Destination
    78  	metadata.Outbound = h.tag
    79  	metadata.Destination = destination
    80  	switch h.overrideOption {
    81  	case 1:
    82  		destination = h.overrideDestination
    83  	case 2:
    84  		newDestination := h.overrideDestination
    85  		newDestination.Port = destination.Port
    86  		destination = newDestination
    87  	case 3:
    88  		destination.Port = h.overrideDestination.Port
    89  	}
    90  	network = N.NetworkName(network)
    91  	switch network {
    92  	case N.NetworkTCP:
    93  		h.logger.InfoContext(ctx, "outbound connection to ", destination)
    94  	case N.NetworkUDP:
    95  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
    96  	}
    97  	conn, err := h.dialer.DialContext(ctx, network, destination)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	if h.proxyProto > 0 {
   102  		source := metadata.Source
   103  		if !source.IsValid() {
   104  			source = M.SocksaddrFromNet(conn.LocalAddr())
   105  		}
   106  		if originDestination.Addr.Is6() {
   107  			source = M.SocksaddrFrom(netip.AddrFrom16(source.Addr.As16()), source.Port)
   108  		}
   109  		header := proxyproto.HeaderProxyFromAddrs(h.proxyProto, source.TCPAddr(), originDestination.TCPAddr())
   110  		_, err = header.WriteTo(conn)
   111  		if err != nil {
   112  			conn.Close()
   113  			return nil, E.Cause(err, "write proxy protocol header")
   114  		}
   115  	}
   116  	return conn, nil
   117  }
   118  
   119  func (h *Direct) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
   120  	ctx, metadata := adapter.AppendContext(ctx)
   121  	originDestination := metadata.Destination
   122  	metadata.Outbound = h.tag
   123  	metadata.Destination = destination
   124  	switch h.overrideOption {
   125  	case 1, 2:
   126  		// override address
   127  		return h.DialContext(ctx, network, destination)
   128  	case 3:
   129  		destination.Port = h.overrideDestination.Port
   130  	}
   131  	network = N.NetworkName(network)
   132  	switch network {
   133  	case N.NetworkTCP:
   134  		h.logger.InfoContext(ctx, "outbound connection to ", destination)
   135  	case N.NetworkUDP:
   136  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   137  	}
   138  	var domainStrategy dns.DomainStrategy
   139  	if h.domainStrategy != dns.DomainStrategyAsIS {
   140  		domainStrategy = h.domainStrategy
   141  	} else {
   142  		domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
   143  	}
   144  	conn, err := N.DialParallel(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.fallbackDelay)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	if h.proxyProto > 0 {
   149  		source := metadata.Source
   150  		if !source.IsValid() {
   151  			source = M.SocksaddrFromNet(conn.LocalAddr())
   152  		}
   153  		if originDestination.Addr.Is6() {
   154  			source = M.SocksaddrFrom(netip.AddrFrom16(source.Addr.As16()), source.Port)
   155  		}
   156  		header := proxyproto.HeaderProxyFromAddrs(h.proxyProto, source.TCPAddr(), originDestination.TCPAddr())
   157  		_, err = header.WriteTo(conn)
   158  		if err != nil {
   159  			conn.Close()
   160  			return nil, E.Cause(err, "write proxy protocol header")
   161  		}
   162  	}
   163  	return conn, nil
   164  }
   165  
   166  func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   167  	ctx, metadata := adapter.AppendContext(ctx)
   168  	metadata.Outbound = h.tag
   169  	metadata.Destination = destination
   170  	switch h.overrideOption {
   171  	case 1:
   172  		destination = h.overrideDestination
   173  	case 2:
   174  		newDestination := h.overrideDestination
   175  		newDestination.Port = destination.Port
   176  		destination = newDestination
   177  	case 3:
   178  		destination.Port = h.overrideDestination.Port
   179  	}
   180  	if h.overrideOption == 0 {
   181  		h.logger.InfoContext(ctx, "outbound packet connection")
   182  	} else {
   183  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   184  	}
   185  	conn, err := h.dialer.ListenPacket(ctx, destination)
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	if h.overrideOption == 0 {
   190  		return conn, nil
   191  	} else {
   192  		return &overridePacketConn{bufio.NewPacketConn(conn), destination}, nil
   193  	}
   194  }
   195  
   196  func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   197  	return NewConnection(ctx, h, conn, metadata)
   198  }
   199  
   200  func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   201  	return NewPacketConnection(ctx, h, conn, metadata)
   202  }
   203  
// overridePacketConn wraps a packet connection so that every write it
// overrides (WritePacket, WriteTo) goes to one fixed destination, regardless
// of the address the caller supplies.
type overridePacketConn struct {
	N.NetPacketConn
	// overrideDestination is the address all outgoing packets are sent to.
	overrideDestination M.Socksaddr
}
   208  
   209  func (c *overridePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   210  	return c.NetPacketConn.WritePacket(buffer, c.overrideDestination)
   211  }
   212  
   213  func (c *overridePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   214  	return c.NetPacketConn.WriteTo(p, c.overrideDestination.UDPAddr())
   215  }
   216  
   217  func (c *overridePacketConn) Upstream() any {
   218  	return c.NetPacketConn
   219  }