github.com/jhump/protoreflect@v1.16.0/internal/codec/decode.go (about)

     1  package codec
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  )
    11  
    12  // ErrOverflow is returned when an integer is too large to be represented.
    13  var ErrOverflow = errors.New("proto: integer overflow")
    14  
    15  // ErrBadWireType is returned when decoding a wire-type from a buffer that
    16  // is not valid.
    17  var ErrBadWireType = errors.New("proto: bad wiretype")
    18  
    19  func (cb *Buffer) decodeVarintSlow() (x uint64, err error) {
    20  	i := cb.index
    21  	l := len(cb.buf)
    22  
    23  	for shift := uint(0); shift < 64; shift += 7 {
    24  		if i >= l {
    25  			err = io.ErrUnexpectedEOF
    26  			return
    27  		}
    28  		b := cb.buf[i]
    29  		i++
    30  		x |= (uint64(b) & 0x7F) << shift
    31  		if b < 0x80 {
    32  			cb.index = i
    33  			return
    34  		}
    35  	}
    36  
    37  	// The number is too large to represent in a 64-bit value.
    38  	err = ErrOverflow
    39  	return
    40  }
    41  
// DecodeVarint reads a varint-encoded integer from the Buffer.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
//
// Each encoded byte carries 7 payload bits, least-significant group first,
// and a set high bit (0x80) means another byte follows. On success cb.index
// is advanced past the varint; on error it is left unchanged. A varint whose
// 10th byte still has the continuation bit set yields ErrOverflow, and input
// that ends mid-varint yields io.ErrUnexpectedEOF.
//
// NOTE(review): in the 10th byte only the lowest payload bit can be kept (it
// lands in bit 63); higher payload bits of that byte are silently discarded
// rather than reported as ErrOverflow. decodeVarintSlow behaves the same way.
func (cb *Buffer) DecodeVarint() (uint64, error) {
	i := cb.index
	buf := cb.buf

	if i >= len(buf) {
		return 0, io.ErrUnexpectedEOF
	} else if buf[i] < 0x80 {
		// Fast path: a single-byte varint (value 0-127).
		cb.index++
		return uint64(buf[i]), nil
	} else if len(buf)-i < 10 {
		// The unrolled decoder below may read up to 10 bytes with no explicit
		// bounds checks, so near the end of the buffer use the checked loop.
		return cb.decodeVarintSlow()
	}

	// Unrolled decode. Each step adds the whole byte at its shift; if the
	// continuation bit was set, that bit is then subtracted back out so only
	// the 7 payload bits remain in x.
	var b uint64
	// we already checked the first byte
	x := uint64(buf[i]) - 0x80
	i++

	b = uint64(buf[i])
	i++
	x += b << 7
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 7

	b = uint64(buf[i])
	i++
	x += b << 14
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 14

	b = uint64(buf[i])
	i++
	x += b << 21
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 21

	b = uint64(buf[i])
	i++
	x += b << 28
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 28

	b = uint64(buf[i])
	i++
	x += b << 35
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 35

	b = uint64(buf[i])
	i++
	x += b << 42
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 42

	b = uint64(buf[i])
	i++
	x += b << 49
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 49

	b = uint64(buf[i])
	i++
	x += b << 56
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 56

	// 10th byte: only its lowest bit fits; b << 63 drops everything else.
	b = uint64(buf[i])
	i++
	x += b << 63
	if b&0x80 == 0 {
		goto done
	}
	// x -= 0x80 << 63 // Always zero.

	// Continuation bit still set after 10 bytes.
	return 0, ErrOverflow

done:
	// Commit the new position only once a complete varint has been read.
	cb.index = i
	return x, nil
}
   142  
   143  // DecodeTagAndWireType decodes a field tag and wire type from input.
   144  // This reads a varint and then extracts the two fields from the varint
   145  // value read.
   146  func (cb *Buffer) DecodeTagAndWireType() (tag int32, wireType int8, err error) {
   147  	var v uint64
   148  	v, err = cb.DecodeVarint()
   149  	if err != nil {
   150  		return
   151  	}
   152  	// low 7 bits is wire type
   153  	wireType = int8(v & 7)
   154  	// rest is int32 tag number
   155  	v = v >> 3
   156  	if v > math.MaxInt32 {
   157  		err = fmt.Errorf("tag number out of range: %d", v)
   158  		return
   159  	}
   160  	tag = int32(v)
   161  	return
   162  }
   163  
   164  // DecodeFixed64 reads a 64-bit integer from the Buffer.
   165  // This is the format for the
   166  // fixed64, sfixed64, and double protocol buffer types.
   167  func (cb *Buffer) DecodeFixed64() (x uint64, err error) {
   168  	// x, err already 0
   169  	i := cb.index + 8
   170  	if i < 0 || i > len(cb.buf) {
   171  		err = io.ErrUnexpectedEOF
   172  		return
   173  	}
   174  	cb.index = i
   175  
   176  	x = uint64(cb.buf[i-8])
   177  	x |= uint64(cb.buf[i-7]) << 8
   178  	x |= uint64(cb.buf[i-6]) << 16
   179  	x |= uint64(cb.buf[i-5]) << 24
   180  	x |= uint64(cb.buf[i-4]) << 32
   181  	x |= uint64(cb.buf[i-3]) << 40
   182  	x |= uint64(cb.buf[i-2]) << 48
   183  	x |= uint64(cb.buf[i-1]) << 56
   184  	return
   185  }
   186  
   187  // DecodeFixed32 reads a 32-bit integer from the Buffer.
   188  // This is the format for the
   189  // fixed32, sfixed32, and float protocol buffer types.
   190  func (cb *Buffer) DecodeFixed32() (x uint64, err error) {
   191  	// x, err already 0
   192  	i := cb.index + 4
   193  	if i < 0 || i > len(cb.buf) {
   194  		err = io.ErrUnexpectedEOF
   195  		return
   196  	}
   197  	cb.index = i
   198  
   199  	x = uint64(cb.buf[i-4])
   200  	x |= uint64(cb.buf[i-3]) << 8
   201  	x |= uint64(cb.buf[i-2]) << 16
   202  	x |= uint64(cb.buf[i-1]) << 24
   203  	return
   204  }
   205  
   206  // DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
   207  // This is the format used for the bytes protocol buffer
   208  // type and for embedded messages.
   209  func (cb *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
   210  	n, err := cb.DecodeVarint()
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	nb := int(n)
   216  	if nb < 0 {
   217  		return nil, fmt.Errorf("proto: bad byte length %d", nb)
   218  	}
   219  	end := cb.index + nb
   220  	if end < cb.index || end > len(cb.buf) {
   221  		return nil, io.ErrUnexpectedEOF
   222  	}
   223  
   224  	if !alloc {
   225  		buf = cb.buf[cb.index:end]
   226  		cb.index = end
   227  		return
   228  	}
   229  
   230  	buf = make([]byte, nb)
   231  	copy(buf, cb.buf[cb.index:])
   232  	cb.index = end
   233  	return
   234  }
   235  
   236  // ReadGroup reads the input until a "group end" tag is found
   237  // and returns the data up to that point. Subsequent reads from
   238  // the buffer will read data after the group end tag. If alloc
   239  // is true, the data is copied to a new slice before being returned.
   240  // Otherwise, the returned slice is a view into the buffer's
   241  // underlying byte slice.
   242  //
   243  // This function correctly handles nested groups: if a "group start"
   244  // tag is found, then that group's end tag will be included in the
   245  // returned data.
   246  func (cb *Buffer) ReadGroup(alloc bool) ([]byte, error) {
   247  	var groupEnd, dataEnd int
   248  	groupEnd, dataEnd, err := cb.findGroupEnd()
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	var results []byte
   253  	if !alloc {
   254  		results = cb.buf[cb.index:dataEnd]
   255  	} else {
   256  		results = make([]byte, dataEnd-cb.index)
   257  		copy(results, cb.buf[cb.index:])
   258  	}
   259  	cb.index = groupEnd
   260  	return results, nil
   261  }
   262  
   263  // SkipGroup is like ReadGroup, except that it discards the
   264  // data and just advances the buffer to point to the input
   265  // right *after* the "group end" tag.
   266  func (cb *Buffer) SkipGroup() error {
   267  	groupEnd, _, err := cb.findGroupEnd()
   268  	if err != nil {
   269  		return err
   270  	}
   271  	cb.index = groupEnd
   272  	return nil
   273  }
   274  
   275  // SkipField attempts to skip the value of a field with the given wire
   276  // type. When consuming a protobuf-encoded stream, it can be called immediately
   277  // after DecodeTagAndWireType to discard the subsequent data for the field.
   278  func (cb *Buffer) SkipField(wireType int8) error {
   279  	switch wireType {
   280  	case proto.WireFixed32:
   281  		if err := cb.Skip(4); err != nil {
   282  			return err
   283  		}
   284  	case proto.WireFixed64:
   285  		if err := cb.Skip(8); err != nil {
   286  			return err
   287  		}
   288  	case proto.WireVarint:
   289  		// skip varint by finding last byte (has high bit unset)
   290  		i := cb.index
   291  		limit := i + 10 // varint cannot be >10 bytes
   292  		for {
   293  			if i >= limit {
   294  				return ErrOverflow
   295  			}
   296  			if i >= len(cb.buf) {
   297  				return io.ErrUnexpectedEOF
   298  			}
   299  			if cb.buf[i]&0x80 == 0 {
   300  				break
   301  			}
   302  			i++
   303  		}
   304  		// TODO: This would only overflow if buffer length was MaxInt and we
   305  		// read the last byte. This is not a real/feasible concern on 64-bit
   306  		// systems. Something to worry about for 32-bit systems? Do we care?
   307  		cb.index = i + 1
   308  	case proto.WireBytes:
   309  		l, err := cb.DecodeVarint()
   310  		if err != nil {
   311  			return err
   312  		}
   313  		if err := cb.Skip(int(l)); err != nil {
   314  			return err
   315  		}
   316  	case proto.WireStartGroup:
   317  		if err := cb.SkipGroup(); err != nil {
   318  			return err
   319  		}
   320  	default:
   321  		return ErrBadWireType
   322  	}
   323  	return nil
   324  }
   325  
   326  func (cb *Buffer) findGroupEnd() (groupEnd int, dataEnd int, err error) {
   327  	start := cb.index
   328  	defer func() {
   329  		cb.index = start
   330  	}()
   331  	for {
   332  		fieldStart := cb.index
   333  		// read a field tag
   334  		_, wireType, err := cb.DecodeTagAndWireType()
   335  		if err != nil {
   336  			return 0, 0, err
   337  		}
   338  		if wireType == proto.WireEndGroup {
   339  			return cb.index, fieldStart, nil
   340  		}
   341  		// skip past the field's data
   342  		if err := cb.SkipField(wireType); err != nil {
   343  			return 0, 0, err
   344  		}
   345  	}
   346  }