github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/safemem/io.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safemem 16 17 import ( 18 "errors" 19 "io" 20 "math" 21 ) 22 23 // ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write 24 // beyond the end of the BlockSeq. 25 var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq") 26 27 // Reader represents a streaming byte source like io.Reader. 28 type Reader interface { 29 // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the 30 // number of bytes read. It may return a partial read without an error 31 // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a 32 // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil); 33 // note that this differs from io.Reader.Read (in particular, io.EOF should 34 // not be returned if ReadToBlocks successfully reads dsts.NumBytes() 35 // bytes.) 36 ReadToBlocks(dsts BlockSeq) (uint64, error) 37 } 38 39 // Writer represents a streaming byte sink like io.Writer. 40 type Writer interface { 41 // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns 42 // the number of bytes written. It may return a partial write without an 43 // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not 44 // return a full write with an error (i.e. srcs.NumBytes(), err) where err 45 // != nil). 46 WriteFromBlocks(srcs BlockSeq) (uint64, error) 47 } 48 49 // ReadFullToBlocks repeatedly invokes r until dsts.NumBytes() bytes have been 50 // read or r returns an error. Note that we avoid a Reader interface receiver 51 // to avoid heap allocation. 52 func ReadFullToBlocks(r ReaderFunc, dsts BlockSeq) (uint64, error) { 53 var done uint64 54 for !dsts.IsEmpty() { 55 n, err := r(dsts) 56 done += n 57 if err != nil { 58 return done, err 59 } 60 dsts = dsts.DropFirst64(n) 61 } 62 return done, nil 63 } 64 65 // WriteFullFromBlocks repeatedly invokes w until srcs.NumBytes() bytes have 66 // been written or w returns an error. Note that we avoid a Writer interface 67 // receiver to avoid heap allocation. 68 func WriteFullFromBlocks(w WriterFunc, srcs BlockSeq) (uint64, error) { 69 var done uint64 70 for !srcs.IsEmpty() { 71 n, err := w(srcs) 72 done += n 73 if err != nil { 74 return done, err 75 } 76 srcs = srcs.DropFirst64(n) 77 } 78 return done, nil 79 } 80 81 // BlockSeqReader implements Reader by reading from a BlockSeq. 82 type BlockSeqReader struct { 83 Blocks BlockSeq 84 } 85 86 // ReadToBlocks implements Reader.ReadToBlocks. 87 func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { 88 n, err := CopySeq(dsts, r.Blocks) 89 r.Blocks = r.Blocks.DropFirst64(n) 90 if err != nil { 91 return n, err 92 } 93 if n < dsts.NumBytes() { 94 return n, io.EOF 95 } 96 return n, nil 97 } 98 99 // BlockSeqWriter implements Writer by writing to a BlockSeq. 100 type BlockSeqWriter struct { 101 Blocks BlockSeq 102 } 103 104 // WriteFromBlocks implements Writer.WriteFromBlocks. 105 func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 106 n, err := CopySeq(w.Blocks, srcs) 107 w.Blocks = w.Blocks.DropFirst64(n) 108 if err != nil { 109 return n, err 110 } 111 if n < srcs.NumBytes() { 112 return n, ErrEndOfBlockSeq 113 } 114 return n, nil 115 } 116 117 // ReaderFunc implements Reader for a function with the semantics of 118 // Reader.ReadToBlocks. 119 type ReaderFunc func(dsts BlockSeq) (uint64, error) 120 121 // ReadToBlocks implements Reader.ReadToBlocks. 122 func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { 123 return f(dsts) 124 } 125 126 // WriterFunc implements Writer for a function with the semantics of 127 // Writer.WriteFromBlocks. 128 type WriterFunc func(srcs BlockSeq) (uint64, error) 129 130 // WriteFromBlocks implements Writer.WriteFromBlocks. 131 func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 132 return f(srcs) 133 } 134 135 // ToIOReader implements io.Reader for a (safemem.)Reader. 136 // 137 // ToIOReader will return a successful partial read iff Reader.ReadToBlocks does 138 // so. 139 type ToIOReader struct { 140 Reader Reader 141 } 142 143 // Read implements io.Reader.Read. 144 func (r ToIOReader) Read(dst []byte) (int, error) { 145 n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst))) 146 return int(n), err 147 } 148 149 // FromIOReader implements Reader for an io.Reader by repeatedly invoking 150 // io.Reader.Read until it returns an error or partial read. This is not 151 // thread-safe. 152 // 153 // FromIOReader will return a successful partial read iff Reader.Read does so. 154 type FromIOReader struct { 155 Reader io.Reader 156 } 157 158 // ReadToBlocks implements Reader.ReadToBlocks. 159 func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { 160 var buf []byte 161 var done uint64 162 for !dsts.IsEmpty() { 163 dst := dsts.Head() 164 var n int 165 var err error 166 n, buf, err = r.readToBlock(dst, buf) 167 done += uint64(n) 168 if n != dst.Len() { 169 return done, err 170 } 171 dsts = dsts.Tail() 172 if err != nil { 173 if dsts.IsEmpty() && err == io.EOF { 174 return done, nil 175 } 176 return done, err 177 } 178 } 179 return done, nil 180 } 181 182 func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { 183 // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require 184 // safecopy. 185 if !dst.NeedSafecopy() { 186 n, err := r.Reader.Read(dst.ToSlice()) 187 return n, buf, err 188 } 189 if len(buf) < dst.Len() { 190 buf = make([]byte, dst.Len()) 191 } 192 rn, rerr := r.Reader.Read(buf[:dst.Len()]) 193 wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) 194 if wberr != nil { 195 return wbn, buf, wberr 196 } 197 return wbn, buf, rerr 198 } 199 200 // FromIOWriter implements Writer for an io.Writer by repeatedly invoking 201 // io.Writer.Write until it returns an error or partial write. 202 // 203 // FromIOWriter will tolerate implementations of io.Writer.Write that return 204 // partial writes with a nil error in contravention of io.Writer's 205 // requirements, since Writer is permitted to do so. FromIOWriter will return a 206 // successful partial write iff Writer.Write does so. 207 type FromIOWriter struct { 208 Writer io.Writer 209 } 210 211 // WriteFromBlocks implements Writer.WriteFromBlocks. 212 func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 213 var buf []byte 214 var done uint64 215 for !srcs.IsEmpty() { 216 src := srcs.Head() 217 var n int 218 var err error 219 n, buf, err = w.writeFromBlock(src, buf) 220 done += uint64(n) 221 if n != src.Len() || err != nil { 222 return done, err 223 } 224 srcs = srcs.Tail() 225 } 226 return done, nil 227 } 228 229 func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) { 230 // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require 231 // safecopy. 232 if !src.NeedSafecopy() { 233 n, err := w.Writer.Write(src.ToSlice()) 234 return n, buf, err 235 } 236 if len(buf) < src.Len() { 237 buf = make([]byte, src.Len()) 238 } 239 bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src) 240 wn, werr := w.Writer.Write(buf[:bufn]) 241 if werr != nil { 242 return wn, buf, werr 243 } 244 return wn, buf, buferr 245 } 246 247 // FromVecReaderFunc implements Reader for a function that reads data into a 248 // [][]byte and returns the number of bytes read as an int64. 249 type FromVecReaderFunc struct { 250 ReadVec func(dsts [][]byte) (int64, error) 251 } 252 253 // ReadToBlocks implements Reader.ReadToBlocks. 254 // 255 // ReadToBlocks calls r.ReadVec at most once. 256 func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { 257 if dsts.IsEmpty() { 258 return 0, nil 259 } 260 // Ensure that we don't pass a [][]byte with a total length > MaxInt64. 261 dsts = dsts.TakeFirst64(uint64(math.MaxInt64)) 262 dstSlices := make([][]byte, 0, dsts.NumBlocks()) 263 // Buffer Blocks that require safecopy. 264 for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() { 265 dst := tmp.Head() 266 if dst.NeedSafecopy() { 267 dstSlices = append(dstSlices, make([]byte, dst.Len())) 268 } else { 269 dstSlices = append(dstSlices, dst.ToSlice()) 270 } 271 } 272 rn, rerr := r.ReadVec(dstSlices) 273 dsts = dsts.TakeFirst64(uint64(rn)) 274 var done uint64 275 var i int 276 for !dsts.IsEmpty() { 277 dst := dsts.Head() 278 if dst.NeedSafecopy() { 279 n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i])) 280 done += uint64(n) 281 if err != nil { 282 return done, err 283 } 284 } else { 285 done += uint64(dst.Len()) 286 } 287 dsts = dsts.Tail() 288 i++ 289 } 290 return done, rerr 291 } 292 293 // FromVecWriterFunc implements Writer for a function that writes data from a 294 // [][]byte and returns the number of bytes written. 295 type FromVecWriterFunc struct { 296 WriteVec func(srcs [][]byte) (int64, error) 297 } 298 299 // WriteFromBlocks implements Writer.WriteFromBlocks. 300 // 301 // WriteFromBlocks calls w.WriteVec at most once. 302 func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 303 if srcs.IsEmpty() { 304 return 0, nil 305 } 306 // Ensure that we don't pass a [][]byte with a total length > MaxInt64. 307 srcs = srcs.TakeFirst64(uint64(math.MaxInt64)) 308 srcSlices := make([][]byte, 0, srcs.NumBlocks()) 309 // Buffer Blocks that require safecopy. 310 var buferr error 311 for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() { 312 src := tmp.Head() 313 if src.NeedSafecopy() { 314 slice := make([]byte, src.Len()) 315 n, err := Copy(BlockFromSafeSlice(slice), src) 316 srcSlices = append(srcSlices, slice[:n]) 317 if err != nil { 318 buferr = err 319 break 320 } 321 } else { 322 srcSlices = append(srcSlices, src.ToSlice()) 323 } 324 } 325 n, err := w.WriteVec(srcSlices) 326 if err != nil { 327 return uint64(n), err 328 } 329 return uint64(n), buferr 330 }