github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/usermem/bytes_io.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package usermem
    16  
    17  import (
    18  	"github.com/SagerNet/gvisor/pkg/context"
    19  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    20  	"github.com/SagerNet/gvisor/pkg/hostarch"
    21  	"github.com/SagerNet/gvisor/pkg/safemem"
    22  	"github.com/SagerNet/gvisor/pkg/syserror"
    23  )
    24  
// maxInt is the largest value representable by int on the build platform
// (all bits of uint set, shifted right once to clear the sign bit). ZeroOut
// uses it to reject lengths that cannot be passed to rangeCheck as an int.
const maxInt = int(^uint(0) >> 1)
    26  
    27  // BytesIO implements IO using a byte slice. Addresses are interpreted as
    28  // offsets into the slice. Reads and writes beyond the end of the slice return
    29  // EFAULT.
    30  type BytesIO struct {
    31  	Bytes []byte
    32  }
    33  
    34  // CopyOut implements IO.CopyOut.
    35  func (b *BytesIO) CopyOut(ctx context.Context, addr hostarch.Addr, src []byte, opts IOOpts) (int, error) {
    36  	rngN, rngErr := b.rangeCheck(addr, len(src))
    37  	if rngN == 0 {
    38  		return 0, rngErr
    39  	}
    40  	return copy(b.Bytes[int(addr):], src[:rngN]), rngErr
    41  }
    42  
    43  // CopyIn implements IO.CopyIn.
    44  func (b *BytesIO) CopyIn(ctx context.Context, addr hostarch.Addr, dst []byte, opts IOOpts) (int, error) {
    45  	rngN, rngErr := b.rangeCheck(addr, len(dst))
    46  	if rngN == 0 {
    47  		return 0, rngErr
    48  	}
    49  	return copy(dst[:rngN], b.Bytes[int(addr):]), rngErr
    50  }
    51  
    52  // ZeroOut implements IO.ZeroOut.
    53  func (b *BytesIO) ZeroOut(ctx context.Context, addr hostarch.Addr, toZero int64, opts IOOpts) (int64, error) {
    54  	if toZero > int64(maxInt) {
    55  		return 0, linuxerr.EINVAL
    56  	}
    57  	rngN, rngErr := b.rangeCheck(addr, int(toZero))
    58  	if rngN == 0 {
    59  		return 0, rngErr
    60  	}
    61  	zeroSlice := b.Bytes[int(addr) : int(addr)+rngN]
    62  	for i := range zeroSlice {
    63  		zeroSlice[i] = 0
    64  	}
    65  	return int64(rngN), rngErr
    66  }
    67  
    68  // CopyOutFrom implements IO.CopyOutFrom.
    69  func (b *BytesIO) CopyOutFrom(ctx context.Context, ars hostarch.AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) {
    70  	dsts, rngErr := b.blocksFromAddrRanges(ars)
    71  	n, err := src.ReadToBlocks(dsts)
    72  	if err != nil {
    73  		return int64(n), err
    74  	}
    75  	return int64(n), rngErr
    76  }
    77  
    78  // CopyInTo implements IO.CopyInTo.
    79  func (b *BytesIO) CopyInTo(ctx context.Context, ars hostarch.AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) {
    80  	srcs, rngErr := b.blocksFromAddrRanges(ars)
    81  	n, err := dst.WriteFromBlocks(srcs)
    82  	if err != nil {
    83  		return int64(n), err
    84  	}
    85  	return int64(n), rngErr
    86  }
    87  
    88  func (b *BytesIO) rangeCheck(addr hostarch.Addr, length int) (int, error) {
    89  	if length == 0 {
    90  		return 0, nil
    91  	}
    92  	if length < 0 {
    93  		return 0, linuxerr.EINVAL
    94  	}
    95  	max := hostarch.Addr(len(b.Bytes))
    96  	if addr >= max {
    97  		return 0, syserror.EFAULT
    98  	}
    99  	end, ok := addr.AddLength(uint64(length))
   100  	if !ok || end > max {
   101  		return int(max - addr), syserror.EFAULT
   102  	}
   103  	return length, nil
   104  }
   105  
   106  func (b *BytesIO) blocksFromAddrRanges(ars hostarch.AddrRangeSeq) (safemem.BlockSeq, error) {
   107  	switch ars.NumRanges() {
   108  	case 0:
   109  		return safemem.BlockSeq{}, nil
   110  	case 1:
   111  		block, err := b.blockFromAddrRange(ars.Head())
   112  		return safemem.BlockSeqOf(block), err
   113  	default:
   114  		blocks := make([]safemem.Block, 0, ars.NumRanges())
   115  		for !ars.IsEmpty() {
   116  			block, err := b.blockFromAddrRange(ars.Head())
   117  			if block.Len() != 0 {
   118  				blocks = append(blocks, block)
   119  			}
   120  			if err != nil {
   121  				return safemem.BlockSeqFromSlice(blocks), err
   122  			}
   123  			ars = ars.Tail()
   124  		}
   125  		return safemem.BlockSeqFromSlice(blocks), nil
   126  	}
   127  }
   128  
   129  func (b *BytesIO) blockFromAddrRange(ar hostarch.AddrRange) (safemem.Block, error) {
   130  	n, err := b.rangeCheck(ar.Start, int(ar.Length()))
   131  	if n == 0 {
   132  		return safemem.Block{}, err
   133  	}
   134  	return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err
   135  }
   136  
   137  // BytesIOSequence returns an IOSequence representing the given byte slice.
   138  func BytesIOSequence(buf []byte) IOSequence {
   139  	return IOSequence{
   140  		IO:    &BytesIO{buf},
   141  		Addrs: hostarch.AddrRangeSeqOf(hostarch.AddrRange{0, hostarch.Addr(len(buf))}),
   142  	}
   143  }