github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/common/mux/client.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/xtls/xray-core/common"
    10  	"github.com/xtls/xray-core/common/buf"
    11  	"github.com/xtls/xray-core/common/errors"
    12  	"github.com/xtls/xray-core/common/net"
    13  	"github.com/xtls/xray-core/common/protocol"
    14  	"github.com/xtls/xray-core/common/session"
    15  	"github.com/xtls/xray-core/common/signal/done"
    16  	"github.com/xtls/xray-core/common/task"
    17  	"github.com/xtls/xray-core/common/xudp"
    18  	"github.com/xtls/xray-core/proxy"
    19  	"github.com/xtls/xray-core/transport"
    20  	"github.com/xtls/xray-core/transport/internet"
    21  	"github.com/xtls/xray-core/transport/pipe"
    22  )
    23  
// ClientManager dispatches outbound links onto pooled mux client workers.
type ClientManager struct {
	Enabled bool // whether mux is enabled from user config
	Picker  WorkerPicker
}
    28  
    29  func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
    30  	for i := 0; i < 16; i++ {
    31  		worker, err := m.Picker.PickAvailable()
    32  		if err != nil {
    33  			return err
    34  		}
    35  		if worker.Dispatch(ctx, link) {
    36  			return nil
    37  		}
    38  	}
    39  
    40  	return newError("unable to find an available mux client").AtWarning()
    41  }
    42  
// WorkerPicker selects a ClientWorker to carry a new mux session.
type WorkerPicker interface {
	PickAvailable() (*ClientWorker, error)
}
    46  
// IncrementalWorkerPicker grows a pool of ClientWorkers on demand: it reuses
// a worker with spare capacity when one exists and asks Factory for a new
// one otherwise. Closed workers are pruned by a periodic cleanup task.
type IncrementalWorkerPicker struct {
	Factory ClientWorkerFactory

	access      sync.Mutex // guards workers and cleanupTask
	workers     []*ClientWorker
	cleanupTask *task.Periodic // lazily created on first worker; prunes closed workers
}
    54  
    55  func (p *IncrementalWorkerPicker) cleanupFunc() error {
    56  	p.access.Lock()
    57  	defer p.access.Unlock()
    58  
    59  	if len(p.workers) == 0 {
    60  		return newError("no worker")
    61  	}
    62  
    63  	p.cleanup()
    64  	return nil
    65  }
    66  
    67  func (p *IncrementalWorkerPicker) cleanup() {
    68  	var activeWorkers []*ClientWorker
    69  	for _, w := range p.workers {
    70  		if !w.Closed() {
    71  			activeWorkers = append(activeWorkers, w)
    72  		}
    73  	}
    74  	p.workers = activeWorkers
    75  }
    76  
    77  func (p *IncrementalWorkerPicker) findAvailable() int {
    78  	for idx, w := range p.workers {
    79  		if !w.IsFull() {
    80  			return idx
    81  		}
    82  	}
    83  
    84  	return -1
    85  }
    86  
    87  func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
    88  	p.access.Lock()
    89  	defer p.access.Unlock()
    90  
    91  	idx := p.findAvailable()
    92  	if idx >= 0 {
    93  		n := len(p.workers)
    94  		if n > 1 && idx != n-1 {
    95  			p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
    96  		}
    97  		return p.workers[idx], false, nil
    98  	}
    99  
   100  	p.cleanup()
   101  
   102  	worker, err := p.Factory.Create()
   103  	if err != nil {
   104  		return nil, false, err
   105  	}
   106  	p.workers = append(p.workers, worker)
   107  
   108  	if p.cleanupTask == nil {
   109  		p.cleanupTask = &task.Periodic{
   110  			Interval: time.Second * 30,
   111  			Execute:  p.cleanupFunc,
   112  		}
   113  	}
   114  
   115  	return worker, true, nil
   116  }
   117  
   118  func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
   119  	worker, start, err := p.pickInternal()
   120  	if start {
   121  		common.Must(p.cleanupTask.Start())
   122  	}
   123  
   124  	return worker, err
   125  }
   126  
// ClientWorkerFactory creates new ClientWorkers for a picker's pool.
type ClientWorkerFactory interface {
	Create() (*ClientWorker, error)
}
   130  
// DialingWorkerFactory creates ClientWorkers whose mux traffic is carried by
// the configured outbound proxy, dialed through Dialer.
type DialingWorkerFactory struct {
	Proxy    proxy.Outbound
	Dialer   internet.Dialer
	Strategy ClientStrategy
}
   136  
// Create builds a ClientWorker wired to the outbound proxy through two
// in-memory pipes, then starts a goroutine that runs the proxy toward the
// mux.cool pseudo-destination. The worker is closed when that proxy run
// returns.
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
	opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
	uplinkReader, upLinkWriter := pipe.New(opts...)
	downlinkReader, downlinkWriter := pipe.New(opts...)

	// The worker reads responses from the downlink pipe and writes requests
	// into the uplink pipe; the proxy goroutine holds the opposite ends.
	c, err := NewClientWorker(transport.Link{
		Reader: downlinkReader,
		Writer: upLinkWriter,
	}, f.Strategy)
	if err != nil {
		return nil, err
	}

	go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
		// Target the special mux.cool address/port so the remote end treats
		// this connection as a mux carrier.
		outbounds := []*session.Outbound{{
			Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
		}}
		ctx := session.ContextWithOutbounds(context.Background(), outbounds)
		ctx, cancel := context.WithCancel(ctx)

		// Process blocks for the lifetime of the carrier connection.
		if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
			errors.New("failed to handler mux client connection").Base(err).WriteToLog()
		}
		// Mark the worker done first, then release the context.
		common.Must(c.Close())
		cancel()
	}(f.Proxy, f.Dialer, c.done)

	return c, nil
}
   166  
// ClientStrategy bounds how much a single ClientWorker may carry.
// A zero value disables the corresponding limit.
type ClientStrategy struct {
	MaxConcurrency uint32 // max simultaneously open sessions per worker
	MaxConnection  uint32 // max total sessions a worker carries over its lifetime
}
   171  
// ClientWorker multiplexes many sessions over one underlying mux link.
type ClientWorker struct {
	sessionManager *SessionManager
	link           transport.Link
	done           *done.Instance // signaled when the worker is closed
	strategy       ClientStrategy
}
   178  
// muxCoolAddress and muxCoolPort form the pseudo-destination that identifies
// a connection as a mux.cool carrier to the remote end.
var (
	muxCoolAddress = net.DomainAddress("v1.mux.cool")
	muxCoolPort    = net.Port(9527)
)
   183  
   184  // NewClientWorker creates a new mux.Client.
   185  func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
   186  	c := &ClientWorker{
   187  		sessionManager: NewSessionManager(),
   188  		link:           stream,
   189  		done:           done.New(),
   190  		strategy:       s,
   191  	}
   192  
   193  	go c.fetchOutput()
   194  	go c.monitor()
   195  
   196  	return c, nil
   197  }
   198  
   199  func (m *ClientWorker) TotalConnections() uint32 {
   200  	return uint32(m.sessionManager.Count())
   201  }
   202  
   203  func (m *ClientWorker) ActiveConnections() uint32 {
   204  	return uint32(m.sessionManager.Size())
   205  }
   206  
// Closed reports whether this Client has been closed.
func (m *ClientWorker) Closed() bool {
	return m.done.Done()
}
   211  
   212  func (m *ClientWorker) monitor() {
   213  	timer := time.NewTicker(time.Second * 16)
   214  	defer timer.Stop()
   215  
   216  	for {
   217  		select {
   218  		case <-m.done.Wait():
   219  			m.sessionManager.Close()
   220  			common.Close(m.link.Writer)
   221  			common.Interrupt(m.link.Reader)
   222  			return
   223  		case <-timer.C:
   224  			size := m.sessionManager.Size()
   225  			if size == 0 && m.sessionManager.CloseIfNoSession() {
   226  				common.Must(m.done.Close())
   227  			}
   228  		}
   229  	}
   230  }
   231  
   232  func writeFirstPayload(reader buf.Reader, writer *Writer) error {
   233  	err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
   234  	if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
   235  		return writer.WriteMultiBuffer(buf.MultiBuffer{})
   236  	}
   237  
   238  	if err != nil {
   239  		return err
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
   246  	outbounds := session.OutboundsFromContext(ctx)
   247  	ob := outbounds[len(outbounds) - 1]
   248  	transferType := protocol.TransferTypeStream
   249  	if ob.Target.Network == net.Network_UDP {
   250  		transferType = protocol.TransferTypePacket
   251  	}
   252  	s.transferType = transferType
   253  	writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
   254  	defer s.Close(false)
   255  	defer writer.Close()
   256  
   257  	newError("dispatching request to ", ob.Target).WriteToLog(session.ExportIDToError(ctx))
   258  	if err := writeFirstPayload(s.input, writer); err != nil {
   259  		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
   260  		writer.hasError = true
   261  		return
   262  	}
   263  
   264  	if err := buf.Copy(s.input, writer); err != nil {
   265  		newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
   266  		writer.hasError = true
   267  		return
   268  	}
   269  }
   270  
   271  func (m *ClientWorker) IsClosing() bool {
   272  	sm := m.sessionManager
   273  	if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
   274  		return true
   275  	}
   276  	return false
   277  }
   278  
   279  func (m *ClientWorker) IsFull() bool {
   280  	if m.IsClosing() || m.Closed() {
   281  		return true
   282  	}
   283  
   284  	sm := m.sessionManager
   285  	if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
   286  		return true
   287  	}
   288  	return false
   289  }
   290  
   291  func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
   292  	if m.IsFull() || m.Closed() {
   293  		return false
   294  	}
   295  
   296  	sm := m.sessionManager
   297  	s := sm.Allocate()
   298  	if s == nil {
   299  		return false
   300  	}
   301  	s.input = link.Reader
   302  	s.output = link.Writer
   303  	go fetchInput(ctx, s, m.link.Writer)
   304  	return true
   305  }
   306  
   307  func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
   308  	if meta.Option.Has(OptionData) {
   309  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   310  	}
   311  	return nil
   312  }
   313  
   314  func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
   315  	if meta.Option.Has(OptionData) {
   316  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   317  	}
   318  	return nil
   319  }
   320  
   321  func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
   322  	if !meta.Option.Has(OptionData) {
   323  		return nil
   324  	}
   325  
   326  	s, found := m.sessionManager.Get(meta.SessionID)
   327  	if !found {
   328  		// Notify remote peer to close this session.
   329  		closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
   330  		closingWriter.Close()
   331  
   332  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   333  	}
   334  
   335  	rr := s.NewReader(reader, &meta.Target)
   336  	err := buf.Copy(rr, s.output)
   337  	if err != nil && buf.IsWriteError(err) {
   338  		newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()
   339  		s.Close(false)
   340  		return buf.Copy(rr, buf.Discard)
   341  	}
   342  
   343  	return err
   344  }
   345  
   346  func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
   347  	if s, found := m.sessionManager.Get(meta.SessionID); found {
   348  		s.Close(false)
   349  	}
   350  	if meta.Option.Has(OptionData) {
   351  		return buf.Copy(NewStreamReader(reader), buf.Discard)
   352  	}
   353  	return nil
   354  }
   355  
   356  func (m *ClientWorker) fetchOutput() {
   357  	defer func() {
   358  		common.Must(m.done.Close())
   359  	}()
   360  
   361  	reader := &buf.BufferedReader{Reader: m.link.Reader}
   362  
   363  	var meta FrameMetadata
   364  	for {
   365  		err := meta.Unmarshal(reader)
   366  		if err != nil {
   367  			if errors.Cause(err) != io.EOF {
   368  				newError("failed to read metadata").Base(err).WriteToLog()
   369  			}
   370  			break
   371  		}
   372  
   373  		switch meta.SessionStatus {
   374  		case SessionStatusKeepAlive:
   375  			err = m.handleStatueKeepAlive(&meta, reader)
   376  		case SessionStatusEnd:
   377  			err = m.handleStatusEnd(&meta, reader)
   378  		case SessionStatusNew:
   379  			err = m.handleStatusNew(&meta, reader)
   380  		case SessionStatusKeep:
   381  			err = m.handleStatusKeep(&meta, reader)
   382  		default:
   383  			status := meta.SessionStatus
   384  			newError("unknown status: ", status).AtError().WriteToLog()
   385  			return
   386  		}
   387  
   388  		if err != nil {
   389  			newError("failed to process data").Base(err).WriteToLog()
   390  			return
   391  		}
   392  	}
   393  }