github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/safemem/io.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safemem
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  	"math"
    21  )
    22  
    23  // ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write
    24  // beyond the end of the BlockSeq.
    25  var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq")
    26  
    27  // Reader represents a streaming byte source like io.Reader.
    28  type Reader interface {
    29  	// ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the
    30  	// number of bytes read. It may return a partial read without an error
    31  	// (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a
    32  	// full read with an error (i.e. (dsts.NumBytes(), err) where err != nil);
    33  	// note that this differs from io.Reader.Read (in particular, io.EOF should
    34  	// not be returned if ReadToBlocks successfully reads dsts.NumBytes()
    35  	// bytes.)
    36  	ReadToBlocks(dsts BlockSeq) (uint64, error)
    37  }
    38  
    39  // Writer represents a streaming byte sink like io.Writer.
    40  type Writer interface {
    41  	// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
    42  	// the number of bytes written. It may return a partial write without an
    43  	// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
    44  	// return a full write with an error (i.e. srcs.NumBytes(), err) where err
    45  	// != nil).
    46  	WriteFromBlocks(srcs BlockSeq) (uint64, error)
    47  }
    48  
    49  // ReadFullToBlocks repeatedly invokes r until dsts.NumBytes() bytes have been
    50  // read or r returns an error. Note that we avoid a Reader interface receiver
    51  // to avoid heap allocation.
    52  func ReadFullToBlocks(r ReaderFunc, dsts BlockSeq) (uint64, error) {
    53  	var done uint64
    54  	for !dsts.IsEmpty() {
    55  		n, err := r(dsts)
    56  		done += n
    57  		if err != nil {
    58  			return done, err
    59  		}
    60  		dsts = dsts.DropFirst64(n)
    61  	}
    62  	return done, nil
    63  }
    64  
    65  // WriteFullFromBlocks repeatedly invokes w until srcs.NumBytes() bytes have
    66  // been written or w returns an error. Note that we avoid a Writer interface
    67  // receiver to avoid heap allocation.
    68  func WriteFullFromBlocks(w WriterFunc, srcs BlockSeq) (uint64, error) {
    69  	var done uint64
    70  	for !srcs.IsEmpty() {
    71  		n, err := w(srcs)
    72  		done += n
    73  		if err != nil {
    74  			return done, err
    75  		}
    76  		srcs = srcs.DropFirst64(n)
    77  	}
    78  	return done, nil
    79  }
    80  
    81  // BlockSeqReader implements Reader by reading from a BlockSeq.
    82  type BlockSeqReader struct {
    83  	Blocks BlockSeq
    84  }
    85  
    86  // ReadToBlocks implements Reader.ReadToBlocks.
    87  func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
    88  	n, err := CopySeq(dsts, r.Blocks)
    89  	r.Blocks = r.Blocks.DropFirst64(n)
    90  	if err != nil {
    91  		return n, err
    92  	}
    93  	if n < dsts.NumBytes() {
    94  		return n, io.EOF
    95  	}
    96  	return n, nil
    97  }
    98  
    99  // BlockSeqWriter implements Writer by writing to a BlockSeq.
   100  type BlockSeqWriter struct {
   101  	Blocks BlockSeq
   102  }
   103  
   104  // WriteFromBlocks implements Writer.WriteFromBlocks.
   105  func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   106  	n, err := CopySeq(w.Blocks, srcs)
   107  	w.Blocks = w.Blocks.DropFirst64(n)
   108  	if err != nil {
   109  		return n, err
   110  	}
   111  	if n < srcs.NumBytes() {
   112  		return n, ErrEndOfBlockSeq
   113  	}
   114  	return n, nil
   115  }
   116  
   117  // ReaderFunc implements Reader for a function with the semantics of
   118  // Reader.ReadToBlocks.
   119  type ReaderFunc func(dsts BlockSeq) (uint64, error)
   120  
   121  // ReadToBlocks implements Reader.ReadToBlocks.
   122  func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
   123  	return f(dsts)
   124  }
   125  
   126  // WriterFunc implements Writer for a function with the semantics of
   127  // Writer.WriteFromBlocks.
   128  type WriterFunc func(srcs BlockSeq) (uint64, error)
   129  
   130  // WriteFromBlocks implements Writer.WriteFromBlocks.
   131  func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   132  	return f(srcs)
   133  }
   134  
   135  // ToIOReader implements io.Reader for a (safemem.)Reader.
   136  //
   137  // ToIOReader will return a successful partial read iff Reader.ReadToBlocks does
   138  // so.
   139  type ToIOReader struct {
   140  	Reader Reader
   141  }
   142  
   143  // Read implements io.Reader.Read.
   144  func (r ToIOReader) Read(dst []byte) (int, error) {
   145  	n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst)))
   146  	return int(n), err
   147  }
   148  
   149  // FromIOReader implements Reader for an io.Reader by repeatedly invoking
   150  // io.Reader.Read until it returns an error or partial read. This is not
   151  // thread-safe.
   152  //
   153  // FromIOReader will return a successful partial read iff Reader.Read does so.
   154  type FromIOReader struct {
   155  	Reader io.Reader
   156  }
   157  
   158  // ReadToBlocks implements Reader.ReadToBlocks.
   159  func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
   160  	var buf []byte
   161  	var done uint64
   162  	for !dsts.IsEmpty() {
   163  		dst := dsts.Head()
   164  		var n int
   165  		var err error
   166  		n, buf, err = r.readToBlock(dst, buf)
   167  		done += uint64(n)
   168  		if n != dst.Len() {
   169  			return done, err
   170  		}
   171  		dsts = dsts.Tail()
   172  		if err != nil {
   173  			if dsts.IsEmpty() && err == io.EOF {
   174  				return done, nil
   175  			}
   176  			return done, err
   177  		}
   178  	}
   179  	return done, nil
   180  }
   181  
   182  func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) {
   183  	// io.Reader isn't safecopy-aware, so we have to buffer Blocks that require
   184  	// safecopy.
   185  	if !dst.NeedSafecopy() {
   186  		n, err := r.Reader.Read(dst.ToSlice())
   187  		return n, buf, err
   188  	}
   189  	if len(buf) < dst.Len() {
   190  		buf = make([]byte, dst.Len())
   191  	}
   192  	rn, rerr := r.Reader.Read(buf[:dst.Len()])
   193  	wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
   194  	if wberr != nil {
   195  		return wbn, buf, wberr
   196  	}
   197  	return wbn, buf, rerr
   198  }
   199  
   200  // FromIOWriter implements Writer for an io.Writer by repeatedly invoking
   201  // io.Writer.Write until it returns an error or partial write.
   202  //
   203  // FromIOWriter will tolerate implementations of io.Writer.Write that return
   204  // partial writes with a nil error in contravention of io.Writer's
   205  // requirements, since Writer is permitted to do so. FromIOWriter will return a
   206  // successful partial write iff Writer.Write does so.
   207  type FromIOWriter struct {
   208  	Writer io.Writer
   209  }
   210  
   211  // WriteFromBlocks implements Writer.WriteFromBlocks.
   212  func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   213  	var buf []byte
   214  	var done uint64
   215  	for !srcs.IsEmpty() {
   216  		src := srcs.Head()
   217  		var n int
   218  		var err error
   219  		n, buf, err = w.writeFromBlock(src, buf)
   220  		done += uint64(n)
   221  		if n != src.Len() || err != nil {
   222  			return done, err
   223  		}
   224  		srcs = srcs.Tail()
   225  	}
   226  	return done, nil
   227  }
   228  
   229  func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) {
   230  	// io.Writer isn't safecopy-aware, so we have to buffer Blocks that require
   231  	// safecopy.
   232  	if !src.NeedSafecopy() {
   233  		n, err := w.Writer.Write(src.ToSlice())
   234  		return n, buf, err
   235  	}
   236  	if len(buf) < src.Len() {
   237  		buf = make([]byte, src.Len())
   238  	}
   239  	bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src)
   240  	wn, werr := w.Writer.Write(buf[:bufn])
   241  	if werr != nil {
   242  		return wn, buf, werr
   243  	}
   244  	return wn, buf, buferr
   245  }
   246  
   247  // FromVecReaderFunc implements Reader for a function that reads data into a
   248  // [][]byte and returns the number of bytes read as an int64.
   249  type FromVecReaderFunc struct {
   250  	ReadVec func(dsts [][]byte) (int64, error)
   251  }
   252  
   253  // ReadToBlocks implements Reader.ReadToBlocks.
   254  //
   255  // ReadToBlocks calls r.ReadVec at most once.
   256  func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
   257  	if dsts.IsEmpty() {
   258  		return 0, nil
   259  	}
   260  	// Ensure that we don't pass a [][]byte with a total length > MaxInt64.
   261  	dsts = dsts.TakeFirst64(uint64(math.MaxInt64))
   262  	dstSlices := make([][]byte, 0, dsts.NumBlocks())
   263  	// Buffer Blocks that require safecopy.
   264  	for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() {
   265  		dst := tmp.Head()
   266  		if dst.NeedSafecopy() {
   267  			dstSlices = append(dstSlices, make([]byte, dst.Len()))
   268  		} else {
   269  			dstSlices = append(dstSlices, dst.ToSlice())
   270  		}
   271  	}
   272  	rn, rerr := r.ReadVec(dstSlices)
   273  	dsts = dsts.TakeFirst64(uint64(rn))
   274  	var done uint64
   275  	var i int
   276  	for !dsts.IsEmpty() {
   277  		dst := dsts.Head()
   278  		if dst.NeedSafecopy() {
   279  			n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i]))
   280  			done += uint64(n)
   281  			if err != nil {
   282  				return done, err
   283  			}
   284  		} else {
   285  			done += uint64(dst.Len())
   286  		}
   287  		dsts = dsts.Tail()
   288  		i++
   289  	}
   290  	return done, rerr
   291  }
   292  
   293  // FromVecWriterFunc implements Writer for a function that writes data from a
   294  // [][]byte and returns the number of bytes written.
   295  type FromVecWriterFunc struct {
   296  	WriteVec func(srcs [][]byte) (int64, error)
   297  }
   298  
   299  // WriteFromBlocks implements Writer.WriteFromBlocks.
   300  //
   301  // WriteFromBlocks calls w.WriteVec at most once.
   302  func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   303  	if srcs.IsEmpty() {
   304  		return 0, nil
   305  	}
   306  	// Ensure that we don't pass a [][]byte with a total length > MaxInt64.
   307  	srcs = srcs.TakeFirst64(uint64(math.MaxInt64))
   308  	srcSlices := make([][]byte, 0, srcs.NumBlocks())
   309  	// Buffer Blocks that require safecopy.
   310  	var buferr error
   311  	for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() {
   312  		src := tmp.Head()
   313  		if src.NeedSafecopy() {
   314  			slice := make([]byte, src.Len())
   315  			n, err := Copy(BlockFromSafeSlice(slice), src)
   316  			srcSlices = append(srcSlices, slice[:n])
   317  			if err != nil {
   318  				buferr = err
   319  				break
   320  			}
   321  		} else {
   322  			srcSlices = append(srcSlices, src.ToSlice())
   323  		}
   324  	}
   325  	n, err := w.WriteVec(srcSlices)
   326  	if err != nil {
   327  		return uint64(n), err
   328  	}
   329  	return uint64(n), buferr
   330  }