github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/common/mux/client.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/v2fly/v2ray-core/v5/common"
    10  	"github.com/v2fly/v2ray-core/v5/common/buf"
    11  	"github.com/v2fly/v2ray-core/v5/common/errors"
    12  	"github.com/v2fly/v2ray-core/v5/common/net"
    13  	"github.com/v2fly/v2ray-core/v5/common/protocol"
    14  	"github.com/v2fly/v2ray-core/v5/common/session"
    15  	"github.com/v2fly/v2ray-core/v5/common/signal/done"
    16  	"github.com/v2fly/v2ray-core/v5/common/task"
    17  	"github.com/v2fly/v2ray-core/v5/proxy"
    18  	"github.com/v2fly/v2ray-core/v5/transport"
    19  	"github.com/v2fly/v2ray-core/v5/transport/internet"
    20  	"github.com/v2fly/v2ray-core/v5/transport/pipe"
    21  )
    22  
// ClientManager dispatches outbound links onto mux client workers obtained
// from Picker.
type ClientManager struct {
	Enabled bool // whether mux is enabled from user config
	Picker  WorkerPicker
}
    27  
    28  func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
    29  	for i := 0; i < 16; i++ {
    30  		worker, err := m.Picker.PickAvailable()
    31  		if err != nil {
    32  			return err
    33  		}
    34  		if worker.Dispatch(ctx, link) {
    35  			return nil
    36  		}
    37  	}
    38  
    39  	return newError("unable to find an available mux client").AtWarning()
    40  }
    41  
// WorkerPicker selects a ClientWorker that can accept a new session.
type WorkerPicker interface {
	PickAvailable() (*ClientWorker, error)
}

// IncrementalWorkerPicker reuses existing workers while one of them is not
// full, and creates a new worker through Factory only when all of them are
// full or closed.
type IncrementalWorkerPicker struct {
	Factory ClientWorkerFactory

	access      sync.Mutex      // guards workers; cleanupTask is assigned while holding it
	workers     []*ClientWorker // workers created so far; closed ones are removed by cleanup
	cleanupTask *task.Periodic  // created lazily by pickInternal, runs cleanupFunc every 30s
}
    53  
    54  func (p *IncrementalWorkerPicker) cleanupFunc() error {
    55  	p.access.Lock()
    56  	defer p.access.Unlock()
    57  
    58  	if len(p.workers) == 0 {
    59  		return newError("no worker")
    60  	}
    61  
    62  	p.cleanup()
    63  	return nil
    64  }
    65  
    66  func (p *IncrementalWorkerPicker) cleanup() {
    67  	var activeWorkers []*ClientWorker
    68  	for _, w := range p.workers {
    69  		if !w.Closed() {
    70  			activeWorkers = append(activeWorkers, w)
    71  		}
    72  	}
    73  	p.workers = activeWorkers
    74  }
    75  
    76  func (p *IncrementalWorkerPicker) findAvailable() int {
    77  	for idx, w := range p.workers {
    78  		if !w.IsFull() {
    79  			return idx
    80  		}
    81  	}
    82  
    83  	return -1
    84  }
    85  
    86  func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
    87  	p.access.Lock()
    88  	defer p.access.Unlock()
    89  
    90  	idx := p.findAvailable()
    91  	if idx >= 0 {
    92  		n := len(p.workers)
    93  		if n > 1 && idx != n-1 {
    94  			p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
    95  		}
    96  		return p.workers[idx], false, nil
    97  	}
    98  
    99  	p.cleanup()
   100  
   101  	worker, err := p.Factory.Create()
   102  	if err != nil {
   103  		return nil, false, err
   104  	}
   105  	p.workers = append(p.workers, worker)
   106  
   107  	if p.cleanupTask == nil {
   108  		p.cleanupTask = &task.Periodic{
   109  			Interval: time.Second * 30,
   110  			Execute:  p.cleanupFunc,
   111  		}
   112  	}
   113  
   114  	return worker, true, nil
   115  }
   116  
   117  func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
   118  	worker, start, err := p.pickInternal()
   119  	if start {
   120  		common.Must(p.cleanupTask.Start())
   121  	}
   122  
   123  	return worker, err
   124  }
   125  
// ClientWorkerFactory creates new ClientWorkers; IncrementalWorkerPicker uses
// one when all of its existing workers are full or closed.
type ClientWorkerFactory interface {
	Create() (*ClientWorker, error)
}

// DialingWorkerFactory creates ClientWorkers whose multiplexed stream is sent,
// through Proxy and Dialer, to the mux.cool pseudo destination.
type DialingWorkerFactory struct {
	Proxy    proxy.Outbound
	Dialer   internet.Dialer
	Strategy ClientStrategy // limits applied to every worker the factory creates

	ctx context.Context // parent context for the connections started by Create
}
   137  
   138  func NewDialingWorkerFactory(ctx context.Context, proxy proxy.Outbound, dialer internet.Dialer, strategy ClientStrategy) *DialingWorkerFactory {
   139  	return &DialingWorkerFactory{
   140  		Proxy:    proxy,
   141  		Dialer:   dialer,
   142  		Strategy: strategy,
   143  		ctx:      ctx,
   144  	}
   145  }
   146  
   147  func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
   148  	opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
   149  	uplinkReader, upLinkWriter := pipe.New(opts...)
   150  	downlinkReader, downlinkWriter := pipe.New(opts...)
   151  
   152  	c, err := NewClientWorker(transport.Link{
   153  		Reader: downlinkReader,
   154  		Writer: upLinkWriter,
   155  	}, f.Strategy)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
   161  		ctx := session.ContextWithOutbound(f.ctx, &session.Outbound{
   162  			Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
   163  		})
   164  		ctx, cancel := context.WithCancel(ctx)
   165  
   166  		if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
   167  			errors.New("failed to handler mux client connection").Base(err).WriteToLog()
   168  		}
   169  		common.Must(c.Close())
   170  		cancel()
   171  	}(f.Proxy, f.Dialer, c.done)
   172  
   173  	return c, nil
   174  }
   175  
// ClientStrategy limits how much a single ClientWorker is used. A zero value in
// either field means "no limit". MaxConcurrency is checked against the session
// manager's Size() in IsFull; MaxConnection against its Count() in IsClosing.
// NOTE(review): Size() presumably counts currently open sessions and Count()
// sessions allocated over the worker's lifetime; confirm against SessionManager.
type ClientStrategy struct {
	MaxConcurrency uint32 // limit on sessionManager.Size(); 0 means no limit
	MaxConnection  uint32 // limit on sessionManager.Count(); 0 means no limit
}

// ClientWorker multiplexes many sessions over a single link.
type ClientWorker struct {
	sessionManager *SessionManager
	link           transport.Link // multiplexed stream shared by all sessions
	done           *done.Instance // closed when the worker shuts down
	strategy       ClientStrategy
}

// muxCoolAddress and muxCoolPort form the pseudo destination (v1.mux.cool:9527)
// that DialingWorkerFactory.Create sets as the outbound target for mux
// connections.
var (
	muxCoolAddress = net.DomainAddress("v1.mux.cool")
	muxCoolPort    = net.Port(9527)
)
   192  
   193  // NewClientWorker creates a new mux.Client.
   194  func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
   195  	c := &ClientWorker{
   196  		sessionManager: NewSessionManager(),
   197  		link:           stream,
   198  		done:           done.New(),
   199  		strategy:       s,
   200  	}
   201  
   202  	go c.fetchOutput()
   203  	go c.monitor()
   204  
   205  	return c, nil
   206  }
   207  
   208  func (m *ClientWorker) TotalConnections() uint32 {
   209  	return uint32(m.sessionManager.Count())
   210  }
   211  
   212  func (m *ClientWorker) ActiveConnections() uint32 {
   213  	return uint32(m.sessionManager.Size())
   214  }
   215  
   216  // Closed returns true if this Client is closed.
   217  func (m *ClientWorker) Closed() bool {
   218  	return m.done.Done()
   219  }
   220  
   221  func (m *ClientWorker) monitor() {
   222  	timer := time.NewTicker(time.Second * 16)
   223  	defer timer.Stop()
   224  
   225  	for {
   226  		select {
   227  		case <-m.done.Wait():
   228  			m.sessionManager.Close()
   229  			common.Close(m.link.Writer)
   230  			common.Interrupt(m.link.Reader)
   231  			return
   232  		case <-timer.C:
   233  			size := m.sessionManager.Size()
   234  			if size == 0 && m.sessionManager.CloseIfNoSession() {
   235  				common.Must(m.done.Close())
   236  			}
   237  		}
   238  	}
   239  }
   240  
   241  func writeFirstPayload(reader buf.Reader, writer *Writer) error {
   242  	err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
   243  	if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
   244  		return writer.WriteMultiBuffer(buf.MultiBuffer{})
   245  	}
   246  
   247  	if err != nil {
   248  		return err
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
   255  	dest := session.OutboundFromContext(ctx).Target
   256  	transferType := protocol.TransferTypeStream
   257  	if dest.Network == net.Network_UDP {
   258  		transferType = protocol.TransferTypePacket
   259  	}
   260  	s.transferType = transferType
   261  	writer := NewWriter(s.ID, dest, output, transferType)
   262  	defer s.Close()
   263  	defer writer.Close()
   264  
   265  	newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   266  	if err := writeFirstPayload(s.input, writer); err != nil {
   267  		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
   268  		writer.hasError = true
   269  		common.Interrupt(s.input)
   270  		return
   271  	}
   272  
   273  	if err := buf.Copy(s.input, writer); err != nil {
   274  		newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
   275  		writer.hasError = true
   276  		common.Interrupt(s.input)
   277  		return
   278  	}
   279  }
   280  
   281  func (m *ClientWorker) IsClosing() bool {
   282  	sm := m.sessionManager
   283  	if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
   284  		return true
   285  	}
   286  	return false
   287  }
   288  
   289  func (m *ClientWorker) IsFull() bool {
   290  	if m.IsClosing() || m.Closed() {
   291  		return true
   292  	}
   293  
   294  	sm := m.sessionManager
   295  	if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
   296  		return true
   297  	}
   298  	return false
   299  }
   300  
   301  func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
   302  	if m.IsFull() || m.Closed() {
   303  		return false
   304  	}
   305  
   306  	sm := m.sessionManager
   307  	s := sm.Allocate()
   308  	if s == nil {
   309  		return false
   310  	}
   311  	s.input = link.Reader
   312  	s.output = link.Writer
   313  	go fetchInput(ctx, s, m.link.Writer)
   314  	return true
   315  }
   316  
   317  func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
   318  	if meta.Option.Has(OptionData) {
   319  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   320  	}
   321  	return nil
   322  }
   323  
   324  func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
   325  	if meta.Option.Has(OptionData) {
   326  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   327  	}
   328  	return nil
   329  }
   330  
// handleStatusKeep delivers the payload of a Keep frame to the session it
// belongs to. Frames without OptionData are ignored.
//
// If the session is unknown, the peer is told to close it and the payload is
// discarded. If writing to the session's output fails:
//   - the error is logged;
//   - the peer is told to close the session;
//   - the rest of the frame payload is drained;
//   - the session's input is interrupted and the session is closed locally.
//
// In that case the drain error is returned. Otherwise the copy error, if any,
// is returned to the caller (fetchOutput).
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
	if !meta.Option.Has(OptionData) {
		return nil
	}

	s, found := m.sessionManager.Get(meta.SessionID)
	if !found {
		// Notify remote peer to close this session.
		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
		closingWriter.Close()

		// Skip this frame's payload since there is no session to deliver it to.
		return buf.Copy(NewStreamReader(reader), buf.Discard)
	}

	rr := s.NewReader(reader)
	err := buf.Copy(rr, s.output)
	// Only failures on the session's output side tear the session down here.
	if err != nil && buf.IsWriteError(err) {
		newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()

		// Notify remote peer to close this session.
		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
		closingWriter.Close()

		// Consume the remainder of this frame's payload. Presumably this keeps the
		// shared reader positioned at the next frame header for fetchOutput.
		drainErr := buf.Copy(rr, buf.Discard)
		common.Interrupt(s.input)
		s.Close()
		return drainErr
	}

	return err
}
   362  
   363  func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
   364  	if s, found := m.sessionManager.Get(meta.SessionID); found {
   365  		if meta.Option.Has(OptionError) {
   366  			common.Interrupt(s.input)
   367  			common.Interrupt(s.output)
   368  		}
   369  		s.Close()
   370  	}
   371  	if meta.Option.Has(OptionData) {
   372  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   373  	}
   374  	return nil
   375  }
   376  
// fetchOutput is the worker's read loop, started by NewClientWorker.
//
// It decodes frame metadata from the link's reader and passes each frame to
// the handler for its session status. The loop ends on any of the following:
//   - a metadata read error, which is logged unless its cause is io.EOF;
//   - an unknown session status;
//   - a handler error.
//
// In every case the worker's done instance is closed on return. That makes
// monitor shut the worker down.
func (m *ClientWorker) fetchOutput() {
	defer func() {
		common.Must(m.done.Close())
	}()

	// A single buffered reader is shared by all handlers; each one consumes its
	// frame's payload from it before the next metadata is decoded.
	reader := &buf.BufferedReader{Reader: m.link.Reader}

	var meta FrameMetadata
	for {
		err := meta.Unmarshal(reader)
		if err != nil {
			// io.EOF is the normal end of the link and is not logged.
			if errors.Cause(err) != io.EOF {
				newError("failed to read metadata").Base(err).WriteToLog()
			}
			break
		}

		switch meta.SessionStatus {
		case SessionStatusKeepAlive:
			err = m.handleStatueKeepAlive(&meta, reader)
		case SessionStatusEnd:
			err = m.handleStatusEnd(&meta, reader)
		case SessionStatusNew:
			err = m.handleStatusNew(&meta, reader)
		case SessionStatusKeep:
			err = m.handleStatusKeep(&meta, reader)
		default:
			status := meta.SessionStatus
			newError("unknown status: ", status).AtError().WriteToLog()
			return
		}

		if err != nil {
			newError("failed to process data").Base(err).WriteToLog()
			return
		}
	}
}