github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/codec/decode.go (about)

     1  package codec
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    11  )
    12  
    13  // ErrOverflow is returned when an integer is too large to be represented.
    14  var ErrOverflow = errors.New("proto: integer overflow")
    15  
    16  // ErrBadWireType is returned when decoding a wire-type from a buffer that
    17  // is not valid.
    18  var ErrBadWireType = errors.New("proto: bad wiretype")
    19  
    20  var varintTypes = map[descriptor.FieldDescriptorProto_Type]bool{}
    21  var fixed32Types = map[descriptor.FieldDescriptorProto_Type]bool{}
    22  var fixed64Types = map[descriptor.FieldDescriptorProto_Type]bool{}
    23  
    24  func init() {
    25  	varintTypes[descriptor.FieldDescriptorProto_TYPE_BOOL] = true
    26  	varintTypes[descriptor.FieldDescriptorProto_TYPE_INT32] = true
    27  	varintTypes[descriptor.FieldDescriptorProto_TYPE_INT64] = true
    28  	varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT32] = true
    29  	varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT64] = true
    30  	varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT32] = true
    31  	varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT64] = true
    32  	varintTypes[descriptor.FieldDescriptorProto_TYPE_ENUM] = true
    33  
    34  	fixed32Types[descriptor.FieldDescriptorProto_TYPE_FIXED32] = true
    35  	fixed32Types[descriptor.FieldDescriptorProto_TYPE_SFIXED32] = true
    36  	fixed32Types[descriptor.FieldDescriptorProto_TYPE_FLOAT] = true
    37  
    38  	fixed64Types[descriptor.FieldDescriptorProto_TYPE_FIXED64] = true
    39  	fixed64Types[descriptor.FieldDescriptorProto_TYPE_SFIXED64] = true
    40  	fixed64Types[descriptor.FieldDescriptorProto_TYPE_DOUBLE] = true
    41  }
    42  
    43  func (cb *Buffer) decodeVarintSlow() (x uint64, err error) {
    44  	i := cb.index
    45  	l := len(cb.buf)
    46  
    47  	for shift := uint(0); shift < 64; shift += 7 {
    48  		if i >= l {
    49  			err = io.ErrUnexpectedEOF
    50  			return
    51  		}
    52  		b := cb.buf[i]
    53  		i++
    54  		x |= (uint64(b) & 0x7F) << shift
    55  		if b < 0x80 {
    56  			cb.index = i
    57  			return
    58  		}
    59  	}
    60  
    61  	// The number is too large to represent in a 64-bit value.
    62  	err = ErrOverflow
    63  	return
    64  }
    65  
// DecodeVarint reads a varint-encoded integer from the Buffer.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types (sint32/sint64 values are additionally zig-zag
// encoded; see DecodeZigZag32 and DecodeZigZag64).
//
// On success the buffer is advanced past the varint. On failure cb.index is
// left unchanged and the error is io.ErrUnexpectedEOF (the input ends before
// the varint terminates) or ErrOverflow (no terminating byte within 10
// bytes). Only the lowest bit of a 10th byte contributes to the result; its
// other payload bits are discarded rather than reported as overflow.
func (cb *Buffer) DecodeVarint() (uint64, error) {
	i := cb.index
	buf := cb.buf

	if i >= len(buf) {
		return 0, io.ErrUnexpectedEOF
	} else if buf[i] < 0x80 {
		// Single-byte varint: continuation bit is clear, value is the byte itself.
		cb.index++
		return uint64(buf[i]), nil
	} else if len(buf)-i < 10 {
		// The unrolled decoder below may read up to 10 bytes without bounds
		// checks; with fewer remaining, use the bounds-checked loop instead.
		return cb.decodeVarintSlow()
	}

	// From here at least 10 bytes are available starting at i, so every read
	// below is in range. Each step adds the whole byte shifted into place; if
	// the byte's continuation bit (0x80) was set, that bit is subtracted back
	// out so only the 7 payload bits remain in x.
	var b uint64
	// we already checked the first byte
	x := uint64(buf[i]) - 0x80
	i++

	b = uint64(buf[i])
	i++
	x += b << 7
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 7

	b = uint64(buf[i])
	i++
	x += b << 14
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 14

	b = uint64(buf[i])
	i++
	x += b << 21
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 21

	b = uint64(buf[i])
	i++
	x += b << 28
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 28

	b = uint64(buf[i])
	i++
	x += b << 35
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 35

	b = uint64(buf[i])
	i++
	x += b << 42
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 42

	b = uint64(buf[i])
	i++
	x += b << 49
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 49

	b = uint64(buf[i])
	i++
	x += b << 56
	if b&0x80 == 0 {
		goto done
	}
	x -= 0x80 << 56

	// 10th byte: shifting by 63 keeps only its lowest bit.
	b = uint64(buf[i])
	i++
	x += b << 63
	if b&0x80 == 0 {
		goto done
	}
	// x -= 0x80 << 63 // Always zero.

	// Continuation bit still set after 10 bytes: not a valid 64-bit varint.
	// cb.index has not been modified on this path.
	return 0, ErrOverflow

done:
	// Commit the new position only once a complete varint has been decoded.
	cb.index = i
	return x, nil
}
   166  
   167  // DecodeTagAndWireType decodes a field tag and wire type from input.
   168  // This reads a varint and then extracts the two fields from the varint
   169  // value read.
   170  func (cb *Buffer) DecodeTagAndWireType() (tag int32, wireType int8, err error) {
   171  	var v uint64
   172  	v, err = cb.DecodeVarint()
   173  	if err != nil {
   174  		return
   175  	}
   176  	// low 7 bits is wire type
   177  	wireType = int8(v & 7)
   178  	// rest is int32 tag number
   179  	v = v >> 3
   180  	if v > math.MaxInt32 {
   181  		err = fmt.Errorf("tag number out of range: %d", v)
   182  		return
   183  	}
   184  	tag = int32(v)
   185  	return
   186  }
   187  
   188  // DecodeFixed64 reads a 64-bit integer from the Buffer.
   189  // This is the format for the
   190  // fixed64, sfixed64, and double protocol buffer types.
   191  func (cb *Buffer) DecodeFixed64() (x uint64, err error) {
   192  	// x, err already 0
   193  	i := cb.index + 8
   194  	if i < 0 || i > len(cb.buf) {
   195  		err = io.ErrUnexpectedEOF
   196  		return
   197  	}
   198  	cb.index = i
   199  
   200  	x = uint64(cb.buf[i-8])
   201  	x |= uint64(cb.buf[i-7]) << 8
   202  	x |= uint64(cb.buf[i-6]) << 16
   203  	x |= uint64(cb.buf[i-5]) << 24
   204  	x |= uint64(cb.buf[i-4]) << 32
   205  	x |= uint64(cb.buf[i-3]) << 40
   206  	x |= uint64(cb.buf[i-2]) << 48
   207  	x |= uint64(cb.buf[i-1]) << 56
   208  	return
   209  }
   210  
   211  // DecodeFixed32 reads a 32-bit integer from the Buffer.
   212  // This is the format for the
   213  // fixed32, sfixed32, and float protocol buffer types.
   214  func (cb *Buffer) DecodeFixed32() (x uint64, err error) {
   215  	// x, err already 0
   216  	i := cb.index + 4
   217  	if i < 0 || i > len(cb.buf) {
   218  		err = io.ErrUnexpectedEOF
   219  		return
   220  	}
   221  	cb.index = i
   222  
   223  	x = uint64(cb.buf[i-4])
   224  	x |= uint64(cb.buf[i-3]) << 8
   225  	x |= uint64(cb.buf[i-2]) << 16
   226  	x |= uint64(cb.buf[i-1]) << 24
   227  	return
   228  }
   229  
   230  // DecodeZigZag32 decodes a signed 32-bit integer from the given
   231  // zig-zag encoded value.
   232  func DecodeZigZag32(v uint64) int32 {
   233  	return int32((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31))
   234  }
   235  
   236  // DecodeZigZag64 decodes a signed 64-bit integer from the given
   237  // zig-zag encoded value.
   238  func DecodeZigZag64(v uint64) int64 {
   239  	return int64((v >> 1) ^ uint64((int64(v&1)<<63)>>63))
   240  }
   241  
   242  // DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
   243  // This is the format used for the bytes protocol buffer
   244  // type and for embedded messages.
   245  func (cb *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
   246  	n, err := cb.DecodeVarint()
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  
   251  	nb := int(n)
   252  	if nb < 0 {
   253  		return nil, fmt.Errorf("proto: bad byte length %d", nb)
   254  	}
   255  	end := cb.index + nb
   256  	if end < cb.index || end > len(cb.buf) {
   257  		return nil, io.ErrUnexpectedEOF
   258  	}
   259  
   260  	if !alloc {
   261  		buf = cb.buf[cb.index:end]
   262  		cb.index = end
   263  		return
   264  	}
   265  
   266  	buf = make([]byte, nb)
   267  	copy(buf, cb.buf[cb.index:])
   268  	cb.index = end
   269  	return
   270  }
   271  
   272  // ReadGroup reads the input until a "group end" tag is found
   273  // and returns the data up to that point. Subsequent reads from
   274  // the buffer will read data after the group end tag. If alloc
   275  // is true, the data is copied to a new slice before being returned.
   276  // Otherwise, the returned slice is a view into the buffer's
   277  // underlying byte slice.
   278  //
   279  // This function correctly handles nested groups: if a "group start"
   280  // tag is found, then that group's end tag will be included in the
   281  // returned data.
   282  func (cb *Buffer) ReadGroup(alloc bool) ([]byte, error) {
   283  	var groupEnd, dataEnd int
   284  	groupEnd, dataEnd, err := cb.findGroupEnd()
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  	var results []byte
   289  	if !alloc {
   290  		results = cb.buf[cb.index:dataEnd]
   291  	} else {
   292  		results = make([]byte, dataEnd-cb.index)
   293  		copy(results, cb.buf[cb.index:])
   294  	}
   295  	cb.index = groupEnd
   296  	return results, nil
   297  }
   298  
   299  // SkipGroup is like ReadGroup, except that it discards the
   300  // data and just advances the buffer to point to the input
   301  // right *after* the "group end" tag.
   302  func (cb *Buffer) SkipGroup() error {
   303  	groupEnd, _, err := cb.findGroupEnd()
   304  	if err != nil {
   305  		return err
   306  	}
   307  	cb.index = groupEnd
   308  	return nil
   309  }
   310  
   311  func (cb *Buffer) findGroupEnd() (groupEnd int, dataEnd int, err error) {
   312  	bs := cb.buf
   313  	start := cb.index
   314  	defer func() {
   315  		cb.index = start
   316  	}()
   317  	for {
   318  		fieldStart := cb.index
   319  		// read a field tag
   320  		_, wireType, err := cb.DecodeTagAndWireType()
   321  		if err != nil {
   322  			return 0, 0, err
   323  		}
   324  		// skip past the field's data
   325  		switch wireType {
   326  		case proto.WireFixed32:
   327  			if err := cb.Skip(4); err != nil {
   328  				return 0, 0, err
   329  			}
   330  		case proto.WireFixed64:
   331  			if err := cb.Skip(8); err != nil {
   332  				return 0, 0, err
   333  			}
   334  		case proto.WireVarint:
   335  			// skip varint by finding last byte (has high bit unset)
   336  			i := cb.index
   337  			limit := i + 10 // varint cannot be >10 bytes
   338  			for {
   339  				if i >= limit {
   340  					return 0, 0, ErrOverflow
   341  				}
   342  				if i >= len(bs) {
   343  					return 0, 0, io.ErrUnexpectedEOF
   344  				}
   345  				if bs[i]&0x80 == 0 {
   346  					break
   347  				}
   348  				i++
   349  			}
   350  			// TODO: This would only overflow if buffer length was MaxInt and we
   351  			// read the last byte. This is not a real/feasible concern on 64-bit
   352  			// systems. Something to worry about for 32-bit systems? Do we care?
   353  			cb.index = i + 1
   354  		case proto.WireBytes:
   355  			l, err := cb.DecodeVarint()
   356  			if err != nil {
   357  				return 0, 0, err
   358  			}
   359  			if err := cb.Skip(int(l)); err != nil {
   360  				return 0, 0, err
   361  			}
   362  		case proto.WireStartGroup:
   363  			if err := cb.SkipGroup(); err != nil {
   364  				return 0, 0, err
   365  			}
   366  		case proto.WireEndGroup:
   367  			return cb.index, fieldStart, nil
   368  		default:
   369  			return 0, 0, ErrBadWireType
   370  		}
   371  	}
   372  }