github.com/xmplusdev/xray-core@v1.8.10/app/dispatcher/default.go (about) 1 package dispatcher 2 3 //go:generate go run github.com/xmplusdev/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/xmplusdev/xray-core/common" 12 "github.com/xmplusdev/xray-core/common/buf" 13 "github.com/xmplusdev/xray-core/common/log" 14 "github.com/xmplusdev/xray-core/common/net" 15 "github.com/xmplusdev/xray-core/common/protocol" 16 "github.com/xmplusdev/xray-core/common/session" 17 "github.com/xmplusdev/xray-core/core" 18 "github.com/xmplusdev/xray-core/features/dns" 19 "github.com/xmplusdev/xray-core/features/outbound" 20 "github.com/xmplusdev/xray-core/features/policy" 21 "github.com/xmplusdev/xray-core/features/routing" 22 routing_session "github.com/xmplusdev/xray-core/features/routing/session" 23 "github.com/xmplusdev/xray-core/features/stats" 24 "github.com/xmplusdev/xray-core/transport" 25 "github.com/xmplusdev/xray-core/transport/pipe" 26 ) 27 28 var errSniffingTimeout = newError("timeout on sniffing") 29 30 type cachedReader struct { 31 sync.Mutex 32 reader *pipe.Reader 33 cache buf.MultiBuffer 34 } 35 36 func (r *cachedReader) Cache(b *buf.Buffer) { 37 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 38 r.Lock() 39 if !mb.IsEmpty() { 40 r.cache, _ = buf.MergeMulti(r.cache, mb) 41 } 42 b.Clear() 43 rawBytes := b.Extend(buf.Size) 44 n := r.cache.Copy(rawBytes) 45 b.Resize(0, int32(n)) 46 r.Unlock() 47 } 48 49 func (r *cachedReader) readInternal() buf.MultiBuffer { 50 r.Lock() 51 defer r.Unlock() 52 53 if r.cache != nil && !r.cache.IsEmpty() { 54 mb := r.cache 55 r.cache = nil 56 return mb 57 } 58 59 return nil 60 } 61 62 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 63 mb := r.readInternal() 64 if mb != nil { 65 return mb, nil 66 } 67 68 return r.reader.ReadMultiBuffer() 69 } 70 71 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 72 mb := r.readInternal() 73 if mb != nil { 74 return mb, nil 75 } 76 77 return r.reader.ReadMultiBufferTimeout(timeout) 78 } 79 80 func (r *cachedReader) Interrupt() { 81 r.Lock() 82 if r.cache != nil { 83 r.cache = buf.ReleaseMulti(r.cache) 84 } 85 r.Unlock() 86 r.reader.Interrupt() 87 } 88 89 // DefaultDispatcher is a default implementation of Dispatcher. 90 type DefaultDispatcher struct { 91 ohm outbound.Manager 92 router routing.Router 93 policy policy.Manager 94 stats stats.Manager 95 dns dns.Client 96 fdns dns.FakeDNSEngine 97 } 98 99 func init() { 100 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 101 d := new(DefaultDispatcher) 102 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { 103 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 104 d.fdns = fdns 105 }) 106 return d.Init(config.(*Config), om, router, pm, sm, dc) 107 }); err != nil { 108 return nil, err 109 } 110 return d, nil 111 })) 112 } 113 114 // Init initializes DefaultDispatcher. 115 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error { 116 d.ohm = om 117 d.router = router 118 d.policy = pm 119 d.stats = sm 120 d.dns = dns 121 return nil 122 } 123 124 // Type implements common.HasType. 125 func (*DefaultDispatcher) Type() interface{} { 126 return routing.DispatcherType() 127 } 128 129 // Start implements common.Runnable. 130 func (*DefaultDispatcher) Start() error { 131 return nil 132 } 133 134 // Close implements common.Closable. 135 func (*DefaultDispatcher) Close() error { return nil } 136 137 func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { 138 opt := pipe.OptionsFromContext(ctx) 139 uplinkReader, uplinkWriter := pipe.New(opt...) 140 downlinkReader, downlinkWriter := pipe.New(opt...) 141 142 inboundLink := &transport.Link{ 143 Reader: downlinkReader, 144 Writer: uplinkWriter, 145 } 146 147 outboundLink := &transport.Link{ 148 Reader: uplinkReader, 149 Writer: downlinkWriter, 150 } 151 152 sessionInbound := session.InboundFromContext(ctx) 153 var user *protocol.MemoryUser 154 if sessionInbound != nil { 155 user = sessionInbound.User 156 } 157 158 if user != nil && len(user.Email) > 0 { 159 p := d.policy.ForLevel(user.Level) 160 if p.Stats.UserUplink { 161 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 162 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 163 inboundLink.Writer = &SizeStatWriter{ 164 Counter: c, 165 Writer: inboundLink.Writer, 166 } 167 } 168 } 169 if p.Stats.UserDownlink { 170 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 171 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 172 outboundLink.Writer = &SizeStatWriter{ 173 Counter: c, 174 Writer: outboundLink.Writer, 175 } 176 } 177 } 178 } 179 180 return inboundLink, outboundLink 181 } 182 183 func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { 184 domain := result.Domain() 185 if domain == "" { 186 return false 187 } 188 for _, d := range request.ExcludeForDomain { 189 if strings.ToLower(domain) == d { 190 return false 191 } 192 } 193 protocolString := result.Protocol() 194 if resComp, ok := result.(SnifferResultComposite); ok { 195 protocolString = resComp.ProtocolForDomainResult() 196 } 197 for _, p := range request.OverrideDestinationForProtocol { 198 if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) { 199 return true 200 } 201 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && 202 fkr0.IsIPInIPPool(destination.Address) { 203 newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) 204 return true 205 } 206 if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok { 207 if resultSubset.IsProtoSubsetOf(p) { 208 return true 209 } 210 } 211 } 212 213 return false 214 } 215 216 // Dispatch implements routing.Dispatcher. 217 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 218 if !destination.IsValid() { 219 panic("Dispatcher: Invalid destination.") 220 } 221 ob := session.OutboundFromContext(ctx) 222 if ob == nil { 223 ob = &session.Outbound{} 224 ctx = session.ContextWithOutbound(ctx, ob) 225 } 226 ob.OriginalTarget = destination 227 ob.Target = destination 228 content := session.ContentFromContext(ctx) 229 if content == nil { 230 content = new(session.Content) 231 ctx = session.ContextWithContent(ctx, content) 232 } 233 234 sniffingRequest := content.SniffingRequest 235 inbound, outbound := d.getLink(ctx) 236 if !sniffingRequest.Enabled { 237 go d.routedDispatch(ctx, outbound, destination) 238 } else { 239 go func() { 240 cReader := &cachedReader{ 241 reader: outbound.Reader.(*pipe.Reader), 242 } 243 outbound.Reader = cReader 244 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 245 if err == nil { 246 content.Protocol = result.Protocol() 247 } 248 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 249 domain := result.Domain() 250 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 251 destination.Address = net.ParseAddress(domain) 252 protocol := result.Protocol() 253 if resComp, ok := result.(SnifferResultComposite); ok { 254 protocol = resComp.ProtocolForDomainResult() 255 } 256 isFakeIP := false 257 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) { 258 isFakeIP = true 259 } 260 if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { 261 ob.RouteTarget = destination 262 } else { 263 ob.Target = destination 264 } 265 } 266 d.routedDispatch(ctx, outbound, destination) 267 }() 268 } 269 return inbound, nil 270 } 271 272 // DispatchLink implements routing.Dispatcher. 273 func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { 274 if !destination.IsValid() { 275 return newError("Dispatcher: Invalid destination.") 276 } 277 ob := session.OutboundFromContext(ctx) 278 if ob == nil { 279 ob = &session.Outbound{} 280 ctx = session.ContextWithOutbound(ctx, ob) 281 } 282 ob.OriginalTarget = destination 283 ob.Target = destination 284 content := session.ContentFromContext(ctx) 285 if content == nil { 286 content = new(session.Content) 287 ctx = session.ContextWithContent(ctx, content) 288 } 289 sniffingRequest := content.SniffingRequest 290 if !sniffingRequest.Enabled { 291 d.routedDispatch(ctx, outbound, destination) 292 } else { 293 cReader := &cachedReader{ 294 reader: outbound.Reader.(*pipe.Reader), 295 } 296 outbound.Reader = cReader 297 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 298 if err == nil { 299 content.Protocol = result.Protocol() 300 } 301 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 302 domain := result.Domain() 303 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 304 destination.Address = net.ParseAddress(domain) 305 protocol := result.Protocol() 306 if resComp, ok := result.(SnifferResultComposite); ok { 307 protocol = resComp.ProtocolForDomainResult() 308 } 309 isFakeIP := false 310 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) { 311 isFakeIP = true 312 } 313 if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { 314 ob.RouteTarget = destination 315 } else { 316 ob.Target = destination 317 } 318 } 319 d.routedDispatch(ctx, outbound, destination) 320 } 321 322 return nil 323 } 324 325 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { 326 payload := buf.New() 327 defer payload.Release() 328 329 sniffer := NewSniffer(ctx) 330 331 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 332 333 if metadataOnly { 334 return metaresult, metadataErr 335 } 336 337 contentResult, contentErr := func() (SniffResult, error) { 338 totalAttempt := 0 339 for { 340 select { 341 case <-ctx.Done(): 342 return nil, ctx.Err() 343 default: 344 totalAttempt++ 345 if totalAttempt > 2 { 346 return nil, errSniffingTimeout 347 } 348 349 cReader.Cache(payload) 350 if !payload.IsEmpty() { 351 result, err := sniffer.Sniff(ctx, payload.Bytes(), network) 352 if err != common.ErrNoClue { 353 return result, err 354 } 355 } 356 if payload.IsFull() { 357 return nil, errUnknownContent 358 } 359 } 360 } 361 }() 362 if contentErr != nil && metadataErr == nil { 363 return metaresult, nil 364 } 365 if contentErr == nil && metadataErr == nil { 366 return CompositeResult(metaresult, contentResult), nil 367 } 368 return contentResult, contentErr 369 } 370 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 371 ob := session.OutboundFromContext(ctx) 372 if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { 373 proxied := hosts.LookupHosts(ob.Target.String()) 374 if proxied != nil { 375 ro := ob.RouteTarget == destination 376 destination.Address = *proxied 377 if ro { 378 ob.RouteTarget = destination 379 } else { 380 ob.Target = destination 381 } 382 } 383 } 384 385 var handler outbound.Handler 386 387 routingLink := routing_session.AsRoutingContext(ctx) 388 inTag := routingLink.GetInboundTag() 389 isPickRoute := 0 390 if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { 391 ctx = session.SetForcedOutboundTagToContext(ctx, "") 392 if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { 393 isPickRoute = 1 394 newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 395 handler = h 396 } else { 397 newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx)) 398 common.Close(link.Writer) 399 common.Interrupt(link.Reader) 400 return 401 } 402 } else if d.router != nil { 403 if route, err := d.router.PickRoute(routingLink); err == nil { 404 outTag := route.GetOutboundTag() 405 if h := d.ohm.GetHandler(outTag); h != nil { 406 isPickRoute = 2 407 newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 408 handler = h 409 } else { 410 newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 411 } 412 } else { 413 newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx)) 414 } 415 } 416 417 if handler == nil { 418 handler = d.ohm.GetDefaultHandler() 419 } 420 421 if handler == nil { 422 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 423 common.Close(link.Writer) 424 common.Interrupt(link.Reader) 425 return 426 } 427 428 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 429 if tag := handler.Tag(); tag != "" { 430 if inTag == "" { 431 accessMessage.Detour = tag 432 } else if isPickRoute == 1 { 433 accessMessage.Detour = inTag + " ==> " + tag 434 } else if isPickRoute == 2 { 435 accessMessage.Detour = inTag + " -> " + tag 436 } else { 437 accessMessage.Detour = inTag + " >> " + tag 438 } 439 } 440 log.Record(accessMessage) 441 } 442 443 handler.Dispatch(ctx, link) 444 }