zombiezen.com/go/lua@v0.0.0-20231013005828-290725fb9140/internal/bufseek/bufseek.go (about)

     1  // Copyright 2023 Ross Light
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy of
     4  // this software and associated documentation files (the “Software”), to deal in
     5  // the Software without restriction, including without limitation the rights to
     6  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
     7  // the Software, and to permit persons to whom the Software is furnished to do so,
     8  // subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in all
    11  // copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
    15  // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
    16  // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
    17  // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
    18  // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
    19  //
    20  // SPDX-License-Identifier: MIT
    21  
    22  // Package bufseek provides a buffered [io.Reader] that also implements [io.Seeker].
    23  package bufseek
    24  
    25  import (
    26  	"errors"
    27  	"fmt"
    28  	"io"
    29  )
    30  
    31  const defaultBufSize = 4096
    32  const maxConsecutiveEmptyReads = 100
    33  
    34  // Reader implements buffering for an [io.ReadSeeker] object.
    35  type Reader struct {
    36  	buf  []byte
    37  	rd   io.ReadSeeker
    38  	r, w int
    39  	err  error
    40  	// pos is the stream position of the beginning of buf.
    41  	pos int64
    42  }
    43  
    44  // NewReaderSize returns a new [Reader]
    45  // whose buffer has at least the specified size.
    46  // If the argument [io.Reader] is already a *Reader or *ReadWriter with large enough size,
    47  // it returns the underlying *Reader.
    48  func NewReaderSize(rd io.ReadSeeker, size int) *Reader {
    49  	switch b := rd.(type) {
    50  	case *Reader:
    51  		if len(b.buf) >= size {
    52  			return b
    53  		}
    54  	case *ReadWriter:
    55  		if len(b.r.buf) >= size {
    56  			return b.r
    57  		}
    58  	}
    59  	size = max(size, 16)
    60  	return &Reader{
    61  		buf: make([]byte, size),
    62  		rd:  rd,
    63  		pos: -1,
    64  	}
    65  }
    66  
    67  // NewReader returns a new Reader whose buffer has the default size.
    68  func NewReader(rd io.ReadSeeker) *Reader {
    69  	return NewReaderSize(rd, defaultBufSize)
    70  }
    71  
    72  func (b *Reader) advance(n int) {
    73  	if b.pos >= 0 {
    74  		b.pos += int64(n)
    75  	}
    76  }
    77  
    78  func (b *Reader) fill() {
    79  	// Slide existing data to beginning.
    80  	if b.r > 0 {
    81  		copy(b.buf, b.buf[b.r:b.w])
    82  		b.advance(b.r)
    83  		b.w -= b.r
    84  		b.r = 0
    85  	}
    86  
    87  	if b.w >= len(b.buf) {
    88  		panic("bufseek: tried to fill full buffer")
    89  	}
    90  
    91  	// Read new data: try a limited number of times.
    92  	for i := maxConsecutiveEmptyReads; i > 0; i-- {
    93  		n, err := b.rd.Read(b.buf[b.w:])
    94  		if n < 0 {
    95  			panic(errNegativeRead)
    96  		}
    97  		b.w += n
    98  		if err != nil {
    99  			b.err = err
   100  			return
   101  		}
   102  		if n > 0 {
   103  			return
   104  		}
   105  	}
   106  	b.err = io.ErrNoProgress
   107  }
   108  
   109  func (b *Reader) readErr() error {
   110  	err := b.err
   111  	b.err = nil
   112  	return err
   113  }
   114  
   115  // ReadByte reads and returns a single byte.
   116  // If no byte is available, returns an error.
   117  func (b *Reader) ReadByte() (byte, error) {
   118  	for b.r == b.w {
   119  		if b.err != nil {
   120  			return 0, b.readErr()
   121  		}
   122  		b.fill() // Buffer is empty.
   123  	}
   124  	c := b.buf[b.r]
   125  	b.r++
   126  	return c, nil
   127  }
   128  
   129  // Read reads data into p.
   130  // It returns the number of bytes read into p.
   131  // The bytes are taken from at most one Read on the underlying Reader,
   132  // hence n may be less than len(p).
   133  // To read exactly len(p) bytes, use io.ReadFull(b, p).
   134  // If the underlying Reader can return a non-zero count with io.EOF,
   135  // then this Read method can do so as well; see the [io.Reader] docs.
   136  func (b *Reader) Read(p []byte) (n int, err error) {
   137  	if len(p) == 0 {
   138  		if b.Buffered() > 0 {
   139  			return 0, nil
   140  		}
   141  		return 0, b.readErr()
   142  	}
   143  	if b.r == b.w {
   144  		if b.err != nil {
   145  			return 0, b.readErr()
   146  		}
   147  		if len(p) >= len(b.buf) {
   148  			// Large read, empty buffer.
   149  			// Read directly into p to avoid copy.
   150  			n, b.err = b.rd.Read(p)
   151  			if n < 0 {
   152  				panic(errNegativeRead)
   153  			}
   154  			b.advance(b.r)
   155  			b.advance(n)
   156  			b.r = 0
   157  			b.w = 0
   158  			return n, b.readErr()
   159  		}
   160  		// One read.
   161  		// Do not use b.fill, which will loop.
   162  		b.advance(b.r)
   163  		b.r = 0
   164  		b.w = 0
   165  		n, b.err = b.rd.Read(b.buf)
   166  		if n < 0 {
   167  			panic(errNegativeRead)
   168  		}
   169  		if n == 0 {
   170  			return 0, b.readErr()
   171  		}
   172  		b.w += n
   173  	}
   174  
   175  	// Copy as much as we can.
   176  	// Note: if the slice panics here, it is probably because
   177  	// the underlying reader returned a bad count. See https://go.dev/issue/49795.
   178  	n = copy(p, b.buf[b.r:b.w])
   179  	b.r += n
   180  	return n, nil
   181  }
   182  
   183  // Seek sets the offset for the next Read to offset, interpreted according to whence;
   184  // see the [io.Seeker] docs.
   185  func (b *Reader) Seek(offset int64, whence int) (pos int64, err error) {
   186  	if whence == io.SeekCurrent {
   187  		if 0 <= offset && offset <= int64(b.Buffered()) {
   188  			if b.pos < 0 {
   189  				pos, err := b.rd.Seek(0, io.SeekCurrent)
   190  				if err != nil {
   191  					return 0, err
   192  				}
   193  				b.pos = pos - int64(b.w)
   194  			}
   195  			b.r += int(offset)
   196  			return b.pos + int64(b.r), nil
   197  		}
   198  		pos, err = b.rd.Seek(offset-int64(b.w), io.SeekCurrent)
   199  	} else {
   200  		pos, err = b.rd.Seek(offset, whence)
   201  	}
   202  	if err == nil {
   203  		b.clear(pos)
   204  	}
   205  	return pos, err
   206  }
   207  
   208  func (b *Reader) clear(pos int64) {
   209  	b.pos = pos
   210  	b.r = 0
   211  	b.w = 0
   212  	b.err = nil
   213  }
   214  
   215  // Buffered returns the number of bytes that can be read from the current buffer.
   216  func (b *Reader) Buffered() int { return b.w - b.r }
   217  
   218  // ReadWriter implements buffering for an [io.ReadWriter] or an [io.ReadWriteSeeker] object.
   219  type ReadWriter struct {
   220  	r *Reader
   221  	w io.Writer
   222  }
   223  
   224  // NewReadWriterSize returns a new [ReadWriter]
   225  // whose buffer has at least the specified size.
   226  // If the argument [io.ReadWriter] is already a *ReadWriter with large enough size,
   227  // it returns the underlying *ReadWriter.
   228  func NewReadWriterSize(rw io.ReadWriteSeeker, size int) *ReadWriter {
   229  	if b, ok := rw.(*ReadWriter); ok && len(b.r.buf) >= size {
   230  		return b
   231  	}
   232  	return &ReadWriter{
   233  		r: NewReaderSize(rw, size),
   234  		w: rw,
   235  	}
   236  }
   237  
   238  // NewReadWriter returns a new ReadWriter that has the default size.
   239  func NewReadWriter(rw io.ReadWriteSeeker) *ReadWriter {
   240  	return NewReadWriterSize(rw, defaultBufSize)
   241  }
   242  
   243  // ReadByte reads and returns a single byte.
   244  // If no byte is available, returns an error.
   245  func (b *ReadWriter) ReadByte() (byte, error) {
   246  	return b.r.ReadByte()
   247  }
   248  
   249  // Read reads data into p.
   250  // It returns the number of bytes read into p.
   251  // The bytes are taken from at most one Read on the underlying Reader,
   252  // hence n may be less than len(p).
   253  // To read exactly len(p) bytes, use io.ReadFull(b, p).
   254  // If the underlying Reader can return a non-zero count with io.EOF,
   255  // then this Read method can do so as well; see the [io.Reader] docs.
   256  func (b *ReadWriter) Read(p []byte) (n int, err error) {
   257  	return b.r.Read(p)
   258  }
   259  
   260  // Seek sets the offset for the next Read to offset, interpreted according to whence;
   261  // see the [io.Seeker] docs.
   262  func (b *ReadWriter) Seek(offset int64, whence int) (int64, error) {
   263  	return b.r.Seek(offset, whence)
   264  }
   265  
   266  // Write writes data from p.
   267  // It returns the number of bytes written from p
   268  // and any error encountered that caused the write to stop early.
   269  func (b *ReadWriter) Write(p []byte) (n int, err error) {
   270  	if len(p) == 0 {
   271  		return 0, nil
   272  	}
   273  	if err := b.syncWritePosition(); err != nil {
   274  		return 0, err
   275  	}
   276  
   277  	n, err = b.w.Write(p)
   278  	// If we cached a Seek position, we have certainly invalidated it.
   279  	// Files opened for appending make the final position hard to predict,
   280  	// so we just clear the position and recompute it as needed.
   281  	b.r.clear(-1)
   282  	return n, err
   283  }
   284  
   285  // WriteString writes data from s.
   286  // It returns the number of bytes written from s
   287  // and any error encountered that caused the write to stop early.
   288  func (b *ReadWriter) WriteString(s string) (n int, err error) {
   289  	if len(s) == 0 {
   290  		return 0, nil
   291  	}
   292  	if err := b.syncWritePosition(); err != nil {
   293  		return 0, err
   294  	}
   295  
   296  	n, err = io.WriteString(b.w, s)
   297  	// Same note as in Write.
   298  	b.r.clear(-1)
   299  	return n, err
   300  }
   301  
   302  func (b *ReadWriter) syncWritePosition() error {
   303  	if b.r.Buffered() > 0 {
   304  		_, err := b.r.rd.Seek(-int64(b.r.Buffered()), io.SeekCurrent)
   305  		if err != nil {
   306  			return fmt.Errorf("bufseek: seek for write: %w", err)
   307  		}
   308  	}
   309  	return nil
   310  }
   311  
   312  var errNegativeRead = errors.New("bufseek: reader returned negative count from Read")