github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/dispatcher/default.go (about) 1 package dispatcher 2 3 //go:generate go run github.com/xtls/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/xtls/xray-core/common" 12 "github.com/xtls/xray-core/common/buf" 13 "github.com/xtls/xray-core/common/log" 14 "github.com/xtls/xray-core/common/net" 15 "github.com/xtls/xray-core/common/protocol" 16 "github.com/xtls/xray-core/common/session" 17 "github.com/xtls/xray-core/core" 18 "github.com/xtls/xray-core/features/dns" 19 "github.com/xtls/xray-core/features/outbound" 20 "github.com/xtls/xray-core/features/policy" 21 "github.com/xtls/xray-core/features/routing" 22 routing_session "github.com/xtls/xray-core/features/routing/session" 23 "github.com/xtls/xray-core/features/stats" 24 "github.com/xtls/xray-core/transport" 25 "github.com/xtls/xray-core/transport/pipe" 26 ) 27 28 var errSniffingTimeout = newError("timeout on sniffing") 29 30 type cachedReader struct { 31 sync.Mutex 32 reader *pipe.Reader 33 cache buf.MultiBuffer 34 } 35 36 func (r *cachedReader) Cache(b *buf.Buffer) { 37 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 38 r.Lock() 39 if !mb.IsEmpty() { 40 r.cache, _ = buf.MergeMulti(r.cache, mb) 41 } 42 b.Clear() 43 rawBytes := b.Extend(buf.Size) 44 n := r.cache.Copy(rawBytes) 45 b.Resize(0, int32(n)) 46 r.Unlock() 47 } 48 49 func (r *cachedReader) readInternal() buf.MultiBuffer { 50 r.Lock() 51 defer r.Unlock() 52 53 if r.cache != nil && !r.cache.IsEmpty() { 54 mb := r.cache 55 r.cache = nil 56 return mb 57 } 58 59 return nil 60 } 61 62 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 63 mb := r.readInternal() 64 if mb != nil { 65 return mb, nil 66 } 67 68 return r.reader.ReadMultiBuffer() 69 } 70 71 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 72 mb := r.readInternal() 73 if mb != nil { 74 return mb, nil 75 } 76 77 return r.reader.ReadMultiBufferTimeout(timeout) 78 } 79 80 func (r *cachedReader) Interrupt() { 81 r.Lock() 82 if r.cache != nil { 83 r.cache = buf.ReleaseMulti(r.cache) 84 } 85 r.Unlock() 86 r.reader.Interrupt() 87 } 88 89 // DefaultDispatcher is a default implementation of Dispatcher. 90 type DefaultDispatcher struct { 91 ohm outbound.Manager 92 router routing.Router 93 policy policy.Manager 94 stats stats.Manager 95 dns dns.Client 96 fdns dns.FakeDNSEngine 97 } 98 99 func init() { 100 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 101 d := new(DefaultDispatcher) 102 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { 103 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 104 d.fdns = fdns 105 }) 106 return d.Init(config.(*Config), om, router, pm, sm, dc) 107 }); err != nil { 108 return nil, err 109 } 110 return d, nil 111 })) 112 } 113 114 // Init initializes DefaultDispatcher. 115 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error { 116 d.ohm = om 117 d.router = router 118 d.policy = pm 119 d.stats = sm 120 d.dns = dns 121 return nil 122 } 123 124 // Type implements common.HasType. 125 func (*DefaultDispatcher) Type() interface{} { 126 return routing.DispatcherType() 127 } 128 129 // Start implements common.Runnable. 130 func (*DefaultDispatcher) Start() error { 131 return nil 132 } 133 134 // Close implements common.Closable. 135 func (*DefaultDispatcher) Close() error { return nil } 136 137 func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { 138 opt := pipe.OptionsFromContext(ctx) 139 uplinkReader, uplinkWriter := pipe.New(opt...) 140 downlinkReader, downlinkWriter := pipe.New(opt...) 141 142 inboundLink := &transport.Link{ 143 Reader: downlinkReader, 144 Writer: uplinkWriter, 145 } 146 147 outboundLink := &transport.Link{ 148 Reader: uplinkReader, 149 Writer: downlinkWriter, 150 } 151 152 sessionInbound := session.InboundFromContext(ctx) 153 var user *protocol.MemoryUser 154 if sessionInbound != nil { 155 user = sessionInbound.User 156 } 157 158 if user != nil && len(user.Email) > 0 { 159 p := d.policy.ForLevel(user.Level) 160 if p.Stats.UserUplink { 161 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 162 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 163 inboundLink.Writer = &SizeStatWriter{ 164 Counter: c, 165 Writer: inboundLink.Writer, 166 } 167 } 168 } 169 if p.Stats.UserDownlink { 170 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 171 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 172 outboundLink.Writer = &SizeStatWriter{ 173 Counter: c, 174 Writer: outboundLink.Writer, 175 } 176 } 177 } 178 } 179 180 return inboundLink, outboundLink 181 } 182 183 func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { 184 domain := result.Domain() 185 if domain == "" { 186 return false 187 } 188 for _, d := range request.ExcludeForDomain { 189 if strings.ToLower(domain) == d { 190 return false 191 } 192 } 193 protocolString := result.Protocol() 194 if resComp, ok := result.(SnifferResultComposite); ok { 195 protocolString = resComp.ProtocolForDomainResult() 196 } 197 for _, p := range request.OverrideDestinationForProtocol { 198 if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) { 199 return true 200 } 201 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && 202 fkr0.IsIPInIPPool(destination.Address) { 203 newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) 204 return true 205 } 206 if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok { 207 if resultSubset.IsProtoSubsetOf(p) { 208 return true 209 } 210 } 211 } 212 213 return false 214 } 215 216 // Dispatch implements routing.Dispatcher. 217 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 218 if !destination.IsValid() { 219 panic("Dispatcher: Invalid destination.") 220 } 221 outbounds := session.OutboundsFromContext(ctx) 222 if len(outbounds) == 0 { 223 outbounds = []*session.Outbound{{}} 224 ctx = session.ContextWithOutbounds(ctx, outbounds) 225 } 226 ob := outbounds[len(outbounds) - 1] 227 ob.OriginalTarget = destination 228 ob.Target = destination 229 content := session.ContentFromContext(ctx) 230 if content == nil { 231 content = new(session.Content) 232 ctx = session.ContextWithContent(ctx, content) 233 } 234 235 sniffingRequest := content.SniffingRequest 236 inbound, outbound := d.getLink(ctx) 237 if !sniffingRequest.Enabled { 238 go d.routedDispatch(ctx, outbound, destination) 239 } else { 240 go func() { 241 cReader := &cachedReader{ 242 reader: outbound.Reader.(*pipe.Reader), 243 } 244 outbound.Reader = cReader 245 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 246 if err == nil { 247 content.Protocol = result.Protocol() 248 } 249 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 250 domain := result.Domain() 251 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 252 destination.Address = net.ParseAddress(domain) 253 protocol := result.Protocol() 254 if resComp, ok := result.(SnifferResultComposite); ok { 255 protocol = resComp.ProtocolForDomainResult() 256 } 257 isFakeIP := false 258 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) { 259 isFakeIP = true 260 } 261 if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { 262 ob.RouteTarget = destination 263 } else { 264 ob.Target = destination 265 } 266 } 267 d.routedDispatch(ctx, outbound, destination) 268 }() 269 } 270 return inbound, nil 271 } 272 273 // DispatchLink implements routing.Dispatcher. 274 func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { 275 if !destination.IsValid() { 276 return newError("Dispatcher: Invalid destination.") 277 } 278 outbounds := session.OutboundsFromContext(ctx) 279 if len(outbounds) == 0 { 280 outbounds = []*session.Outbound{{}} 281 ctx = session.ContextWithOutbounds(ctx, outbounds) 282 } 283 ob := outbounds[len(outbounds) - 1] 284 ob.OriginalTarget = destination 285 ob.Target = destination 286 content := session.ContentFromContext(ctx) 287 if content == nil { 288 content = new(session.Content) 289 ctx = session.ContextWithContent(ctx, content) 290 } 291 sniffingRequest := content.SniffingRequest 292 if !sniffingRequest.Enabled { 293 d.routedDispatch(ctx, outbound, destination) 294 } else { 295 cReader := &cachedReader{ 296 reader: outbound.Reader.(*pipe.Reader), 297 } 298 outbound.Reader = cReader 299 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 300 if err == nil { 301 content.Protocol = result.Protocol() 302 } 303 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 304 domain := result.Domain() 305 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 306 destination.Address = net.ParseAddress(domain) 307 protocol := result.Protocol() 308 if resComp, ok := result.(SnifferResultComposite); ok { 309 protocol = resComp.ProtocolForDomainResult() 310 } 311 isFakeIP := false 312 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) { 313 isFakeIP = true 314 } 315 if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { 316 ob.RouteTarget = destination 317 } else { 318 ob.Target = destination 319 } 320 } 321 d.routedDispatch(ctx, outbound, destination) 322 } 323 324 return nil 325 } 326 327 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { 328 payload := buf.New() 329 defer payload.Release() 330 331 sniffer := NewSniffer(ctx) 332 333 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 334 335 if metadataOnly { 336 return metaresult, metadataErr 337 } 338 339 contentResult, contentErr := func() (SniffResult, error) { 340 totalAttempt := 0 341 for { 342 select { 343 case <-ctx.Done(): 344 return nil, ctx.Err() 345 default: 346 totalAttempt++ 347 if totalAttempt > 2 { 348 return nil, errSniffingTimeout 349 } 350 351 cReader.Cache(payload) 352 if !payload.IsEmpty() { 353 result, err := sniffer.Sniff(ctx, payload.Bytes(), network) 354 if err != common.ErrNoClue { 355 return result, err 356 } 357 } 358 if payload.IsFull() { 359 return nil, errUnknownContent 360 } 361 } 362 } 363 }() 364 if contentErr != nil && metadataErr == nil { 365 return metaresult, nil 366 } 367 if contentErr == nil && metadataErr == nil { 368 return CompositeResult(metaresult, contentResult), nil 369 } 370 return contentResult, contentErr 371 } 372 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 373 outbounds := session.OutboundsFromContext(ctx) 374 ob := outbounds[len(outbounds) - 1] 375 if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { 376 proxied := hosts.LookupHosts(ob.Target.String()) 377 if proxied != nil { 378 ro := ob.RouteTarget == destination 379 destination.Address = *proxied 380 if ro { 381 ob.RouteTarget = destination 382 } else { 383 ob.Target = destination 384 } 385 } 386 } 387 388 var handler outbound.Handler 389 390 routingLink := routing_session.AsRoutingContext(ctx) 391 inTag := routingLink.GetInboundTag() 392 isPickRoute := 0 393 if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { 394 ctx = session.SetForcedOutboundTagToContext(ctx, "") 395 if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { 396 isPickRoute = 1 397 newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 398 handler = h 399 } else { 400 newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx)) 401 common.Close(link.Writer) 402 common.Interrupt(link.Reader) 403 return 404 } 405 } else if d.router != nil { 406 if route, err := d.router.PickRoute(routingLink); err == nil { 407 outTag := route.GetOutboundTag() 408 if h := d.ohm.GetHandler(outTag); h != nil { 409 isPickRoute = 2 410 newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 411 handler = h 412 } else { 413 newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 414 } 415 } else { 416 newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx)) 417 } 418 } 419 420 if handler == nil { 421 handler = d.ohm.GetDefaultHandler() 422 } 423 424 if handler == nil { 425 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 426 common.Close(link.Writer) 427 common.Interrupt(link.Reader) 428 return 429 } 430 431 ob.Tag = handler.Tag() 432 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 433 if tag := handler.Tag(); tag != "" { 434 if inTag == "" { 435 accessMessage.Detour = tag 436 } else if isPickRoute == 1 { 437 accessMessage.Detour = inTag + " ==> " + tag 438 } else if isPickRoute == 2 { 439 accessMessage.Detour = inTag + " -> " + tag 440 } else { 441 accessMessage.Detour = inTag + " >> " + tag 442 } 443 } 444 log.Record(accessMessage) 445 } 446 447 handler.Dispatch(ctx, link) 448 }