github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/pipe/impl.go (about)

     1  package pipe
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"runtime"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/xtls/xray-core/common"
    11  	"github.com/xtls/xray-core/common/buf"
    12  	"github.com/xtls/xray-core/common/signal"
    13  	"github.com/xtls/xray-core/common/signal/done"
    14  )
    15  
    16  type state byte
    17  
    18  const (
    19  	open state = iota
    20  	closed
    21  	errord
    22  )
    23  
    24  type pipeOption struct {
    25  	limit           int32 // maximum buffer size in bytes
    26  	discardOverflow bool
    27  }
    28  
    29  func (o *pipeOption) isFull(curSize int32) bool {
    30  	return o.limit >= 0 && curSize > o.limit
    31  }
    32  
// pipe is the shared state of an in-memory pipe: writers append MultiBuffers
// to data and readers take the whole of data at once. Within this file, data
// and state are only touched while the embedded mutex is held.
type pipe struct {
	sync.Mutex
	data        buf.MultiBuffer  // written but not yet read; handed over wholesale on each read
	readSignal  *signal.Notifier // signaled by writers to wake a reader waiting for data
	writeSignal *signal.Notifier // signaled by readers after a read to wake a writer waiting for room
	done        *done.Instance   // closed by Close/Interrupt; wakes every waiting reader and writer
	errChan     chan error       // an error received here is returned by ReadMultiBuffer (sender not visible in this file)
	option      pipeOption       // capacity limit and overflow policy
	state       state            // lifecycle stage: open, closed or errord
}
    43  
    44  var (
    45  	errBufferFull = errors.New("buffer full")
    46  	errSlowDown   = errors.New("slow down")
    47  )
    48  
    49  func (p *pipe) getState(forRead bool) error {
    50  	switch p.state {
    51  	case open:
    52  		if !forRead && p.option.isFull(p.data.Len()) {
    53  			return errBufferFull
    54  		}
    55  		return nil
    56  	case closed:
    57  		if !forRead {
    58  			return io.ErrClosedPipe
    59  		}
    60  		if !p.data.IsEmpty() {
    61  			return nil
    62  		}
    63  		return io.EOF
    64  	case errord:
    65  		return io.ErrClosedPipe
    66  	default:
    67  		panic("impossible case")
    68  	}
    69  }
    70  
    71  func (p *pipe) readMultiBufferInternal() (buf.MultiBuffer, error) {
    72  	p.Lock()
    73  	defer p.Unlock()
    74  
    75  	if err := p.getState(true); err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	data := p.data
    80  	p.data = nil
    81  	return data, nil
    82  }
    83  
    84  func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) {
    85  	for {
    86  		data, err := p.readMultiBufferInternal()
    87  		if data != nil || err != nil {
    88  			p.writeSignal.Signal()
    89  			return data, err
    90  		}
    91  
    92  		select {
    93  		case <-p.readSignal.Wait():
    94  		case <-p.done.Wait():
    95  		case err = <-p.errChan:
    96  			return nil, err
    97  		}
    98  	}
    99  }
   100  
   101  func (p *pipe) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error) {
   102  	timer := time.NewTimer(d)
   103  	defer timer.Stop()
   104  
   105  	for {
   106  		data, err := p.readMultiBufferInternal()
   107  		if data != nil || err != nil {
   108  			p.writeSignal.Signal()
   109  			return data, err
   110  		}
   111  
   112  		select {
   113  		case <-p.readSignal.Wait():
   114  		case <-p.done.Wait():
   115  		case <-timer.C:
   116  			return nil, buf.ErrReadTimeout
   117  		}
   118  	}
   119  }
   120  
   121  func (p *pipe) writeMultiBufferInternal(mb buf.MultiBuffer) error {
   122  	p.Lock()
   123  	defer p.Unlock()
   124  
   125  	if err := p.getState(false); err != nil {
   126  		return err
   127  	}
   128  
   129  	if p.data == nil {
   130  		p.data = mb
   131  		return nil
   132  	}
   133  
   134  	p.data, _ = buf.MergeMulti(p.data, mb)
   135  	return errSlowDown
   136  }
   137  
   138  func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
   139  	if mb.IsEmpty() {
   140  		return nil
   141  	}
   142  
   143  	for {
   144  		err := p.writeMultiBufferInternal(mb)
   145  		if err == nil {
   146  			p.readSignal.Signal()
   147  			return nil
   148  		}
   149  
   150  		if err == errSlowDown {
   151  			p.readSignal.Signal()
   152  
   153  			// Yield current goroutine. Hopefully the reading counterpart can pick up the payload.
   154  			runtime.Gosched()
   155  			return nil
   156  		}
   157  
   158  		if err == errBufferFull && p.option.discardOverflow {
   159  			buf.ReleaseMulti(mb)
   160  			return nil
   161  		}
   162  
   163  		if err != errBufferFull {
   164  			buf.ReleaseMulti(mb)
   165  			p.readSignal.Signal()
   166  			return err
   167  		}
   168  
   169  		select {
   170  		case <-p.writeSignal.Wait():
   171  		case <-p.done.Wait():
   172  			return io.ErrClosedPipe
   173  		}
   174  	}
   175  }
   176  
   177  func (p *pipe) Close() error {
   178  	p.Lock()
   179  	defer p.Unlock()
   180  
   181  	if p.state == closed || p.state == errord {
   182  		return nil
   183  	}
   184  
   185  	p.state = closed
   186  	common.Must(p.done.Close())
   187  	return nil
   188  }
   189  
   190  // Interrupt implements common.Interruptible.
   191  func (p *pipe) Interrupt() {
   192  	p.Lock()
   193  	defer p.Unlock()
   194  
   195  	if p.state == closed || p.state == errord {
   196  		return
   197  	}
   198  
   199  	p.state = errord
   200  
   201  	if !p.data.IsEmpty() {
   202  		buf.ReleaseMulti(p.data)
   203  		p.data = nil
   204  	}
   205  
   206  	common.Must(p.done.Close())
   207  }