github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/app/reverse/portal.go (about)

     1  package reverse
     2  
import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/xtls/xray-core/common"
	"github.com/xtls/xray-core/common/buf"
	"github.com/xtls/xray-core/common/mux"
	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/common/session"
	"github.com/xtls/xray-core/common/task"
	"github.com/xtls/xray-core/features/outbound"
	"github.com/xtls/xray-core/transport"
	"github.com/xtls/xray-core/transport/pipe"
	"google.golang.org/protobuf/proto"
)
    19  
// Portal registers an Outbound handler with ohm under tag (see Start and
// Close). Links whose target isDomain reports as matching domain are turned
// into PortalWorkers held by picker; every other link is dispatched through
// client, whose Picker is that same picker.
type Portal struct {
	// ohm is the outbound manager the Outbound handler is added to and removed from.
	ohm outbound.Manager
	// tag is the handler tag, copied from PortalConfig.Tag.
	tag string
	// domain is compared against link targets to recognize worker connections.
	domain string
	// picker holds the registered PortalWorkers.
	picker *StaticMuxPicker
	// client dispatches ordinary links over a worker chosen by picker.
	client *mux.ClientManager
}
    27  
    28  func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) {
    29  	if config.Tag == "" {
    30  		return nil, newError("portal tag is empty")
    31  	}
    32  
    33  	if config.Domain == "" {
    34  		return nil, newError("portal domain is empty")
    35  	}
    36  
    37  	picker, err := NewStaticMuxPicker()
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  
    42  	return &Portal{
    43  		ohm:    ohm,
    44  		tag:    config.Tag,
    45  		domain: config.Domain,
    46  		picker: picker,
    47  		client: &mux.ClientManager{
    48  			Picker: picker,
    49  		},
    50  	}, nil
    51  }
    52  
    53  func (p *Portal) Start() error {
    54  	return p.ohm.AddHandler(context.Background(), &Outbound{
    55  		portal: p,
    56  		tag:    p.tag,
    57  	})
    58  }
    59  
    60  func (p *Portal) Close() error {
    61  	return p.ohm.RemoveHandler(context.Background(), p.tag)
    62  }
    63  
    64  func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
    65  	outbounds := session.OutboundsFromContext(ctx)
    66  	ob := outbounds[len(outbounds) - 1]
    67  	if ob == nil {
    68  		return newError("outbound metadata not found").AtError()
    69  	}
    70  
    71  	if isDomain(ob.Target, p.domain) {
    72  		muxClient, err := mux.NewClientWorker(*link, mux.ClientStrategy{})
    73  		if err != nil {
    74  			return newError("failed to create mux client worker").Base(err).AtWarning()
    75  		}
    76  
    77  		worker, err := NewPortalWorker(muxClient)
    78  		if err != nil {
    79  			return newError("failed to create portal worker").Base(err)
    80  		}
    81  
    82  		p.picker.AddWorker(worker)
    83  		return nil
    84  	}
    85  
    86  	return p.client.Dispatch(ctx, link)
    87  }
    88  
// Outbound is the outbound handler that Portal.Start registers. Dispatch
// forwards every link it receives to portal.HandleConnection.
type Outbound struct {
	// portal is the Portal that handles dispatched links.
	portal *Portal
	// tag is the handler tag reported by Tag.
	tag string
}
    93  
    94  func (o *Outbound) Tag() string {
    95  	return o.tag
    96  }
    97  
    98  func (o *Outbound) Dispatch(ctx context.Context, link *transport.Link) {
    99  	if err := o.portal.HandleConnection(ctx, link); err != nil {
   100  		newError("failed to process reverse connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
   101  		common.Interrupt(link.Writer)
   102  	}
   103  }
   104  
   105  func (o *Outbound) Start() error {
   106  	return nil
   107  }
   108  
   109  func (o *Outbound) Close() error {
   110  	return nil
   111  }
   112  
// StaticMuxPicker holds the PortalWorkers added via AddWorker and chooses one
// of their mux clients in PickAvailable.
type StaticMuxPicker struct {
	// access guards workers.
	access  sync.Mutex
	workers []*PortalWorker
	// cTask periodically runs cleanup to prune workers whose client has closed.
	cTask   *task.Periodic
}
   118  
   119  func NewStaticMuxPicker() (*StaticMuxPicker, error) {
   120  	p := &StaticMuxPicker{}
   121  	p.cTask = &task.Periodic{
   122  		Execute:  p.cleanup,
   123  		Interval: time.Second * 30,
   124  	}
   125  	p.cTask.Start()
   126  	return p, nil
   127  }
   128  
   129  func (p *StaticMuxPicker) cleanup() error {
   130  	p.access.Lock()
   131  	defer p.access.Unlock()
   132  
   133  	var activeWorkers []*PortalWorker
   134  	for _, w := range p.workers {
   135  		if !w.Closed() {
   136  			activeWorkers = append(activeWorkers, w)
   137  		}
   138  	}
   139  
   140  	if len(activeWorkers) != len(p.workers) {
   141  		p.workers = activeWorkers
   142  	}
   143  
   144  	return nil
   145  }
   146  
   147  func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) {
   148  	p.access.Lock()
   149  	defer p.access.Unlock()
   150  
   151  	if len(p.workers) == 0 {
   152  		return nil, newError("empty worker list")
   153  	}
   154  
   155  	var minIdx int = -1
   156  	var minConn uint32 = 9999
   157  	for i, w := range p.workers {
   158  		if w.draining {
   159  			continue
   160  		}
   161  		if w.client.Closed() {
   162  			continue
   163  		}
   164  		if w.client.ActiveConnections() < minConn {
   165  			minConn = w.client.ActiveConnections()
   166  			minIdx = i
   167  		}
   168  	}
   169  
   170  	if minIdx == -1 {
   171  		for i, w := range p.workers {
   172  			if w.IsFull() {
   173  				continue
   174  			}
   175  			if w.client.ActiveConnections() < minConn {
   176  				minConn = w.client.ActiveConnections()
   177  				minIdx = i
   178  			}
   179  		}
   180  	}
   181  
   182  	if minIdx != -1 {
   183  		return p.workers[minIdx].client, nil
   184  	}
   185  
   186  	return nil, newError("no mux client worker available")
   187  }
   188  
   189  func (p *StaticMuxPicker) AddWorker(worker *PortalWorker) {
   190  	p.access.Lock()
   191  	defer p.access.Unlock()
   192  
   193  	p.workers = append(p.workers, worker)
   194  }
   195  
   196  type PortalWorker struct {
   197  	client   *mux.ClientWorker
   198  	control  *task.Periodic
   199  	writer   buf.Writer
   200  	reader   buf.Reader
   201  	draining bool
   202  }
   203  
   204  func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
   205  	opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)}
   206  	uplinkReader, uplinkWriter := pipe.New(opt...)
   207  	downlinkReader, downlinkWriter := pipe.New(opt...)
   208  
   209  	ctx := context.Background()
   210  	outbounds := []*session.Outbound{{
   211  		Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
   212  	}}
   213  	ctx = session.ContextWithOutbounds(ctx, outbounds)
   214  	f := client.Dispatch(ctx, &transport.Link{
   215  		Reader: uplinkReader,
   216  		Writer: downlinkWriter,
   217  	})
   218  	if !f {
   219  		return nil, newError("unable to dispatch control connection")
   220  	}
   221  	w := &PortalWorker{
   222  		client: client,
   223  		reader: downlinkReader,
   224  		writer: uplinkWriter,
   225  	}
   226  	w.control = &task.Periodic{
   227  		Execute:  w.heartbeat,
   228  		Interval: time.Second * 2,
   229  	}
   230  	w.control.Start()
   231  	return w, nil
   232  }
   233  
   234  func (w *PortalWorker) heartbeat() error {
   235  	if w.client.Closed() {
   236  		return newError("client worker stopped")
   237  	}
   238  
   239  	if w.draining || w.writer == nil {
   240  		return newError("already disposed")
   241  	}
   242  
   243  	msg := &Control{}
   244  	msg.FillInRandom()
   245  
   246  	if w.client.TotalConnections() > 256 {
   247  		w.draining = true
   248  		msg.State = Control_DRAIN
   249  
   250  		defer func() {
   251  			common.Close(w.writer)
   252  			common.Interrupt(w.reader)
   253  			w.writer = nil
   254  		}()
   255  	}
   256  
   257  	b, err := proto.Marshal(msg)
   258  	common.Must(err)
   259  	mb := buf.MergeBytes(nil, b)
   260  	return w.writer.WriteMultiBuffer(mb)
   261  }
   262  
   263  func (w *PortalWorker) IsFull() bool {
   264  	return w.client.IsFull()
   265  }
   266  
   267  func (w *PortalWorker) Closed() bool {
   268  	return w.client.Closed()
   269  }