github.com/moqsien/xraycore@v1.8.5/app/dispatcher/default.go (about)

     1  package dispatcher
     2  
     3  //go:generate go run github.com/moqsien/xraycore/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/moqsien/xraycore/common"
    12  	"github.com/moqsien/xraycore/common/buf"
    13  	"github.com/moqsien/xraycore/common/log"
    14  	"github.com/moqsien/xraycore/common/net"
    15  	"github.com/moqsien/xraycore/common/protocol"
    16  	"github.com/moqsien/xraycore/common/session"
    17  	"github.com/moqsien/xraycore/core"
    18  	"github.com/moqsien/xraycore/features/dns"
    19  	"github.com/moqsien/xraycore/features/outbound"
    20  	"github.com/moqsien/xraycore/features/policy"
    21  	"github.com/moqsien/xraycore/features/routing"
    22  	routing_session "github.com/moqsien/xraycore/features/routing/session"
    23  	"github.com/moqsien/xraycore/features/stats"
    24  	"github.com/moqsien/xraycore/transport"
    25  	"github.com/moqsien/xraycore/transport/pipe"
    26  )
    27  
    28  var errSniffingTimeout = newError("timeout on sniffing")
    29  
    30  type cachedReader struct {
    31  	sync.Mutex
    32  	reader *pipe.Reader
    33  	cache  buf.MultiBuffer
    34  }
    35  
    36  func (r *cachedReader) Cache(b *buf.Buffer) {
    37  	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
    38  	r.Lock()
    39  	if !mb.IsEmpty() {
    40  		r.cache, _ = buf.MergeMulti(r.cache, mb)
    41  	}
    42  	b.Clear()
    43  	rawBytes := b.Extend(buf.Size)
    44  	n := r.cache.Copy(rawBytes)
    45  	b.Resize(0, int32(n))
    46  	r.Unlock()
    47  }
    48  
    49  func (r *cachedReader) readInternal() buf.MultiBuffer {
    50  	r.Lock()
    51  	defer r.Unlock()
    52  
    53  	if r.cache != nil && !r.cache.IsEmpty() {
    54  		mb := r.cache
    55  		r.cache = nil
    56  		return mb
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
    63  	mb := r.readInternal()
    64  	if mb != nil {
    65  		return mb, nil
    66  	}
    67  
    68  	return r.reader.ReadMultiBuffer()
    69  }
    70  
    71  func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
    72  	mb := r.readInternal()
    73  	if mb != nil {
    74  		return mb, nil
    75  	}
    76  
    77  	return r.reader.ReadMultiBufferTimeout(timeout)
    78  }
    79  
    80  func (r *cachedReader) Interrupt() {
    81  	r.Lock()
    82  	if r.cache != nil {
    83  		r.cache = buf.ReleaseMulti(r.cache)
    84  	}
    85  	r.Unlock()
    86  	r.reader.Interrupt()
    87  }
    88  
    89  // DefaultDispatcher is a default implementation of Dispatcher.
    90  type DefaultDispatcher struct {
    91  	ohm    outbound.Manager
    92  	router routing.Router
    93  	policy policy.Manager
    94  	stats  stats.Manager
    95  	dns    dns.Client
    96  	fdns   dns.FakeDNSEngine
    97  }
    98  
    99  func init() {
   100  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   101  		d := new(DefaultDispatcher)
   102  		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
   103  			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
   104  				d.fdns = fdns
   105  			})
   106  			return d.Init(config.(*Config), om, router, pm, sm, dc)
   107  		}); err != nil {
   108  			return nil, err
   109  		}
   110  		return d, nil
   111  	}))
   112  }
   113  
   114  // Init initializes DefaultDispatcher.
   115  func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
   116  	d.ohm = om
   117  	d.router = router
   118  	d.policy = pm
   119  	d.stats = sm
   120  	d.dns = dns
   121  	return nil
   122  }
   123  
   124  // Type implements common.HasType.
   125  func (*DefaultDispatcher) Type() interface{} {
   126  	return routing.DispatcherType()
   127  }
   128  
   129  // Start implements common.Runnable.
   130  func (*DefaultDispatcher) Start() error {
   131  	return nil
   132  }
   133  
   134  // Close implements common.Closable.
   135  func (*DefaultDispatcher) Close() error { return nil }
   136  
   137  func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
   138  	opt := pipe.OptionsFromContext(ctx)
   139  	uplinkReader, uplinkWriter := pipe.New(opt...)
   140  	downlinkReader, downlinkWriter := pipe.New(opt...)
   141  
   142  	inboundLink := &transport.Link{
   143  		Reader: downlinkReader,
   144  		Writer: uplinkWriter,
   145  	}
   146  
   147  	outboundLink := &transport.Link{
   148  		Reader: uplinkReader,
   149  		Writer: downlinkWriter,
   150  	}
   151  
   152  	sessionInbound := session.InboundFromContext(ctx)
   153  	var user *protocol.MemoryUser
   154  	if sessionInbound != nil {
   155  		user = sessionInbound.User
   156  	}
   157  
   158  	if user != nil && len(user.Email) > 0 {
   159  		p := d.policy.ForLevel(user.Level)
   160  		if p.Stats.UserUplink {
   161  			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
   162  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   163  				inboundLink.Writer = &SizeStatWriter{
   164  					Counter: c,
   165  					Writer:  inboundLink.Writer,
   166  				}
   167  			}
   168  		}
   169  		if p.Stats.UserDownlink {
   170  			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
   171  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   172  				outboundLink.Writer = &SizeStatWriter{
   173  					Counter: c,
   174  					Writer:  outboundLink.Writer,
   175  				}
   176  			}
   177  		}
   178  	}
   179  
   180  	return inboundLink, outboundLink
   181  }
   182  
   183  func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
   184  	domain := result.Domain()
   185  	if domain == "" {
   186  		return false
   187  	}
   188  	for _, d := range request.ExcludeForDomain {
   189  		if strings.ToLower(domain) == d {
   190  			return false
   191  		}
   192  	}
   193  	protocolString := result.Protocol()
   194  	if resComp, ok := result.(SnifferResultComposite); ok {
   195  		protocolString = resComp.ProtocolForDomainResult()
   196  	}
   197  	for _, p := range request.OverrideDestinationForProtocol {
   198  		if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
   199  			return true
   200  		}
   201  		if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
   202  			destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) {
   203  			newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
   204  			return true
   205  		}
   206  		if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
   207  			if resultSubset.IsProtoSubsetOf(p) {
   208  				return true
   209  			}
   210  		}
   211  	}
   212  
   213  	return false
   214  }
   215  
   216  // Dispatch implements routing.Dispatcher.
   217  func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
   218  	if !destination.IsValid() {
   219  		panic("Dispatcher: Invalid destination.")
   220  	}
   221  	ob := &session.Outbound{
   222  		OriginalTarget: destination,
   223  		Target:         destination,
   224  	}
   225  	ctx = session.ContextWithOutbound(ctx, ob)
   226  	content := session.ContentFromContext(ctx)
   227  	if content == nil {
   228  		content = new(session.Content)
   229  		ctx = session.ContextWithContent(ctx, content)
   230  	}
   231  	sniffingRequest := content.SniffingRequest
   232  	inbound, outbound := d.getLink(ctx)
   233  	if !sniffingRequest.Enabled {
   234  		go d.routedDispatch(ctx, outbound, destination)
   235  	} else {
   236  		go func() {
   237  			cReader := &cachedReader{
   238  				reader: outbound.Reader.(*pipe.Reader),
   239  			}
   240  			outbound.Reader = cReader
   241  			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   242  			if err == nil {
   243  				content.Protocol = result.Protocol()
   244  			}
   245  			if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   246  				domain := result.Domain()
   247  				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   248  				destination.Address = net.ParseAddress(domain)
   249  				protocol := result.Protocol()
   250  				if resComp, ok := result.(SnifferResultComposite); ok {
   251  					protocol = resComp.ProtocolForDomainResult()
   252  				}
   253  				isFakeIP := false
   254  				if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
   255  					isFakeIP = true
   256  				}
   257  				if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
   258  					ob.RouteTarget = destination
   259  				} else {
   260  					ob.Target = destination
   261  				}
   262  			}
   263  			d.routedDispatch(ctx, outbound, destination)
   264  		}()
   265  	}
   266  	return inbound, nil
   267  }
   268  
   269  // DispatchLink implements routing.Dispatcher.
   270  func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
   271  	if !destination.IsValid() {
   272  		return newError("Dispatcher: Invalid destination.")
   273  	}
   274  	ob := &session.Outbound{
   275  		OriginalTarget: destination,
   276  		Target:         destination,
   277  	}
   278  	ctx = session.ContextWithOutbound(ctx, ob)
   279  	content := session.ContentFromContext(ctx)
   280  	if content == nil {
   281  		content = new(session.Content)
   282  		ctx = session.ContextWithContent(ctx, content)
   283  	}
   284  	sniffingRequest := content.SniffingRequest
   285  	if !sniffingRequest.Enabled {
   286  		d.routedDispatch(ctx, outbound, destination)
   287  	} else {
   288  		cReader := &cachedReader{
   289  			reader: outbound.Reader.(*pipe.Reader),
   290  		}
   291  		outbound.Reader = cReader
   292  		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   293  		if err == nil {
   294  			content.Protocol = result.Protocol()
   295  		}
   296  		if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   297  			domain := result.Domain()
   298  			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   299  			destination.Address = net.ParseAddress(domain)
   300  			protocol := result.Protocol()
   301  			if resComp, ok := result.(SnifferResultComposite); ok {
   302  				protocol = resComp.ProtocolForDomainResult()
   303  			}
   304  			isFakeIP := false
   305  			if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
   306  				isFakeIP = true
   307  			}
   308  			if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
   309  				ob.RouteTarget = destination
   310  			} else {
   311  				ob.Target = destination
   312  			}
   313  		}
   314  		d.routedDispatch(ctx, outbound, destination)
   315  	}
   316  
   317  	return nil
   318  }
   319  
   320  func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
   321  	payload := buf.New()
   322  	defer payload.Release()
   323  
   324  	sniffer := NewSniffer(ctx)
   325  
   326  	metaresult, metadataErr := sniffer.SniffMetadata(ctx)
   327  
   328  	if metadataOnly {
   329  		return metaresult, metadataErr
   330  	}
   331  
   332  	contentResult, contentErr := func() (SniffResult, error) {
   333  		totalAttempt := 0
   334  		for {
   335  			select {
   336  			case <-ctx.Done():
   337  				return nil, ctx.Err()
   338  			default:
   339  				totalAttempt++
   340  				if totalAttempt > 2 {
   341  					return nil, errSniffingTimeout
   342  				}
   343  
   344  				cReader.Cache(payload)
   345  				if !payload.IsEmpty() {
   346  					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
   347  					if err != common.ErrNoClue {
   348  						return result, err
   349  					}
   350  				}
   351  				if payload.IsFull() {
   352  					return nil, errUnknownContent
   353  				}
   354  			}
   355  		}
   356  	}()
   357  	if contentErr != nil && metadataErr == nil {
   358  		return metaresult, nil
   359  	}
   360  	if contentErr == nil && metadataErr == nil {
   361  		return CompositeResult(metaresult, contentResult), nil
   362  	}
   363  	return contentResult, contentErr
   364  }
   365  
   366  func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
   367  	ob := session.OutboundFromContext(ctx)
   368  	if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
   369  		proxied := hosts.LookupHosts(ob.Target.String())
   370  		if proxied != nil {
   371  			ro := ob.RouteTarget == destination
   372  			destination.Address = *proxied
   373  			if ro {
   374  				ob.RouteTarget = destination
   375  			} else {
   376  				ob.Target = destination
   377  			}
   378  		}
   379  	}
   380  
   381  	var handler outbound.Handler
   382  
   383  	routingLink := routing_session.AsRoutingContext(ctx)
   384  	inTag := routingLink.GetInboundTag()
   385  	isPickRoute := 0
   386  	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
   387  		ctx = session.SetForcedOutboundTagToContext(ctx, "")
   388  		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
   389  			isPickRoute = 1
   390  			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   391  			handler = h
   392  		} else {
   393  			newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
   394  			common.Close(link.Writer)
   395  			common.Interrupt(link.Reader)
   396  			return
   397  		}
   398  	} else if d.router != nil {
   399  		if route, err := d.router.PickRoute(routingLink); err == nil {
   400  			outTag := route.GetOutboundTag()
   401  			if h := d.ohm.GetHandler(outTag); h != nil {
   402  				isPickRoute = 2
   403  				newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   404  				handler = h
   405  			} else {
   406  				newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   407  			}
   408  		} else {
   409  			newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
   410  		}
   411  	}
   412  
   413  	if handler == nil {
   414  		handler = d.ohm.GetDefaultHandler()
   415  	}
   416  
   417  	if handler == nil {
   418  		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
   419  		common.Close(link.Writer)
   420  		common.Interrupt(link.Reader)
   421  		return
   422  	}
   423  
   424  	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
   425  		if tag := handler.Tag(); tag != "" {
   426  			if inTag == "" {
   427  				accessMessage.Detour = tag
   428  			} else if isPickRoute == 1 {
   429  				accessMessage.Detour = inTag + " ==> " + tag
   430  			} else if isPickRoute == 2 {
   431  				accessMessage.Detour = inTag + " -> " + tag
   432  			} else {
   433  				accessMessage.Detour = inTag + " >> " + tag
   434  			}
   435  		}
   436  		log.Record(accessMessage)
   437  	}
   438  
   439  	handler.Dispatch(ctx, link)
   440  }