github.com/sagernet/sing-box@v1.9.0-rc.20/inbound/trojan.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"os"
     7  
     8  	"github.com/sagernet/sing-box/adapter"
     9  	"github.com/sagernet/sing-box/common/mux"
    10  	"github.com/sagernet/sing-box/common/tls"
    11  	C "github.com/sagernet/sing-box/constant"
    12  	"github.com/sagernet/sing-box/log"
    13  	"github.com/sagernet/sing-box/option"
    14  	"github.com/sagernet/sing-box/transport/trojan"
    15  	"github.com/sagernet/sing-box/transport/v2ray"
    16  	"github.com/sagernet/sing/common"
    17  	"github.com/sagernet/sing/common/auth"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	F "github.com/sagernet/sing/common/format"
    20  	M "github.com/sagernet/sing/common/metadata"
    21  	N "github.com/sagernet/sing/common/network"
    22  )
    23  
    24  var (
    25  	_ adapter.Inbound           = (*Trojan)(nil)
    26  	_ adapter.InjectableInbound = (*Trojan)(nil)
    27  )
    28  
    29  type Trojan struct {
    30  	myInboundAdapter
    31  	service                  *trojan.Service[int]
    32  	users                    []option.TrojanUser
    33  	tlsConfig                tls.ServerConfig
    34  	fallbackAddr             M.Socksaddr
    35  	fallbackAddrTLSNextProto map[string]M.Socksaddr
    36  	transport                adapter.V2RayServerTransport
    37  }
    38  
    39  func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanInboundOptions) (*Trojan, error) {
    40  	inbound := &Trojan{
    41  		myInboundAdapter: myInboundAdapter{
    42  			protocol:      C.TypeTrojan,
    43  			network:       []string{N.NetworkTCP},
    44  			ctx:           ctx,
    45  			router:        router,
    46  			logger:        logger,
    47  			tag:           tag,
    48  			listenOptions: options.ListenOptions,
    49  		},
    50  		users: options.Users,
    51  	}
    52  	if options.TLS != nil {
    53  		tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  		inbound.tlsConfig = tlsConfig
    58  	}
    59  	var fallbackHandler N.TCPConnectionHandler
    60  	if options.Fallback != nil && options.Fallback.Server != "" || len(options.FallbackForALPN) > 0 {
    61  		if options.Fallback != nil && options.Fallback.Server != "" {
    62  			inbound.fallbackAddr = options.Fallback.Build()
    63  			if !inbound.fallbackAddr.IsValid() {
    64  				return nil, E.New("invalid fallback address: ", inbound.fallbackAddr)
    65  			}
    66  		}
    67  		if len(options.FallbackForALPN) > 0 {
    68  			if inbound.tlsConfig == nil {
    69  				return nil, E.New("fallback for ALPN is not supported without TLS")
    70  			}
    71  			fallbackAddrNextProto := make(map[string]M.Socksaddr)
    72  			for nextProto, destination := range options.FallbackForALPN {
    73  				fallbackAddr := destination.Build()
    74  				if !fallbackAddr.IsValid() {
    75  					return nil, E.New("invalid fallback address for ALPN ", nextProto, ": ", fallbackAddr)
    76  				}
    77  				fallbackAddrNextProto[nextProto] = fallbackAddr
    78  			}
    79  			inbound.fallbackAddrTLSNextProto = fallbackAddrNextProto
    80  		}
    81  		fallbackHandler = adapter.NewUpstreamContextHandler(inbound.fallbackConnection, nil, nil)
    82  	}
    83  	service := trojan.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound), fallbackHandler)
    84  	err := service.UpdateUsers(common.MapIndexed(options.Users, func(index int, it option.TrojanUser) int {
    85  		return index
    86  	}), common.Map(options.Users, func(it option.TrojanUser) string {
    87  		return it.Password
    88  	}))
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	if options.Transport != nil {
    93  		inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*trojanTransportHandler)(inbound))
    94  		if err != nil {
    95  			return nil, E.Cause(err, "create server transport: ", options.Transport.Type)
    96  		}
    97  	}
    98  	inbound.router, err = mux.NewRouterWithOptions(inbound.router, logger, common.PtrValueOrDefault(options.Multiplex))
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	inbound.service = service
   103  	inbound.connHandler = inbound
   104  	return inbound, nil
   105  }
   106  
   107  func (h *Trojan) Start() error {
   108  	if h.tlsConfig != nil {
   109  		err := h.tlsConfig.Start()
   110  		if err != nil {
   111  			return E.Cause(err, "create TLS config")
   112  		}
   113  	}
   114  	if h.transport == nil {
   115  		return h.myInboundAdapter.Start()
   116  	}
   117  	if common.Contains(h.transport.Network(), N.NetworkTCP) {
   118  		tcpListener, err := h.myInboundAdapter.ListenTCP()
   119  		if err != nil {
   120  			return err
   121  		}
   122  		go func() {
   123  			sErr := h.transport.Serve(tcpListener)
   124  			if sErr != nil && !E.IsClosed(sErr) {
   125  				h.logger.Error("transport serve error: ", sErr)
   126  			}
   127  		}()
   128  	}
   129  	if common.Contains(h.transport.Network(), N.NetworkUDP) {
   130  		udpConn, err := h.myInboundAdapter.ListenUDP()
   131  		if err != nil {
   132  			return err
   133  		}
   134  		go func() {
   135  			sErr := h.transport.ServePacket(udpConn)
   136  			if sErr != nil && !E.IsClosed(sErr) {
   137  				h.logger.Error("transport serve error: ", sErr)
   138  			}
   139  		}()
   140  	}
   141  	return nil
   142  }
   143  
   144  func (h *Trojan) Close() error {
   145  	return common.Close(
   146  		&h.myInboundAdapter,
   147  		h.tlsConfig,
   148  		h.transport,
   149  	)
   150  }
   151  
   152  func (h *Trojan) newTransportConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   153  	h.injectTCP(conn, metadata)
   154  	return nil
   155  }
   156  
   157  func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   158  	var err error
   159  	if h.tlsConfig != nil && h.transport == nil {
   160  		conn, err = tls.ServerHandshake(ctx, conn, h.tlsConfig)
   161  		if err != nil {
   162  			return err
   163  		}
   164  	}
   165  	return h.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata))
   166  }
   167  
   168  func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   169  	return os.ErrInvalid
   170  }
   171  
   172  func (h *Trojan) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   173  	userIndex, loaded := auth.UserFromContext[int](ctx)
   174  	if !loaded {
   175  		return os.ErrInvalid
   176  	}
   177  	user := h.users[userIndex].Name
   178  	if user == "" {
   179  		user = F.ToString(userIndex)
   180  	} else {
   181  		metadata.User = user
   182  	}
   183  	h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination)
   184  	return h.router.RouteConnection(ctx, conn, metadata)
   185  }
   186  
   187  func (h *Trojan) fallbackConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   188  	var fallbackAddr M.Socksaddr
   189  	if len(h.fallbackAddrTLSNextProto) > 0 {
   190  		if tlsConn, loaded := common.Cast[tls.Conn](conn); loaded {
   191  			connectionState := tlsConn.ConnectionState()
   192  			if connectionState.NegotiatedProtocol != "" {
   193  				if fallbackAddr, loaded = h.fallbackAddrTLSNextProto[connectionState.NegotiatedProtocol]; !loaded {
   194  					return E.New("fallback disabled for ALPN: ", connectionState.NegotiatedProtocol)
   195  				}
   196  			}
   197  		}
   198  	}
   199  	if !fallbackAddr.IsValid() {
   200  		if !h.fallbackAddr.IsValid() {
   201  			return E.New("fallback disabled by default")
   202  		}
   203  		fallbackAddr = h.fallbackAddr
   204  	}
   205  	h.logger.InfoContext(ctx, "fallback connection to ", fallbackAddr)
   206  	metadata.Destination = fallbackAddr
   207  	return h.router.RouteConnection(ctx, conn, metadata)
   208  }
   209  
   210  func (h *Trojan) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   211  	userIndex, loaded := auth.UserFromContext[int](ctx)
   212  	if !loaded {
   213  		return os.ErrInvalid
   214  	}
   215  	user := h.users[userIndex].Name
   216  	if user == "" {
   217  		user = F.ToString(userIndex)
   218  	} else {
   219  		metadata.User = user
   220  	}
   221  	h.logger.InfoContext(ctx, "[", user, "] inbound packet connection to ", metadata.Destination)
   222  	return h.router.RoutePacketConnection(ctx, conn, metadata)
   223  }
   224  
   225  var _ adapter.V2RayServerTransportHandler = (*trojanTransportHandler)(nil)
   226  
   227  type trojanTransportHandler Trojan
   228  
   229  func (t *trojanTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   230  	return (*Trojan)(t).newTransportConnection(ctx, conn, adapter.InboundContext{
   231  		Source:      metadata.Source,
   232  		Destination: metadata.Destination,
   233  	})
   234  }