github.com/jhump/protoreflect@v1.16.0/codec/encode_fields.go (about)

     1  package codec
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"reflect"
     7  	"sort"
     8  
     9  	"github.com/golang/protobuf/proto"
    10  	"google.golang.org/protobuf/types/descriptorpb"
    11  
    12  	"github.com/jhump/protoreflect/desc"
    13  )
    14  
    15  // EncodeZigZag64 does zig-zag encoding to convert the given
    16  // signed 64-bit integer into a form that can be expressed
    17  // efficiently as a varint, even for negative values.
    18  func EncodeZigZag64(v int64) uint64 {
    19  	return (uint64(v) << 1) ^ uint64(v>>63)
    20  }
    21  
    22  // EncodeZigZag32 does zig-zag encoding to convert the given
    23  // signed 32-bit integer into a form that can be expressed
    24  // efficiently as a varint, even for negative values.
    25  func EncodeZigZag32(v int32) uint64 {
    26  	return uint64((uint32(v) << 1) ^ uint32((v >> 31)))
    27  }
    28  
    29  func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
    30  	if fd.IsMap() {
    31  		mp := val.(map[interface{}]interface{})
    32  		entryType := fd.GetMessageType()
    33  		keyType := entryType.FindFieldByNumber(1)
    34  		valType := entryType.FindFieldByNumber(2)
    35  		var entryBuffer Buffer
    36  		if cb.IsDeterministic() {
    37  			entryBuffer.SetDeterministic(true)
    38  			keys := make([]interface{}, 0, len(mp))
    39  			for k := range mp {
    40  				keys = append(keys, k)
    41  			}
    42  			sort.Sort(sortable(keys))
    43  			for _, k := range keys {
    44  				v := mp[k]
    45  				entryBuffer.Reset()
    46  				if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
    47  					return err
    48  				}
    49  				rv := reflect.ValueOf(v)
    50  				if rv.Kind() != reflect.Ptr || !rv.IsNil() {
    51  					if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
    52  						return err
    53  					}
    54  				}
    55  				if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
    56  					return err
    57  				}
    58  				if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
    59  					return err
    60  				}
    61  			}
    62  		} else {
    63  			for k, v := range mp {
    64  				entryBuffer.Reset()
    65  				if err := entryBuffer.encodeFieldElement(keyType, k); err != nil {
    66  					return err
    67  				}
    68  				rv := reflect.ValueOf(v)
    69  				if rv.Kind() != reflect.Ptr || !rv.IsNil() {
    70  					if err := entryBuffer.encodeFieldElement(valType, v); err != nil {
    71  						return err
    72  					}
    73  				}
    74  				if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
    75  					return err
    76  				}
    77  				if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil {
    78  					return err
    79  				}
    80  			}
    81  		}
    82  		return nil
    83  	} else if fd.IsRepeated() {
    84  		sl := val.([]interface{})
    85  		wt, err := getWireType(fd.GetType())
    86  		if err != nil {
    87  			return err
    88  		}
    89  		if isPacked(fd) && len(sl) > 0 &&
    90  			(wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) {
    91  			// packed repeated field
    92  			var packedBuffer Buffer
    93  			for _, v := range sl {
    94  				if err := packedBuffer.encodeFieldValue(fd, v); err != nil {
    95  					return err
    96  				}
    97  			}
    98  			if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil {
    99  				return err
   100  			}
   101  			return cb.EncodeRawBytes(packedBuffer.Bytes())
   102  		} else {
   103  			// non-packed repeated field
   104  			for _, v := range sl {
   105  				if err := cb.encodeFieldElement(fd, v); err != nil {
   106  					return err
   107  				}
   108  			}
   109  			return nil
   110  		}
   111  	} else {
   112  		return cb.encodeFieldElement(fd, val)
   113  	}
   114  }
   115  
   116  func isPacked(fd *desc.FieldDescriptor) bool {
   117  	opts := fd.AsFieldDescriptorProto().GetOptions()
   118  	// if set, use that value
   119  	if opts != nil && opts.Packed != nil {
   120  		return opts.GetPacked()
   121  	}
   122  	// if unset: proto2 defaults to false, proto3 to true
   123  	return fd.GetFile().IsProto3()
   124  }
   125  
   126  // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64),
   127  // bools, or strings.
   128  type sortable []interface{}
   129  
   130  func (s sortable) Len() int {
   131  	return len(s)
   132  }
   133  
   134  func (s sortable) Less(i, j int) bool {
   135  	vi := s[i]
   136  	vj := s[j]
   137  	switch reflect.TypeOf(vi).Kind() {
   138  	case reflect.Int32:
   139  		return vi.(int32) < vj.(int32)
   140  	case reflect.Int64:
   141  		return vi.(int64) < vj.(int64)
   142  	case reflect.Uint32:
   143  		return vi.(uint32) < vj.(uint32)
   144  	case reflect.Uint64:
   145  		return vi.(uint64) < vj.(uint64)
   146  	case reflect.String:
   147  		return vi.(string) < vj.(string)
   148  	case reflect.Bool:
   149  		return !vi.(bool) && vj.(bool)
   150  	default:
   151  		panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi)))
   152  	}
   153  }
   154  
   155  func (s sortable) Swap(i, j int) {
   156  	s[i], s[j] = s[j], s[i]
   157  }
   158  
   159  func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error {
   160  	wt, err := getWireType(fd.GetType())
   161  	if err != nil {
   162  		return err
   163  	}
   164  	if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil {
   165  		return err
   166  	}
   167  	if err := b.encodeFieldValue(fd, val); err != nil {
   168  		return err
   169  	}
   170  	if wt == proto.WireStartGroup {
   171  		return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup)
   172  	}
   173  	return nil
   174  }
   175  
   176  func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error {
   177  	switch fd.GetType() {
   178  	case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   179  		v := val.(bool)
   180  		if v {
   181  			return b.EncodeVarint(1)
   182  		}
   183  		return b.EncodeVarint(0)
   184  
   185  	case descriptorpb.FieldDescriptorProto_TYPE_ENUM,
   186  		descriptorpb.FieldDescriptorProto_TYPE_INT32:
   187  		v := val.(int32)
   188  		return b.EncodeVarint(uint64(v))
   189  
   190  	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
   191  		v := val.(int32)
   192  		return b.EncodeFixed32(uint64(v))
   193  
   194  	case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
   195  		v := val.(int32)
   196  		return b.EncodeVarint(EncodeZigZag32(v))
   197  
   198  	case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
   199  		v := val.(uint32)
   200  		return b.EncodeVarint(uint64(v))
   201  
   202  	case descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
   203  		v := val.(uint32)
   204  		return b.EncodeFixed32(uint64(v))
   205  
   206  	case descriptorpb.FieldDescriptorProto_TYPE_INT64:
   207  		v := val.(int64)
   208  		return b.EncodeVarint(uint64(v))
   209  
   210  	case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
   211  		v := val.(int64)
   212  		return b.EncodeFixed64(uint64(v))
   213  
   214  	case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   215  		v := val.(int64)
   216  		return b.EncodeVarint(EncodeZigZag64(v))
   217  
   218  	case descriptorpb.FieldDescriptorProto_TYPE_UINT64:
   219  		v := val.(uint64)
   220  		return b.EncodeVarint(v)
   221  
   222  	case descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
   223  		v := val.(uint64)
   224  		return b.EncodeFixed64(v)
   225  
   226  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
   227  		v := val.(float64)
   228  		return b.EncodeFixed64(math.Float64bits(v))
   229  
   230  	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
   231  		v := val.(float32)
   232  		return b.EncodeFixed32(uint64(math.Float32bits(v)))
   233  
   234  	case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
   235  		v := val.([]byte)
   236  		return b.EncodeRawBytes(v)
   237  
   238  	case descriptorpb.FieldDescriptorProto_TYPE_STRING:
   239  		v := val.(string)
   240  		return b.EncodeRawBytes(([]byte)(v))
   241  
   242  	case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
   243  		return b.EncodeDelimitedMessage(val.(proto.Message))
   244  
   245  	case descriptorpb.FieldDescriptorProto_TYPE_GROUP:
   246  		// just append the nested message to this buffer
   247  		return b.EncodeMessage(val.(proto.Message))
   248  		// whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag
   249  
   250  	default:
   251  		return fmt.Errorf("unrecognized field type: %v", fd.GetType())
   252  	}
   253  }
   254  
   255  func getWireType(t descriptorpb.FieldDescriptorProto_Type) (int8, error) {
   256  	switch t {
   257  	case descriptorpb.FieldDescriptorProto_TYPE_ENUM,
   258  		descriptorpb.FieldDescriptorProto_TYPE_BOOL,
   259  		descriptorpb.FieldDescriptorProto_TYPE_INT32,
   260  		descriptorpb.FieldDescriptorProto_TYPE_SINT32,
   261  		descriptorpb.FieldDescriptorProto_TYPE_UINT32,
   262  		descriptorpb.FieldDescriptorProto_TYPE_INT64,
   263  		descriptorpb.FieldDescriptorProto_TYPE_SINT64,
   264  		descriptorpb.FieldDescriptorProto_TYPE_UINT64:
   265  		return proto.WireVarint, nil
   266  
   267  	case descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
   268  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
   269  		descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
   270  		return proto.WireFixed32, nil
   271  
   272  	case descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
   273  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
   274  		descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
   275  		return proto.WireFixed64, nil
   276  
   277  	case descriptorpb.FieldDescriptorProto_TYPE_BYTES,
   278  		descriptorpb.FieldDescriptorProto_TYPE_STRING,
   279  		descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
   280  		return proto.WireBytes, nil
   281  
   282  	case descriptorpb.FieldDescriptorProto_TYPE_GROUP:
   283  		return proto.WireStartGroup, nil
   284  
   285  	default:
   286  		return 0, ErrBadWireType
   287  	}
   288  }