github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/safemem/io.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safemem
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  	"math"
    21  )
    22  
// ErrEndOfBlockSeq is returned by BlockSeqWriter.WriteFromBlocks when srcs
// holds more bytes than remain in the destination BlockSeq. The bytes that did
// fit are still copied, and they are included in the returned count.
var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq")
    26  
// Reader represents a streaming byte source like io.Reader.
type Reader interface {
	// ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the
	// number of bytes read. It may return a partial read without an error
	// (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a
	// full read with an error (i.e. (dsts.NumBytes(), err) where err != nil).
	// Note that this differs from io.Reader.Read: in particular, io.EOF should
	// not be returned if ReadToBlocks successfully reads dsts.NumBytes()
	// bytes.
	//
	// Callers that need dsts filled completely can use ReadFullToBlocks.
	ReadToBlocks(dsts BlockSeq) (uint64, error)
}
    38  
// Writer represents a streaming byte sink like io.Writer.
type Writer interface {
	// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
	// the number of bytes written. It may return a partial write without an
	// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
	// return a full write with an error (i.e. (srcs.NumBytes(), err) where
	// err != nil).
	//
	// Allowing error-free partial writes differs from io.Writer.Write.
	// Callers that need all of srcs written can use WriteFullFromBlocks.
	WriteFromBlocks(srcs BlockSeq) (uint64, error)
}
    48  
    49  // ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes()
    50  // bytes have been read or ReadToBlocks returns an error.
    51  func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) {
    52  	var done uint64
    53  	for !dsts.IsEmpty() {
    54  		n, err := r.ReadToBlocks(dsts)
    55  		done += n
    56  		if err != nil {
    57  			return done, err
    58  		}
    59  		dsts = dsts.DropFirst64(n)
    60  	}
    61  	return done, nil
    62  }
    63  
    64  // WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until
    65  // srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error.
    66  func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) {
    67  	var done uint64
    68  	for !srcs.IsEmpty() {
    69  		n, err := w.WriteFromBlocks(srcs)
    70  		done += n
    71  		if err != nil {
    72  			return done, err
    73  		}
    74  		srcs = srcs.DropFirst64(n)
    75  	}
    76  	return done, nil
    77  }
    78  
    79  // BlockSeqReader implements Reader by reading from a BlockSeq.
    80  type BlockSeqReader struct {
    81  	Blocks BlockSeq
    82  }
    83  
    84  // ReadToBlocks implements Reader.ReadToBlocks.
    85  func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
    86  	n, err := CopySeq(dsts, r.Blocks)
    87  	r.Blocks = r.Blocks.DropFirst64(n)
    88  	if err != nil {
    89  		return n, err
    90  	}
    91  	if n < dsts.NumBytes() {
    92  		return n, io.EOF
    93  	}
    94  	return n, nil
    95  }
    96  
    97  // BlockSeqWriter implements Writer by writing to a BlockSeq.
    98  type BlockSeqWriter struct {
    99  	Blocks BlockSeq
   100  }
   101  
   102  // WriteFromBlocks implements Writer.WriteFromBlocks.
   103  func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   104  	n, err := CopySeq(w.Blocks, srcs)
   105  	w.Blocks = w.Blocks.DropFirst64(n)
   106  	if err != nil {
   107  		return n, err
   108  	}
   109  	if n < srcs.NumBytes() {
   110  		return n, ErrEndOfBlockSeq
   111  	}
   112  	return n, nil
   113  }
   114  
   115  // ReaderFunc implements Reader for a function with the semantics of
   116  // Reader.ReadToBlocks.
   117  type ReaderFunc func(dsts BlockSeq) (uint64, error)
   118  
   119  // ReadToBlocks implements Reader.ReadToBlocks.
   120  func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
   121  	return f(dsts)
   122  }
   123  
   124  // WriterFunc implements Writer for a function with the semantics of
   125  // Writer.WriteFromBlocks.
   126  type WriterFunc func(srcs BlockSeq) (uint64, error)
   127  
   128  // WriteFromBlocks implements Writer.WriteFromBlocks.
   129  func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   130  	return f(srcs)
   131  }
   132  
   133  // ToIOReader implements io.Reader for a (safemem.)Reader.
   134  //
   135  // ToIOReader will return a successful partial read iff Reader.ReadToBlocks does
   136  // so.
   137  type ToIOReader struct {
   138  	Reader Reader
   139  }
   140  
   141  // Read implements io.Reader.Read.
   142  func (r ToIOReader) Read(dst []byte) (int, error) {
   143  	n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst)))
   144  	return int(n), err
   145  }
   146  
   147  // ToIOWriter implements io.Writer for a (safemem.)Writer.
   148  type ToIOWriter struct {
   149  	Writer Writer
   150  }
   151  
   152  // Write implements io.Writer.Write.
   153  func (w ToIOWriter) Write(src []byte) (int, error) {
   154  	// io.Writer does not permit partial writes.
   155  	n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src)))
   156  	return int(n), err
   157  }
   158  
   159  // FromIOReader implements Reader for an io.Reader by repeatedly invoking
   160  // io.Reader.Read until it returns an error or partial read. This is not
   161  // thread-safe.
   162  //
   163  // FromIOReader will return a successful partial read iff Reader.Read does so.
   164  type FromIOReader struct {
   165  	Reader io.Reader
   166  }
   167  
   168  // ReadToBlocks implements Reader.ReadToBlocks.
   169  func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) {
   170  	var buf []byte
   171  	var done uint64
   172  	for !dsts.IsEmpty() {
   173  		dst := dsts.Head()
   174  		var n int
   175  		var err error
   176  		n, buf, err = r.readToBlock(dst, buf)
   177  		done += uint64(n)
   178  		if n != dst.Len() {
   179  			return done, err
   180  		}
   181  		dsts = dsts.Tail()
   182  		if err != nil {
   183  			if dsts.IsEmpty() && err == io.EOF {
   184  				return done, nil
   185  			}
   186  			return done, err
   187  		}
   188  	}
   189  	return done, nil
   190  }
   191  
   192  func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) {
   193  	// io.Reader isn't safecopy-aware, so we have to buffer Blocks that require
   194  	// safecopy.
   195  	if !dst.NeedSafecopy() {
   196  		n, err := r.Reader.Read(dst.ToSlice())
   197  		return n, buf, err
   198  	}
   199  	if len(buf) < dst.Len() {
   200  		buf = make([]byte, dst.Len())
   201  	}
   202  	rn, rerr := r.Reader.Read(buf[:dst.Len()])
   203  	wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn]))
   204  	if wberr != nil {
   205  		return wbn, buf, wberr
   206  	}
   207  	return wbn, buf, rerr
   208  }
   209  
   210  // FromIOWriter implements Writer for an io.Writer by repeatedly invoking
   211  // io.Writer.Write until it returns an error or partial write.
   212  //
   213  // FromIOWriter will tolerate implementations of io.Writer.Write that return
   214  // partial writes with a nil error in contravention of io.Writer's
   215  // requirements, since Writer is permitted to do so. FromIOWriter will return a
   216  // successful partial write iff Writer.Write does so.
   217  type FromIOWriter struct {
   218  	Writer io.Writer
   219  }
   220  
   221  // WriteFromBlocks implements Writer.WriteFromBlocks.
   222  func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   223  	var buf []byte
   224  	var done uint64
   225  	for !srcs.IsEmpty() {
   226  		src := srcs.Head()
   227  		var n int
   228  		var err error
   229  		n, buf, err = w.writeFromBlock(src, buf)
   230  		done += uint64(n)
   231  		if n != src.Len() || err != nil {
   232  			return done, err
   233  		}
   234  		srcs = srcs.Tail()
   235  	}
   236  	return done, nil
   237  }
   238  
   239  func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) {
   240  	// io.Writer isn't safecopy-aware, so we have to buffer Blocks that require
   241  	// safecopy.
   242  	if !src.NeedSafecopy() {
   243  		n, err := w.Writer.Write(src.ToSlice())
   244  		return n, buf, err
   245  	}
   246  	if len(buf) < src.Len() {
   247  		buf = make([]byte, src.Len())
   248  	}
   249  	bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src)
   250  	wn, werr := w.Writer.Write(buf[:bufn])
   251  	if werr != nil {
   252  		return wn, buf, werr
   253  	}
   254  	return wn, buf, buferr
   255  }
   256  
   257  // FromVecReaderFunc implements Reader for a function that reads data into a
   258  // [][]byte and returns the number of bytes read as an int64.
   259  type FromVecReaderFunc struct {
   260  	ReadVec func(dsts [][]byte) (int64, error)
   261  }
   262  
   263  // ReadToBlocks implements Reader.ReadToBlocks.
   264  //
   265  // ReadToBlocks calls r.ReadVec at most once.
   266  func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) {
   267  	if dsts.IsEmpty() {
   268  		return 0, nil
   269  	}
   270  	// Ensure that we don't pass a [][]byte with a total length > MaxInt64.
   271  	dsts = dsts.TakeFirst64(uint64(math.MaxInt64))
   272  	dstSlices := make([][]byte, 0, dsts.NumBlocks())
   273  	// Buffer Blocks that require safecopy.
   274  	for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() {
   275  		dst := tmp.Head()
   276  		if dst.NeedSafecopy() {
   277  			dstSlices = append(dstSlices, make([]byte, dst.Len()))
   278  		} else {
   279  			dstSlices = append(dstSlices, dst.ToSlice())
   280  		}
   281  	}
   282  	rn, rerr := r.ReadVec(dstSlices)
   283  	dsts = dsts.TakeFirst64(uint64(rn))
   284  	var done uint64
   285  	var i int
   286  	for !dsts.IsEmpty() {
   287  		dst := dsts.Head()
   288  		if dst.NeedSafecopy() {
   289  			n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i]))
   290  			done += uint64(n)
   291  			if err != nil {
   292  				return done, err
   293  			}
   294  		} else {
   295  			done += uint64(dst.Len())
   296  		}
   297  		dsts = dsts.Tail()
   298  		i++
   299  	}
   300  	return done, rerr
   301  }
   302  
   303  // FromVecWriterFunc implements Writer for a function that writes data from a
   304  // [][]byte and returns the number of bytes written.
   305  type FromVecWriterFunc struct {
   306  	WriteVec func(srcs [][]byte) (int64, error)
   307  }
   308  
   309  // WriteFromBlocks implements Writer.WriteFromBlocks.
   310  //
   311  // WriteFromBlocks calls w.WriteVec at most once.
   312  func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) {
   313  	if srcs.IsEmpty() {
   314  		return 0, nil
   315  	}
   316  	// Ensure that we don't pass a [][]byte with a total length > MaxInt64.
   317  	srcs = srcs.TakeFirst64(uint64(math.MaxInt64))
   318  	srcSlices := make([][]byte, 0, srcs.NumBlocks())
   319  	// Buffer Blocks that require safecopy.
   320  	var buferr error
   321  	for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() {
   322  		src := tmp.Head()
   323  		if src.NeedSafecopy() {
   324  			slice := make([]byte, src.Len())
   325  			n, err := Copy(BlockFromSafeSlice(slice), src)
   326  			srcSlices = append(srcSlices, slice[:n])
   327  			if err != nil {
   328  				buferr = err
   329  				break
   330  			}
   331  		} else {
   332  			srcSlices = append(srcSlices, src.ToSlice())
   333  		}
   334  	}
   335  	n, err := w.WriteVec(srcSlices)
   336  	if err != nil {
   337  		return uint64(n), err
   338  	}
   339  	return uint64(n), buferr
   340  }