github.com/imannamdari/v2ray-core/v5@v5.0.5/app/dispatcher/default.go (about)

     1  package dispatcher
     2  
     3  //go:generate go run github.com/imannamdari/v2ray-core/v5/common/errors/errorgen
     4  
     5  import (
     6  	"context"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	core "github.com/imannamdari/v2ray-core/v5"
    12  	"github.com/imannamdari/v2ray-core/v5/common"
    13  	"github.com/imannamdari/v2ray-core/v5/common/buf"
    14  	"github.com/imannamdari/v2ray-core/v5/common/log"
    15  	"github.com/imannamdari/v2ray-core/v5/common/net"
    16  	"github.com/imannamdari/v2ray-core/v5/common/protocol"
    17  	"github.com/imannamdari/v2ray-core/v5/common/session"
    18  	"github.com/imannamdari/v2ray-core/v5/common/strmatcher"
    19  	"github.com/imannamdari/v2ray-core/v5/features/outbound"
    20  	"github.com/imannamdari/v2ray-core/v5/features/policy"
    21  	"github.com/imannamdari/v2ray-core/v5/features/routing"
    22  	routing_session "github.com/imannamdari/v2ray-core/v5/features/routing/session"
    23  	"github.com/imannamdari/v2ray-core/v5/features/stats"
    24  	"github.com/imannamdari/v2ray-core/v5/transport"
    25  	"github.com/imannamdari/v2ray-core/v5/transport/pipe"
    26  )
    27  
    28  var errSniffingTimeout = newError("timeout on sniffing")
    29  
    30  type cachedReader struct {
    31  	sync.Mutex
    32  	reader *pipe.Reader
    33  	cache  buf.MultiBuffer
    34  }
    35  
    36  func (r *cachedReader) Cache(b *buf.Buffer) {
    37  	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
    38  	r.Lock()
    39  	if !mb.IsEmpty() {
    40  		r.cache, _ = buf.MergeMulti(r.cache, mb)
    41  	}
    42  	b.Clear()
    43  	rawBytes := b.Extend(buf.Size)
    44  	n := r.cache.Copy(rawBytes)
    45  	b.Resize(0, int32(n))
    46  	r.Unlock()
    47  }
    48  
    49  func (r *cachedReader) readInternal() buf.MultiBuffer {
    50  	r.Lock()
    51  	defer r.Unlock()
    52  
    53  	if r.cache != nil && !r.cache.IsEmpty() {
    54  		mb := r.cache
    55  		r.cache = nil
    56  		return mb
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
    63  	mb := r.readInternal()
    64  	if mb != nil {
    65  		return mb, nil
    66  	}
    67  
    68  	return r.reader.ReadMultiBuffer()
    69  }
    70  
    71  func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
    72  	mb := r.readInternal()
    73  	if mb != nil {
    74  		return mb, nil
    75  	}
    76  
    77  	return r.reader.ReadMultiBufferTimeout(timeout)
    78  }
    79  
    80  func (r *cachedReader) Interrupt() {
    81  	r.Lock()
    82  	if r.cache != nil {
    83  		r.cache = buf.ReleaseMulti(r.cache)
    84  	}
    85  	r.Unlock()
    86  	r.reader.Interrupt()
    87  }
    88  
    89  // DefaultDispatcher is a default implementation of Dispatcher.
    90  type DefaultDispatcher struct {
    91  	ohm    outbound.Manager
    92  	router routing.Router
    93  	policy policy.Manager
    94  	stats  stats.Manager
    95  }
    96  
    97  func init() {
    98  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
    99  		d := new(DefaultDispatcher)
   100  		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
   101  			return d.Init(config.(*Config), om, router, pm, sm)
   102  		}); err != nil {
   103  			return nil, err
   104  		}
   105  		return d, nil
   106  	}))
   107  }
   108  
   109  // Init initializes DefaultDispatcher.
   110  func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
   111  	d.ohm = om
   112  	d.router = router
   113  	d.policy = pm
   114  	d.stats = sm
   115  	return nil
   116  }
   117  
   118  // Type implements common.HasType.
   119  func (*DefaultDispatcher) Type() interface{} {
   120  	return routing.DispatcherType()
   121  }
   122  
   123  // Start implements common.Runnable.
   124  func (*DefaultDispatcher) Start() error {
   125  	return nil
   126  }
   127  
   128  // Close implements common.Closable.
   129  func (*DefaultDispatcher) Close() error { return nil }
   130  
   131  func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
   132  	opt := pipe.OptionsFromContext(ctx)
   133  	uplinkReader, uplinkWriter := pipe.New(opt...)
   134  	downlinkReader, downlinkWriter := pipe.New(opt...)
   135  
   136  	inboundLink := &transport.Link{
   137  		Reader: downlinkReader,
   138  		Writer: uplinkWriter,
   139  	}
   140  
   141  	outboundLink := &transport.Link{
   142  		Reader: uplinkReader,
   143  		Writer: downlinkWriter,
   144  	}
   145  
   146  	sessionInbound := session.InboundFromContext(ctx)
   147  	var user *protocol.MemoryUser
   148  	if sessionInbound != nil {
   149  		user = sessionInbound.User
   150  	}
   151  
   152  	if user != nil && len(user.Email) > 0 {
   153  		p := d.policy.ForLevel(user.Level)
   154  		if p.Stats.UserUplink {
   155  			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
   156  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   157  				inboundLink.Writer = &SizeStatWriter{
   158  					Counter: c,
   159  					Writer:  inboundLink.Writer,
   160  				}
   161  			}
   162  		}
   163  		if p.Stats.UserDownlink {
   164  			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
   165  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   166  				outboundLink.Writer = &SizeStatWriter{
   167  					Counter: c,
   168  					Writer:  outboundLink.Writer,
   169  				}
   170  			}
   171  		}
   172  	}
   173  
   174  	return inboundLink, outboundLink
   175  }
   176  
   177  func shouldOverride(result SniffResult, domainOverride []string) bool {
   178  	if result.Domain() == "" {
   179  		return false
   180  	}
   181  	protocolString := result.Protocol()
   182  	if resComp, ok := result.(SnifferResultComposite); ok {
   183  		protocolString = resComp.ProtocolForDomainResult()
   184  	}
   185  	for _, p := range domainOverride {
   186  		if strings.HasPrefix(protocolString, p) || strings.HasSuffix(protocolString, p) {
   187  			return true
   188  		}
   189  		if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
   190  			if resultSubset.IsProtoSubsetOf(p) {
   191  				return true
   192  			}
   193  		}
   194  	}
   195  	return false
   196  }
   197  
   198  // Dispatch implements routing.Dispatcher.
   199  func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
   200  	if !destination.IsValid() {
   201  		panic("Dispatcher: Invalid destination.")
   202  	}
   203  	ob := &session.Outbound{
   204  		Target: destination,
   205  	}
   206  	ctx = session.ContextWithOutbound(ctx, ob)
   207  
   208  	inbound, outbound := d.getLink(ctx)
   209  	content := session.ContentFromContext(ctx)
   210  	if content == nil {
   211  		content = new(session.Content)
   212  		ctx = session.ContextWithContent(ctx, content)
   213  	}
   214  	sniffingRequest := content.SniffingRequest
   215  	if !sniffingRequest.Enabled {
   216  		go d.routedDispatch(ctx, outbound, destination)
   217  	} else {
   218  		go func() {
   219  			cReader := &cachedReader{
   220  				reader: outbound.Reader.(*pipe.Reader),
   221  			}
   222  			outbound.Reader = cReader
   223  			result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
   224  			if err == nil {
   225  				content.Protocol = result.Protocol()
   226  			}
   227  			if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
   228  				if domain, err := strmatcher.ToDomain(result.Domain()); err == nil {
   229  					newError("sniffed domain: ", domain, " for ", destination).WriteToLog(session.ExportIDToError(ctx))
   230  					destination.Address = net.ParseAddress(domain)
   231  					ob.Target = destination
   232  				}
   233  			}
   234  			d.routedDispatch(ctx, outbound, destination)
   235  		}()
   236  	}
   237  
   238  	return inbound, nil
   239  }
   240  
   241  func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
   242  	payload := buf.New()
   243  	defer payload.Release()
   244  
   245  	sniffer := NewSniffer(ctx)
   246  
   247  	metaresult, metadataErr := sniffer.SniffMetadata(ctx)
   248  
   249  	if metadataOnly {
   250  		return metaresult, metadataErr
   251  	}
   252  
   253  	contentResult, contentErr := func() (SniffResult, error) {
   254  		totalAttempt := 0
   255  		for {
   256  			select {
   257  			case <-ctx.Done():
   258  				return nil, ctx.Err()
   259  			default:
   260  				totalAttempt++
   261  				if totalAttempt > 2 {
   262  					return nil, errSniffingTimeout
   263  				}
   264  
   265  				cReader.Cache(payload)
   266  				if !payload.IsEmpty() {
   267  					result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
   268  					if err != common.ErrNoClue {
   269  						return result, err
   270  					}
   271  				}
   272  				if payload.IsFull() {
   273  					return nil, errUnknownContent
   274  				}
   275  			}
   276  		}
   277  	}()
   278  	if contentErr != nil && metadataErr == nil {
   279  		return metaresult, nil
   280  	}
   281  	if contentErr == nil && metadataErr == nil {
   282  		return CompositeResult(metaresult, contentResult), nil
   283  	}
   284  	return contentResult, contentErr
   285  }
   286  
   287  func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
   288  	var handler outbound.Handler
   289  
   290  	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
   291  		ctx = session.SetForcedOutboundTagToContext(ctx, "")
   292  		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
   293  			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   294  			handler = h
   295  		} else {
   296  			newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx))
   297  			common.Close(link.Writer)
   298  			common.Interrupt(link.Reader)
   299  			return
   300  		}
   301  	} else if d.router != nil {
   302  		if route, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil {
   303  			tag := route.GetOutboundTag()
   304  			if h := d.ohm.GetHandler(tag); h != nil {
   305  				newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   306  				handler = h
   307  			} else {
   308  				newError("non existing tag: ", tag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   309  			}
   310  		} else {
   311  			newError("default route for ", destination).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   312  		}
   313  	}
   314  
   315  	if handler == nil {
   316  		handler = d.ohm.GetDefaultHandler()
   317  	}
   318  
   319  	if handler == nil {
   320  		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
   321  		common.Close(link.Writer)
   322  		common.Interrupt(link.Reader)
   323  		return
   324  	}
   325  
   326  	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
   327  		if tag := handler.Tag(); tag != "" {
   328  			accessMessage.Detour = tag
   329  			if d.policy.ForSystem().OverrideAccessLogDest {
   330  				accessMessage.To = destination
   331  			}
   332  		}
   333  		log.Record(accessMessage)
   334  	}
   335  
   336  	handler.Dispatch(ctx, link)
   337  }