github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/reader.go (about)

     1  package avro
     2  
import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"math"
	"strings"
	"unsafe"
)
    10  
    11  const (
    12  	maxIntBufSize  = 5
    13  	maxLongBufSize = 10
    14  )
    15  
    16  // ReaderFunc is a function used to customize the Reader.
    17  type ReaderFunc func(r *Reader)
    18  
    19  // WithReaderConfig specifies the configuration to use with a reader.
    20  func WithReaderConfig(cfg API) ReaderFunc {
    21  	return func(r *Reader) {
    22  		r.cfg = cfg.(*frozenConfig)
    23  	}
    24  }
    25  
    26  // Reader is an Avro specific io.Reader.
    27  type Reader struct {
    28  	cfg    *frozenConfig
    29  	reader io.Reader
    30  	slab   []byte
    31  	buf    []byte
    32  	head   int
    33  	tail   int
    34  	Error  error
    35  }
    36  
    37  // NewReader creates a new Reader.
    38  func NewReader(r io.Reader, bufSize int, opts ...ReaderFunc) *Reader {
    39  	reader := &Reader{
    40  		cfg:    DefaultConfig.(*frozenConfig),
    41  		reader: r,
    42  		buf:    make([]byte, bufSize),
    43  		head:   0,
    44  		tail:   0,
    45  	}
    46  
    47  	for _, opt := range opts {
    48  		opt(reader)
    49  	}
    50  
    51  	return reader
    52  }
    53  
    54  // Reset resets a Reader with a new byte array attached.
    55  func (r *Reader) Reset(b []byte) *Reader {
    56  	r.reader = nil
    57  	r.buf = b
    58  	r.head = 0
    59  	r.tail = len(b)
    60  	return r
    61  }
    62  
    63  // ReportError record a error in iterator instance with current position.
    64  func (r *Reader) ReportError(operation, msg string) {
    65  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
    66  		return
    67  	}
    68  
    69  	r.Error = fmt.Errorf("avro: %s: %s", operation, msg)
    70  }
    71  
    72  func (r *Reader) loadMore() bool {
    73  	if r.reader == nil {
    74  		if r.Error == nil {
    75  			r.head = r.tail
    76  			r.Error = io.EOF
    77  		}
    78  		return false
    79  	}
    80  
    81  	for {
    82  		n, err := r.reader.Read(r.buf)
    83  		if n == 0 {
    84  			if err != nil {
    85  				if r.Error == nil {
    86  					r.Error = err
    87  				}
    88  				return false
    89  			}
    90  			continue
    91  		}
    92  
    93  		r.head = 0
    94  		r.tail = n
    95  		return true
    96  	}
    97  }
    98  
    99  func (r *Reader) readByte() byte {
   100  	if r.head == r.tail {
   101  		if !r.loadMore() {
   102  			r.Error = io.ErrUnexpectedEOF
   103  			return 0
   104  		}
   105  	}
   106  
   107  	b := r.buf[r.head]
   108  	r.head++
   109  
   110  	return b
   111  }
   112  
   113  // Peek returns the next byte in the buffer.
   114  // The Reader Error will be io.EOF if no next byte exists.
   115  func (r *Reader) Peek() byte {
   116  	if r.head == r.tail {
   117  		if !r.loadMore() {
   118  			return 0
   119  		}
   120  	}
   121  	return r.buf[r.head]
   122  }
   123  
   124  // Read reads data into the given bytes.
   125  func (r *Reader) Read(b []byte) {
   126  	size := len(b)
   127  	read := 0
   128  
   129  	for read < size {
   130  		if r.head == r.tail {
   131  			if !r.loadMore() {
   132  				r.Error = io.ErrUnexpectedEOF
   133  				return
   134  			}
   135  		}
   136  
   137  		n := copy(b[read:], r.buf[r.head:r.tail])
   138  		r.head += n
   139  		read += n
   140  	}
   141  }
   142  
   143  // ReadBool reads a Bool from the Reader.
   144  func (r *Reader) ReadBool() bool {
   145  	b := r.readByte()
   146  
   147  	if b != 0 && b != 1 {
   148  		r.ReportError("ReadBool", "invalid bool")
   149  	}
   150  	return b == 1
   151  }
   152  
   153  // ReadInt reads an Int from the Reader.
   154  //
   155  //nolint:dupl
   156  func (r *Reader) ReadInt() int32 {
   157  	if r.Error != nil {
   158  		return 0
   159  	}
   160  
   161  	var (
   162  		n int
   163  		v uint32
   164  		s uint8
   165  	)
   166  
   167  	for {
   168  		tail := r.tail
   169  		if r.tail-r.head+n > maxIntBufSize {
   170  			tail = r.head + maxIntBufSize - n
   171  		}
   172  
   173  		// Consume what it is in the buffer.
   174  		var i int
   175  		for _, b := range r.buf[r.head:tail] {
   176  			v |= uint32(b&0x7f) << s
   177  			if b&0x80 == 0 {
   178  				r.head += i + 1
   179  				return int32((v >> 1) ^ -(v & 1))
   180  			}
   181  			s += 7
   182  			i++
   183  		}
   184  		if n >= maxIntBufSize {
   185  			r.ReportError("ReadInt", "int overflow")
   186  			return 0
   187  		}
   188  		r.head += i
   189  		n += i
   190  
   191  		// We ran out of buffer and are not at the end of the int,
   192  		// Read more into the buffer.
   193  		if !r.loadMore() {
   194  			r.Error = fmt.Errorf("reading int: %w", r.Error)
   195  			return 0
   196  		}
   197  	}
   198  }
   199  
   200  // ReadLong reads a Long from the Reader.
   201  //
   202  //nolint:dupl
   203  func (r *Reader) ReadLong() int64 {
   204  	if r.Error != nil {
   205  		return 0
   206  	}
   207  
   208  	var (
   209  		n int
   210  		v uint64
   211  		s uint8
   212  	)
   213  
   214  	for {
   215  		tail := r.tail
   216  		if r.tail-r.head+n > maxLongBufSize {
   217  			tail = r.head + maxLongBufSize - n
   218  		}
   219  
   220  		// Consume what it is in the buffer.
   221  		var i int
   222  		for _, b := range r.buf[r.head:tail] {
   223  			v |= uint64(b&0x7f) << s
   224  			if b&0x80 == 0 {
   225  				r.head += i + 1
   226  				return int64((v >> 1) ^ -(v & 1))
   227  			}
   228  			s += 7
   229  			i++
   230  		}
   231  		if n >= maxLongBufSize {
   232  			r.ReportError("ReadLong", "int overflow")
   233  			return 0
   234  		}
   235  		r.head += i
   236  		n += i
   237  
   238  		// We ran out of buffer and are not at the end of the long,
   239  		// Read more into the buffer.
   240  		if !r.loadMore() {
   241  			r.Error = fmt.Errorf("reading long: %w", r.Error)
   242  			return 0
   243  		}
   244  	}
   245  }
   246  
   247  // ReadFloat reads a Float from the Reader.
   248  func (r *Reader) ReadFloat() float32 {
   249  	var buf [4]byte
   250  	r.Read(buf[:])
   251  
   252  	float := *(*float32)(unsafe.Pointer(&buf[0]))
   253  	return float
   254  }
   255  
   256  // ReadDouble reads a Double from the Reader.
   257  func (r *Reader) ReadDouble() float64 {
   258  	var buf [8]byte
   259  	r.Read(buf[:])
   260  
   261  	float := *(*float64)(unsafe.Pointer(&buf[0]))
   262  	return float
   263  }
   264  
   265  // ReadBytes reads Bytes from the Reader.
   266  func (r *Reader) ReadBytes() []byte {
   267  	return r.readBytes("bytes")
   268  }
   269  
   270  // ReadString reads a String from the Reader.
   271  func (r *Reader) ReadString() string {
   272  	b := r.readBytes("string")
   273  	if len(b) == 0 {
   274  		return ""
   275  	}
   276  
   277  	return *(*string)(unsafe.Pointer(&b))
   278  }
   279  
   280  func (r *Reader) readBytes(op string) []byte {
   281  	size := int(r.ReadLong())
   282  	if size < 0 {
   283  		fnName := "Read" + strings.ToTitle(op)
   284  		r.ReportError(fnName, "invalid "+op+" length")
   285  		return nil
   286  	}
   287  	if size == 0 {
   288  		return []byte{}
   289  	}
   290  	if max := r.cfg.getMaxByteSliceSize(); max > 0 && size > max {
   291  		fnName := "Read" + strings.ToTitle(op)
   292  		r.ReportError(fnName, "size is greater than `Config.MaxByteSliceSize`")
   293  		return nil
   294  	}
   295  
   296  	// The bytes are entirely in the buffer and of a reasonable size.
   297  	// Use the byte slab.
   298  	if r.head+size <= r.tail && size <= 1024 {
   299  		if cap(r.slab) < size {
   300  			r.slab = make([]byte, 1024)
   301  		}
   302  		dst := r.slab[:size]
   303  		r.slab = r.slab[size:]
   304  		copy(dst, r.buf[r.head:r.head+size])
   305  		r.head += size
   306  		return dst
   307  	}
   308  
   309  	buf := make([]byte, size)
   310  	r.Read(buf)
   311  	return buf
   312  }
   313  
   314  // ReadBlockHeader reads a Block Header from the Reader.
   315  func (r *Reader) ReadBlockHeader() (int64, int64) {
   316  	length := r.ReadLong()
   317  	if length < 0 {
   318  		size := r.ReadLong()
   319  
   320  		return -length, size
   321  	}
   322  
   323  	return length, 0
   324  }