github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/zerocopy/stream.go (about)

     1  package zerocopy
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"sync"
     8  )
     9  
    10  // defaultBufferSize is the default buffer size to use
    11  // when neither the reader nor the writer has buffer size requirements.
    12  // It's the same default as io.Copy.
    13  const defaultBufferSize = 32768
    14  
// ReaderInfo contains information about a reader.
type ReaderInfo struct {
	// Headroom is the buffer space the reader requires before and after the
	// payload buffer passed to ReadZeroCopy.
	Headroom Headroom

	// MinPayloadBufferSizePerRead is the minimum size of payload buffer
	// the ReadZeroCopy method requires for an unbuffered read.
	//
	// This is usually required by chunk-based protocols to be able to read
	// whole chunks without needing internal caching.
	//
	// 0 means no minimum; callers such as Relay then choose their own size.
	MinPayloadBufferSizePerRead int
}
    26  
// Reader provides a stream interface for reading.
type Reader interface {
	// ReaderInfo returns information about the reader.
	ReaderInfo() ReaderInfo

	// ReadZeroCopy uses b as buffer space to initiate a read operation.
	//
	// b must have at least [ReaderInfo.Headroom.Front] bytes before payloadBufStart
	// and [ReaderInfo.Headroom.Rear] bytes after payloadBufStart + payloadBufLen.
	//
	// payloadBufLen must be at least [ReaderInfo.MinPayloadBufferSizePerRead].
	//
	// The read operation may use the whole space of b.
	// The actual payload will be confined in [payloadBufStart, payloadBufStart+payloadBufLen).
	//
	// If no error occurs, the returned payload is b[payloadBufStart : payloadBufStart+payloadLen].
	//
	// A non-zero payloadLen may accompany a non-nil err (for example io.EOF).
	// Relay writes such a payload before acting on err, and treats a zero
	// payloadLen as the end of the relay.
	ReadZeroCopy(b []byte, payloadBufStart, payloadBufLen int) (payloadLen int, err error)
}
    45  
// WriterInfo contains information about a writer.
type WriterInfo struct {
	// Headroom is the buffer space the writer requires before and after the
	// payload passed to WriteZeroCopy.
	Headroom Headroom

	// MaxPayloadSizePerWrite is the maximum size of payload
	// the WriteZeroCopy method can write at a time.
	//
	// This is usually required by chunk-based protocols to be able to write
	// one chunk at a time without needing to break up the payload.
	//
	// 0 means no size limit.
	MaxPayloadSizePerWrite int
}
    59  
// Writer provides a stream interface for writing.
type Writer interface {
	// WriterInfo returns information about the writer.
	WriterInfo() WriterInfo

	// WriteZeroCopy uses b as buffer space to initiate a write operation.
	//
	// The payload to write is b[payloadStart : payloadStart+payloadLen].
	//
	// b must have at least [WriterInfo.Headroom.Front] bytes before payloadStart
	// and [WriterInfo.Headroom.Rear] bytes after payloadStart + payloadLen.
	//
	// payloadLen must not exceed [WriterInfo.MaxPayloadSizePerWrite],
	// unless that value is 0 (no limit).
	//
	// The write operation may use the whole space of b, so the payload bytes
	// in b should not be relied upon after the call.
	//
	// payloadWritten is the number of payload bytes written.
	WriteZeroCopy(b []byte, payloadStart, payloadLen int) (payloadWritten int, err error)
}
    75  
// DirectReader provides access to the underlying [io.Reader].
//
// When both ends implement it (together with [DirectWriter]), Relay skips the
// zero-copy path and uses [io.Copy] on the underlying streams.
type DirectReader interface {
	// DirectReader returns the underlying reader for direct reads.
	DirectReader() io.Reader
}

// DirectWriter provides access to the underlying [io.Writer].
//
// Relay uses it together with [DirectReader] to copy between the underlying
// streams directly.
type DirectWriter interface {
	// DirectWriter returns the underlying writer for direct writes.
	DirectWriter() io.Writer
}
    87  
    88  // Relay reads from r and writes to w using zero-copy methods.
    89  // It returns the number of bytes transferred, and any error occurred during transfer.
    90  func Relay(w Writer, r Reader) (n int64, err error) {
    91  	// Use direct read/write when possible.
    92  	if dr, ok := r.(DirectReader); ok {
    93  		if dw, ok := w.(DirectWriter); ok {
    94  			r := dr.DirectReader()
    95  			w := dw.DirectWriter()
    96  			return io.Copy(w, r)
    97  		}
    98  	}
    99  
   100  	// Process reader and writer info.
   101  	ri := r.ReaderInfo()
   102  	wi := w.WriterInfo()
   103  	headroom := MaxHeadroom(ri.Headroom, wi.Headroom)
   104  
   105  	// Check payload buffer size requirement compatibility.
   106  	if wi.MaxPayloadSizePerWrite > 0 && ri.MinPayloadBufferSizePerRead > wi.MaxPayloadSizePerWrite {
   107  		return relayFallback(w, r, headroom.Front, headroom.Rear, ri.MinPayloadBufferSizePerRead, wi.MaxPayloadSizePerWrite)
   108  	}
   109  
   110  	payloadBufSize := ri.MinPayloadBufferSizePerRead
   111  	if payloadBufSize == 0 {
   112  		payloadBufSize = wi.MaxPayloadSizePerWrite
   113  		if payloadBufSize == 0 {
   114  			payloadBufSize = defaultBufferSize
   115  		}
   116  	}
   117  
   118  	// Make buffer.
   119  	b := make([]byte, headroom.Front+payloadBufSize+headroom.Rear)
   120  
   121  	// Main relay loop.
   122  	for {
   123  		var payloadLen int
   124  		payloadLen, err = r.ReadZeroCopy(b, headroom.Front, payloadBufSize)
   125  		if payloadLen == 0 {
   126  			if err == io.EOF {
   127  				err = nil
   128  			}
   129  			return
   130  		}
   131  
   132  		payloadWritten, werr := w.WriteZeroCopy(b, headroom.Front, payloadLen)
   133  		n += int64(payloadWritten)
   134  		if werr != nil {
   135  			err = werr
   136  			return
   137  		}
   138  
   139  		if err != nil {
   140  			if err == io.EOF {
   141  				err = nil
   142  			}
   143  			return
   144  		}
   145  	}
   146  }
   147  
   148  // relayFallback uses copying to handle situations where the reader requires more payload buffer space than the writer can handle in one write call.
   149  func relayFallback(w Writer, r Reader, frontHeadroom, rearHeadroom, readMaxPayloadSize, writeMaxPayloadSize int) (n int64, err error) {
   150  	br := make([]byte, frontHeadroom+readMaxPayloadSize+rearHeadroom)
   151  	bw := make([]byte, frontHeadroom+writeMaxPayloadSize+rearHeadroom)
   152  
   153  	for {
   154  		var payloadLen int
   155  		payloadLen, err = r.ReadZeroCopy(br, frontHeadroom, readMaxPayloadSize)
   156  		if payloadLen == 0 {
   157  			if err == io.EOF {
   158  				err = nil
   159  			}
   160  			return
   161  		}
   162  
   163  		// Short-circuit to avoid copying if payload can fit in one write.
   164  		if payloadLen <= writeMaxPayloadSize {
   165  			payloadWritten, werr := w.WriteZeroCopy(br, frontHeadroom, payloadLen)
   166  			n += int64(payloadWritten)
   167  			if werr != nil {
   168  				err = werr
   169  			}
   170  			if err != nil {
   171  				return
   172  			}
   173  			continue
   174  		}
   175  
   176  		// Loop until all of br[frontHeadroom : frontHeadroom+payloadLen] is written.
   177  		for i, j := 0, 0; i < payloadLen; i += j {
   178  			j = copy(bw[frontHeadroom:frontHeadroom+writeMaxPayloadSize], br[frontHeadroom+i:frontHeadroom+payloadLen])
   179  			payloadWritten, werr := w.WriteZeroCopy(bw, frontHeadroom, j)
   180  			n += int64(payloadWritten)
   181  			if werr != nil {
   182  				err = werr
   183  				return
   184  			}
   185  		}
   186  
   187  		if err != nil {
   188  			if err == io.EOF {
   189  				err = nil
   190  			}
   191  			return
   192  		}
   193  	}
   194  }
   195  
// CloseRead provides the CloseRead method.
type CloseRead interface {
	// CloseRead indicates to the underlying reader that no further reads will happen.
	CloseRead() error
}

// CloseWrite provides the CloseWrite method.
type CloseWrite interface {
	// CloseWrite indicates to the underlying writer that no further writes will happen.
	CloseWrite() error
}

// ReadWriter provides a stream interface for reading and writing.
//
// It combines the zero-copy [Reader] and [Writer] methods with
// direction-specific closing (CloseRead, CloseWrite) and a full Close.
// TwoWayRelay calls CloseWrite on a destination once relaying into it has finished.
type ReadWriter interface {
	Reader
	Writer
	CloseRead
	CloseWrite
	io.Closer
}
   216  
   217  // TwoWayRelay relays data between left and right using zero-copy methods.
   218  // It returns the number of bytes sent from left to right, from right to left,
   219  // and any error occurred during transfer.
   220  func TwoWayRelay(left, right ReadWriter) (nl2r, nr2l int64, err error) {
   221  	var l2rErr error
   222  	l2rDone := make(chan struct{})
   223  
   224  	go func() {
   225  		nl2r, l2rErr = Relay(right, left)
   226  		right.CloseWrite()
   227  		close(l2rDone)
   228  	}()
   229  
   230  	nr2l, err = Relay(left, right)
   231  	left.CloseWrite()
   232  	<-l2rDone
   233  
   234  	if l2rErr != nil {
   235  		err = l2rErr
   236  	}
   237  	return
   238  }
   239  
// DirectReadWriteCloser extends io.ReadWriteCloser with CloseRead and CloseWrite.
//
// It is the stream type relayed by DirectTwoWayRelay, which calls CloseWrite
// on each destination after copying into it finishes.
type DirectReadWriteCloser interface {
	io.ReadWriteCloser
	CloseRead
	CloseWrite
}
   246  
   247  // DirectTwoWayRelay relays data between left and right using [io.Copy].
   248  // It returns the number of bytes sent from left to right, from right to left,
   249  // and any error occurred during transfer.
   250  func DirectTwoWayRelay(left, right DirectReadWriteCloser) (nl2r, nr2l int64, err error) {
   251  	var l2rErr error
   252  	l2rDone := make(chan struct{})
   253  
   254  	go func() {
   255  		nl2r, l2rErr = io.Copy(right, left)
   256  		right.CloseWrite()
   257  		close(l2rDone)
   258  	}()
   259  
   260  	nr2l, err = io.Copy(left, right)
   261  	left.CloseWrite()
   262  	<-l2rDone
   263  
   264  	if l2rErr != nil {
   265  		err = l2rErr
   266  	}
   267  	return
   268  }
   269  
// DirectReadWriteCloserOpener provides the Open method to open a [DirectReadWriteCloser].
type DirectReadWriteCloserOpener interface {
	// Open opens a [DirectReadWriteCloser] with the specified initial payload b.
	//
	// NOTE(review): whether ctx bounds only the open or also the returned
	// stream is implementation-defined; SimpleDirectReadWriteCloserOpener ignores it.
	Open(ctx context.Context, b []byte) (DirectReadWriteCloser, error)
}
   275  
   276  // SimpleDirectReadWriteCloserOpener wraps a [DirectReadWriteCloser] for the Open method to return.
   277  type SimpleDirectReadWriteCloserOpener struct {
   278  	DirectReadWriteCloser
   279  }
   280  
   281  // Open implements the DirectReadWriteCloserOpener Open method.
   282  func (o *SimpleDirectReadWriteCloserOpener) Open(ctx context.Context, b []byte) (DirectReadWriteCloser, error) {
   283  	_, err := o.DirectReadWriteCloser.Write(b)
   284  	return o.DirectReadWriteCloser, err
   285  }
   286  
   287  // ReadWriterTestFunc tests the left and right ReadWriters by performing 2 writes
   288  // on each ReadWriter and validating the read results.
   289  //
   290  // The left and right ReadWriters must be connected with a duplex pipe.
   291  func ReadWriterTestFunc(t tester, l, r ReadWriter) {
   292  	defer r.Close()
   293  	defer l.Close()
   294  
   295  	var (
   296  		hello = []byte{'h', 'e', 'l', 'l', 'o'}
   297  		world = []byte{'w', 'o', 'r', 'l', 'd'}
   298  	)
   299  
   300  	lri := l.ReaderInfo()
   301  	lwi := l.WriterInfo()
   302  	lwmax := lwi.MaxPayloadSizePerWrite
   303  	if lwmax == 0 {
   304  		lwmax = 5
   305  	}
   306  	lrmin := lri.MinPayloadBufferSizePerRead
   307  	if lrmin == 0 {
   308  		lrmin = 5
   309  	}
   310  	lwbuf := make([]byte, lwi.Headroom.Front+lwmax+lwi.Headroom.Rear)
   311  	lrbuf := make([]byte, lri.Headroom.Front+lrmin+lri.Headroom.Rear)
   312  
   313  	rri := r.ReaderInfo()
   314  	rwi := r.WriterInfo()
   315  	rwmax := rwi.MaxPayloadSizePerWrite
   316  	if rwmax == 0 {
   317  		rwmax = 5
   318  	}
   319  	rrmin := rri.MinPayloadBufferSizePerRead
   320  	if rrmin == 0 {
   321  		rrmin = 5
   322  	}
   323  	rwbuf := make([]byte, rwi.Headroom.Front+rwmax+rwi.Headroom.Rear)
   324  	rrbuf := make([]byte, rri.Headroom.Front+rrmin+rri.Headroom.Rear)
   325  
   326  	var wg sync.WaitGroup
   327  	wg.Add(2)
   328  
   329  	// Start read goroutines.
   330  	go func() {
   331  		defer wg.Done()
   332  
   333  		pl, err := l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin)
   334  		if err != nil {
   335  			t.Error(err)
   336  		}
   337  		if pl != 5 {
   338  			t.Errorf("Expected payloadLen 5, got %d", pl)
   339  		}
   340  		p := lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl]
   341  		if !bytes.Equal(p, world) {
   342  			t.Errorf("Expected payload %v, got %v", world, p)
   343  		}
   344  
   345  		pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin)
   346  		if err != nil {
   347  			t.Error(err)
   348  		}
   349  		if pl != 5 {
   350  			t.Errorf("Expected payloadLen 5, got %d", pl)
   351  		}
   352  		p = lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl]
   353  		if !bytes.Equal(p, hello) {
   354  			t.Errorf("Expected payload %v, got %v", hello, p)
   355  		}
   356  
   357  		pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin)
   358  		if err != io.EOF {
   359  			t.Errorf("Expected io.EOF, got %v", err)
   360  		}
   361  		if pl != 0 {
   362  			t.Errorf("Expected payloadLen 0, got %v", pl)
   363  		}
   364  	}()
   365  
   366  	go func() {
   367  		defer wg.Done()
   368  
   369  		pl, err := r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin)
   370  		if err != nil {
   371  			t.Error(err)
   372  		}
   373  		if pl != 5 {
   374  			t.Errorf("Expected payloadLen 5, got %d", pl)
   375  		}
   376  		p := rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl]
   377  		if !bytes.Equal(p, hello) {
   378  			t.Errorf("Expected payload %v, got %v", hello, p)
   379  		}
   380  
   381  		pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin)
   382  		if err != nil {
   383  			t.Error(err)
   384  		}
   385  		if pl != 5 {
   386  			t.Errorf("Expected payloadLen 5, got %d", pl)
   387  		}
   388  		p = rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl]
   389  		if !bytes.Equal(p, world) {
   390  			t.Errorf("Expected payload %v, got %v", world, p)
   391  		}
   392  
   393  		pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin)
   394  		if err != io.EOF {
   395  			t.Errorf("Expected io.EOF, got %v", err)
   396  		}
   397  		if pl != 0 {
   398  			t.Errorf("Expected payloadLen 0, got %v", pl)
   399  		}
   400  	}()
   401  
   402  	// Write from left to right.
   403  	n := copy(lwbuf[lwi.Headroom.Front:], hello)
   404  	written, err := l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n)
   405  	if err != nil {
   406  		t.Error(err)
   407  	}
   408  	if written != n {
   409  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   410  	}
   411  
   412  	n = copy(lwbuf[lwi.Headroom.Front:], world)
   413  	written, err = l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n)
   414  	if err != nil {
   415  		t.Error(err)
   416  	}
   417  	if written != n {
   418  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   419  	}
   420  
   421  	err = l.CloseWrite()
   422  	if err != nil {
   423  		t.Error(err)
   424  	}
   425  
   426  	// Write from right to left.
   427  	n = copy(rwbuf[rwi.Headroom.Front:], world)
   428  	written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n)
   429  	if err != nil {
   430  		t.Error(err)
   431  	}
   432  	if written != n {
   433  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   434  	}
   435  
   436  	n = copy(rwbuf[rwi.Headroom.Front:], hello)
   437  	written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n)
   438  	if err != nil {
   439  		t.Error(err)
   440  	}
   441  	if written != n {
   442  		t.Errorf("Expected bytes written: %d, got %d", n, written)
   443  	}
   444  
   445  	err = r.CloseWrite()
   446  	if err != nil {
   447  		t.Error(err)
   448  	}
   449  
   450  	wg.Wait()
   451  }
   452  
   453  // CopyReadWriter wraps a ReadWriter and provides the io.ReadWriter Read and Write methods
   454  // by copying from and to internal buffers and using the zerocopy methods on them.
   455  //
   456  // The io.ReaderFrom ReadFrom method is implemented using the internal write buffer without copying.
   457  type CopyReadWriter struct {
   458  	ReadWriter
   459  
   460  	readHeadroom  Headroom
   461  	writeHeadroom Headroom
   462  
   463  	readBuf       []byte
   464  	readBufStart  int
   465  	readBufLength int
   466  
   467  	writeBuf []byte
   468  }
   469  
   470  func NewCopyReadWriter(rw ReadWriter) *CopyReadWriter {
   471  	ri := rw.ReaderInfo()
   472  	wi := rw.WriterInfo()
   473  
   474  	readBufSize := ri.MinPayloadBufferSizePerRead
   475  	if readBufSize == 0 {
   476  		readBufSize = defaultBufferSize
   477  	}
   478  
   479  	writeBufSize := wi.MaxPayloadSizePerWrite
   480  	if writeBufSize == 0 {
   481  		writeBufSize = defaultBufferSize
   482  	}
   483  
   484  	return &CopyReadWriter{
   485  		ReadWriter:    rw,
   486  		readHeadroom:  ri.Headroom,
   487  		writeHeadroom: wi.Headroom,
   488  		readBuf:       make([]byte, ri.Headroom.Front+readBufSize+ri.Headroom.Front),
   489  		writeBuf:      make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear),
   490  	}
   491  }
   492  
   493  // Read implements the io.Reader Read method.
   494  func (rw *CopyReadWriter) Read(b []byte) (n int, err error) {
   495  	if rw.readBufLength == 0 {
   496  		rw.readBufStart = rw.readHeadroom.Front
   497  		rw.readBufLength = len(rw.readBuf) - rw.readHeadroom.Front - rw.readHeadroom.Rear
   498  		rw.readBufLength, err = rw.ReadWriter.ReadZeroCopy(rw.readBuf, rw.readBufStart, rw.readBufLength)
   499  		if err != nil {
   500  			return
   501  		}
   502  	}
   503  
   504  	n = copy(b, rw.readBuf[rw.readBufStart:rw.readBufStart+rw.readBufLength])
   505  	rw.readBufStart += n
   506  	rw.readBufLength -= n
   507  	return n, nil
   508  }
   509  
   510  // Write implements the io.Writer Write method.
   511  func (rw *CopyReadWriter) Write(b []byte) (n int, err error) {
   512  	payloadBuf := rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear]
   513  
   514  	for n < len(b) {
   515  		payloadLength := copy(payloadBuf, b[n:])
   516  		var payloadWritten int
   517  		payloadWritten, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, payloadLength)
   518  		n += payloadWritten
   519  		if err != nil {
   520  			return
   521  		}
   522  	}
   523  
   524  	return
   525  }
   526  
   527  // ReadFrom implements the io.ReaderFrom ReadFrom method.
   528  func (rw *CopyReadWriter) ReadFrom(r io.Reader) (n int64, err error) {
   529  	for {
   530  		nr, err := r.Read(rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear])
   531  		n += int64(nr)
   532  		switch err {
   533  		case nil:
   534  		case io.EOF:
   535  			return n, nil
   536  		default:
   537  			return n, err
   538  		}
   539  
   540  		_, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, nr)
   541  		if err != nil {
   542  			return n, err
   543  		}
   544  	}
   545  }
   546  
   547  func CopyWriteOnce(w Writer, b []byte) (n int, err error) {
   548  	wi := w.WriterInfo()
   549  	writeBufSize := wi.MaxPayloadSizePerWrite
   550  	if writeBufSize == 0 {
   551  		writeBufSize = defaultBufferSize
   552  	}
   553  	if writeBufSize > len(b) {
   554  		writeBufSize = len(b)
   555  	}
   556  
   557  	writeBuf := make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear)
   558  	payloadBuf := writeBuf[wi.Headroom.Front : wi.Headroom.Front+writeBufSize]
   559  
   560  	for n < len(b) {
   561  		payloadLength := copy(payloadBuf, b[n:])
   562  		var payloadWritten int
   563  		payloadWritten, err = w.WriteZeroCopy(writeBuf, wi.Headroom.Front, payloadLength)
   564  		n += payloadWritten
   565  		if err != nil {
   566  			return
   567  		}
   568  	}
   569  
   570  	return
   571  }