github.com/database64128/shadowsocks-go@v1.7.0/zerocopy/stream.go (about)

     1  package zerocopy
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  )
     7  
// defaultBufferSize is the fallback payload buffer size, in bytes, used
// when a reader or writer reports no buffer size requirement of its own.
// The value (32 KiB) is the same default as io.Copy.
const defaultBufferSize = 32768
    12  
// ReaderInfo contains information about a reader's buffer requirements.
type ReaderInfo struct {
	// Headroom is the buffer space the reader needs around the payload buffer:
	// Headroom.Front bytes before it and Headroom.Rear bytes after it.
	Headroom Headroom

	// MinPayloadBufferSizePerRead is the minimum size of payload buffer
	// the ReadZeroCopy method requires for an unbuffered read.
	//
	// This is usually required by chunk-based protocols to be able to read
	// whole chunks without needing internal caching.
	//
	// 0 means no minimum; the caller then picks a payload buffer size itself.
	MinPayloadBufferSizePerRead int
}
    24  
// Reader provides a stream interface for reading.
type Reader interface {
	// ReaderInfo returns information about the reader.
	ReaderInfo() ReaderInfo

	// ReadZeroCopy uses b as buffer space to initiate a read operation.
	//
	// b must have at least [ReaderInfo.Headroom.Front] bytes before payloadBufStart
	// and [ReaderInfo.Headroom.Rear] bytes after payloadBufStart + payloadBufLen.
	//
	// payloadBufLen must be at least [ReaderInfo.MinPayloadBufferSizePerRead].
	//
	// The read operation may use the whole space of b.
	// The actual payload will be confined in
	// [payloadBufStart, payloadBufStart+payloadBufLen).
	//
	// If no error occurs, the returned payload is b[payloadBufStart : payloadBufStart+payloadLen].
	ReadZeroCopy(b []byte, payloadBufStart, payloadBufLen int) (payloadLen int, err error)
}
    43  
// WriterInfo contains information about a writer's buffer requirements.
type WriterInfo struct {
	// Headroom is the buffer space the writer needs around the payload:
	// Headroom.Front bytes before it and Headroom.Rear bytes after it.
	Headroom Headroom

	// MaxPayloadSizePerWrite is the maximum size of payload
	// the WriteZeroCopy method can write at a time.
	//
	// This is usually required by chunk-based protocols to be able to write
	// one chunk at a time without needing to break up the payload.
	//
	// 0 means no size limit.
	MaxPayloadSizePerWrite int
}
    57  
// Writer provides a stream interface for writing.
type Writer interface {
	// WriterInfo returns information about the writer.
	WriterInfo() WriterInfo

	// WriteZeroCopy uses b as buffer space to initiate a write operation.
	//
	// The payload to write is b[payloadStart : payloadStart+payloadLen].
	//
	// b must have at least [WriterInfo.Headroom.Front] bytes before payloadStart
	// and [WriterInfo.Headroom.Rear] bytes after payloadStart + payloadLen.
	//
	// payloadLen must not exceed [WriterInfo.MaxPayloadSizePerWrite],
	// unless that limit is 0, which means no size limit.
	//
	// The write operation may use the whole space of b.
	WriteZeroCopy(b []byte, payloadStart, payloadLen int) (payloadWritten int, err error)
}
    73  
// DirectReader provides access to the underlying [io.Reader].
//
// Relay checks for this interface on its reader and, when the writer also
// implements [DirectWriter], copies between the underlying streams with io.Copy.
type DirectReader interface {
	// DirectReader returns the underlying reader for direct reads.
	DirectReader() io.Reader
}
    79  
// DirectWriter provides access to the underlying [io.Writer].
//
// Relay checks for this interface on its writer and, when the reader also
// implements [DirectReader], copies between the underlying streams with io.Copy.
type DirectWriter interface {
	// DirectWriter returns the underlying writer for direct writes.
	DirectWriter() io.Writer
}
    85  
    86  // Relay reads from r and writes to w using zero-copy methods.
    87  // It returns the number of bytes transferred, and any error occurred during transfer.
    88  func Relay(w Writer, r Reader) (n int64, err error) {
    89  	// Use direct read/write when possible.
    90  	if dr, ok := r.(DirectReader); ok {
    91  		if dw, ok := w.(DirectWriter); ok {
    92  			r := dr.DirectReader()
    93  			w := dw.DirectWriter()
    94  			return io.Copy(w, r)
    95  		}
    96  	}
    97  
    98  	// Process reader and writer info.
    99  	ri := r.ReaderInfo()
   100  	wi := w.WriterInfo()
   101  	headroom := MaxHeadroom(ri.Headroom, wi.Headroom)
   102  
   103  	// Check payload buffer size requirement compatibility.
   104  	if wi.MaxPayloadSizePerWrite > 0 && ri.MinPayloadBufferSizePerRead > wi.MaxPayloadSizePerWrite {
   105  		return relayFallback(w, r, headroom.Front, headroom.Rear, ri.MinPayloadBufferSizePerRead, wi.MaxPayloadSizePerWrite)
   106  	}
   107  
   108  	payloadBufSize := ri.MinPayloadBufferSizePerRead
   109  	if payloadBufSize == 0 {
   110  		payloadBufSize = wi.MaxPayloadSizePerWrite
   111  		if payloadBufSize == 0 {
   112  			payloadBufSize = defaultBufferSize
   113  		}
   114  	}
   115  
   116  	// Make buffer.
   117  	b := make([]byte, headroom.Front+payloadBufSize+headroom.Rear)
   118  
   119  	// Main relay loop.
   120  	for {
   121  		var payloadLen int
   122  		payloadLen, err = r.ReadZeroCopy(b, headroom.Front, payloadBufSize)
   123  		if payloadLen == 0 {
   124  			if err == io.EOF {
   125  				err = nil
   126  			}
   127  			return
   128  		}
   129  
   130  		payloadWritten, werr := w.WriteZeroCopy(b, headroom.Front, payloadLen)
   131  		n += int64(payloadWritten)
   132  		if werr != nil {
   133  			err = werr
   134  			return
   135  		}
   136  
   137  		if err != nil {
   138  			if err == io.EOF {
   139  				err = nil
   140  			}
   141  			return
   142  		}
   143  	}
   144  }
   145  
   146  // relayFallback uses copying to handle situations where the reader requires more payload buffer space than the writer can handle in one write call.
   147  func relayFallback(w Writer, r Reader, frontHeadroom, rearHeadroom, readMaxPayloadSize, writeMaxPayloadSize int) (n int64, err error) {
   148  	br := make([]byte, frontHeadroom+readMaxPayloadSize+rearHeadroom)
   149  	bw := make([]byte, frontHeadroom+writeMaxPayloadSize+rearHeadroom)
   150  
   151  	for {
   152  		var payloadLen int
   153  		payloadLen, err = r.ReadZeroCopy(br, frontHeadroom, readMaxPayloadSize)
   154  		if payloadLen == 0 {
   155  			if err == io.EOF {
   156  				err = nil
   157  			}
   158  			return
   159  		}
   160  
   161  		// Short-circuit to avoid copying if payload can fit in one write.
   162  		if payloadLen <= writeMaxPayloadSize {
   163  			payloadWritten, werr := w.WriteZeroCopy(br, frontHeadroom, payloadLen)
   164  			n += int64(payloadWritten)
   165  			if werr != nil {
   166  				err = werr
   167  			}
   168  			if err != nil {
   169  				return
   170  			}
   171  			continue
   172  		}
   173  
   174  		// Loop until all of br[frontHeadroom : frontHeadroom+payloadLen] is written.
   175  		for i, j := 0, 0; i < payloadLen; i += j {
   176  			j = copy(bw[frontHeadroom:frontHeadroom+writeMaxPayloadSize], br[frontHeadroom+i:frontHeadroom+payloadLen])
   177  			payloadWritten, werr := w.WriteZeroCopy(bw, frontHeadroom, j)
   178  			n += int64(payloadWritten)
   179  			if werr != nil {
   180  				err = werr
   181  				return
   182  			}
   183  		}
   184  
   185  		if err != nil {
   186  			if err == io.EOF {
   187  				err = nil
   188  			}
   189  			return
   190  		}
   191  	}
   192  }
   193  
// CloseRead provides the CloseRead method for closing the read side of a stream.
type CloseRead interface {
	// CloseRead indicates to the underlying reader that no further reads will happen.
	CloseRead() error
}
   199  
// CloseWrite provides the CloseWrite method for closing the write side of a stream.
type CloseWrite interface {
	// CloseWrite indicates to the underlying writer that no further writes will happen.
	CloseWrite() error
}
   205  
// ReadWriter provides a stream interface for reading and writing.
//
// It combines the zero-copy [Reader] and [Writer] interfaces with
// [CloseRead], [CloseWrite] and [io.Closer].
type ReadWriter interface {
	Reader
	Writer
	CloseRead
	CloseWrite
	io.Closer
}
   214  
   215  // TwoWayRelay relays data between left and right using zero-copy methods.
   216  // It returns the number of bytes sent from left to right, from right to left,
   217  // and any error occurred during transfer.
   218  func TwoWayRelay(left, right ReadWriter) (nl2r, nr2l int64, err error) {
   219  	var l2rErr error
   220  	ctrlCh := make(chan struct{})
   221  
   222  	go func() {
   223  		nl2r, l2rErr = Relay(right, left)
   224  		right.CloseWrite()
   225  		ctrlCh <- struct{}{}
   226  	}()
   227  
   228  	nr2l, err = Relay(left, right)
   229  	left.CloseWrite()
   230  	<-ctrlCh
   231  
   232  	if l2rErr != nil {
   233  		err = l2rErr
   234  	}
   235  	return
   236  }
   237  
// DirectReadWriteCloser extends io.ReadWriteCloser with CloseRead and CloseWrite,
// so that each direction of the stream can be shut down separately.
type DirectReadWriteCloser interface {
	io.ReadWriteCloser
	CloseRead
	CloseWrite
}
   244  
   245  // DirectTwoWayRelay relays data between left and right using [io.Copy].
   246  // It returns the number of bytes sent from left to right, from right to left,
   247  // and any error occurred during transfer.
   248  func DirectTwoWayRelay(left, right DirectReadWriteCloser) (nl2r, nr2l int64, err error) {
   249  	var l2rErr error
   250  	ctrlCh := make(chan struct{})
   251  
   252  	go func() {
   253  		nl2r, l2rErr = io.Copy(right, left)
   254  		right.CloseWrite()
   255  		ctrlCh <- struct{}{}
   256  	}()
   257  
   258  	nr2l, err = io.Copy(left, right)
   259  	left.CloseWrite()
   260  	<-ctrlCh
   261  
   262  	if l2rErr != nil {
   263  		err = l2rErr
   264  	}
   265  	return
   266  }
   267  
// DirectReadWriteCloserOpener provides the Open method to open a [DirectReadWriteCloser].
type DirectReadWriteCloserOpener interface {
	// Open opens a [DirectReadWriteCloser] with the specified initial payload.
	//
	// b is the initial payload to send on the newly opened stream.
	Open(b []byte) (DirectReadWriteCloser, error)
}
   273  
// SimpleDirectReadWriteCloserOpener wraps a [DirectReadWriteCloser] for the Open method to return.
//
// Open does not create a new stream: it writes the initial payload to the
// embedded DirectReadWriteCloser and returns that same value.
type SimpleDirectReadWriteCloserOpener struct {
	DirectReadWriteCloser
}
   278  
   279  // Open implements the DirectReadWriteCloserOpener Open method.
   280  func (o *SimpleDirectReadWriteCloserOpener) Open(b []byte) (DirectReadWriteCloser, error) {
   281  	_, err := o.DirectReadWriteCloser.Write(b)
   282  	return o.DirectReadWriteCloser, err
   283  }
   284  
   285  // ReadWriterTestFunc tests the left and right ReadWriters by performing 2 writes
   286  // on each ReadWriter and validating the read results.
   287  //
   288  // The left and right ReadWriters must be connected with a duplex pipe.
   289  func ReadWriterTestFunc(t tester, l, r ReadWriter) {
   290  	defer r.Close()
   291  	defer l.Close()
   292  
   293  	var (
   294  		hello = []byte{'h', 'e', 'l', 'l', 'o'}
   295  		world = []byte{'w', 'o', 'r', 'l', 'd'}
   296  	)
   297  
   298  	lri := l.ReaderInfo()
   299  	lwi := l.WriterInfo()
   300  	lwmax := lwi.MaxPayloadSizePerWrite
   301  	if lwmax == 0 {
   302  		lwmax = 5
   303  	}
   304  	lrmin := lri.MinPayloadBufferSizePerRead
   305  	if lrmin == 0 {
   306  		lrmin = 5
   307  	}
   308  	lwbuf := make([]byte, lwi.Headroom.Front+lwmax+lwi.Headroom.Rear)
   309  	lrbuf := make([]byte, lri.Headroom.Front+lrmin+lri.Headroom.Rear)
   310  
   311  	rri := r.ReaderInfo()
   312  	rwi := r.WriterInfo()
   313  	rwmax := rwi.MaxPayloadSizePerWrite
   314  	if rwmax == 0 {
   315  		rwmax = 5
   316  	}
   317  	rrmin := rri.MinPayloadBufferSizePerRead
   318  	if rrmin == 0 {
   319  		rrmin = 5
   320  	}
   321  	rwbuf := make([]byte, rwi.Headroom.Front+rwmax+rwi.Headroom.Rear)
   322  	rrbuf := make([]byte, rri.Headroom.Front+rrmin+rri.Headroom.Rear)
   323  
   324  	ctrlCh := make(chan struct{})
   325  
   326  	// Start read goroutines.
   327  	go func() {
   328  		pl, err := l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin)
   329  		if err != nil {
   330  			t.Error(err)
   331  		}
   332  		if pl != 5 {
   333  			t.Errorf("Expected payloadLen 5, got %d", pl)
   334  		}
   335  		p := lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl]
   336  		if !bytes.Equal(p, world) {
   337  			t.Errorf("Expected payload %v, got %v", world, p)
   338  		}
   339  
   340  		pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin)
   341  		if err != nil {
   342  			t.Error(err)
   343  		}
   344  		if pl != 5 {
   345  			t.Errorf("Expected payloadLen 5, got %d", pl)
   346  		}
   347  		p = lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl]
   348  		if !bytes.Equal(p, hello) {
   349  			t.Errorf("Expected payload %v, got %v", hello, p)
   350  		}
   351  
   352  		pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin)
   353  		if err != io.EOF {
   354  			t.Errorf("Expected io.EOF, got %v", err)
   355  		}
   356  		if pl != 0 {
   357  			t.Errorf("Expected payloadLen 0, got %v", pl)
   358  		}
   359  
   360  		ctrlCh <- struct{}{}
   361  	}()
   362  
   363  	go func() {
   364  		pl, err := r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin)
   365  		if err != nil {
   366  			t.Error(err)
   367  		}
   368  		if pl != 5 {
   369  			t.Errorf("Expected payloadLen 5, got %d", pl)
   370  		}
   371  		p := rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl]
   372  		if !bytes.Equal(p, hello) {
   373  			t.Errorf("Expected payload %v, got %v", hello, p)
   374  		}
   375  
   376  		pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin)
   377  		if err != nil {
   378  			t.Error(err)
   379  		}
   380  		if pl != 5 {
   381  			t.Errorf("Expected payloadLen 5, got %d", pl)
   382  		}
   383  		p = rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl]
   384  		if !bytes.Equal(p, world) {
   385  			t.Errorf("Expected payload %v, got %v", world, p)
   386  		}
   387  
   388  		pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin)
   389  		if err != io.EOF {
   390  			t.Errorf("Expected io.EOF, got %v", err)
   391  		}
   392  		if pl != 0 {
   393  			t.Errorf("Expected payloadLen 0, got %v", pl)
   394  		}
   395  
   396  		ctrlCh <- struct{}{}
   397  	}()
   398  
   399  	// Write from left to right.
   400  	n := copy(lwbuf[lwi.Headroom.Front:], hello)
   401  	written, err := l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n)
   402  	if err != nil {
   403  		t.Error(err)
   404  	}
   405  	if written != n {
   406  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   407  	}
   408  
   409  	n = copy(lwbuf[lwi.Headroom.Front:], world)
   410  	written, err = l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n)
   411  	if err != nil {
   412  		t.Error(err)
   413  	}
   414  	if written != n {
   415  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   416  	}
   417  
   418  	err = l.CloseWrite()
   419  	if err != nil {
   420  		t.Error(err)
   421  	}
   422  
   423  	// Write from right to left.
   424  	n = copy(rwbuf[rwi.Headroom.Front:], world)
   425  	written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n)
   426  	if err != nil {
   427  		t.Error(err)
   428  	}
   429  	if written != n {
   430  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   431  	}
   432  
   433  	n = copy(rwbuf[rwi.Headroom.Front:], hello)
   434  	written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n)
   435  	if err != nil {
   436  		t.Error(err)
   437  	}
   438  	if written != n {
   439  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   440  	}
   441  
   442  	err = r.CloseWrite()
   443  	if err != nil {
   444  		t.Error(err)
   445  	}
   446  
   447  	<-ctrlCh
   448  	<-ctrlCh
   449  }
   450  
// CopyReadWriter wraps a ReadWriter and provides the io.ReadWriter Read and Write methods
// by copying from and to internal buffers and using the zerocopy methods on them.
//
// The io.ReaderFrom ReadFrom method is implemented using the internal write buffer without copying.
type CopyReadWriter struct {
	ReadWriter

	// Headroom reported by the wrapped ReadWriter for reads and writes.
	readHeadroom  Headroom
	writeHeadroom Headroom

	// readBuf is the buffer passed to ReadZeroCopy.
	// readBufStart and readBufLength delimit the payload in readBuf
	// that has been read but not yet returned by Read.
	readBuf       []byte
	readBufStart  int
	readBufLength int

	// writeBuf is the buffer passed to WriteZeroCopy by Write and ReadFrom.
	writeBuf []byte
}
   467  
   468  func NewCopyReadWriter(rw ReadWriter) *CopyReadWriter {
   469  	ri := rw.ReaderInfo()
   470  	wi := rw.WriterInfo()
   471  
   472  	readBufSize := ri.MinPayloadBufferSizePerRead
   473  	if readBufSize == 0 {
   474  		readBufSize = defaultBufferSize
   475  	}
   476  
   477  	writeBufSize := wi.MaxPayloadSizePerWrite
   478  	if writeBufSize == 0 {
   479  		writeBufSize = defaultBufferSize
   480  	}
   481  
   482  	return &CopyReadWriter{
   483  		ReadWriter:    rw,
   484  		readHeadroom:  ri.Headroom,
   485  		writeHeadroom: wi.Headroom,
   486  		readBuf:       make([]byte, ri.Headroom.Front+readBufSize+ri.Headroom.Front),
   487  		writeBuf:      make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear),
   488  	}
   489  }
   490  
   491  // Read implements the io.Reader Read method.
   492  func (rw *CopyReadWriter) Read(b []byte) (n int, err error) {
   493  	if rw.readBufLength == 0 {
   494  		rw.readBufStart = rw.readHeadroom.Front
   495  		rw.readBufLength = len(rw.readBuf) - rw.readHeadroom.Front - rw.readHeadroom.Rear
   496  		rw.readBufLength, err = rw.ReadWriter.ReadZeroCopy(rw.readBuf, rw.readBufStart, rw.readBufLength)
   497  		if err != nil {
   498  			return
   499  		}
   500  	}
   501  
   502  	n = copy(b, rw.readBuf[rw.readBufStart:rw.readBufStart+rw.readBufLength])
   503  	rw.readBufStart += n
   504  	rw.readBufLength -= n
   505  	return n, nil
   506  }
   507  
   508  // Write implements the io.Writer Write method.
   509  func (rw *CopyReadWriter) Write(b []byte) (n int, err error) {
   510  	payloadBuf := rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear]
   511  
   512  	for n < len(b) {
   513  		payloadLength := copy(payloadBuf, b[n:])
   514  		var payloadWritten int
   515  		payloadWritten, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, payloadLength)
   516  		n += payloadWritten
   517  		if err != nil {
   518  			return
   519  		}
   520  	}
   521  
   522  	return
   523  }
   524  
   525  // ReadFrom implements the io.ReaderFrom ReadFrom method.
   526  func (rw *CopyReadWriter) ReadFrom(r io.Reader) (n int64, err error) {
   527  	for {
   528  		nr, err := r.Read(rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear])
   529  		n += int64(nr)
   530  		switch err {
   531  		case nil:
   532  		case io.EOF:
   533  			return n, nil
   534  		default:
   535  			return n, err
   536  		}
   537  
   538  		_, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, nr)
   539  		if err != nil {
   540  			return n, err
   541  		}
   542  	}
   543  }
   544  
   545  func CopyWriteOnce(w Writer, b []byte) (n int, err error) {
   546  	wi := w.WriterInfo()
   547  	writeBufSize := wi.MaxPayloadSizePerWrite
   548  	if writeBufSize == 0 {
   549  		writeBufSize = defaultBufferSize
   550  	}
   551  	if writeBufSize > len(b) {
   552  		writeBufSize = len(b)
   553  	}
   554  
   555  	writeBuf := make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear)
   556  	payloadBuf := writeBuf[wi.Headroom.Front : wi.Headroom.Front+writeBufSize]
   557  
   558  	for n < len(b) {
   559  		payloadLength := copy(payloadBuf, b[n:])
   560  		var payloadWritten int
   561  		payloadWritten, err = w.WriteZeroCopy(writeBuf, wi.Headroom.Front, payloadLength)
   562  		n += payloadWritten
   563  		if err != nil {
   564  			return
   565  		}
   566  	}
   567  
   568  	return
   569  }