github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/inbound/trojan.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"os"
     7  
     8  	"github.com/inazumav/sing-box/adapter"
     9  	"github.com/inazumav/sing-box/common/tls"
    10  	C "github.com/inazumav/sing-box/constant"
    11  	"github.com/inazumav/sing-box/log"
    12  	"github.com/inazumav/sing-box/option"
    13  	"github.com/inazumav/sing-box/transport/trojan"
    14  	"github.com/inazumav/sing-box/transport/v2ray"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/auth"
    17  	E "github.com/sagernet/sing/common/exceptions"
    18  	F "github.com/sagernet/sing/common/format"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  )
    22  
    23  var (
    24  	_ adapter.Inbound           = (*Trojan)(nil)
    25  	_ adapter.InjectableInbound = (*Trojan)(nil)
    26  )
    27  
    28  type Trojan struct {
    29  	myInboundAdapter
    30  	service                  *trojan.Service[int]
    31  	users                    []option.TrojanUser
    32  	tlsConfig                tls.ServerConfig
    33  	fallbackAddr             M.Socksaddr
    34  	fallbackAddrTLSNextProto map[string]M.Socksaddr
    35  	transport                adapter.V2RayServerTransport
    36  }
    37  
    38  func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanInboundOptions) (*Trojan, error) {
    39  	inbound := &Trojan{
    40  		myInboundAdapter: myInboundAdapter{
    41  			protocol:      C.TypeTrojan,
    42  			network:       []string{N.NetworkTCP},
    43  			ctx:           ctx,
    44  			router:        router,
    45  			logger:        logger,
    46  			tag:           tag,
    47  			listenOptions: options.ListenOptions,
    48  		},
    49  		users: options.Users,
    50  	}
    51  	if options.TLS != nil {
    52  		tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  		inbound.tlsConfig = tlsConfig
    57  	}
    58  	var fallbackHandler N.TCPConnectionHandler
    59  	if options.Fallback != nil && options.Fallback.Server != "" || len(options.FallbackForALPN) > 0 {
    60  		if options.Fallback != nil && options.Fallback.Server != "" {
    61  			inbound.fallbackAddr = options.Fallback.Build()
    62  			if !inbound.fallbackAddr.IsValid() {
    63  				return nil, E.New("invalid fallback address: ", inbound.fallbackAddr)
    64  			}
    65  		}
    66  		if len(options.FallbackForALPN) > 0 {
    67  			if inbound.tlsConfig == nil {
    68  				return nil, E.New("fallback for ALPN is not supported without TLS")
    69  			}
    70  			fallbackAddrNextProto := make(map[string]M.Socksaddr)
    71  			for nextProto, destination := range options.FallbackForALPN {
    72  				fallbackAddr := destination.Build()
    73  				if !fallbackAddr.IsValid() {
    74  					return nil, E.New("invalid fallback address for ALPN ", nextProto, ": ", fallbackAddr)
    75  				}
    76  				fallbackAddrNextProto[nextProto] = fallbackAddr
    77  			}
    78  			inbound.fallbackAddrTLSNextProto = fallbackAddrNextProto
    79  		}
    80  		fallbackHandler = adapter.NewUpstreamContextHandler(inbound.fallbackConnection, nil, nil)
    81  	}
    82  	service := trojan.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound), fallbackHandler)
    83  	err := service.UpdateUsers(common.MapIndexed(options.Users, func(index int, it option.TrojanUser) int {
    84  		return index
    85  	}), common.Map(options.Users, func(it option.TrojanUser) string {
    86  		return it.Password
    87  	}))
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	if options.Transport != nil {
    92  		inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*trojanTransportHandler)(inbound))
    93  		if err != nil {
    94  			return nil, E.Cause(err, "create server transport: ", options.Transport.Type)
    95  		}
    96  	}
    97  	inbound.service = service
    98  	inbound.connHandler = inbound
    99  	return inbound, nil
   100  }
   101  
   102  func (h *Trojan) Start() error {
   103  	if h.tlsConfig != nil {
   104  		err := h.tlsConfig.Start()
   105  		if err != nil {
   106  			return E.Cause(err, "create TLS config")
   107  		}
   108  	}
   109  	if h.transport == nil {
   110  		return h.myInboundAdapter.Start()
   111  	}
   112  	if common.Contains(h.transport.Network(), N.NetworkTCP) {
   113  		tcpListener, err := h.myInboundAdapter.ListenTCP()
   114  		if err != nil {
   115  			return err
   116  		}
   117  		go func() {
   118  			sErr := h.transport.Serve(tcpListener)
   119  			if sErr != nil && !E.IsClosed(sErr) {
   120  				h.logger.Error("transport serve error: ", sErr)
   121  			}
   122  		}()
   123  	}
   124  	if common.Contains(h.transport.Network(), N.NetworkUDP) {
   125  		udpConn, err := h.myInboundAdapter.ListenUDP()
   126  		if err != nil {
   127  			return err
   128  		}
   129  		go func() {
   130  			sErr := h.transport.ServePacket(udpConn)
   131  			if sErr != nil && !E.IsClosed(sErr) {
   132  				h.logger.Error("transport serve error: ", sErr)
   133  			}
   134  		}()
   135  	}
   136  	return nil
   137  }
   138  
   139  func (h *Trojan) Close() error {
   140  	return common.Close(
   141  		&h.myInboundAdapter,
   142  		h.tlsConfig,
   143  		h.transport,
   144  	)
   145  }
   146  
   147  func (h *Trojan) AddUsers(users []option.TrojanUser) error {
   148  	if cap(h.users)-len(h.users) >= len(users) {
   149  		h.users = append(h.users, users...)
   150  	} else {
   151  		tmp := make([]option.TrojanUser, 0, len(h.users)+len(users)+10)
   152  		tmp = append(tmp, h.users...)
   153  		tmp = append(tmp, users...)
   154  		h.users = tmp
   155  	}
   156  	err := h.service.UpdateUsers(common.MapIndexed(h.users, func(index int, user option.TrojanUser) int {
   157  		return index
   158  	}), common.Map(h.users, func(user option.TrojanUser) string {
   159  		return user.Password
   160  	}))
   161  	if err != nil {
   162  		return err
   163  	}
   164  	return nil
   165  }
   166  
   167  func (h *Trojan) DelUsers(names []string) error {
   168  	is := make([]int, 0, len(names))
   169  	ulen := len(names)
   170  	for i := range h.users {
   171  		for _, n := range names {
   172  			if h.users[i].Name == n {
   173  				is = append(is, i)
   174  				ulen--
   175  			}
   176  			if ulen == 0 {
   177  				break
   178  			}
   179  		}
   180  	}
   181  	ulen = len(h.users)
   182  	for _, i := range is {
   183  		h.users[i] = h.users[ulen-1]
   184  		h.users[ulen-1] = option.TrojanUser{}
   185  		h.users = h.users[:ulen-1]
   186  		ulen--
   187  	}
   188  	err := h.service.UpdateUsers(common.MapIndexed(h.users, func(index int, user option.TrojanUser) int {
   189  		return index
   190  	}), common.Map(h.users, func(user option.TrojanUser) string {
   191  		return user.Password
   192  	}))
   193  	return err
   194  }
   195  
   196  func (h *Trojan) newTransportConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   197  	h.injectTCP(conn, metadata)
   198  	return nil
   199  }
   200  
   201  func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   202  	var err error
   203  	if h.tlsConfig != nil && h.transport == nil {
   204  		conn, err = tls.ServerHandshake(ctx, conn, h.tlsConfig)
   205  		if err != nil {
   206  			return err
   207  		}
   208  	}
   209  	return h.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata))
   210  }
   211  
   212  func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   213  	return os.ErrInvalid
   214  }
   215  
   216  func (h *Trojan) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   217  	userIndex, loaded := auth.UserFromContext[int](ctx)
   218  	if !loaded {
   219  		return os.ErrInvalid
   220  	}
   221  	user := h.users[userIndex].Name
   222  	if user == "" {
   223  		user = F.ToString(userIndex)
   224  	} else {
   225  		metadata.User = user
   226  	}
   227  	h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination)
   228  	return h.router.RouteConnection(ctx, conn, metadata)
   229  }
   230  
   231  func (h *Trojan) fallbackConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   232  	var fallbackAddr M.Socksaddr
   233  	if len(h.fallbackAddrTLSNextProto) > 0 {
   234  		if tlsConn, loaded := common.Cast[tls.Conn](conn); loaded {
   235  			connectionState := tlsConn.ConnectionState()
   236  			if connectionState.NegotiatedProtocol != "" {
   237  				if fallbackAddr, loaded = h.fallbackAddrTLSNextProto[connectionState.NegotiatedProtocol]; !loaded {
   238  					return E.New("fallback disabled for ALPN: ", connectionState.NegotiatedProtocol)
   239  				}
   240  			}
   241  		}
   242  	}
   243  	if !fallbackAddr.IsValid() {
   244  		if !h.fallbackAddr.IsValid() {
   245  			return E.New("fallback disabled by default")
   246  		}
   247  		fallbackAddr = h.fallbackAddr
   248  	}
   249  	h.logger.InfoContext(ctx, "fallback connection to ", fallbackAddr)
   250  	metadata.Destination = fallbackAddr
   251  	return h.router.RouteConnection(ctx, conn, metadata)
   252  }
   253  
   254  func (h *Trojan) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   255  	userIndex, loaded := auth.UserFromContext[int](ctx)
   256  	if !loaded {
   257  		return os.ErrInvalid
   258  	}
   259  	user := h.users[userIndex].Name
   260  	if user == "" {
   261  		user = F.ToString(userIndex)
   262  	} else {
   263  		metadata.User = user
   264  	}
   265  	h.logger.InfoContext(ctx, "[", user, "] inbound packet connection to ", metadata.Destination)
   266  	return h.router.RoutePacketConnection(ctx, conn, metadata)
   267  }
   268  
   269  var _ adapter.V2RayServerTransportHandler = (*trojanTransportHandler)(nil)
   270  
   271  type trojanTransportHandler Trojan
   272  
   273  func (t *trojanTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   274  	return (*Trojan)(t).newTransportConnection(ctx, conn, adapter.InboundContext{
   275  		Source:      metadata.Source,
   276  		Destination: metadata.Destination,
   277  	})
   278  }
   279  
   280  func (t *trojanTransportHandler) FallbackConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   281  	return (*Trojan)(t).fallbackConnection(ctx, conn, adapter.InboundContext{
   282  		Source:      metadata.Source,
   283  		Destination: metadata.Destination,
   284  	})
   285  }