github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/usermem/bytes_io.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package usermem
    16  
    17  import (
    18  	"github.com/MerlinKodo/gvisor/pkg/context"
    19  	"github.com/MerlinKodo/gvisor/pkg/errors/linuxerr"
    20  	"github.com/MerlinKodo/gvisor/pkg/hostarch"
    21  	"github.com/MerlinKodo/gvisor/pkg/safemem"
    22  )
    23  
    24  const maxInt = int(^uint(0) >> 1)
    25  
    26  // BytesIO implements IO using a byte slice. Addresses are interpreted as
    27  // offsets into the slice. Reads and writes beyond the end of the slice return
    28  // EFAULT.
    29  type BytesIO struct {
    30  	Bytes []byte
    31  }
    32  
    33  // CopyOut implements IO.CopyOut.
    34  func (b *BytesIO) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error) {
    35  	rngN, rngErr := b.rangeCheck(addr, len(src))
    36  	if rngN == 0 {
    37  		return 0, rngErr
    38  	}
    39  	return copy(b.Bytes[int(addr):], src[:rngN]), rngErr
    40  }
    41  
    42  // CopyIn implements IO.CopyIn.
    43  func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error) {
    44  	rngN, rngErr := b.rangeCheck(addr, len(dst))
    45  	if rngN == 0 {
    46  		return 0, rngErr
    47  	}
    48  	return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr
    49  }
    50  
    51  // ZeroOut implements IO.ZeroOut.
    52  func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) {
    53  	if toZero > int64(maxInt) {
    54  		return 0, linuxerr.EINVAL
    55  	}
    56  	rngN, rngErr := b.rangeCheck(addr, int(toZero))
    57  	if rngN == 0 {
    58  		return 0, rngErr
    59  	}
    60  	zeroSlice := b.Bytes[int(addr) : int(addr)+rngN]
    61  	for i := range zeroSlice {
    62  		zeroSlice[i] = 0
    63  	}
    64  	return int64(rngN), rngErr
    65  }
    66  
    67  // CopyOutFrom implements IO.CopyOutFrom.
    68  func (b *BytesIO) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
    69  	dsts, rngErr := b.blocksFromAddrRanges(ars)
    70  	n, err := src.ReadToBlocks(dsts)
    71  	if err != nil {
    72  		return int64(n), err
    73  	}
    74  	return int64(n), rngErr
    75  }
    76  
    77  // CopyInTo implements IO.CopyInTo.
    78  func (b *BytesIO) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
    79  	srcs, rngErr := b.blocksFromAddrRanges(ars)
    80  	n, err := dst.WriteFromBlocks(srcs)
    81  	if err != nil {
    82  		return int64(n), err
    83  	}
    84  	return int64(n), rngErr
    85  }
    86  
    87  func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) {
    88  	if length == 0 {
    89  		return 0, nil
    90  	}
    91  	if length < 0 {
    92  		return 0, linuxerr.EINVAL
    93  	}
    94  	max := hostarch.Addr(len(b.Bytes))
    95  	if addr >= max {
    96  		return 0, linuxerr.EFAULT
    97  	}
    98  	end, ok := addr.AddLength(uint64(length))
    99  	if !ok || end > max {
   100  		return int(max - addr), linuxerr.EFAULT
   101  	}
   102  	return length, nil
   103  }
   104  
   105  func (b *BytesIO) blocksFromAddrRanges(ars hostarch.AddrRangeSeq) (safemem.BlockSeq, error) {
   106  	switch ars.NumRanges() {
   107  	case 0:
   108  		return safemem.BlockSeq{}, nil
   109  	case 1:
   110  		block, err := b.blockFromAddrRange(ars.Head())
   111  		return safemem.BlockSeqOf(block), err
   112  	default:
   113  		blocks := make([]safemem.Block, 0, ars.NumRanges())
   114  		for !ars.IsEmpty() {
   115  			block, err := b.blockFromAddrRange(ars.Head())
   116  			if block.Len() != 0 {
   117  				blocks = append(blocks, block)
   118  			}
   119  			if err != nil {
   120  				return safemem.BlockSeqFromSlice(blocks), err
   121  			}
   122  			ars = ars.Tail()
   123  		}
   124  		return safemem.BlockSeqFromSlice(blocks), nil
   125  	}
   126  }
   127  
   128  func (b *BytesIO) blockFromAddrRange(ar hostarch.AddrRange) (safemem.Block, error) {
   129  	n, err := b.rangeCheck(ar.Start, int(ar.Length()))
   130  	if n == 0 {
   131  		return safemem.Block{}, err
   132  	}
   133  	return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err
   134  }
   135  
   136  // BytesIOSequence returns an IOSequence representing the given byte slice.
   137  func BytesIOSequence(buf []byte) IOSequence {
   138  	return IOSequence{
   139  		IO:    &BytesIO{buf},
   140  		Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(len(buf))}),
   141  	}
   142  }