github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/inbound/trojan.go (about) 1 package inbound 2 3 import ( 4 "context" 5 "net" 6 "os" 7 8 "github.com/inazumav/sing-box/adapter" 9 "github.com/inazumav/sing-box/common/tls" 10 C "github.com/inazumav/sing-box/constant" 11 "github.com/inazumav/sing-box/log" 12 "github.com/inazumav/sing-box/option" 13 "github.com/inazumav/sing-box/transport/trojan" 14 "github.com/inazumav/sing-box/transport/v2ray" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/auth" 17 E "github.com/sagernet/sing/common/exceptions" 18 F "github.com/sagernet/sing/common/format" 19 M "github.com/sagernet/sing/common/metadata" 20 N "github.com/sagernet/sing/common/network" 21 ) 22 23 var ( 24 _ adapter.Inbound = (*Trojan)(nil) 25 _ adapter.InjectableInbound = (*Trojan)(nil) 26 ) 27 28 type Trojan struct { 29 myInboundAdapter 30 service *trojan.Service[int] 31 users []option.TrojanUser 32 tlsConfig tls.ServerConfig 33 fallbackAddr M.Socksaddr 34 fallbackAddrTLSNextProto map[string]M.Socksaddr 35 transport adapter.V2RayServerTransport 36 } 37 38 func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanInboundOptions) (*Trojan, error) { 39 inbound := &Trojan{ 40 myInboundAdapter: myInboundAdapter{ 41 protocol: C.TypeTrojan, 42 network: []string{N.NetworkTCP}, 43 ctx: ctx, 44 router: router, 45 logger: logger, 46 tag: tag, 47 listenOptions: options.ListenOptions, 48 }, 49 users: options.Users, 50 } 51 if options.TLS != nil { 52 tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) 53 if err != nil { 54 return nil, err 55 } 56 inbound.tlsConfig = tlsConfig 57 } 58 var fallbackHandler N.TCPConnectionHandler 59 if options.Fallback != nil && options.Fallback.Server != "" || len(options.FallbackForALPN) > 0 { 60 if options.Fallback != nil && options.Fallback.Server != "" { 61 inbound.fallbackAddr = options.Fallback.Build() 62 if !inbound.fallbackAddr.IsValid() { 63 return nil, E.New("invalid fallback address: ", inbound.fallbackAddr) 64 } 65 } 66 if len(options.FallbackForALPN) > 0 { 67 if inbound.tlsConfig == nil { 68 return nil, E.New("fallback for ALPN is not supported without TLS") 69 } 70 fallbackAddrNextProto := make(map[string]M.Socksaddr) 71 for nextProto, destination := range options.FallbackForALPN { 72 fallbackAddr := destination.Build() 73 if !fallbackAddr.IsValid() { 74 return nil, E.New("invalid fallback address for ALPN ", nextProto, ": ", fallbackAddr) 75 } 76 fallbackAddrNextProto[nextProto] = fallbackAddr 77 } 78 inbound.fallbackAddrTLSNextProto = fallbackAddrNextProto 79 } 80 fallbackHandler = adapter.NewUpstreamContextHandler(inbound.fallbackConnection, nil, nil) 81 } 82 service := trojan.NewService[int](adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound), fallbackHandler) 83 err := service.UpdateUsers(common.MapIndexed(options.Users, func(index int, it option.TrojanUser) int { 84 return index 85 }), common.Map(options.Users, func(it option.TrojanUser) string { 86 return it.Password 87 })) 88 if err != nil { 89 return nil, err 90 } 91 if options.Transport != nil { 92 inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*trojanTransportHandler)(inbound)) 93 if err != nil { 94 return nil, E.Cause(err, "create server transport: ", options.Transport.Type) 95 } 96 } 97 inbound.service = service 98 inbound.connHandler = inbound 99 return inbound, nil 100 } 101 102 func (h *Trojan) Start() error { 103 if h.tlsConfig != nil { 104 err := h.tlsConfig.Start() 105 if err != nil { 106 return E.Cause(err, "create TLS config") 107 } 108 } 109 if h.transport == nil { 110 return h.myInboundAdapter.Start() 111 } 112 if common.Contains(h.transport.Network(), N.NetworkTCP) { 113 tcpListener, err := h.myInboundAdapter.ListenTCP() 114 if err != nil { 115 return err 116 } 117 go func() { 118 sErr := h.transport.Serve(tcpListener) 119 if sErr != nil && !E.IsClosed(sErr) { 120 h.logger.Error("transport serve error: ", sErr) 121 } 122 }() 123 } 124 if common.Contains(h.transport.Network(), N.NetworkUDP) { 125 udpConn, err := h.myInboundAdapter.ListenUDP() 126 if err != nil { 127 return err 128 } 129 go func() { 130 sErr := h.transport.ServePacket(udpConn) 131 if sErr != nil && !E.IsClosed(sErr) { 132 h.logger.Error("transport serve error: ", sErr) 133 } 134 }() 135 } 136 return nil 137 } 138 139 func (h *Trojan) Close() error { 140 return common.Close( 141 &h.myInboundAdapter, 142 h.tlsConfig, 143 h.transport, 144 ) 145 } 146 147 func (h *Trojan) AddUsers(users []option.TrojanUser) error { 148 if cap(h.users)-len(h.users) >= len(users) { 149 h.users = append(h.users, users...) 150 } else { 151 tmp := make([]option.TrojanUser, 0, len(h.users)+len(users)+10) 152 tmp = append(tmp, h.users...) 153 tmp = append(tmp, users...) 154 h.users = tmp 155 } 156 err := h.service.UpdateUsers(common.MapIndexed(h.users, func(index int, user option.TrojanUser) int { 157 return index 158 }), common.Map(h.users, func(user option.TrojanUser) string { 159 return user.Password 160 })) 161 if err != nil { 162 return err 163 } 164 return nil 165 } 166 167 func (h *Trojan) DelUsers(names []string) error { 168 is := make([]int, 0, len(names)) 169 ulen := len(names) 170 for i := range h.users { 171 for _, n := range names { 172 if h.users[i].Name == n { 173 is = append(is, i) 174 ulen-- 175 } 176 if ulen == 0 { 177 break 178 } 179 } 180 } 181 ulen = len(h.users) 182 for _, i := range is { 183 h.users[i] = h.users[ulen-1] 184 h.users[ulen-1] = option.TrojanUser{} 185 h.users = h.users[:ulen-1] 186 ulen-- 187 } 188 err := h.service.UpdateUsers(common.MapIndexed(h.users, func(index int, user option.TrojanUser) int { 189 return index 190 }), common.Map(h.users, func(user option.TrojanUser) string { 191 return user.Password 192 })) 193 return err 194 } 195 196 func (h *Trojan) newTransportConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 197 h.injectTCP(conn, metadata) 198 return nil 199 } 200 201 func (h *Trojan) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 202 var err error 203 if h.tlsConfig != nil && h.transport == nil { 204 conn, err = tls.ServerHandshake(ctx, conn, h.tlsConfig) 205 if err != nil { 206 return err 207 } 208 } 209 return h.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata)) 210 } 211 212 func (h *Trojan) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 213 return os.ErrInvalid 214 } 215 216 func (h *Trojan) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 217 userIndex, loaded := auth.UserFromContext[int](ctx) 218 if !loaded { 219 return os.ErrInvalid 220 } 221 user := h.users[userIndex].Name 222 if user == "" { 223 user = F.ToString(userIndex) 224 } else { 225 metadata.User = user 226 } 227 h.logger.InfoContext(ctx, "[", user, "] inbound connection to ", metadata.Destination) 228 return h.router.RouteConnection(ctx, conn, metadata) 229 } 230 231 func (h *Trojan) fallbackConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 232 var fallbackAddr M.Socksaddr 233 if len(h.fallbackAddrTLSNextProto) > 0 { 234 if tlsConn, loaded := common.Cast[tls.Conn](conn); loaded { 235 connectionState := tlsConn.ConnectionState() 236 if connectionState.NegotiatedProtocol != "" { 237 if fallbackAddr, loaded = h.fallbackAddrTLSNextProto[connectionState.NegotiatedProtocol]; !loaded { 238 return E.New("fallback disabled for ALPN: ", connectionState.NegotiatedProtocol) 239 } 240 } 241 } 242 } 243 if !fallbackAddr.IsValid() { 244 if !h.fallbackAddr.IsValid() { 245 return E.New("fallback disabled by default") 246 } 247 fallbackAddr = h.fallbackAddr 248 } 249 h.logger.InfoContext(ctx, "fallback connection to ", fallbackAddr) 250 metadata.Destination = fallbackAddr 251 return h.router.RouteConnection(ctx, conn, metadata) 252 } 253 254 func (h *Trojan) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 255 userIndex, loaded := auth.UserFromContext[int](ctx) 256 if !loaded { 257 return os.ErrInvalid 258 } 259 user := h.users[userIndex].Name 260 if user == "" { 261 user = F.ToString(userIndex) 262 } else { 263 metadata.User = user 264 } 265 h.logger.InfoContext(ctx, "[", user, "] inbound packet connection to ", metadata.Destination) 266 return h.router.RoutePacketConnection(ctx, conn, metadata) 267 } 268 269 var _ adapter.V2RayServerTransportHandler = (*trojanTransportHandler)(nil) 270 271 type trojanTransportHandler Trojan 272 273 func (t *trojanTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 274 return (*Trojan)(t).newTransportConnection(ctx, conn, adapter.InboundContext{ 275 Source: metadata.Source, 276 Destination: metadata.Destination, 277 }) 278 } 279 280 func (t *trojanTransportHandler) FallbackConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 281 return (*Trojan)(t).fallbackConnection(ctx, conn, adapter.InboundContext{ 282 Source: metadata.Source, 283 Destination: metadata.Destination, 284 }) 285 }