github.com/sagernet/sing-box@v1.9.0-rc.20/inbound/trojan.go (about) 1 package inbound 2 3 import ( 4 "context" 5 "net" 6 "os" 7 8 "github.com/sagernet/sing-box/adapter" 9 "github.com/sagernet/sing-box/common/mux" 10 "github.com/sagernet/sing-box/common/tls" 11 C "github.com/sagernet/sing-box/constant" 12 "github.com/sagernet/sing-box/log" 13 "github.com/sagernet/sing-box/option" 14 "github.com/sagernet/sing-box/transport/trojan" 15 "github.com/sagernet/sing-box/transport/v2ray" 16 "github.com/sagernet/sing/common" 17 "github.com/sagernet/sing/common/auth" 18 E "github.com/sagernet/sing/common/exceptions" 19 F "github.com/sagernet/sing/common/format" 20 M "github.com/sagernet/sing/common/metadata" 21 N "github.com/sagernet/sing/common/network" 22 ) 23 24 var ( 25 _ adapter.Inbound = (*Trojan)(nil) 26 _ adapter.InjectableInbound = (*Trojan)(nil) 27 ) 28 29 type Trojan struct { 30 myInboundAdapter 31 service *trojan.Service[int] 32 users []option.TrojanUser 33 tlsConfig tls.ServerConfig 34 fallbackAddr M.Socksaddr 35 fallbackAddrTLSNextProto map[string]M.Socksaddr 36 transport adapter.V2RayServerTransport 37 } 38 39 func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanInboundOptions) (*Trojan, error) { 40 inbound := &Trojan{ 41 myInboundAdapter: myInboundAdapter{ 42 protocol: C.TypeTrojan, 43 network: []string{N.NetworkTCP}, 44 ctx: ctx, 45 router: router, 46 logger: logger, 47 tag: tag, 48 listenOptions: options.ListenOptions, 49 }, 50 users: options.Users, 51 } 52 if options.TLS != nil { 53 tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) 54 if err != nil { 55 return nil, err 56 } 57 inbound.tlsConfig = tlsConfig 58 } 59 var fallbackHandler N.TCPConnectionHandler 60 if options.Fallback != nil && options.Fallback.Server != "" || len(options.FallbackForALPN) > 0 { 61 if options.Fallback != nil && options.Fallback.Server != "" { 62 inbound.fallbackAddr = options.Fallback.Build() 63 if !inbound.fallbackAddr.IsValid() { 64 return nil, E.New("invalid fallback address: ", inbound.fallbackAddr) 65 } 66 } 67 if len(options.FallbackForALPN) > 0 { 68 if inbound.tlsConfig == nil { 69 return nil, E.New("fallback for ALPN is not supported without TLS") 70 } 71 fallbackAddrNextProto := make(map[string]M.Socksaddr) 72 for nextProto, destination := range options.FallbackForALPN { 73 fallbackAddr := destination.Build() 74 if !fallbackAddr.IsValid() { 75 return nil, E.New("invalid fallback address for ALPN ", nextProto, ": ", fallbackAddr) 76 } 77 fallbackAddrNextProto[nextProto] = fallbackAddr 78 } 79 inbound.fallbackAddrTLSNextProto = fallbackAddrNextProto 80 } 81 fallbackHandler = adapter.NewUpstreamContextHandler(inbound.fallbackConnection, nil, nil) 82 } 83 service := trojan.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound), fallbackHandler) 84 err := service.UpdateUsers(common.MapIndexed(options.Users, func(index int, it option.TrojanUser) int { 85 return index 86 }), common.Map(options.Users, func(it option.TrojanUser) string { 87 return it.Password 88 })) 89 if err != nil { 90 return nil, err 91 } 92 if options.Transport != nil { 93 inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*trojanTransportHandler)(inbound)) 94 if err != nil { 95 return nil, E.Cause(err, "create server transport: ", options.Transport.Type) 96 } 97 } 98 inbound.router, err = mux.NewRouterWithOptions(inbound.router, logger, common.PtrValueOrDefault(options.Multiplex)) 99 if err != nil { 100 return nil, err 101 } 102 inbound.service = service 103 inbound.connHandler = inbound 104 return inbound, nil 105 } 106 107 func (h *Trojan) Start() error { 108 if h.tlsConfig != nil { 109 err := h.tlsConfig.Start() 110 if err != nil { 111 return E.Cause(err, "create TLS config") 112 } 113 } 114 if h.transport == nil { 115 return h.myInboundAdapter.Start() 116 } 117 if common.Contains(h.transport.Network(), N.NetworkTCP) { 118 tcpListener, err := h.myInboundAdapter.ListenTCP() 119 if err != nil { 120 return err 121 } 122 go func() { 123 sErr := h.transport.Serve(tcpListener) 124 if sErr != nil && !E.IsClosed(sErr) { 125 h.logger.Error("transport serve error: ", sErr) 126 } 127 }() 128 } 129 if common.Contains(h.transport.Network(), N.NetworkUDP) { 130 udpConn, err := h.myInboundAdapter.ListenUDP() 131 if err != nil { 132 return err 133 } 134 go func() { 135 sErr := h.transport.ServePacket(udpConn) 136 if sErr != nil && !E.IsClosed(sErr) { 137 h.logger.Error("transport serve error: ", sErr) 138 } 139 }() 140 } 141 return nil 142 } 143 144 func (h *Trojan) Close() error { 145 return common.Close( 146 &h.myInboundAdapter, 147 h.tlsConfig, 148 h.transport, 149 ) 150 } 151 152 func (h *Trojan) newTransportConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 153 h.injectTCP(conn, metadata) 154 return nil 155 } 156 157 func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 158 var err error 159 if h.tlsConfig != nil && h.transport == nil { 160 conn, err = tls.ServerHandshake(ctx, conn, h.tlsConfig) 161 if err != nil { 162 return err 163 } 164 } 165 return h.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata)) 166 } 167 168 func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 169 return os.ErrInvalid 170 } 171 172 func (h *Trojan) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 173 userIndex, loaded := auth.UserFromContext[int](ctx) 174 if !loaded { 175 return os.ErrInvalid 176 } 177 user := h.users[userIndex].Name 178 if user == "" { 179 user = F.ToString(userIndex) 180 } else { 181 metadata.User = user 182 } 183 h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination) 184 return h.router.RouteConnection(ctx, conn, metadata) 185 } 186 187 func (h *Trojan) fallbackConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 188 var fallbackAddr M.Socksaddr 189 if len(h.fallbackAddrTLSNextProto) > 0 { 190 if tlsConn, loaded := common.Cast[tls.Conn](conn); loaded { 191 connectionState := tlsConn.ConnectionState() 192 if connectionState.NegotiatedProtocol != "" { 193 if fallbackAddr, loaded = h.fallbackAddrTLSNextProto[connectionState.NegotiatedProtocol]; !loaded { 194 return E.New("fallback disabled for ALPN: ", connectionState.NegotiatedProtocol) 195 } 196 } 197 } 198 } 199 if !fallbackAddr.IsValid() { 200 if !h.fallbackAddr.IsValid() { 201 return E.New("fallback disabled by default") 202 } 203 fallbackAddr = h.fallbackAddr 204 } 205 h.logger.InfoContext(ctx, "fallback connection to ", fallbackAddr) 206 metadata.Destination = fallbackAddr 207 return h.router.RouteConnection(ctx, conn, metadata) 208 } 209 210 func (h *Trojan) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 211 userIndex, loaded := auth.UserFromContext[int](ctx) 212 if !loaded { 213 return os.ErrInvalid 214 } 215 user := h.users[userIndex].Name 216 if user == "" { 217 user = F.ToString(userIndex) 218 } else { 219 metadata.User = user 220 } 221 h.logger.InfoContext(ctx, "[", user, "] inbound packet connection to ", metadata.Destination) 222 return h.router.RoutePacketConnection(ctx, conn, metadata) 223 } 224 225 var _ adapter.V2RayServerTransportHandler = (*trojanTransportHandler)(nil) 226 227 type trojanTransportHandler Trojan 228 229 func (t *trojanTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 230 return (*Trojan)(t).newTransportConnection(ctx, conn, adapter.InboundContext{ 231 Source: metadata.Source, 232 Destination: metadata.Destination, 233 }) 234 }