github.com/xraypb/xray-core@v1.6.6/app/dispatcher/default.go (about)

     1  package dispatcher
     2  
     3  //go:generate go run github.com/xraypb/xray-core/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xraypb/xray-core/common"
    13  	"github.com/xraypb/xray-core/common/buf"
    14  	"github.com/xraypb/xray-core/common/log"
    15  	"github.com/xraypb/xray-core/common/net"
    16  	"github.com/xraypb/xray-core/common/protocol"
    17  	"github.com/xraypb/xray-core/common/session"
    18  	"github.com/xraypb/xray-core/core"
    19  	"github.com/xraypb/xray-core/features/dns"
    20  	"github.com/xraypb/xray-core/features/outbound"
    21  	"github.com/xraypb/xray-core/features/policy"
    22  	"github.com/xraypb/xray-core/features/routing"
    23  	routing_session "github.com/xraypb/xray-core/features/routing/session"
    24  	"github.com/xraypb/xray-core/features/stats"
    25  	"github.com/xraypb/xray-core/transport"
    26  	"github.com/xraypb/xray-core/transport/pipe"
    27  )
    28  
    29  var errSniffingTimeout = newError("timeout on sniffing")
    30  
    31  type cachedReader struct {
    32  	sync.Mutex
    33  	reader *pipe.Reader
    34  	cache  buf.MultiBuffer
    35  }
    36  
    37  func (r *cachedReader) Cache(b *buf.Buffer) {
    38  	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
    39  	r.Lock()
    40  	if !mb.IsEmpty() {
    41  		r.cache, _ = buf.MergeMulti(r.cache, mb)
    42  	}
    43  	b.Clear()
    44  	rawBytes := b.Extend(buf.Size)
    45  	n := r.cache.Copy(rawBytes)
    46  	b.Resize(0, int32(n))
    47  	r.Unlock()
    48  }
    49  
    50  func (r *cachedReader) readInternal() buf.MultiBuffer {
    51  	r.Lock()
    52  	defer r.Unlock()
    53  
    54  	if r.cache != nil && !r.cache.IsEmpty() {
    55  		mb := r.cache
    56  		r.cache = nil
    57  		return mb
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
    64  	mb := r.readInternal()
    65  	if mb != nil {
    66  		return mb, nil
    67  	}
    68  
    69  	return r.reader.ReadMultiBuffer()
    70  }
    71  
    72  func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
    73  	mb := r.readInternal()
    74  	if mb != nil {
    75  		return mb, nil
    76  	}
    77  
    78  	return r.reader.ReadMultiBufferTimeout(timeout)
    79  }
    80  
    81  func (r *cachedReader) Interrupt() {
    82  	r.Lock()
    83  	if r.cache != nil {
    84  		r.cache = buf.ReleaseMulti(r.cache)
    85  	}
    86  	r.Unlock()
    87  	r.reader.Interrupt()
    88  }
    89  
    90  // DefaultDispatcher is a default implementation of Dispatcher.
    91  type DefaultDispatcher struct {
    92  	ohm    outbound.Manager
    93  	router routing.Router
    94  	policy policy.Manager
    95  	stats  stats.Manager
    96  	dns    dns.Client
    97  	fdns   dns.FakeDNSEngine
    98  }
    99  
   100  func init() {
   101  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   102  		d := new(DefaultDispatcher)
   103  		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
   104  			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
   105  				d.fdns = fdns
   106  			})
   107  			return d.Init(config.(*Config), om, router, pm, sm, dc)
   108  		}); err != nil {
   109  			return nil, err
   110  		}
   111  		return d, nil
   112  	}))
   113  }
   114  
   115  // Init initializes DefaultDispatcher.
   116  func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dns dns.Client) error {
   117  	d.ohm = om
   118  	d.router = router
   119  	d.policy = pm
   120  	d.stats = sm
   121  	d.dns = dns
   122  	return nil
   123  }
   124  
   125  // Type implements common.HasType.
   126  func (*DefaultDispatcher) Type() interface{} {
   127  	return routing.DispatcherType()
   128  }
   129  
   130  // Start implements common.Runnable.
   131  func (*DefaultDispatcher) Start() error {
   132  	return nil
   133  }
   134  
   135  // Close implements common.Closable.
   136  func (*DefaultDispatcher) Close() error { return nil }
   137  
   138  func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
   139  	downOpt := pipe.OptionsFromContext(ctx)
   140  	upOpt := downOpt
   141  
   142  	if network == net.Network_UDP {
   143  		var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
   144  		// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
   145  		// When target replies, server will restore the domain and send back to client.
   146  		// Note: this map is not global but per connection context
   147  		upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
   148  			for i, buffer := range mb {
   149  				if buffer.UDP == nil {
   150  					continue
   151  				}
   152  				addr := buffer.UDP.Address
   153  				if addr.Family().IsIP() {
   154  					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
   155  						domain := fkr0.GetDomainFromFakeDNS(addr)
   156  						if len(domain) > 0 {
   157  							buffer.UDP.Address = net.DomainAddress(domain)
   158  							newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   159  						} else {
   160  							newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   161  						}
   162  					}
   163  				} else {
   164  					if ip2domain == nil {
   165  						ip2domain = new(sync.Map)
   166  						newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
   167  					}
   168  					domain := addr.Domain()
   169  					ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
   170  					if err == nil {
   171  						for _, ip := range ips {
   172  							ip2domain.Store(ip.String(), domain)
   173  						}
   174  						newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   175  					} else {
   176  						newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
   177  					}
   178  				}
   179  			}
   180  			return mb
   181  		}))
   182  		downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
   183  			for i, buffer := range mb {
   184  				if buffer.UDP == nil {
   185  					continue
   186  				}
   187  				addr := buffer.UDP.Address
   188  				if addr.Family().IsIP() {
   189  					if ip2domain == nil {
   190  						continue
   191  					}
   192  					if domain, found := ip2domain.Load(addr.IP().String()); found {
   193  						buffer.UDP.Address = net.DomainAddress(domain.(string))
   194  						newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   195  					}
   196  				} else {
   197  					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
   198  						fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
   199  						buffer.UDP.Address = fakeIp[0]
   200  						newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
   201  					}
   202  				}
   203  			}
   204  			return mb
   205  		}))
   206  	}
   207  	uplinkReader, uplinkWriter := pipe.New(upOpt...)
   208  	downlinkReader, downlinkWriter := pipe.New(downOpt...)
   209  
   210  	inboundLink := &transport.Link{
   211  		Reader: downlinkReader,
   212  		Writer: uplinkWriter,
   213  	}
   214  
   215  	outboundLink := &transport.Link{
   216  		Reader: uplinkReader,
   217  		Writer: downlinkWriter,
   218  	}
   219  
   220  	sessionInbound := session.InboundFromContext(ctx)
   221  	var user *protocol.MemoryUser
   222  	if sessionInbound != nil {
   223  		user = sessionInbound.User
   224  	}
   225  
   226  	if user != nil && len(user.Email) > 0 {
   227  		p := d.policy.ForLevel(user.Level)
   228  		if p.Stats.UserUplink {
   229  			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
   230  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   231  				inboundLink.Writer = &SizeStatWriter{
   232  					Counter: c,
   233  					Writer:  inboundLink.Writer,
   234  				}
   235  			}
   236  		}
   237  		if p.Stats.UserDownlink {
   238  			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
   239  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   240  				outboundLink.Writer = &SizeStatWriter{
   241  					Counter: c,
   242  					Writer:  outboundLink.Writer,
   243  				}
   244  			}
   245  		}
   246  	}
   247  
   248  	return inboundLink, outboundLink
   249  }
   250  
   251  func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
   252  	domain := result.Domain()
   253  	if domain == "" {
   254  		return false
   255  	}
   256  	for _, d := range request.ExcludeForDomain {
   257  		if strings.ToLower(domain) == d {
   258  			return false
   259  		}
   260  	}
   261  	protocolString := result.Protocol()
   262  	if resComp, ok := result.(SnifferResultComposite); ok {
   263  		protocolString = resComp.ProtocolForDomainResult()
   264  	}
   265  	for _, p := range request.OverrideDestinationForProtocol {
   266  		if strings.HasPrefix(protocolString, p) {
   267  			return true
   268  		}
   269  		if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
   270  			destination.Address.Family().IsIP() && fkr0.IsIPInIPPool(destination.Address) {
   271  			newError("Using sniffer ", protocolString, " since the fake DNS missed").WriteToLog(session.ExportIDToError(ctx))
   272  			return true
   273  		}
   274  		if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
   275  			if resultSubset.IsProtoSubsetOf(p) {
   276  				return true
   277  			}
   278  		}
   279  	}
   280  
   281  	return false
   282  }
   283  
   284  // Dispatch implements routing.Dispatcher.
   285  func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
   286  	if !destination.IsValid() {
   287  		panic("Dispatcher: Invalid destination.")
   288  	}
   289  	ob := &session.Outbound{
   290  		Target: destination,
   291  	}
   292  	ctx = session.ContextWithOutbound(ctx, ob)
   293  	content := session.ContentFromContext(ctx)
   294  	if content == nil {
   295  		content = new(session.Content)
   296  		ctx = session.ContextWithContent(ctx, content)
   297  	}
   298  
   299  	sniffingRequest := content.SniffingRequest
   300  	inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
   301  	if !sniffingRequest.Enabled {
   302  		go d.routedDispatch(ctx, outbound, destination)
   303  	} else {
   304  		go func() {
   305  			cReader := &cachedReader{
   306  				reader: outbound.Reader.(*pipe.Reader),
   307  			}
   308  			outbound.Reader = cReader
   309  			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   310  			if err == nil {
   311  				content.Protocol = result.Protocol()
   312  			}
   313  			if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   314  				domain := result.Domain()
   315  				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   316  				destination.Address = net.ParseAddress(domain)
   317  				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
   318  					ob.RouteTarget = destination
   319  				} else {
   320  					ob.Target = destination
   321  				}
   322  			}
   323  			d.routedDispatch(ctx, outbound, destination)
   324  		}()
   325  	}
   326  	return inbound, nil
   327  }
   328  
   329  // DispatchLink implements routing.Dispatcher.
   330  func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
   331  	if !destination.IsValid() {
   332  		return newError("Dispatcher: Invalid destination.")
   333  	}
   334  	ob := &session.Outbound{
   335  		Target: destination,
   336  	}
   337  	ctx = session.ContextWithOutbound(ctx, ob)
   338  	content := session.ContentFromContext(ctx)
   339  	if content == nil {
   340  		content = new(session.Content)
   341  		ctx = session.ContextWithContent(ctx, content)
   342  	}
   343  	sniffingRequest := content.SniffingRequest
   344  	if !sniffingRequest.Enabled {
   345  		go d.routedDispatch(ctx, outbound, destination)
   346  	} else {
   347  		go func() {
   348  			cReader := &cachedReader{
   349  				reader: outbound.Reader.(*pipe.Reader),
   350  			}
   351  			outbound.Reader = cReader
   352  			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   353  			if err == nil {
   354  				content.Protocol = result.Protocol()
   355  			}
   356  			if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
   357  				domain := result.Domain()
   358  				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   359  				destination.Address = net.ParseAddress(domain)
   360  				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
   361  					ob.RouteTarget = destination
   362  				} else {
   363  					ob.Target = destination
   364  				}
   365  			}
   366  			d.routedDispatch(ctx, outbound, destination)
   367  		}()
   368  	}
   369  
   370  	return nil
   371  }
   372  
   373  func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
   374  	payload := buf.New()
   375  	defer payload.Release()
   376  
   377  	sniffer := NewSniffer(ctx)
   378  
   379  	metaresult, metadataErr := sniffer.SniffMetadata(ctx)
   380  
   381  	if metadataOnly {
   382  		return metaresult, metadataErr
   383  	}
   384  
   385  	contentResult, contentErr := func() (SniffResult, error) {
   386  		totalAttempt := 0
   387  		for {
   388  			select {
   389  			case <-ctx.Done():
   390  				return nil, ctx.Err()
   391  			default:
   392  				totalAttempt++
   393  				if totalAttempt > 2 {
   394  					return nil, errSniffingTimeout
   395  				}
   396  
   397  				cReader.Cache(payload)
   398  				if !payload.IsEmpty() {
   399  					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
   400  					if err != common.ErrNoClue {
   401  						return result, err
   402  					}
   403  				}
   404  				if payload.IsFull() {
   405  					return nil, errUnknownContent
   406  				}
   407  			}
   408  		}
   409  	}()
   410  	if contentErr != nil && metadataErr == nil {
   411  		return metaresult, nil
   412  	}
   413  	if contentErr == nil && metadataErr == nil {
   414  		return CompositeResult(metaresult, contentResult), nil
   415  	}
   416  	return contentResult, contentErr
   417  }
   418  
   419  func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
   420  	ob := session.OutboundFromContext(ctx)
   421  	if hosts, ok := d.dns.(dns.HostsLookup); ok && destination.Address.Family().IsDomain() {
   422  		proxied := hosts.LookupHosts(ob.Target.String())
   423  		if proxied != nil {
   424  			ro := ob.RouteTarget == destination
   425  			destination.Address = *proxied
   426  			if ro {
   427  				ob.RouteTarget = destination
   428  			} else {
   429  				ob.Target = destination
   430  			}
   431  		}
   432  	}
   433  
   434  	var handler outbound.Handler
   435  
   436  	routingLink := routing_session.AsRoutingContext(ctx)
   437  	inTag := routingLink.GetInboundTag()
   438  	isPickRoute := 0
   439  	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
   440  		ctx = session.SetForcedOutboundTagToContext(ctx, "")
   441  		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
   442  			isPickRoute = 1
   443  			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   444  			handler = h
   445  		} else {
   446  			newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
   447  			common.Close(link.Writer)
   448  			common.Interrupt(link.Reader)
   449  			return
   450  		}
   451  	} else if d.router != nil {
   452  		if route, err := d.router.PickRoute(routingLink); err == nil {
   453  			outTag := route.GetOutboundTag()
   454  			if h := d.ohm.GetHandler(outTag); h != nil {
   455  				isPickRoute = 2
   456  				newError("taking detour [", outTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   457  				handler = h
   458  			} else {
   459  				newError("non existing outTag: ", outTag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   460  			}
   461  		} else {
   462  			newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
   463  		}
   464  	}
   465  
   466  	if handler == nil {
   467  		handler = d.ohm.GetDefaultHandler()
   468  	}
   469  
   470  	if handler == nil {
   471  		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
   472  		common.Close(link.Writer)
   473  		common.Interrupt(link.Reader)
   474  		return
   475  	}
   476  
   477  	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
   478  		if tag := handler.Tag(); tag != "" {
   479  			if inTag == "" {
   480  				accessMessage.Detour = tag
   481  			} else if isPickRoute == 1 {
   482  				accessMessage.Detour = inTag + " ==> " + tag
   483  			} else if isPickRoute == 2 {
   484  				accessMessage.Detour = inTag + " -> " + tag
   485  			} else {
   486  				accessMessage.Detour = inTag + " >> " + tag
   487  			}
   488  		}
   489  		log.Record(accessMessage)
   490  	}
   491  
   492  	handler.Dispatch(ctx, link)
   493  }