github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/usermem/bytes_io.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package usermem
    16  
    17  import (
    18  	"github.com/sagernet/gvisor/pkg/context"
    19  	"github.com/sagernet/gvisor/pkg/errors/linuxerr"
    20  	"github.com/sagernet/gvisor/pkg/hostarch"
    21  	"github.com/sagernet/gvisor/pkg/safemem"
    22  )
    23  
    24  const maxInt = int(^uint(0) >> 1)
    25  
    26  // BytesIO implements IO using a byte slice. Addresses are interpreted as
    27  // offsets into the slice. Reads and writes beyond the end of the slice return
    28  // EFAULT.
    29  type BytesIO struct {
    30  	Bytes []byte
    31  }
    32  
    33  // CopyOut implements IO.CopyOut.
    34  func (b *BytesIO) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error) {
    35  	rngN, rngErr := b.rangeCheck(addr, len(src))
    36  	if rngN == 0 {
    37  		return 0, rngErr
    38  	}
    39  	return copy(b.Bytes[int(addr):], src[:rngN]), rngErr
    40  }
    41  
    42  // CopyIn implements IO.CopyIn.
    43  func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error) {
    44  	rngN, rngErr := b.rangeCheck(addr, len(dst))
    45  	if rngN == 0 {
    46  		return 0, rngErr
    47  	}
    48  	return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr
    49  }
    50  
    51  // ZeroOut implements IO.ZeroOut.
    52  func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) {
    53  	if toZero > int64(maxInt) {
    54  		return 0, linuxerr.EINVAL
    55  	}
    56  	rngN, rngErr := b.rangeCheck(addr, int(toZero))
    57  	if rngN == 0 {
    58  		return 0, rngErr
    59  	}
    60  	zeroSlice := b.Bytes[int(addr) : int(addr)+rngN]
    61  	clear(zeroSlice)
    62  	return int64(rngN), rngErr
    63  }
    64  
    65  // CopyOutFrom implements IO.CopyOutFrom.
    66  func (b *BytesIO) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
    67  	dsts, rngErr := b.blocksFromAddrRanges(ars)
    68  	n, err := src.ReadToBlocks(dsts)
    69  	if err != nil {
    70  		return int64(n), err
    71  	}
    72  	return int64(n), rngErr
    73  }
    74  
    75  // CopyInTo implements IO.CopyInTo.
    76  func (b *BytesIO) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
    77  	srcs, rngErr := b.blocksFromAddrRanges(ars)
    78  	n, err := dst.WriteFromBlocks(srcs)
    79  	if err != nil {
    80  		return int64(n), err
    81  	}
    82  	return int64(n), rngErr
    83  }
    84  
    85  func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) {
    86  	if length == 0 {
    87  		return 0, nil
    88  	}
    89  	if length < 0 {
    90  		return 0, linuxerr.EINVAL
    91  	}
    92  	max := hostarch.Addr(len(b.Bytes))
    93  	if addr >= max {
    94  		return 0, linuxerr.EFAULT
    95  	}
    96  	end, ok := addr.AddLength(uint64(length))
    97  	if !ok || end > max {
    98  		return int(max - addr), linuxerr.EFAULT
    99  	}
   100  	return length, nil
   101  }
   102  
   103  func (b *BytesIO) blocksFromAddrRanges(ars hostarch.AddrRangeSeq) (safemem.BlockSeq, error) {
   104  	switch ars.NumRanges() {
   105  	case 0:
   106  		return safemem.BlockSeq{}, nil
   107  	case 1:
   108  		block, err := b.blockFromAddrRange(ars.Head())
   109  		return safemem.BlockSeqOf(block), err
   110  	default:
   111  		blocks := make([]safemem.Block, 0, ars.NumRanges())
   112  		for !ars.IsEmpty() {
   113  			block, err := b.blockFromAddrRange(ars.Head())
   114  			if block.Len() != 0 {
   115  				blocks = append(blocks, block)
   116  			}
   117  			if err != nil {
   118  				return safemem.BlockSeqFromSlice(blocks), err
   119  			}
   120  			ars = ars.Tail()
   121  		}
   122  		return safemem.BlockSeqFromSlice(blocks), nil
   123  	}
   124  }
   125  
   126  func (b *BytesIO) blockFromAddrRange(ar hostarch.AddrRange) (safemem.Block, error) {
   127  	n, err := b.rangeCheck(ar.Start, int(ar.Length()))
   128  	if n == 0 {
   129  		return safemem.Block{}, err
   130  	}
   131  	return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err
   132  }
   133  
   134  // BytesIOSequence returns an IOSequence representing the given byte slice.
   135  func BytesIOSequence(buf []byte) IOSequence {
   136  	return IOSequence{
   137  		IO:    &BytesIO{buf},
   138  		Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(len(buf))}),
   139  	}
   140  }