github.com/xtls/xray-core@v1.8.3/app/dispatcher/default.go (about)

     1  package dispatcher
     2  
     3  //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xtls/xray-core/common"
    13  	"github.com/xtls/xray-core/common/buf"
    14  	"github.com/xtls/xray-core/common/log"
    15  	"github.com/xtls/xray-core/common/net"
    16  	"github.com/xtls/xray-core/common/protocol"
    17  	"github.com/xtls/xray-core/common/session"
    18  	"github.com/xtls/xray-core/core"
    19  	"github.com/xtls/xray-core/features/dns"
    20  	"github.com/xtls/xray-core/features/outbound"
    21  	"github.com/xtls/xray-core/features/policy"
    22  	"github.com/xtls/xray-core/features/routing"
    23  	routing_session "github.com/xtls/xray-core/features/routing/session"
    24  	"github.com/xtls/xray-core/features/stats"
    25  	"github.com/xtls/xray-core/transport"
    26  	"github.com/xtls/xray-core/transport/pipe"
    27  )
    28  
    29  var errSniffingTimeout = newError("timeout on sniffing")
    30  
    31  type cachedReader struct {
    32  	sync.Mutex
    33  	reader *pipe.Reader
    34  	cache  buf.MultiBuffer
    35  }
    36  
    37  func (r *cachedReader) Cache(b *buf.Buffer) {
    38  	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
    39  	r.Lock()
    40  	if !mb.IsEmpty() {
    41  		r.cache, _ = buf.MergeMulti(r.cache, mb)
    42  	}
    43  	b.Clear()
    44  	rawBytes := b.Extend(buf.Size)
    45  	n := r.cache.Copy(rawBytes)
    46  	b.Resize(0, int32(n))
    47  	r.Unlock()
    48  }
    49  
    50  func (r *cachedReader) readInternal() buf.MultiBuffer {
    51  	r.Lock()
    52  	defer r.Unlock()
    53  
    54  	if r.cache != nil && !r.cache.IsEmpty() {
    55  		mb := r.cache
    56  		r.cache = nil
    57  		return mb
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
    64  	mb := r.readInternal()
    65  	if mb != nil {
    66  		return mb, nil
    67  	}
    68  
    69  	return r.reader.ReadMultiBuffer()
    70  }
    71  
    72  func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
    73  	mb := r.readInternal()
    74  	if mb != nil {
    75  		return mb, nil
    76  	}
    77  
    78  	return r.reader.ReadMultiBufferTimeout(timeout)
    79  }
    80  
    81  func (r *cachedReader) Interrupt() {
    82  	r.Lock()
    83  	if r.cache != nil {
    84  		r.cache = buf.ReleaseMulti(r.cache)
    85  	}
    86  	r.Unlock()
    87  	r.reader.Interrupt()
    88  }
    89  
    90  // DefaultDispatcher is a default implementation of Dispatcher.
    91  type DefaultDispatcher struct {
    92  	ohm    outbound.Manager
    93  	router routing.Router
    94  	policy policy.Manager
    95  	stats  stats.Manager
    96  	dns    dns.Client
    97  	fdns   dns.FakeDNSEngine
    98  }
    99  
   100  func init() {
   101  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   102  		d := new(DefaultDispatcher)
   103  		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
   104  			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
   105  				d.fdns = fdns
   106  			})
   107  			return d.Init(config.(*Config), om, router, pm, sm, dc)
   108  		}); err != nil {
   109  			return nil, err
   110  		}
   111  		return d, nil
   112  	}))
   113  }
   114  
   115  // Init initializes DefaultDispatcher.
   116  func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
   117  	d.ohm = om
   118  	d.router = router
   119  	d.policy = pm
   120  	d.stats = sm
   121  	d.dns = dns
   122  	return nil
   123  }
   124  
   125  // Type implements common.HasType.
   126  func (*DefaultDispatcher) Type() interface{} {
   127  	return routing.DispatcherType()
   128  }
   129  
   130  // Start implements common.Runnable.
   131  func (*DefaultDispatcher) Start() error {
   132  	return nil
   133  }
   134  
   135  // Close implements common.Closable.
   136  func (*DefaultDispatcher) Close() error { return nil }
   137  
   138  func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
   139  	downOpt := pipe.OptionsFromContext(ctx)
   140  	upOpt := downOpt
   141  
   142  	if network == net.Network_UDP {
   143  		var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
   144  		// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
   145  		// When target replies, server will restore the domain and send back to client.
   146  		// Note: this map is not global but per connection context
   147  		upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
   148  			for i, buffer := range mb {
   149  				if buffer.UDP == nil {
   150  					continue
   151  				}
   152  				addr := buffer.UDP.Address
   153  				if addr.Family().IsIP() {
   154  					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
   155  						domain := fkr0.GetDomainFromFakeDNS(addr)
   156  						if len(domain) > 0 {
   157  							buffer.UDP.Address = net.DomainAddress(domain)
   158  							newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   159  						} else {
   160  							newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   161  						}
   162  					}
   163  				} else {
   164  					if ip2domain == nil {
   165  						ip2domain = new(sync.Map)
   166  						newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
   167  					}
   168  					domain := addr.Domain()
   169  					ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
   170  					if err == nil {
   171  						for _, ip := range ips {
   172  							ip2domain.Store(ip.String(), domain)
   173  						}
   174  						newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   175  					} else {
   176  						newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
   177  					}
   178  				}
   179  			}
   180  			return mb
   181  		}))
   182  		downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
   183  			for i, buffer := range mb {
   184  				if buffer.UDP == nil {
   185  					continue
   186  				}
   187  				addr := buffer.UDP.Address
   188  				if addr.Family().IsIP() {
   189  					if ip2domain == nil {
   190  						continue
   191  					}
   192  					if domain, found := ip2domain.Load(addr.IP().String()); found {
   193  						buffer.UDP.Address = net.DomainAddress(domain.(string))
   194  						newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   195  					}
   196  				} else {
   197  					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
   198  						fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
   199  						buffer.UDP.Address = fakeIp[0]
   200  						newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   201  					}
   202  				}
   203  			}
   204  			return mb
   205  		}))
   206  	}
   207  	uplinkReader, uplinkWriter := pipe.New(upOpt...)
   208  	downlinkReader, downlinkWriter := pipe.New(downOpt...)
   209  
   210  	inboundLink := &transport.Link{
   211  		Reader: downlinkReader,
   212  		Writer: uplinkWriter,
   213  	}
   214  
   215  	outboundLink := &transport.Link{
   216  		Reader: uplinkReader,
   217  		Writer: downlinkWriter,
   218  	}
   219  
   220  	sessionInbound := session.InboundFromContext(ctx)
   221  	var user *protocol.MemoryUser
   222  	if sessionInbound != nil {
   223  		user = sessionInbound.User
   224  	}
   225  
   226  	if user != nil && len(user.Email) > 0 {
   227  		p := d.policy.ForLevel(user.Level)
   228  		if p.Stats.UserUplink {
   229  			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
   230  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   231  				inboundLink.Writer = &SizeStatWriter{
   232  					Counter: c,
   233  					Writer:  inboundLink.Writer,
   234  				}
   235  			}
   236  		}
   237  		if p.Stats.UserDownlink {
   238  			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
   239  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   240  				outboundLink.Writer = &SizeStatWriter{
   241  					Counter: c,
   242  					Writer:  outboundLink.Writer,
   243  				}
   244  			}
   245  		}
   246  	}
   247  
   248  	return inboundLink, outboundLink
   249  }
   250  
   251  func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
   252  	domain := result.Domain()
   253  	if domain == "" {
   254  		return false
   255  	}
   256  	for _, d := range request.ExcludeForDomain {
   257  		if strings.ToLower(domain) == d {
   258  			return false
   259  		}
   260  	}
   261  	protocolString := result.Protocol()
   262  	if resComp, ok := result.(SnifferResultComposite); ok {
   263  		protocolString = resComp.ProtocolForDomainResult()
   264  	}
   265  	for _, p := range request.OverrideDestinationForProtocol {
   266  		if strings.HasPrefix(protocolString, p) {
   267  			return true
   268  		}
   269  		if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
   270  			destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) {
   271  			newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
   272  			return true
   273  		}
   274  		if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
   275  			if resultSubset.IsProtoSubsetOf(p) {
   276  				return true
   277  			}
   278  		}
   279  	}
   280  
   281  	return false
   282  }
   283  
   284  // Dispatch implements routing.Dispatcher.
   285  func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
   286  	if !destination.IsValid() {
   287  		panic("Dispatcher: Invalid destination.")
   288  	}
   289  	ob := &session.Outbound{
   290  		Target: destination,
   291  	}
   292  	ctx = session.ContextWithOutbound(ctx, ob)
   293  	content := session.ContentFromContext(ctx)
   294  	if content == nil {
   295  		content = new(session.Content)
   296  		ctx = session.ContextWithContent(ctx, content)
   297  	}
   298  
   299  	sniffingRequest := content.SniffingRequest
   300  	inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
   301  	if !sniffingRequest.Enabled {
   302  		go d.routedDispatch(ctx, outbound, destination)
   303  	} else {
   304  		go func() {
   305  			cReader := &cachedReader{
   306  				reader: outbound.Reader.(*pipe.Reader),
   307  			}
   308  			outbound.Reader = cReader
   309  			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   310  			if err == nil {
   311  				content.Protocol = result.Protocol()
   312  			}
   313  			if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   314  				domain := result.Domain()
   315  				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   316  				destination.Address = net.ParseAddress(domain)
   317  				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
   318  					ob.RouteTarget = destination
   319  				} else {
   320  					ob.Target = destination
   321  				}
   322  			}
   323  			d.routedDispatch(ctx, outbound, destination)
   324  		}()
   325  	}
   326  	return inbound, nil
   327  }
   328  
   329  // DispatchLink implements routing.Dispatcher.
   330  func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
   331  	if !destination.IsValid() {
   332  		return newError("Dispatcher: Invalid destination.")
   333  	}
   334  	ob := &session.Outbound{
   335  		Target: destination,
   336  	}
   337  	ctx = session.ContextWithOutbound(ctx, ob)
   338  	content := session.ContentFromContext(ctx)
   339  	if content == nil {
   340  		content = new(session.Content)
   341  		ctx = session.ContextWithContent(ctx, content)
   342  	}
   343  	sniffingRequest := content.SniffingRequest
   344  	if !sniffingRequest.Enabled {
   345  		d.routedDispatch(ctx, outbound, destination)
   346  	} else {
   347  		cReader := &cachedReader{
   348  			reader: outbound.Reader.(*pipe.Reader),
   349  		}
   350  		outbound.Reader = cReader
   351  		result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   352  		if err == nil {
   353  			content.Protocol = result.Protocol()
   354  		}
   355  		if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   356  			domain := result.Domain()
   357  			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   358  			destination.Address = net.ParseAddress(domain)
   359  			if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
   360  				ob.RouteTarget = destination
   361  			} else {
   362  				ob.Target = destination
   363  			}
   364  		}
   365  		d.routedDispatch(ctx, outbound, destination)
   366  	}
   367  
   368  	return nil
   369  }
   370  
   371  func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
   372  	payload := buf.New()
   373  	defer payload.Release()
   374  
   375  	sniffer := NewSniffer(ctx)
   376  
   377  	metaresult, metadataErr := sniffer.SniffMetadata(ctx)
   378  
   379  	if metadataOnly {
   380  		return metaresult, metadataErr
   381  	}
   382  
   383  	contentResult, contentErr := func() (SniffResult, error) {
   384  		totalAttempt := 0
   385  		for {
   386  			select {
   387  			case <-ctx.Done():
   388  				return nil, ctx.Err()
   389  			default:
   390  				totalAttempt++
   391  				if totalAttempt > 2 {
   392  					return nil, errSniffingTimeout
   393  				}
   394  
   395  				cReader.Cache(payload)
   396  				if !payload.IsEmpty() {
   397  					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
   398  					if err != common.ErrNoClue {
   399  						return result, err
   400  					}
   401  				}
   402  				if payload.IsFull() {
   403  					return nil, errUnknownContent
   404  				}
   405  			}
   406  		}
   407  	}()
   408  	if contentErr != nil && metadataErr == nil {
   409  		return metaresult, nil
   410  	}
   411  	if contentErr == nil && metadataErr == nil {
   412  		return CompositeResult(metaresult, contentResult), nil
   413  	}
   414  	return contentResult, contentErr
   415  }
   416  
   417  func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
   418  	ob := session.OutboundFromContext(ctx)
   419  	if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
   420  		proxied := hosts.LookupHosts(ob.Target.String())
   421  		if proxied != nil {
   422  			ro := ob.RouteTarget == destination
   423  			destination.Address = *proxied
   424  			if ro {
   425  				ob.RouteTarget = destination
   426  			} else {
   427  				ob.Target = destination
   428  			}
   429  		}
   430  	}
   431  
   432  	var handler outbound.Handler
   433  
   434  	routingLink := routing_session.AsRoutingContext(ctx)
   435  	inTag := routingLink.GetInboundTag()
   436  	isPickRoute := 0
   437  	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
   438  		ctx = session.SetForcedOutboundTagToContext(ctx, "")
   439  		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
   440  			isPickRoute = 1
   441  			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   442  			handler = h
   443  		} else {
   444  			newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
   445  			common.Close(link.Writer)
   446  			common.Interrupt(link.Reader)
   447  			return
   448  		}
   449  	} else if d.router != nil {
   450  		if route, err := d.router.PickRoute(routingLink); err == nil {
   451  			outTag := route.GetOutboundTag()
   452  			if h := d.ohm.GetHandler(outTag); h != nil {
   453  				isPickRoute = 2
   454  				newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   455  				handler = h
   456  			} else {
   457  				newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   458  			}
   459  		} else {
   460  			newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
   461  		}
   462  	}
   463  
   464  	if handler == nil {
   465  		handler = d.ohm.GetDefaultHandler()
   466  	}
   467  
   468  	if handler == nil {
   469  		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
   470  		common.Close(link.Writer)
   471  		common.Interrupt(link.Reader)
   472  		return
   473  	}
   474  
   475  	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
   476  		if tag := handler.Tag(); tag != "" {
   477  			if inTag == "" {
   478  				accessMessage.Detour = tag
   479  			} else if isPickRoute == 1 {
   480  				accessMessage.Detour = inTag + " ==> " + tag
   481  			} else if isPickRoute == 2 {
   482  				accessMessage.Detour = inTag + " -> " + tag
   483  			} else {
   484  				accessMessage.Detour = inTag + " >> " + tag
   485  			}
   486  		}
   487  		log.Record(accessMessage)
   488  	}
   489  
   490  	handler.Dispatch(ctx, link)
   491  }