github.com/xraypb/xray-core@v1.6.6/app/dispatcher/default.go (about) 1 package dispatcher 2 3 //go:generate go run github.com/xraypb/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 "fmt" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/xraypb/xray-core/common" 13 "github.com/xraypb/xray-core/common/buf" 14 "github.com/xraypb/xray-core/common/log" 15 "github.com/xraypb/xray-core/common/net" 16 "github.com/xraypb/xray-core/common/protocol" 17 "github.com/xraypb/xray-core/common/session" 18 "github.com/xraypb/xray-core/core" 19 "github.com/xraypb/xray-core/features/dns" 20 "github.com/xraypb/xray-core/features/outbound" 21 "github.com/xraypb/xray-core/features/policy" 22 "github.com/xraypb/xray-core/features/routing" 23 routing_session "github.com/xraypb/xray-core/features/routing/session" 24 "github.com/xraypb/xray-core/features/stats" 25 "github.com/xraypb/xray-core/transport" 26 "github.com/xraypb/xray-core/transport/pipe" 27 ) 28 29 var errSniffingTimeout = newError("timeout on sniffing") 30 31 type cachedReader struct { 32 sync.Mutex 33 reader *pipe.Reader 34 cache buf.MultiBuffer 35 } 36 37 func (r *cachedReader) Cache(b *buf.Buffer) { 38 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 39 r.Lock() 40 if !mb.IsEmpty() { 41 r.cache, _ = buf.MergeMulti(r.cache, mb) 42 } 43 b.Clear() 44 rawBytes := b.Extend(buf.Size) 45 n := r.cache.Copy(rawBytes) 46 b.Resize(0, int32(n)) 47 r.Unlock() 48 } 49 50 func (r *cachedReader) readInternal() buf.MultiBuffer { 51 r.Lock() 52 defer r.Unlock() 53 54 if r.cache != nil && !r.cache.IsEmpty() { 55 mb := r.cache 56 r.cache = nil 57 return mb 58 } 59 60 return nil 61 } 62 63 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 64 mb := r.readInternal() 65 if mb != nil { 66 return mb, nil 67 } 68 69 return r.reader.ReadMultiBuffer() 70 } 71 72 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 73 mb := r.readInternal() 74 if mb != nil { 75 return mb, nil 76 } 77 78 return r.reader.ReadMultiBufferTimeout(timeout) 79 } 80 81 func (r *cachedReader) Interrupt() { 82 r.Lock() 83 if r.cache != nil { 84 r.cache = buf.ReleaseMulti(r.cache) 85 } 86 r.Unlock() 87 r.reader.Interrupt() 88 } 89 90 // DefaultDispatcher is a default implementation of Dispatcher. 91 type DefaultDispatcher struct { 92 ohm outbound.Manager 93 router routing.Router 94 policy policy.Manager 95 stats stats.Manager 96 dns dns.Client 97 fdns dns.FakeDNSEngine 98 } 99 100 func init() { 101 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 102 d := new(DefaultDispatcher) 103 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { 104 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 105 d.fdns = fdns 106 }) 107 return d.Init(config.(*Config), om, router, pm, sm, dc) 108 }); err != nil { 109 return nil, err 110 } 111 return d, nil 112 })) 113 } 114 115 // Init initializes DefaultDispatcher. 116 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error { 117 d.ohm = om 118 d.router = router 119 d.policy = pm 120 d.stats = sm 121 d.dns = dns 122 return nil 123 } 124 125 // Type implements common.HasType. 126 func (*DefaultDispatcher) Type() interface{} { 127 return routing.DispatcherType() 128 } 129 130 // Start implements common.Runnable. 131 func (*DefaultDispatcher) Start() error { 132 return nil 133 } 134 135 // Close implements common.Closable. 136 func (*DefaultDispatcher) Close() error { return nil } 137 138 func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) { 139 downOpt := pipe.OptionsFromContext(ctx) 140 upOpt := downOpt 141 142 if network == net.Network_UDP { 143 var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns 144 // Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs. 145 // When target replies, server will restore the domain and send back to client. 146 // Note: this map is not global but per connection context 147 upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer { 148 for i, buffer := range mb { 149 if buffer.UDP == nil { 150 continue 151 } 152 addr := buffer.UDP.Address 153 if addr.Family().IsIP() { 154 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled { 155 domain := fkr0.GetDomainFromFakeDNS(addr) 156 if len(domain) > 0 { 157 buffer.UDP.Address = net.DomainAddress(domain) 158 newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) 159 } else { 160 newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 161 } 162 } 163 } else { 164 if ip2domain == nil { 165 ip2domain = new(sync.Map) 166 newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx)) 167 } 168 domain := addr.Domain() 169 ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false}) 170 if err == nil { 171 for _, ip := range ips { 172 ip2domain.Store(ip.String(), domain) 173 } 174 newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) 175 } else { 176 newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx)) 177 } 178 } 179 } 180 return mb 181 })) 182 downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer { 183 for i, buffer := range mb { 184 if buffer.UDP == nil { 185 continue 186 } 187 addr := buffer.UDP.Address 188 if addr.Family().IsIP() { 189 if ip2domain == nil { 190 continue 191 } 192 if domain, found := ip2domain.Load(addr.IP().String()); found { 193 buffer.UDP.Address = net.DomainAddress(domain.(string)) 194 newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) 195 } 196 } else { 197 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok { 198 fakeIp := fkr0.GetFakeIPForDomain(addr.Domain()) 199 buffer.UDP.Address = fakeIp[0] 200 newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx)) 201 } 202 } 203 } 204 return mb 205 })) 206 } 207 uplinkReader, uplinkWriter := pipe.New(upOpt...) 208 downlinkReader, downlinkWriter := pipe.New(downOpt...) 209 210 inboundLink := &transport.Link{ 211 Reader: downlinkReader, 212 Writer: uplinkWriter, 213 } 214 215 outboundLink := &transport.Link{ 216 Reader: uplinkReader, 217 Writer: downlinkWriter, 218 } 219 220 sessionInbound := session.InboundFromContext(ctx) 221 var user *protocol.MemoryUser 222 if sessionInbound != nil { 223 user = sessionInbound.User 224 } 225 226 if user != nil && len(user.Email) > 0 { 227 p := d.policy.ForLevel(user.Level) 228 if p.Stats.UserUplink { 229 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 230 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 231 inboundLink.Writer = &SizeStatWriter{ 232 Counter: c, 233 Writer: inboundLink.Writer, 234 } 235 } 236 } 237 if p.Stats.UserDownlink { 238 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 239 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 240 outboundLink.Writer = &SizeStatWriter{ 241 Counter: c, 242 Writer: outboundLink.Writer, 243 } 244 } 245 } 246 } 247 248 return inboundLink, outboundLink 249 } 250 251 func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { 252 domain := result.Domain() 253 if domain == "" { 254 return false 255 } 256 for _, d := range request.ExcludeForDomain { 257 if strings.ToLower(domain) == d { 258 return false 259 } 260 } 261 protocolString := result.Protocol() 262 if resComp, ok := result.(SnifferResultComposite); ok { 263 protocolString = resComp.ProtocolForDomainResult() 264 } 265 for _, p := range request.OverrideDestinationForProtocol { 266 if strings.HasPrefix(protocolString, p) { 267 return true 268 } 269 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && 270 destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) { 271 newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) 272 return true 273 } 274 if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok { 275 if resultSubset.IsProtoSubsetOf(p) { 276 return true 277 } 278 } 279 } 280 281 return false 282 } 283 284 // Dispatch implements routing.Dispatcher. 285 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 286 if !destination.IsValid() { 287 panic("Dispatcher: Invalid destination.") 288 } 289 ob := &session.Outbound{ 290 Target: destination, 291 } 292 ctx = session.ContextWithOutbound(ctx, ob) 293 content := session.ContentFromContext(ctx) 294 if content == nil { 295 content = new(session.Content) 296 ctx = session.ContextWithContent(ctx, content) 297 } 298 299 sniffingRequest := content.SniffingRequest 300 inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest) 301 if !sniffingRequest.Enabled { 302 go d.routedDispatch(ctx, outbound, destination) 303 } else { 304 go func() { 305 cReader := &cachedReader{ 306 reader: outbound.Reader.(*pipe.Reader), 307 } 308 outbound.Reader = cReader 309 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 310 if err == nil { 311 content.Protocol = result.Protocol() 312 } 313 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 314 domain := result.Domain() 315 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 316 destination.Address = net.ParseAddress(domain) 317 if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { 318 ob.RouteTarget = destination 319 } else { 320 ob.Target = destination 321 } 322 } 323 d.routedDispatch(ctx, outbound, destination) 324 }() 325 } 326 return inbound, nil 327 } 328 329 // DispatchLink implements routing.Dispatcher. 330 func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { 331 if !destination.IsValid() { 332 return newError("Dispatcher: Invalid destination.") 333 } 334 ob := &session.Outbound{ 335 Target: destination, 336 } 337 ctx = session.ContextWithOutbound(ctx, ob) 338 content := session.ContentFromContext(ctx) 339 if content == nil { 340 content = new(session.Content) 341 ctx = session.ContextWithContent(ctx, content) 342 } 343 sniffingRequest := content.SniffingRequest 344 if !sniffingRequest.Enabled { 345 go d.routedDispatch(ctx, outbound, destination) 346 } else { 347 go func() { 348 cReader := &cachedReader{ 349 reader: outbound.Reader.(*pipe.Reader), 350 } 351 outbound.Reader = cReader 352 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 353 if err == nil { 354 content.Protocol = result.Protocol() 355 } 356 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 357 domain := result.Domain() 358 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 359 destination.Address = net.ParseAddress(domain) 360 if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" { 361 ob.RouteTarget = destination 362 } else { 363 ob.Target = destination 364 } 365 } 366 d.routedDispatch(ctx, outbound, destination) 367 }() 368 } 369 370 return nil 371 } 372 373 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { 374 payload := buf.New() 375 defer payload.Release() 376 377 sniffer := NewSniffer(ctx) 378 379 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 380 381 if metadataOnly { 382 return metaresult, metadataErr 383 } 384 385 contentResult, contentErr := func() (SniffResult, error) { 386 totalAttempt := 0 387 for { 388 select { 389 case <-ctx.Done(): 390 return nil, ctx.Err() 391 default: 392 totalAttempt++ 393 if totalAttempt > 2 { 394 return nil, errSniffingTimeout 395 } 396 397 cReader.Cache(payload) 398 if !payload.IsEmpty() { 399 result, err := sniffer.Sniff(ctx, payload.Bytes(), network) 400 if err != common.ErrNoClue { 401 return result, err 402 } 403 } 404 if payload.IsFull() { 405 return nil, errUnknownContent 406 } 407 } 408 } 409 }() 410 if contentErr != nil && metadataErr == nil { 411 return metaresult, nil 412 } 413 if contentErr == nil && metadataErr == nil { 414 return CompositeResult(metaresult, contentResult), nil 415 } 416 return contentResult, contentErr 417 } 418 419 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 420 ob := session.OutboundFromContext(ctx) 421 if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { 422 proxied := hosts.LookupHosts(ob.Target.String()) 423 if proxied != nil { 424 ro := ob.RouteTarget == destination 425 destination.Address = *proxied 426 if ro { 427 ob.RouteTarget = destination 428 } else { 429 ob.Target = destination 430 } 431 } 432 } 433 434 var handler outbound.Handler 435 436 routingLink := routing_session.AsRoutingContext(ctx) 437 inTag := routingLink.GetInboundTag() 438 isPickRoute := 0 439 if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { 440 ctx = session.SetForcedOutboundTagToContext(ctx, "") 441 if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { 442 isPickRoute = 1 443 newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 444 handler = h 445 } else { 446 newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx)) 447 common.Close(link.Writer) 448 common.Interrupt(link.Reader) 449 return 450 } 451 } else if d.router != nil { 452 if route, err := d.router.PickRoute(routingLink); err == nil { 453 outTag := route.GetOutboundTag() 454 if h := d.ohm.GetHandler(outTag); h != nil { 455 isPickRoute = 2 456 newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 457 handler = h 458 } else { 459 newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 460 } 461 } else { 462 newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx)) 463 } 464 } 465 466 if handler == nil { 467 handler = d.ohm.GetDefaultHandler() 468 } 469 470 if handler == nil { 471 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 472 common.Close(link.Writer) 473 common.Interrupt(link.Reader) 474 return 475 } 476 477 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 478 if tag := handler.Tag(); tag != "" { 479 if inTag == "" { 480 accessMessage.Detour = tag 481 } else if isPickRoute == 1 { 482 accessMessage.Detour = inTag + " ==> " + tag 483 } else if isPickRoute == 2 { 484 accessMessage.Detour = inTag + " -> " + tag 485 } else { 486 accessMessage.Detour = inTag + " >> " + tag 487 } 488 } 489 log.Record(accessMessage) 490 } 491 492 handler.Dispatch(ctx, link) 493 }