github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/safemem/io.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safemem 16 17 import ( 18 "errors" 19 "io" 20 "math" 21 ) 22 23 // ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write 24 // beyond the end of the BlockSeq. 25 var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq") 26 27 // Reader represents a streaming byte source like io.Reader. 28 type Reader interface { 29 // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the 30 // number of bytes read. It may return a partial read without an error 31 // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a 32 // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil); 33 // note that this differs from io.Reader.Read (in particular, io.EOF should 34 // not be returned if ReadToBlocks successfully reads dsts.NumBytes() 35 // bytes.) 36 ReadToBlocks(dsts BlockSeq) (uint64, error) 37 } 38 39 // Writer represents a streaming byte sink like io.Writer. 40 type Writer interface { 41 // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns 42 // the number of bytes written. It may return a partial write without an 43 // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not 44 // return a full write with an error (i.e. srcs.NumBytes(), err) where err 45 // != nil). 46 WriteFromBlocks(srcs BlockSeq) (uint64, error) 47 } 48 49 // ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes() 50 // bytes have been read or ReadToBlocks returns an error. 51 func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) { 52 var done uint64 53 for !dsts.IsEmpty() { 54 n, err := r.ReadToBlocks(dsts) 55 done += n 56 if err != nil { 57 return done, err 58 } 59 dsts = dsts.DropFirst64(n) 60 } 61 return done, nil 62 } 63 64 // WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until 65 // srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error. 66 func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) { 67 var done uint64 68 for !srcs.IsEmpty() { 69 n, err := w.WriteFromBlocks(srcs) 70 done += n 71 if err != nil { 72 return done, err 73 } 74 srcs = srcs.DropFirst64(n) 75 } 76 return done, nil 77 } 78 79 // BlockSeqReader implements Reader by reading from a BlockSeq. 80 type BlockSeqReader struct { 81 Blocks BlockSeq 82 } 83 84 // ReadToBlocks implements Reader.ReadToBlocks. 85 func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { 86 n, err := CopySeq(dsts, r.Blocks) 87 r.Blocks = r.Blocks.DropFirst64(n) 88 if err != nil { 89 return n, err 90 } 91 if n < dsts.NumBytes() { 92 return n, io.EOF 93 } 94 return n, nil 95 } 96 97 // BlockSeqWriter implements Writer by writing to a BlockSeq. 98 type BlockSeqWriter struct { 99 Blocks BlockSeq 100 } 101 102 // WriteFromBlocks implements Writer.WriteFromBlocks. 103 func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 104 n, err := CopySeq(w.Blocks, srcs) 105 w.Blocks = w.Blocks.DropFirst64(n) 106 if err != nil { 107 return n, err 108 } 109 if n < srcs.NumBytes() { 110 return n, ErrEndOfBlockSeq 111 } 112 return n, nil 113 } 114 115 // ReaderFunc implements Reader for a function with the semantics of 116 // Reader.ReadToBlocks. 117 type ReaderFunc func(dsts BlockSeq) (uint64, error) 118 119 // ReadToBlocks implements Reader.ReadToBlocks. 120 func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { 121 return f(dsts) 122 } 123 124 // WriterFunc implements Writer for a function with the semantics of 125 // Writer.WriteFromBlocks. 126 type WriterFunc func(srcs BlockSeq) (uint64, error) 127 128 // WriteFromBlocks implements Writer.WriteFromBlocks. 129 func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 130 return f(srcs) 131 } 132 133 // ToIOReader implements io.Reader for a (safemem.)Reader. 134 // 135 // ToIOReader will return a successful partial read iff Reader.ReadToBlocks does 136 // so. 137 type ToIOReader struct { 138 Reader Reader 139 } 140 141 // Read implements io.Reader.Read. 142 func (r ToIOReader) Read(dst []byte) (int, error) { 143 n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst))) 144 return int(n), err 145 } 146 147 // ToIOWriter implements io.Writer for a (safemem.)Writer. 148 type ToIOWriter struct { 149 Writer Writer 150 } 151 152 // Write implements io.Writer.Write. 153 func (w ToIOWriter) Write(src []byte) (int, error) { 154 // io.Writer does not permit partial writes. 155 n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src))) 156 return int(n), err 157 } 158 159 // FromIOReader implements Reader for an io.Reader by repeatedly invoking 160 // io.Reader.Read until it returns an error or partial read. This is not 161 // thread-safe. 162 // 163 // FromIOReader will return a successful partial read iff Reader.Read does so. 164 type FromIOReader struct { 165 Reader io.Reader 166 } 167 168 // ReadToBlocks implements Reader.ReadToBlocks. 169 func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { 170 var buf []byte 171 var done uint64 172 for !dsts.IsEmpty() { 173 dst := dsts.Head() 174 var n int 175 var err error 176 n, buf, err = r.readToBlock(dst, buf) 177 done += uint64(n) 178 if n != dst.Len() { 179 return done, err 180 } 181 dsts = dsts.Tail() 182 if err != nil { 183 if dsts.IsEmpty() && err == io.EOF { 184 return done, nil 185 } 186 return done, err 187 } 188 } 189 return done, nil 190 } 191 192 func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { 193 // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require 194 // safecopy. 195 if !dst.NeedSafecopy() { 196 n, err := r.Reader.Read(dst.ToSlice()) 197 return n, buf, err 198 } 199 if len(buf) < dst.Len() { 200 buf = make([]byte, dst.Len()) 201 } 202 rn, rerr := r.Reader.Read(buf[:dst.Len()]) 203 wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) 204 if wberr != nil { 205 return wbn, buf, wberr 206 } 207 return wbn, buf, rerr 208 } 209 210 // FromIOWriter implements Writer for an io.Writer by repeatedly invoking 211 // io.Writer.Write until it returns an error or partial write. 212 // 213 // FromIOWriter will tolerate implementations of io.Writer.Write that return 214 // partial writes with a nil error in contravention of io.Writer's 215 // requirements, since Writer is permitted to do so. FromIOWriter will return a 216 // successful partial write iff Writer.Write does so. 217 type FromIOWriter struct { 218 Writer io.Writer 219 } 220 221 // WriteFromBlocks implements Writer.WriteFromBlocks. 222 func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 223 var buf []byte 224 var done uint64 225 for !srcs.IsEmpty() { 226 src := srcs.Head() 227 var n int 228 var err error 229 n, buf, err = w.writeFromBlock(src, buf) 230 done += uint64(n) 231 if n != src.Len() || err != nil { 232 return done, err 233 } 234 srcs = srcs.Tail() 235 } 236 return done, nil 237 } 238 239 func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) { 240 // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require 241 // safecopy. 242 if !src.NeedSafecopy() { 243 n, err := w.Writer.Write(src.ToSlice()) 244 return n, buf, err 245 } 246 if len(buf) < src.Len() { 247 buf = make([]byte, src.Len()) 248 } 249 bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src) 250 wn, werr := w.Writer.Write(buf[:bufn]) 251 if werr != nil { 252 return wn, buf, werr 253 } 254 return wn, buf, buferr 255 } 256 257 // FromVecReaderFunc implements Reader for a function that reads data into a 258 // [][]byte and returns the number of bytes read as an int64. 259 type FromVecReaderFunc struct { 260 ReadVec func(dsts [][]byte) (int64, error) 261 } 262 263 // ReadToBlocks implements Reader.ReadToBlocks. 264 // 265 // ReadToBlocks calls r.ReadVec at most once. 266 func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { 267 if dsts.IsEmpty() { 268 return 0, nil 269 } 270 // Ensure that we don't pass a [][]byte with a total length > MaxInt64. 271 dsts = dsts.TakeFirst64(uint64(math.MaxInt64)) 272 dstSlices := make([][]byte, 0, dsts.NumBlocks()) 273 // Buffer Blocks that require safecopy. 274 for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() { 275 dst := tmp.Head() 276 if dst.NeedSafecopy() { 277 dstSlices = append(dstSlices, make([]byte, dst.Len())) 278 } else { 279 dstSlices = append(dstSlices, dst.ToSlice()) 280 } 281 } 282 rn, rerr := r.ReadVec(dstSlices) 283 dsts = dsts.TakeFirst64(uint64(rn)) 284 var done uint64 285 var i int 286 for !dsts.IsEmpty() { 287 dst := dsts.Head() 288 if dst.NeedSafecopy() { 289 n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i])) 290 done += uint64(n) 291 if err != nil { 292 return done, err 293 } 294 } else { 295 done += uint64(dst.Len()) 296 } 297 dsts = dsts.Tail() 298 i++ 299 } 300 return done, rerr 301 } 302 303 // FromVecWriterFunc implements Writer for a function that writes data from a 304 // [][]byte and returns the number of bytes written. 305 type FromVecWriterFunc struct { 306 WriteVec func(srcs [][]byte) (int64, error) 307 } 308 309 // WriteFromBlocks implements Writer.WriteFromBlocks. 310 // 311 // WriteFromBlocks calls w.WriteVec at most once. 312 func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { 313 if srcs.IsEmpty() { 314 return 0, nil 315 } 316 // Ensure that we don't pass a [][]byte with a total length > MaxInt64. 317 srcs = srcs.TakeFirst64(uint64(math.MaxInt64)) 318 srcSlices := make([][]byte, 0, srcs.NumBlocks()) 319 // Buffer Blocks that require safecopy. 320 var buferr error 321 for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() { 322 src := tmp.Head() 323 if src.NeedSafecopy() { 324 slice := make([]byte, src.Len()) 325 n, err := Copy(BlockFromSafeSlice(slice), src) 326 srcSlices = append(srcSlices, slice[:n]) 327 if err != nil { 328 buferr = err 329 break 330 } 331 } else { 332 srcSlices = append(srcSlices, src.ToSlice()) 333 } 334 } 335 n, err := w.WriteVec(srcSlices) 336 if err != nil { 337 return uint64(n), err 338 } 339 return uint64(n), buferr 340 }