github.com/moqsien/xraycore@v1.8.5/app/dispatcher/default.go (about) 1 package dispatcher 2 3 //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen 4 5 import ( 6 "context" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/moqsien/xraycore/common" 12 "github.com/moqsien/xraycore/common/buf" 13 "github.com/moqsien/xraycore/common/log" 14 "github.com/moqsien/xraycore/common/net" 15 "github.com/moqsien/xraycore/common/protocol" 16 "github.com/moqsien/xraycore/common/session" 17 "github.com/moqsien/xraycore/core" 18 "github.com/moqsien/xraycore/features/dns" 19 "github.com/moqsien/xraycore/features/outbound" 20 "github.com/moqsien/xraycore/features/policy" 21 "github.com/moqsien/xraycore/features/routing" 22 routing_session "github.com/moqsien/xraycore/features/routing/session" 23 "github.com/moqsien/xraycore/features/stats" 24 "github.com/moqsien/xraycore/transport" 25 "github.com/moqsien/xraycore/transport/pipe" 26 ) 27 28 var errSniffingTimeout = newError("timeout on sniffing") 29 30 type cachedReader struct { 31 sync.Mutex 32 reader *pipe.Reader 33 cache buf.MultiBuffer 34 } 35 36 func (r *cachedReader) Cache(b *buf.Buffer) { 37 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 38 r.Lock() 39 if !mb.IsEmpty() { 40 r.cache, _ = buf.MergeMulti(r.cache, mb) 41 } 42 b.Clear() 43 rawBytes := b.Extend(buf.Size) 44 n := r.cache.Copy(rawBytes) 45 b.Resize(0, int32(n)) 46 r.Unlock() 47 } 48 49 func (r *cachedReader) readInternal() buf.MultiBuffer { 50 r.Lock() 51 defer r.Unlock() 52 53 if r.cache != nil && !r.cache.IsEmpty() { 54 mb := r.cache 55 r.cache = nil 56 return mb 57 } 58 59 return nil 60 } 61 62 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 63 mb := r.readInternal() 64 if mb != nil { 65 return mb, nil 66 } 67 68 return r.reader.ReadMultiBuffer() 69 } 70 71 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 72 mb := r.readInternal() 73 if mb != nil { 74 return mb, nil 75 } 76 77 return r.reader.ReadMultiBufferTimeout(timeout) 78 } 79 80 func (r *cachedReader) Interrupt() { 81 r.Lock() 82 if r.cache != nil { 83 r.cache = buf.ReleaseMulti(r.cache) 84 } 85 r.Unlock() 86 r.reader.Interrupt() 87 } 88 89 // DefaultDispatcher is a default implementation of Dispatcher. 90 type DefaultDispatcher struct { 91 ohm outbound.Manager 92 router routing.Router 93 policy policy.Manager 94 stats stats.Manager 95 dns dns.Client 96 fdns dns.FakeDNSEngine 97 } 98 99 func init() { 100 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 101 d := new(DefaultDispatcher) 102 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { 103 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 104 d.fdns = fdns 105 }) 106 return d.Init(config.(*Config), om, router, pm, sm, dc) 107 }); err != nil { 108 return nil, err 109 } 110 return d, nil 111 })) 112 } 113 114 // Init initializes DefaultDispatcher. 115 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error { 116 d.ohm = om 117 d.router = router 118 d.policy = pm 119 d.stats = sm 120 d.dns = dns 121 return nil 122 } 123 124 // Type implements common.HasType. 125 func (*DefaultDispatcher) Type() interface{} { 126 return routing.DispatcherType() 127 } 128 129 // Start implements common.Runnable. 130 func (*DefaultDispatcher) Start() error { 131 return nil 132 } 133 134 // Close implements common.Closable. 135 func (*DefaultDispatcher) Close() error { return nil } 136 137 func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { 138 opt := pipe.OptionsFromContext(ctx) 139 uplinkReader, uplinkWriter := pipe.New(opt...) 140 downlinkReader, downlinkWriter := pipe.New(opt...) 141 142 inboundLink := &transport.Link{ 143 Reader: downlinkReader, 144 Writer: uplinkWriter, 145 } 146 147 outboundLink := &transport.Link{ 148 Reader: uplinkReader, 149 Writer: downlinkWriter, 150 } 151 152 sessionInbound := session.InboundFromContext(ctx) 153 var user *protocol.MemoryUser 154 if sessionInbound != nil { 155 user = sessionInbound.User 156 } 157 158 if user != nil && len(user.Email) > 0 { 159 p := d.policy.ForLevel(user.Level) 160 if p.Stats.UserUplink { 161 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 162 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 163 inboundLink.Writer = &SizeStatWriter{ 164 Counter: c, 165 Writer: inboundLink.Writer, 166 } 167 } 168 } 169 if p.Stats.UserDownlink { 170 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 171 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 172 outboundLink.Writer = &SizeStatWriter{ 173 Counter: c, 174 Writer: outboundLink.Writer, 175 } 176 } 177 } 178 } 179 180 return inboundLink, outboundLink 181 } 182 183 func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { 184 domain := result.Domain() 185 if domain == "" { 186 return false 187 } 188 for _, d := range request.ExcludeForDomain { 189 if strings.ToLower(domain) == d { 190 return false 191 } 192 } 193 protocolString := result.Protocol() 194 if resComp, ok := result.(SnifferResultComposite); ok { 195 protocolString = resComp.ProtocolForDomainResult() 196 } 197 for _, p := range request.OverrideDestinationForProtocol { 198 if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) { 199 return true 200 } 201 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" && 202 destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) { 203 newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) 204 return true 205 } 206 if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok { 207 if resultSubset.IsProtoSubsetOf(p) { 208 return true 209 } 210 } 211 } 212 213 return false 214 } 215 216 // Dispatch implements routing.Dispatcher. 217 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 218 if !destination.IsValid() { 219 panic("Dispatcher: Invalid destination.") 220 } 221 ob := &session.Outbound{ 222 OriginalTarget: destination, 223 Target: destination, 224 } 225 ctx = session.ContextWithOutbound(ctx, ob) 226 content := session.ContentFromContext(ctx) 227 if content == nil { 228 content = new(session.Content) 229 ctx = session.ContextWithContent(ctx, content) 230 } 231 sniffingRequest := content.SniffingRequest 232 inbound, outbound := d.getLink(ctx) 233 if !sniffingRequest.Enabled { 234 go d.routedDispatch(ctx, outbound, destination) 235 } else { 236 go func() { 237 cReader := &cachedReader{ 238 reader: outbound.Reader.(*pipe.Reader), 239 } 240 outbound.Reader = cReader 241 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 242 if err == nil { 243 content.Protocol = result.Protocol() 244 } 245 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 246 domain := result.Domain() 247 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 248 destination.Address = net.ParseAddress(domain) 249 protocol := result.Protocol() 250 if resComp, ok := result.(SnifferResultComposite); ok { 251 protocol = resComp.ProtocolForDomainResult() 252 } 253 isFakeIP := false 254 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) { 255 isFakeIP = true 256 } 257 if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { 258 ob.RouteTarget = destination 259 } else { 260 ob.Target = destination 261 } 262 } 263 d.routedDispatch(ctx, outbound, destination) 264 }() 265 } 266 return inbound, nil 267 } 268 269 // DispatchLink implements routing.Dispatcher. 270 func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error { 271 if !destination.IsValid() { 272 return newError("Dispatcher: Invalid destination.") 273 } 274 ob := &session.Outbound{ 275 OriginalTarget: destination, 276 Target: destination, 277 } 278 ctx = session.ContextWithOutbound(ctx, ob) 279 content := session.ContentFromContext(ctx) 280 if content == nil { 281 content = new(session.Content) 282 ctx = session.ContextWithContent(ctx, content) 283 } 284 sniffingRequest := content.SniffingRequest 285 if !sniffingRequest.Enabled { 286 d.routedDispatch(ctx, outbound, destination) 287 } else { 288 cReader := &cachedReader{ 289 reader: outbound.Reader.(*pipe.Reader), 290 } 291 outbound.Reader = cReader 292 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 293 if err == nil { 294 content.Protocol = result.Protocol() 295 } 296 if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) { 297 domain := result.Domain() 298 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 299 destination.Address = net.ParseAddress(domain) 300 protocol := result.Protocol() 301 if resComp, ok := result.(SnifferResultComposite); ok { 302 protocol = resComp.ProtocolForDomainResult() 303 } 304 isFakeIP := false 305 if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) { 306 isFakeIP = true 307 } 308 if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP { 309 ob.RouteTarget = destination 310 } else { 311 ob.Target = destination 312 } 313 } 314 d.routedDispatch(ctx, outbound, destination) 315 } 316 317 return nil 318 } 319 320 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { 321 payload := buf.New() 322 defer payload.Release() 323 324 sniffer := NewSniffer(ctx) 325 326 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 327 328 if metadataOnly { 329 return metaresult, metadataErr 330 } 331 332 contentResult, contentErr := func() (SniffResult, error) { 333 totalAttempt := 0 334 for { 335 select { 336 case <-ctx.Done(): 337 return nil, ctx.Err() 338 default: 339 totalAttempt++ 340 if totalAttempt > 2 { 341 return nil, errSniffingTimeout 342 } 343 344 cReader.Cache(payload) 345 if !payload.IsEmpty() { 346 result, err := sniffer.Sniff(ctx, payload.Bytes(), network) 347 if err != common.ErrNoClue { 348 return result, err 349 } 350 } 351 if payload.IsFull() { 352 return nil, errUnknownContent 353 } 354 } 355 } 356 }() 357 if contentErr != nil && metadataErr == nil { 358 return metaresult, nil 359 } 360 if contentErr == nil && metadataErr == nil { 361 return CompositeResult(metaresult, contentResult), nil 362 } 363 return contentResult, contentErr 364 } 365 366 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 367 ob := session.OutboundFromContext(ctx) 368 if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() { 369 proxied := hosts.LookupHosts(ob.Target.String()) 370 if proxied != nil { 371 ro := ob.RouteTarget == destination 372 destination.Address = *proxied 373 if ro { 374 ob.RouteTarget = destination 375 } else { 376 ob.Target = destination 377 } 378 } 379 } 380 381 var handler outbound.Handler 382 383 routingLink := routing_session.AsRoutingContext(ctx) 384 inTag := routingLink.GetInboundTag() 385 isPickRoute := 0 386 if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { 387 ctx = session.SetForcedOutboundTagToContext(ctx, "") 388 if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { 389 isPickRoute = 1 390 newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 391 handler = h 392 } else { 393 newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx)) 394 common.Close(link.Writer) 395 common.Interrupt(link.Reader) 396 return 397 } 398 } else if d.router != nil { 399 if route, err := d.router.PickRoute(routingLink); err == nil { 400 outTag := route.GetOutboundTag() 401 if h := d.ohm.GetHandler(outTag); h != nil { 402 isPickRoute = 2 403 newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 404 handler = h 405 } else { 406 newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 407 } 408 } else { 409 newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx)) 410 } 411 } 412 413 if handler == nil { 414 handler = d.ohm.GetDefaultHandler() 415 } 416 417 if handler == nil { 418 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 419 common.Close(link.Writer) 420 common.Interrupt(link.Reader) 421 return 422 } 423 424 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 425 if tag := handler.Tag(); tag != "" { 426 if inTag == "" { 427 accessMessage.Detour = tag 428 } else if isPickRoute == 1 { 429 accessMessage.Detour = inTag + " ==> " + tag 430 } else if isPickRoute == 2 { 431 accessMessage.Detour = inTag + " -> " + tag 432 } else { 433 accessMessage.Detour = inTag + " >> " + tag 434 } 435 } 436 log.Record(accessMessage) 437 } 438 439 handler.Dispatch(ctx, link) 440 }