github.com/xraypb/xray-core@v1.6.6/transport/pipe/impl.go (about)

     1  package pipe
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"runtime"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/xraypb/xray-core/common"
    11  	"github.com/xraypb/xray-core/common/buf"
    12  	"github.com/xraypb/xray-core/common/signal"
    13  	"github.com/xraypb/xray-core/common/signal/done"
    14  )
    15  
    16  type state byte
    17  
    18  const (
    19  	open state = iota
    20  	closed
    21  	errord
    22  )
    23  
// pipeOption holds the per-pipe configuration consulted by the write path.
type pipeOption struct {
	// limit is the maximum buffer size in bytes. A negative value disables
	// the limit; otherwise the buffer counts as full once its length exceeds
	// limit (see isFull).
	limit int32

	// discardOverflow makes WriteMultiBuffer release and drop a write that
	// finds the buffer full, returning nil instead of blocking.
	discardOverflow bool

	// onTransmission, if non-nil, is applied by WriteMultiBuffer to every
	// non-empty MultiBuffer before it is queued; its result is what gets
	// queued.
	onTransmission func(buffer buf.MultiBuffer) buf.MultiBuffer
}
    29  
    30  func (o *pipeOption) isFull(curSize int32) bool {
    31  	return o.limit >= 0 && curSize > o.limit
    32  }
    33  
// pipe is an in-memory queue of buffers with blocking reads and, when a limit
// is configured, blocking (or dropping) writes.
//
// Fields:
//   - data: buffers queued by writers and not yet taken by a reader.
//   - readSignal: signaled by writers to wake a reader waiting in
//     ReadMultiBuffer or ReadMultiBufferTimeout.
//   - writeSignal: signaled by readers whenever they return data or an
//     error, to wake a writer waiting for room in a full buffer.
//   - done: closed by Close or Interrupt; the blocking read/write loops
//     select on it.
//   - option: configuration. It is read without the lock in
//     WriteMultiBuffer, so it is presumably fixed after construction; confirm
//     against the constructor.
//   - state: open, closed, or errord.
//
// The embedded Mutex guards data and state. readSignal and writeSignal are
// signaled, and done is waited on, without holding it.
type pipe struct {
	sync.Mutex
	data        buf.MultiBuffer
	readSignal  *signal.Notifier
	writeSignal *signal.Notifier
	done        *done.Instance
	option      pipeOption
	state       state
}
    43  
    44  var (
    45  	errBufferFull = errors.New("buffer full")
    46  	errSlowDown   = errors.New("slow down")
    47  )
    48  
    49  func (p *pipe) getState(forRead bool) error {
    50  	switch p.state {
    51  	case open:
    52  		if !forRead && p.option.isFull(p.data.Len()) {
    53  			return errBufferFull
    54  		}
    55  		return nil
    56  	case closed:
    57  		if !forRead {
    58  			return io.ErrClosedPipe
    59  		}
    60  		if !p.data.IsEmpty() {
    61  			return nil
    62  		}
    63  		return io.EOF
    64  	case errord:
    65  		return io.ErrClosedPipe
    66  	default:
    67  		panic("impossible case")
    68  	}
    69  }
    70  
    71  func (p *pipe) readMultiBufferInternal() (buf.MultiBuffer, error) {
    72  	p.Lock()
    73  	defer p.Unlock()
    74  
    75  	if err := p.getState(true); err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	data := p.data
    80  	p.data = nil
    81  	return data, nil
    82  }
    83  
    84  func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) {
    85  	for {
    86  		data, err := p.readMultiBufferInternal()
    87  		if data != nil || err != nil {
    88  			p.writeSignal.Signal()
    89  			return data, err
    90  		}
    91  
    92  		select {
    93  		case <-p.readSignal.Wait():
    94  		case <-p.done.Wait():
    95  		}
    96  	}
    97  }
    98  
    99  func (p *pipe) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error) {
   100  	timer := time.NewTimer(d)
   101  	defer timer.Stop()
   102  
   103  	for {
   104  		data, err := p.readMultiBufferInternal()
   105  		if data != nil || err != nil {
   106  			p.writeSignal.Signal()
   107  			return data, err
   108  		}
   109  
   110  		select {
   111  		case <-p.readSignal.Wait():
   112  		case <-p.done.Wait():
   113  		case <-timer.C:
   114  			return nil, buf.ErrReadTimeout
   115  		}
   116  	}
   117  }
   118  
   119  func (p *pipe) writeMultiBufferInternal(mb buf.MultiBuffer) error {
   120  	p.Lock()
   121  	defer p.Unlock()
   122  
   123  	if err := p.getState(false); err != nil {
   124  		return err
   125  	}
   126  
   127  	if p.data == nil {
   128  		p.data = mb
   129  		return nil
   130  	}
   131  
   132  	p.data, _ = buf.MergeMulti(p.data, mb)
   133  	return errSlowDown
   134  }
   135  
   136  func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
   137  	if mb.IsEmpty() {
   138  		return nil
   139  	}
   140  
   141  	if p.option.onTransmission != nil {
   142  		mb = p.option.onTransmission(mb)
   143  	}
   144  
   145  	for {
   146  		err := p.writeMultiBufferInternal(mb)
   147  		if err == nil {
   148  			p.readSignal.Signal()
   149  			return nil
   150  		}
   151  
   152  		if err == errSlowDown {
   153  			p.readSignal.Signal()
   154  
   155  			// Yield current goroutine. Hopefully the reading counterpart can pick up the payload.
   156  			runtime.Gosched()
   157  			return nil
   158  		}
   159  
   160  		if err == errBufferFull && p.option.discardOverflow {
   161  			buf.ReleaseMulti(mb)
   162  			return nil
   163  		}
   164  
   165  		if err != errBufferFull {
   166  			buf.ReleaseMulti(mb)
   167  			p.readSignal.Signal()
   168  			return err
   169  		}
   170  
   171  		select {
   172  		case <-p.writeSignal.Wait():
   173  		case <-p.done.Wait():
   174  			return io.ErrClosedPipe
   175  		}
   176  	}
   177  }
   178  
   179  func (p *pipe) Close() error {
   180  	p.Lock()
   181  	defer p.Unlock()
   182  
   183  	if p.state == closed || p.state == errord {
   184  		return nil
   185  	}
   186  
   187  	p.state = closed
   188  	common.Must(p.done.Close())
   189  	return nil
   190  }
   191  
   192  // Interrupt implements common.Interruptible.
   193  func (p *pipe) Interrupt() {
   194  	p.Lock()
   195  	defer p.Unlock()
   196  
   197  	if p.state == closed || p.state == errord {
   198  		return
   199  	}
   200  
   201  	p.state = errord
   202  
   203  	if !p.data.IsEmpty() {
   204  		buf.ReleaseMulti(p.data)
   205  		p.data = nil
   206  	}
   207  
   208  	common.Must(p.done.Close())
   209  }