github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/app/dispatcher/default.go (about)

     1  // +build !confonly
     2  
     3  package dispatcher
     4  
     5  //go:generate go run v2ray.com/core/common/errors/errorgen
     6  
     7  import (
     8  	"context"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	"v2ray.com/core"
    14  	"v2ray.com/core/common"
    15  	"v2ray.com/core/common/buf"
    16  	"v2ray.com/core/common/log"
    17  	"v2ray.com/core/common/net"
    18  	"v2ray.com/core/common/protocol"
    19  	"v2ray.com/core/common/session"
    20  	"v2ray.com/core/features/outbound"
    21  	"v2ray.com/core/features/policy"
    22  	"v2ray.com/core/features/routing"
    23  	routing_session "v2ray.com/core/features/routing/session"
    24  	"v2ray.com/core/features/stats"
    25  	"v2ray.com/core/transport"
    26  	"v2ray.com/core/transport/pipe"
    27  )
    28  
    29  var (
    30  	errSniffingTimeout = newError("timeout on sniffing")
    31  )
    32  
    33  type cachedReader struct {
    34  	sync.Mutex
    35  	reader *pipe.Reader
    36  	cache  buf.MultiBuffer
    37  }
    38  
    39  func (r *cachedReader) Cache(b *buf.Buffer) {
    40  	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
    41  	r.Lock()
    42  	if !mb.IsEmpty() {
    43  		r.cache, _ = buf.MergeMulti(r.cache, mb)
    44  	}
    45  	b.Clear()
    46  	rawBytes := b.Extend(buf.Size)
    47  	n := r.cache.Copy(rawBytes)
    48  	b.Resize(0, int32(n))
    49  	r.Unlock()
    50  }
    51  
    52  func (r *cachedReader) readInternal() buf.MultiBuffer {
    53  	r.Lock()
    54  	defer r.Unlock()
    55  
    56  	if r.cache != nil && !r.cache.IsEmpty() {
    57  		mb := r.cache
    58  		r.cache = nil
    59  		return mb
    60  	}
    61  
    62  	return nil
    63  }
    64  
    65  func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
    66  	mb := r.readInternal()
    67  	if mb != nil {
    68  		return mb, nil
    69  	}
    70  
    71  	return r.reader.ReadMultiBuffer()
    72  }
    73  
    74  func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
    75  	mb := r.readInternal()
    76  	if mb != nil {
    77  		return mb, nil
    78  	}
    79  
    80  	return r.reader.ReadMultiBufferTimeout(timeout)
    81  }
    82  
    83  func (r *cachedReader) Interrupt() {
    84  	r.Lock()
    85  	if r.cache != nil {
    86  		r.cache = buf.ReleaseMulti(r.cache)
    87  	}
    88  	r.Unlock()
    89  	r.reader.Interrupt()
    90  }
    91  
    92  // DefaultDispatcher is a default implementation of Dispatcher.
    93  type DefaultDispatcher struct {
    94  	ohm    outbound.Manager
    95  	router routing.Router
    96  	policy policy.Manager
    97  	stats  stats.Manager
    98  }
    99  
   100  func init() {
   101  	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   102  		d := new(DefaultDispatcher)
   103  		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
   104  			return d.Init(config.(*Config), om, router, pm, sm)
   105  		}); err != nil {
   106  			return nil, err
   107  		}
   108  		return d, nil
   109  	}))
   110  }
   111  
   112  // Init initializes DefaultDispatcher.
   113  func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
   114  	d.ohm = om
   115  	d.router = router
   116  	d.policy = pm
   117  	d.stats = sm
   118  	return nil
   119  }
   120  
   121  // Type implements common.HasType.
   122  func (*DefaultDispatcher) Type() interface{} {
   123  	return routing.DispatcherType()
   124  }
   125  
   126  // Start implements common.Runnable.
   127  func (*DefaultDispatcher) Start() error {
   128  	return nil
   129  }
   130  
   131  // Close implements common.Closable.
   132  func (*DefaultDispatcher) Close() error { return nil }
   133  
   134  func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
   135  	opt := pipe.OptionsFromContext(ctx)
   136  	uplinkReader, uplinkWriter := pipe.New(opt...)
   137  	downlinkReader, downlinkWriter := pipe.New(opt...)
   138  
   139  	inboundLink := &transport.Link{
   140  		Reader: downlinkReader,
   141  		Writer: uplinkWriter,
   142  	}
   143  
   144  	outboundLink := &transport.Link{
   145  		Reader: uplinkReader,
   146  		Writer: downlinkWriter,
   147  	}
   148  
   149  	sessionInbound := session.InboundFromContext(ctx)
   150  	var user *protocol.MemoryUser
   151  	var clientIP net.IP
   152  	if sessionInbound != nil {
   153  		clientIP = sessionInbound.Source.Address.IP()
   154  		user = sessionInbound.User
   155  	}
   156  
   157  	if user != nil && len(user.Email) > 0 {
   158          
   159  		if len(clientIP) > 0 {
   160  			if ipStorager, _ := stats.GetOrRegisterIPStorager(d.stats, "user>>>" +user.Email+">>>ip"); ipStorager != nil {
   161  				ipStorager.Add(clientIP)
   162  			}
   163  		}
   164          var bucket *RateLimiter
   165  
   166  		p := d.policy.ForLevel(user.Level)
   167  
   168  		if p.Buffer.Rate != 0 {
   169  			bucket = NewRateLimiter(int64(p.Buffer.Rate) * 1024)
   170  			inboundLink.Writer = RateWriter(inboundLink.Writer, bucket)
   171  			outboundLink.Writer = RateWriter(outboundLink.Writer, bucket)
   172  		}
   173  
   174  		if p.Stats.UserUplink {
   175  			name := "user>>>" + user.Email + ">>>traffic>>>uplink"
   176  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   177  				inboundLink.Writer = &SizeStatWriter{
   178  					Counter: c,
   179  					Writer:  inboundLink.Writer,
   180  				}
   181  			}
   182  		}
   183  		if p.Stats.UserDownlink {
   184  			name := "user>>>" + user.Email + ">>>traffic>>>downlink"
   185  			if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
   186  				outboundLink.Writer = &SizeStatWriter{
   187  					Counter: c,
   188  					Writer:  outboundLink.Writer,
   189  				}
   190  			}
   191  		}
   192  	}
   193  
   194  	return inboundLink, outboundLink
   195  }
   196  
   197  func shouldOverride(result SniffResult, domainOverride []string) bool {
   198  	for _, p := range domainOverride {
   199  		if strings.HasPrefix(result.Protocol(), p) {
   200  			return true
   201  		}
   202  	}
   203  	return false
   204  }
   205  
   206  // Dispatch implements routing.Dispatcher.
   207  func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
   208  	if !destination.IsValid() {
   209  		panic("Dispatcher: Invalid destination.")
   210  	}
   211  	ob := &session.Outbound{
   212  		Target: destination,
   213  	}
   214  	ctx = session.ContextWithOutbound(ctx, ob)
   215  
   216  	inbound, outbound := d.getLink(ctx)
   217  	content := session.ContentFromContext(ctx)
   218  	if content == nil {
   219  		content = new(session.Content)
   220  		ctx = session.ContextWithContent(ctx, content)
   221  	}
   222  	sniffingRequest := content.SniffingRequest
   223  	if destination.Network != net.Network_TCP || !sniffingRequest.Enabled {
   224  		go d.routedDispatch(ctx, outbound, destination)
   225  	} else {
   226  		go func() {
   227  			cReader := &cachedReader{
   228  				reader: outbound.Reader.(*pipe.Reader),
   229  			}
   230  			outbound.Reader = cReader
   231  			result, err := sniffer(ctx, cReader)
   232  			if err == nil {
   233  				content.Protocol = result.Protocol()
   234  			}
   235  			if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
   236  				domain := result.Domain()
   237  				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
   238  				destination.Address = net.ParseAddress(domain)
   239  				ob.Target = destination
   240  			}
   241  			d.routedDispatch(ctx, outbound, destination)
   242  		}()
   243  	}
   244  	return inbound, nil
   245  }
   246  
   247  func sniffer(ctx context.Context, cReader *cachedReader) (SniffResult, error) {
   248  	payload := buf.New()
   249  	defer payload.Release()
   250  
   251  	sniffer := NewSniffer()
   252  	totalAttempt := 0
   253  	for {
   254  		select {
   255  		case <-ctx.Done():
   256  			return nil, ctx.Err()
   257  		default:
   258  			totalAttempt++
   259  			if totalAttempt > 2 {
   260  				return nil, errSniffingTimeout
   261  			}
   262  
   263  			cReader.Cache(payload)
   264  			if !payload.IsEmpty() {
   265  				result, err := sniffer.Sniff(payload.Bytes())
   266  				if err != common.ErrNoClue {
   267  					return result, err
   268  				}
   269  			}
   270  			if payload.IsFull() {
   271  				return nil, errUnknownContent
   272  			}
   273  		}
   274  	}
   275  }
   276  
   277  func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
   278  	var handler outbound.Handler
   279  
   280  	skipRoutePick := false
   281  	if content := session.ContentFromContext(ctx); content != nil {
   282  		skipRoutePick = content.SkipRoutePick
   283  	}
   284  
   285  	if d.router != nil && !skipRoutePick {
   286  		if route, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil {
   287  			tag := route.GetOutboundTag()
   288  			if h := d.ohm.GetHandler(tag); h != nil {
   289  				newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
   290  				handler = h
   291  			} else {
   292  				newError("non existing tag: ", tag).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   293  			}
   294  		} else {
   295  			newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx))
   296  		}
   297  	}
   298  
   299  	if handler == nil {
   300  		handler = d.ohm.GetDefaultHandler()
   301  	}
   302  
   303  	if handler == nil {
   304  		newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx))
   305  		common.Close(link.Writer)
   306  		common.Interrupt(link.Reader)
   307  		return
   308  	}
   309  
   310  	if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
   311  		if tag := handler.Tag(); tag != "" {
   312  			accessMessage.Detour = tag
   313  		}
   314  		log.Record(accessMessage)
   315  	}
   316  
   317  	handler.Dispatch(ctx, link)
   318  }