github.com/sagernet/sing-box@v1.2.7/outbound/direct.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"time"
     8  
     9  	"github.com/sagernet/sing-box/adapter"
    10  	"github.com/sagernet/sing-box/common/dialer"
    11  	C "github.com/sagernet/sing-box/constant"
    12  	"github.com/sagernet/sing-box/log"
    13  	"github.com/sagernet/sing-box/option"
    14  	"github.com/sagernet/sing-dns"
    15  	"github.com/sagernet/sing/common/buf"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	E "github.com/sagernet/sing/common/exceptions"
    18  	M "github.com/sagernet/sing/common/metadata"
    19  	N "github.com/sagernet/sing/common/network"
    20  
    21  	"github.com/pires/go-proxyproto"
    22  )
    23  
    24  var (
    25  	_ adapter.Outbound = (*Direct)(nil)
    26  	_ N.ParallelDialer = (*Direct)(nil)
    27  )
    28  
// Direct is an outbound that dials destinations through its own dialer,
// optionally rewriting the destination (see overrideOption) and optionally
// writing a PROXY protocol header to newly dialed connections.
type Direct struct {
	myOutboundAdapter
	// dialer performs the actual connects / packet listens.
	dialer              N.Dialer
	// domainStrategy, when not AsIS, takes precedence over the inbound's
	// domain strategy when choosing the IPv6 preference in DialParallel.
	domainStrategy      dns.DomainStrategy
	// fallbackDelay is forwarded to N.DialParallel.
	fallbackDelay       time.Duration
	// overrideOption selects how dial destinations are rewritten:
	// 0 = no rewrite, 1 = replace address and port, 2 = replace address but
	// keep the requested port, 3 = replace only the port.
	overrideOption      int
	// overrideDestination holds the replacement address and/or port used
	// according to overrideOption.
	overrideDestination M.Socksaddr
	// proxyProto is the PROXY protocol header version to write
	// (0 disables it; NewDirect rejects values above 2).
	proxyProto          uint8
}
    38  
    39  func NewDirect(router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (*Direct, error) {
    40  	options.UDPFragmentDefault = true
    41  	outbound := &Direct{
    42  		myOutboundAdapter: myOutboundAdapter{
    43  			protocol: C.TypeDirect,
    44  			network:  []string{N.NetworkTCP, N.NetworkUDP},
    45  			router:   router,
    46  			logger:   logger,
    47  			tag:      tag,
    48  		},
    49  		domainStrategy: dns.DomainStrategy(options.DomainStrategy),
    50  		fallbackDelay:  time.Duration(options.FallbackDelay),
    51  		dialer:         dialer.New(router, options.DialerOptions),
    52  		proxyProto:     options.ProxyProtocol,
    53  	}
    54  	if options.ProxyProtocol > 2 {
    55  		return nil, E.New("invalid proxy protocol option: ", options.ProxyProtocol)
    56  	}
    57  	if options.OverrideAddress != "" && options.OverridePort != 0 {
    58  		outbound.overrideOption = 1
    59  		outbound.overrideDestination = M.ParseSocksaddrHostPort(options.OverrideAddress, options.OverridePort)
    60  	} else if options.OverrideAddress != "" {
    61  		outbound.overrideOption = 2
    62  		outbound.overrideDestination = M.ParseSocksaddrHostPort(options.OverrideAddress, options.OverridePort)
    63  	} else if options.OverridePort != 0 {
    64  		outbound.overrideOption = 3
    65  		outbound.overrideDestination = M.Socksaddr{Port: options.OverridePort}
    66  	}
    67  	return outbound, nil
    68  }
    69  
    70  func (h *Direct) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    71  	ctx, metadata := adapter.AppendContext(ctx)
    72  	originDestination := metadata.Destination
    73  	metadata.Outbound = h.tag
    74  	metadata.Destination = destination
    75  	switch h.overrideOption {
    76  	case 1:
    77  		destination = h.overrideDestination
    78  	case 2:
    79  		newDestination := h.overrideDestination
    80  		newDestination.Port = destination.Port
    81  		destination = newDestination
    82  	case 3:
    83  		destination.Port = h.overrideDestination.Port
    84  	}
    85  	network = N.NetworkName(network)
    86  	switch network {
    87  	case N.NetworkTCP:
    88  		h.logger.InfoContext(ctx, "outbound connection to ", destination)
    89  	case N.NetworkUDP:
    90  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
    91  	}
    92  	conn, err := h.dialer.DialContext(ctx, network, destination)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	if h.proxyProto > 0 {
    97  		source := metadata.Source
    98  		if !source.IsValid() {
    99  			source = M.SocksaddrFromNet(conn.LocalAddr())
   100  		}
   101  		if originDestination.Addr.Is6() {
   102  			source = M.SocksaddrFrom(netip.AddrFrom16(source.Addr.As16()), source.Port)
   103  		}
   104  		header := proxyproto.HeaderProxyFromAddrs(h.proxyProto, source.TCPAddr(), originDestination.TCPAddr())
   105  		_, err = header.WriteTo(conn)
   106  		if err != nil {
   107  			conn.Close()
   108  			return nil, E.Cause(err, "write proxy protocol header")
   109  		}
   110  	}
   111  	return conn, nil
   112  }
   113  
   114  func (h *Direct) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
   115  	ctx, metadata := adapter.AppendContext(ctx)
   116  	originDestination := metadata.Destination
   117  	metadata.Outbound = h.tag
   118  	metadata.Destination = destination
   119  	switch h.overrideOption {
   120  	case 1, 2:
   121  		// override address
   122  		return h.DialContext(ctx, network, destination)
   123  	case 3:
   124  		destination.Port = h.overrideDestination.Port
   125  	}
   126  	network = N.NetworkName(network)
   127  	switch network {
   128  	case N.NetworkTCP:
   129  		h.logger.InfoContext(ctx, "outbound connection to ", destination)
   130  	case N.NetworkUDP:
   131  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   132  	}
   133  	var domainStrategy dns.DomainStrategy
   134  	if h.domainStrategy != dns.DomainStrategyAsIS {
   135  		domainStrategy = h.domainStrategy
   136  	} else {
   137  		domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
   138  	}
   139  	conn, err := N.DialParallel(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.fallbackDelay)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  	if h.proxyProto > 0 {
   144  		source := metadata.Source
   145  		if !source.IsValid() {
   146  			source = M.SocksaddrFromNet(conn.LocalAddr())
   147  		}
   148  		if originDestination.Addr.Is6() {
   149  			source = M.SocksaddrFrom(netip.AddrFrom16(source.Addr.As16()), source.Port)
   150  		}
   151  		header := proxyproto.HeaderProxyFromAddrs(h.proxyProto, source.TCPAddr(), originDestination.TCPAddr())
   152  		_, err = header.WriteTo(conn)
   153  		if err != nil {
   154  			conn.Close()
   155  			return nil, E.Cause(err, "write proxy protocol header")
   156  		}
   157  	}
   158  	return conn, nil
   159  }
   160  
   161  func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   162  	ctx, metadata := adapter.AppendContext(ctx)
   163  	metadata.Outbound = h.tag
   164  	metadata.Destination = destination
   165  	switch h.overrideOption {
   166  	case 1:
   167  		destination = h.overrideDestination
   168  	case 2:
   169  		newDestination := h.overrideDestination
   170  		newDestination.Port = destination.Port
   171  		destination = newDestination
   172  	case 3:
   173  		destination.Port = h.overrideDestination.Port
   174  	}
   175  	if h.overrideOption == 0 {
   176  		h.logger.InfoContext(ctx, "outbound packet connection")
   177  	} else {
   178  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   179  	}
   180  	conn, err := h.dialer.ListenPacket(ctx, destination)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  	if h.overrideOption == 0 {
   185  		return conn, nil
   186  	} else {
   187  		return &overridePacketConn{bufio.NewPacketConn(conn), destination}, nil
   188  	}
   189  }
   190  
   191  func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   192  	return NewConnection(ctx, h, conn, metadata)
   193  }
   194  
   195  func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   196  	return NewPacketConnection(ctx, h, conn, metadata)
   197  }
   198  
   199  type overridePacketConn struct {
   200  	N.NetPacketConn
   201  	overrideDestination M.Socksaddr
   202  }
   203  
   204  func (c *overridePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   205  	return c.NetPacketConn.WritePacket(buffer, c.overrideDestination)
   206  }
   207  
   208  func (c *overridePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   209  	return c.NetPacketConn.WriteTo(p, c.overrideDestination.UDPAddr())
   210  }
   211  
   212  func (c *overridePacketConn) Upstream() any {
   213  	return c.NetPacketConn
   214  }