github.com/xmplusdev/xray-core@v1.8.10/app/dispatcher/default.go (about)

     1  package dispatcher
     2  
     3  //go:generate go run github.com/xmplusdev/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/xmplusdev/xray-core/common"
    12  	"github.com/xmplusdev/xray-core/common/buf"
    13  	"github.com/xmplusdev/xray-core/common/log"
    14  	"github.com/xmplusdev/xray-core/common/net"
    15  	"github.com/xmplusdev/xray-core/common/protocol"
    16  	"github.com/xmplusdev/xray-core/common/session"
    17  	"github.com/xmplusdev/xray-core/core"
    18  	"github.com/xmplusdev/xray-core/features/dns"
    19  	"github.com/xmplusdev/xray-core/features/outbound"
    20  	"github.com/xmplusdev/xray-core/features/policy"
    21  	"github.com/xmplusdev/xray-core/features/routing"
    22  	routing_session "github.com/xmplusdev/xray-core/features/routing/session"
    23  	"github.com/xmplusdev/xray-core/features/stats"
    24  	"github.com/xmplusdev/xray-core/transport"
    25  	"github.com/xmplusdev/xray-core/transport/pipe"
    26  )
    27  
    28  var errSniffingTimeout = newError("timeout on sniffing")
    29  
    30  type cachedReader struct {
    31  	sync.Mutex
    32  	reader *pipe.Reader
    33  	cache  buf.MultiBuffer
    34  }
    35  
    36  func (r *cachedReader) Cache(b *buf.Buffer) {
    37  	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
    38  	r.Lock()
    39  	if !mb.IsEmpty() {
    40  		r.cache, _ = buf.MergeMulti(r.cache, mb)
    41  	}
    42  	b.Clear()
    43  	rawBytes := b.Extend(buf.Size)
    44  	n := r.cache.Copy(rawBytes)
    45  	b.Resize(0, int32(n))
    46  	r.Unlock()
    47  }
    48  
    49  func (r *cachedReader) readInternal() buf.MultiBuffer {
    50  	r.Lock()
    51  	defer r.Unlock()
    52  
    53  	if r.cache != nil && !r.cache.IsEmpty() {
    54  		mb := r.cache
    55  		r.cache = nil
    56  		return mb
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
    63  	mb := r.readInternal()
    64  	if mb != nil {
    65  		return mb, nil
    66  	}
    67  
    68  	return r.reader.ReadMultiBuffer()
    69  }
    70  
    71  func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
    72  	mb := r.readInternal()
    73  	if mb != nil {
    74  		return mb, nil
    75  	}
    76  
    77  	return r.reader.ReadMultiBufferTimeout(timeout)
    78  }
    79  
    80  func (r *cachedReader) Interrupt() {
    81  	r.Lock()
    82  	if r.cache != nil {
    83  		r.cache = buf.ReleaseMulti(r.cache)
    84  	}
    85  	r.Unlock()
    86  	r.reader.Interrupt()
    87  }
    88  
    89  // DefaultDispatcher is a default implementation of Dispatcher.
    90  type DefaultDispatcher struct {
    91  	ohm    outbound.Manager
    92  	router routing.Router
    93  	policy policy.Manager
    94  	stats  stats.Manager
    95  	dns    dns.Client
    96  	fdns   dns.FakeDNSEngine
    97  }
    98  
    99  func init() {
   100  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   101  		d := new(DefaultDispatcher)
   102  		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
   103  			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
   104  				d.fdns = fdns
   105  			})
   106  			return d.Init(config.(*Config), om, router, pm, sm, dc)
   107  		}); err != nil {
   108  			return nil, err
   109  		}
   110  		return d, nil
   111  	}))
   112  }
   113  
   114  // Init initializes DefaultDispatcher.
   115  func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
   116  	d.ohm = om
   117  	d.router = router
   118  	d.policy = pm
   119  	d.stats = sm
   120  	d.dns = dns
   121  	return nil
   122  }
   123  
   124  // Type implements common.HasType.
   125  func (*DefaultDispatcher) Type() interface{} {
   126  	return routing.DispatcherType()
   127  }
   128  
   129  // Start implements common.Runnable.
   130  func (*DefaultDispatcher) Start() error {
   131  	return nil
   132  }
   133  
   134  // Close implements common.Closable.
   135  func (*DefaultDispatcher) Close() error { return nil }
   136  
   137  func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
   138  	opt := pipe.OptionsFromContext(ctx)
   139  	uplinkReader, uplinkWriter := pipe.New(opt...)
   140  	downlinkReader, downlinkWriter := pipe.New(opt...)
   141  
   142  	inboundLink := &transport.Link{
   143  		Reader: downlinkReader,
   144  		Writer: uplinkWriter,
   145  	}
   146  
   147  	outboundLink := &transport.Link{
   148  		Reader: uplinkReader,
   149  		Writer: downlinkWriter,
   150  	}
   151  
   152  	sessionInbound := session.InboundFromContext(ctx)
   153  	var user *protocol.MemoryUser
   154  	if sessionInbound != nil {
   155  		user = sessionInbound.User
   156  	}
   157  
   158  	if user != nil && len(user.Email) > 0 {
   159  		p := d.policy.ForLevel(user.Level)
   160  		if p.Stats.UserUplink {
   161  			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
   162  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   163  				inboundLink.Writer = &SizeStatWriter{
   164  					Counter: c,
   165  					Writer:  inboundLink.Writer,
   166  				}
   167  			}
   168  		}
   169  		if p.Stats.UserDownlink {
   170  			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
   171  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   172  				outboundLink.Writer = &SizeStatWriter{
   173  					Counter: c,
   174  					Writer:  outboundLink.Writer,
   175  				}
   176  			}
   177  		}
   178  	}
   179  
   180  	return inboundLink, outboundLink
   181  }
   182  
   183  func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
   184  	domain := result.Domain()
   185  	if domain == "" {
   186  		return false
   187  	}
   188  	for _, d := range request.ExcludeForDomain {
   189  		if strings.ToLower(domain) == d {
   190  			return false
   191  		}
   192  	}
   193  	protocolString := result.Protocol()
   194  	if resComp, ok := result.(SnifferResultComposite); ok {
   195  		protocolString = resComp.ProtocolForDomainResult()
   196  	}
   197  	for _, p := range request.OverrideDestinationForProtocol {
   198  		if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) {
   199  			return true
   200  		}
   201  		if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
   202  			fkr0.IsIPInIPPool(destination.Address) {
   203  			newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
   204  			return true
   205  		}
   206  		if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
   207  			if resultSubset.IsProtoSubsetOf(p) {
   208  				return true
   209  			}
   210  		}
   211  	}
   212  
   213  	return false
   214  }
   215  
   216  // Dispatch implements routing.Dispatcher.
   217  func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
   218  	if !destination.IsValid() {
   219  		panic("Dispatcher: Invalid destination.")
   220  	}
   221  	ob := session.OutboundFromContext(ctx)
   222  	if ob == nil {
   223  		ob = &session.Outbound{}
   224  		ctx = session.ContextWithOutbound(ctx, ob)
   225  	}
   226  	ob.OriginalTarget = destination
   227  	ob.Target = destination
   228  	content := session.ContentFromContext(ctx)
   229  	if content == nil {
   230  		content = new(session.Content)
   231  		ctx = session.ContextWithContent(ctx, content)
   232  	}
   233  
   234  	sniffingRequest := content.SniffingRequest
   235  	inbound, outbound := d.getLink(ctx)
   236  	if !sniffingRequest.Enabled {
   237  		go d.routedDispatch(ctx, outbound, destination)
   238  	} else {
   239  		go func() {
   240  			cReader := &cachedReader{
   241  				reader: outbound.Reader.(*pipe.Reader),
   242  			}
   243  			outbound.Reader = cReader
   244  			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   245  			if err == nil {
   246  				content.Protocol = result.Protocol()
   247  			}
   248  			if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   249  				domain := result.Domain()
   250  				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   251  				destination.Address = net.ParseAddress(domain)
   252  				protocol := result.Protocol()
   253  				if resComp, ok := result.(SnifferResultComposite); ok {
   254  					protocol = resComp.ProtocolForDomainResult()
   255  				}
   256  				isFakeIP := false
   257  				if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
   258  					isFakeIP = true
   259  				}
   260  				if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
   261  					ob.RouteTarget = destination
   262  				} else {
   263  					ob.Target = destination
   264  				}
   265  			}
   266  			d.routedDispatch(ctx, outbound, destination)
   267  		}()
   268  	}
   269  	return inbound, nil
   270  }
   271  
   272  // DispatchLink implements routing.Dispatcher.
   273  func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
   274  	if !destination.IsValid() {
   275  		return newError("Dispatcher: Invalid destination.")
   276  	}
   277  	ob := session.OutboundFromContext(ctx)
   278  	if ob == nil {
   279  		ob = &session.Outbound{}
   280  		ctx = session.ContextWithOutbound(ctx, ob)
   281  	}
   282  	ob.OriginalTarget = destination
   283  	ob.Target = destination
   284  	content := session.ContentFromContext(ctx)
   285  	if content == nil {
   286  		content = new(session.Content)
   287  		ctx = session.ContextWithContent(ctx, content)
   288  	}
   289  	sniffingRequest := content.SniffingRequest
   290  	if !sniffingRequest.Enabled {
   291  		d.routedDispatch(ctx, outbound, destination)
   292  	} else {
   293  		cReader := &cachedReader{
   294  			reader: outbound.Reader.(*pipe.Reader),
   295  		}
   296  		outbound.Reader = cReader
   297  		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   298  		if err == nil {
   299  			content.Protocol = result.Protocol()
   300  		}
   301  		if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   302  			domain := result.Domain()
   303  			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   304  			destination.Address = net.ParseAddress(domain)
   305  			protocol := result.Protocol()
   306  			if resComp, ok := result.(SnifferResultComposite); ok {
   307  				protocol = resComp.ProtocolForDomainResult()
   308  			}
   309  			isFakeIP := false
   310  			if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
   311  				isFakeIP = true
   312  			}
   313  			if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
   314  				ob.RouteTarget = destination
   315  			} else {
   316  				ob.Target = destination
   317  			}
   318  		}
   319  		d.routedDispatch(ctx, outbound, destination)
   320  	}
   321  
   322  	return nil
   323  }
   324  
   325  func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
   326  	payload := buf.New()
   327  	defer payload.Release()
   328  
   329  	sniffer := NewSniffer(ctx)
   330  
   331  	metaresult, metadataErr := sniffer.SniffMetadata(ctx)
   332  
   333  	if metadataOnly {
   334  		return metaresult, metadataErr
   335  	}
   336  
   337  	contentResult, contentErr := func() (SniffResult, error) {
   338  		totalAttempt := 0
   339  		for {
   340  			select {
   341  			case <-ctx.Done():
   342  				return nil, ctx.Err()
   343  			default:
   344  				totalAttempt++
   345  				if totalAttempt > 2 {
   346  					return nil, errSniffingTimeout
   347  				}
   348  
   349  				cReader.Cache(payload)
   350  				if !payload.IsEmpty() {
   351  					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
   352  					if err != common.ErrNoClue {
   353  						return result, err
   354  					}
   355  				}
   356  				if payload.IsFull() {
   357  					return nil, errUnknownContent
   358  				}
   359  			}
   360  		}
   361  	}()
   362  	if contentErr != nil && metadataErr == nil {
   363  		return metaresult, nil
   364  	}
   365  	if contentErr == nil && metadataErr == nil {
   366  		return CompositeResult(metaresult, contentResult), nil
   367  	}
   368  	return contentResult, contentErr
   369  }
   370  func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
   371  	ob := session.OutboundFromContext(ctx)
   372  	if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
   373  		proxied := hosts.LookupHosts(ob.Target.String())
   374  		if proxied != nil {
   375  			ro := ob.RouteTarget == destination
   376  			destination.Address = *proxied
   377  			if ro {
   378  				ob.RouteTarget = destination
   379  			} else {
   380  				ob.Target = destination
   381  			}
   382  		}
   383  	}
   384  
   385  	var handler outbound.Handler
   386  
   387  	routingLink := routing_session.AsRoutingContext(ctx)
   388  	inTag := routingLink.GetInboundTag()
   389  	isPickRoute := 0
   390  	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
   391  		ctx = session.SetForcedOutboundTagToContext(ctx, "")
   392  		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
   393  			isPickRoute = 1
   394  			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   395  			handler = h
   396  		} else {
   397  			newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
   398  			common.Close(link.Writer)
   399  			common.Interrupt(link.Reader)
   400  			return
   401  		}
   402  	} else if d.router != nil {
   403  		if route, err := d.router.PickRoute(routingLink); err == nil {
   404  			outTag := route.GetOutboundTag()
   405  			if h := d.ohm.GetHandler(outTag); h != nil {
   406  				isPickRoute = 2
   407  				newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   408  				handler = h
   409  			} else {
   410  				newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   411  			}
   412  		} else {
   413  			newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
   414  		}
   415  	}
   416  
   417  	if handler == nil {
   418  		handler = d.ohm.GetDefaultHandler()
   419  	}
   420  
   421  	if handler == nil {
   422  		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
   423  		common.Close(link.Writer)
   424  		common.Interrupt(link.Reader)
   425  		return
   426  	}
   427  
   428  	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
   429  		if tag := handler.Tag(); tag != "" {
   430  			if inTag == "" {
   431  				accessMessage.Detour = tag
   432  			} else if isPickRoute == 1 {
   433  				accessMessage.Detour = inTag + " ==> " + tag
   434  			} else if isPickRoute == 2 {
   435  				accessMessage.Detour = inTag + " -> " + tag
   436  			} else {
   437  				accessMessage.Detour = inTag + " >> " + tag
   438  			}
   439  		}
   440  		log.Record(accessMessage)
   441  	}
   442  
   443  	handler.Dispatch(ctx, link)
   444  }