github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/outbound/trojan.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  
     7  	"github.com/inazumav/sing-box/adapter"
     8  	"github.com/inazumav/sing-box/common/dialer"
     9  	"github.com/inazumav/sing-box/common/mux"
    10  	"github.com/inazumav/sing-box/common/tls"
    11  	C "github.com/inazumav/sing-box/constant"
    12  	"github.com/inazumav/sing-box/log"
    13  	"github.com/inazumav/sing-box/option"
    14  	"github.com/inazumav/sing-box/transport/trojan"
    15  	"github.com/inazumav/sing-box/transport/v2ray"
    16  	"github.com/sagernet/sing/common"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  )
    22  
    23  var _ adapter.Outbound = (*Trojan)(nil)
    24  
    25  type Trojan struct {
    26  	myOutboundAdapter
    27  	dialer          N.Dialer
    28  	serverAddr      M.Socksaddr
    29  	key             [56]byte
    30  	multiplexDialer *mux.Client
    31  	tlsConfig       tls.Config
    32  	transport       adapter.V2RayClientTransport
    33  }
    34  
    35  func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanOutboundOptions) (*Trojan, error) {
    36  	outboundDialer, err := dialer.New(router, options.DialerOptions)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	outbound := &Trojan{
    41  		myOutboundAdapter: myOutboundAdapter{
    42  			protocol:     C.TypeTrojan,
    43  			network:      options.Network.Build(),
    44  			router:       router,
    45  			logger:       logger,
    46  			tag:          tag,
    47  			dependencies: withDialerDependency(options.DialerOptions),
    48  		},
    49  		dialer:     outboundDialer,
    50  		serverAddr: options.ServerOptions.Build(),
    51  		key:        trojan.Key(options.Password),
    52  	}
    53  	if options.TLS != nil {
    54  		outbound.tlsConfig, err = tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
    55  		if err != nil {
    56  			return nil, err
    57  		}
    58  	}
    59  	if options.Transport != nil {
    60  		outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig)
    61  		if err != nil {
    62  			return nil, E.Cause(err, "create client transport: ", options.Transport.Type)
    63  		}
    64  	}
    65  	outbound.multiplexDialer, err = mux.NewClientWithOptions((*trojanDialer)(outbound), common.PtrValueOrDefault(options.Multiplex))
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	return outbound, nil
    70  }
    71  
    72  func (h *Trojan) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    73  	if h.multiplexDialer == nil {
    74  		switch N.NetworkName(network) {
    75  		case N.NetworkTCP:
    76  			h.logger.InfoContext(ctx, "outbound connection to ", destination)
    77  		case N.NetworkUDP:
    78  			h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
    79  		}
    80  		return (*trojanDialer)(h).DialContext(ctx, network, destination)
    81  	} else {
    82  		switch N.NetworkName(network) {
    83  		case N.NetworkTCP:
    84  			h.logger.InfoContext(ctx, "outbound multiplex connection to ", destination)
    85  		case N.NetworkUDP:
    86  			h.logger.InfoContext(ctx, "outbound multiplex packet connection to ", destination)
    87  		}
    88  		return h.multiplexDialer.DialContext(ctx, network, destination)
    89  	}
    90  }
    91  
    92  func (h *Trojan) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
    93  	if h.multiplexDialer == nil {
    94  		h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
    95  		return (*trojanDialer)(h).ListenPacket(ctx, destination)
    96  	} else {
    97  		h.logger.InfoContext(ctx, "outbound multiplex packet connection to ", destination)
    98  		return h.multiplexDialer.ListenPacket(ctx, destination)
    99  	}
   100  }
   101  
   102  func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   103  	return NewConnection(ctx, h, conn, metadata)
   104  }
   105  
   106  func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   107  	return NewPacketConnection(ctx, h, conn, metadata)
   108  }
   109  
   110  func (h *Trojan) InterfaceUpdated() {
   111  	if h.multiplexDialer != nil {
   112  		h.multiplexDialer.Reset()
   113  	}
   114  	return
   115  }
   116  
   117  func (h *Trojan) Close() error {
   118  	return common.Close(common.PtrOrNil(h.multiplexDialer), h.transport)
   119  }
   120  
   121  type trojanDialer Trojan
   122  
   123  func (h *trojanDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   124  	ctx, metadata := adapter.AppendContext(ctx)
   125  	metadata.Outbound = h.tag
   126  	metadata.Destination = destination
   127  	var conn net.Conn
   128  	var err error
   129  	if h.transport != nil {
   130  		conn, err = h.transport.DialContext(ctx)
   131  	} else {
   132  		conn, err = h.dialer.DialContext(ctx, N.NetworkTCP, h.serverAddr)
   133  		if err == nil && h.tlsConfig != nil {
   134  			conn, err = tls.ClientHandshake(ctx, conn, h.tlsConfig)
   135  		}
   136  	}
   137  	if err != nil {
   138  		common.Close(conn)
   139  		return nil, err
   140  	}
   141  	switch N.NetworkName(network) {
   142  	case N.NetworkTCP:
   143  		return trojan.NewClientConn(conn, h.key, destination), nil
   144  	case N.NetworkUDP:
   145  		return bufio.NewBindPacketConn(trojan.NewClientPacketConn(conn, h.key), destination), nil
   146  	default:
   147  		return nil, E.Extend(N.ErrUnknownNetwork, network)
   148  	}
   149  }
   150  
   151  func (h *trojanDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   152  	conn, err := h.DialContext(ctx, N.NetworkUDP, destination)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	return conn.(net.PacketConn), nil
   157  }