github.com/EagleQL/Xray-core@v1.4.3/app/dispatcher/default.go (about) 1 package dispatcher 2 3 //go:generate go run github.com/xtls/xray-core/common/errors/errorgen 4 5 import ( 6 "context" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/xtls/xray-core/common" 12 "github.com/xtls/xray-core/common/buf" 13 "github.com/xtls/xray-core/common/log" 14 "github.com/xtls/xray-core/common/net" 15 "github.com/xtls/xray-core/common/protocol" 16 "github.com/xtls/xray-core/common/session" 17 "github.com/xtls/xray-core/core" 18 "github.com/xtls/xray-core/features/dns" 19 "github.com/xtls/xray-core/features/outbound" 20 "github.com/xtls/xray-core/features/policy" 21 "github.com/xtls/xray-core/features/routing" 22 routing_session "github.com/xtls/xray-core/features/routing/session" 23 "github.com/xtls/xray-core/features/stats" 24 "github.com/xtls/xray-core/transport" 25 "github.com/xtls/xray-core/transport/pipe" 26 ) 27 28 var ( 29 errSniffingTimeout = newError("timeout on sniffing") 30 ) 31 32 type cachedReader struct { 33 sync.Mutex 34 reader *pipe.Reader 35 cache buf.MultiBuffer 36 } 37 38 func (r *cachedReader) Cache(b *buf.Buffer) { 39 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 40 r.Lock() 41 if !mb.IsEmpty() { 42 r.cache, _ = buf.MergeMulti(r.cache, mb) 43 } 44 b.Clear() 45 rawBytes := b.Extend(buf.Size) 46 n := r.cache.Copy(rawBytes) 47 b.Resize(0, int32(n)) 48 r.Unlock() 49 } 50 51 func (r *cachedReader) readInternal() buf.MultiBuffer { 52 r.Lock() 53 defer r.Unlock() 54 55 if r.cache != nil && !r.cache.IsEmpty() { 56 mb := r.cache 57 r.cache = nil 58 return mb 59 } 60 61 return nil 62 } 63 64 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 65 mb := r.readInternal() 66 if mb != nil { 67 return mb, nil 68 } 69 70 return r.reader.ReadMultiBuffer() 71 } 72 73 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 74 mb := r.readInternal() 75 if mb != nil { 76 return mb, nil 77 } 78 79 return r.reader.ReadMultiBufferTimeout(timeout) 80 } 81 82 func (r *cachedReader) Interrupt() { 83 r.Lock() 84 if r.cache != nil { 85 r.cache = buf.ReleaseMulti(r.cache) 86 } 87 r.Unlock() 88 r.reader.Interrupt() 89 } 90 91 // DefaultDispatcher is a default implementation of Dispatcher. 92 type DefaultDispatcher struct { 93 ohm outbound.Manager 94 router routing.Router 95 policy policy.Manager 96 stats stats.Manager 97 } 98 99 func init() { 100 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 101 d := new(DefaultDispatcher) 102 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { 103 return d.Init(config.(*Config), om, router, pm, sm) 104 }); err != nil { 105 return nil, err 106 } 107 return d, nil 108 })) 109 } 110 111 // Init initializes DefaultDispatcher. 112 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { 113 d.ohm = om 114 d.router = router 115 d.policy = pm 116 d.stats = sm 117 return nil 118 } 119 120 // Type implements common.HasType. 121 func (*DefaultDispatcher) Type() interface{} { 122 return routing.DispatcherType() 123 } 124 125 // Start implements common.Runnable. 126 func (*DefaultDispatcher) Start() error { 127 return nil 128 } 129 130 // Close implements common.Closable. 131 func (*DefaultDispatcher) Close() error { return nil } 132 133 func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { 134 opt := pipe.OptionsFromContext(ctx) 135 uplinkReader, uplinkWriter := pipe.New(opt...) 136 downlinkReader, downlinkWriter := pipe.New(opt...) 137 138 inboundLink := &transport.Link{ 139 Reader: downlinkReader, 140 Writer: uplinkWriter, 141 } 142 143 outboundLink := &transport.Link{ 144 Reader: uplinkReader, 145 Writer: downlinkWriter, 146 } 147 148 sessionInbound := session.InboundFromContext(ctx) 149 var user *protocol.MemoryUser 150 if sessionInbound != nil { 151 user = sessionInbound.User 152 } 153 154 if user != nil && len(user.Email) > 0 { 155 p := d.policy.ForLevel(user.Level) 156 if p.Stats.UserUplink { 157 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 158 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 159 inboundLink.Writer = &SizeStatWriter{ 160 Counter: c, 161 Writer: inboundLink.Writer, 162 } 163 } 164 } 165 if p.Stats.UserDownlink { 166 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 167 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 168 outboundLink.Writer = &SizeStatWriter{ 169 Counter: c, 170 Writer: outboundLink.Writer, 171 } 172 } 173 } 174 } 175 176 return inboundLink, outboundLink 177 } 178 179 func shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool { 180 domain := result.Domain() 181 for _, d := range request.ExcludeForDomain { 182 if strings.ToLower(domain) == d { 183 return false 184 } 185 } 186 var fakeDNSEngine dns.FakeDNSEngine 187 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 188 fakeDNSEngine = fdns 189 }) 190 protocolString := result.Protocol() 191 if resComp, ok := result.(SnifferResultComposite); ok { 192 protocolString = resComp.ProtocolForDomainResult() 193 } 194 for _, p := range request.OverrideDestinationForProtocol { 195 if strings.HasPrefix(protocolString, p) { 196 return true 197 } 198 if fakeDNSEngine != nil && protocolString != "bittorrent" && p == "fakedns" && 199 destination.Address.Family().IsIP() && fakeDNSEngine.GetFakeIPRange().Contains(destination.Address.IP()) { 200 newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx)) 201 return true 202 } 203 } 204 205 return false 206 } 207 208 // Dispatch implements routing.Dispatcher. 209 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 210 if !destination.IsValid() { 211 panic("Dispatcher: Invalid destination.") 212 } 213 ob := &session.Outbound{ 214 Target: destination, 215 } 216 ctx = session.ContextWithOutbound(ctx, ob) 217 218 inbound, outbound := d.getLink(ctx) 219 content := session.ContentFromContext(ctx) 220 if content == nil { 221 content = new(session.Content) 222 ctx = session.ContextWithContent(ctx, content) 223 } 224 sniffingRequest := content.SniffingRequest 225 switch { 226 case !sniffingRequest.Enabled: 227 go d.routedDispatch(ctx, outbound, destination) 228 case destination.Network != net.Network_TCP: 229 // Only metadata sniff will be used for non tcp connection 230 result, err := sniffer(ctx, nil, true) 231 if err == nil { 232 content.Protocol = result.Protocol() 233 if shouldOverride(ctx, result, sniffingRequest, destination) { 234 domain := result.Domain() 235 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 236 destination.Address = net.ParseAddress(domain) 237 ob.Target = destination 238 } 239 } 240 go d.routedDispatch(ctx, outbound, destination) 241 default: 242 go func() { 243 cReader := &cachedReader{ 244 reader: outbound.Reader.(*pipe.Reader), 245 } 246 outbound.Reader = cReader 247 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly) 248 if err == nil { 249 content.Protocol = result.Protocol() 250 } 251 if err == nil && shouldOverride(ctx, result, sniffingRequest, destination) { 252 domain := result.Domain() 253 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 254 destination.Address = net.ParseAddress(domain) 255 ob.Target = destination 256 } 257 d.routedDispatch(ctx, outbound, destination) 258 }() 259 } 260 return inbound, nil 261 } 262 263 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) { 264 payload := buf.New() 265 defer payload.Release() 266 267 sniffer := NewSniffer(ctx) 268 269 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 270 271 if metadataOnly { 272 return metaresult, metadataErr 273 } 274 275 contentResult, contentErr := func() (SniffResult, error) { 276 totalAttempt := 0 277 for { 278 select { 279 case <-ctx.Done(): 280 return nil, ctx.Err() 281 default: 282 totalAttempt++ 283 if totalAttempt > 2 { 284 return nil, errSniffingTimeout 285 } 286 287 cReader.Cache(payload) 288 if !payload.IsEmpty() { 289 result, err := sniffer.Sniff(ctx, payload.Bytes()) 290 if err != common.ErrNoClue { 291 return result, err 292 } 293 } 294 if payload.IsFull() { 295 return nil, errUnknownContent 296 } 297 } 298 } 299 }() 300 if contentErr != nil && metadataErr == nil { 301 return metaresult, nil 302 } 303 if contentErr == nil && metadataErr == nil { 304 return CompositeResult(metaresult, contentResult), nil 305 } 306 return contentResult, contentErr 307 } 308 309 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 310 var handler outbound.Handler 311 312 skipRoutePick := false 313 if content := session.ContentFromContext(ctx); content != nil { 314 skipRoutePick = content.SkipRoutePick 315 } 316 317 routingLink := routing_session.AsRoutingContext(ctx) 318 inTag := routingLink.GetInboundTag() 319 isPickRoute := false 320 if d.router != nil && !skipRoutePick { 321 if route, err := d.router.PickRoute(routingLink); err == nil { 322 outTag := route.GetOutboundTag() 323 isPickRoute = true 324 if h := d.ohm.GetHandler(outTag); h != nil { 325 newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 326 handler = h 327 } else { 328 newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 329 } 330 } else { 331 newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx)) 332 } 333 } 334 335 if handler == nil { 336 handler = d.ohm.GetDefaultHandler() 337 } 338 339 if handler == nil { 340 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 341 common.Close(link.Writer) 342 common.Interrupt(link.Reader) 343 return 344 } 345 346 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 347 if tag := handler.Tag(); tag != "" { 348 if isPickRoute { 349 if inTag != "" { 350 accessMessage.Detour = inTag + " -> " + tag 351 } else { 352 accessMessage.Detour = tag 353 } 354 } else { 355 if inTag != "" { 356 accessMessage.Detour = inTag + " >> " + tag 357 } else { 358 accessMessage.Detour = tag 359 } 360 } 361 } 362 log.Record(accessMessage) 363 } 364 365 handler.Dispatch(ctx, link) 366 }