github.com/jhump/protoreflect@v1.16.0/codec/encode_fields.go (about) 1 package codec 2 3 import ( 4 "fmt" 5 "math" 6 "reflect" 7 "sort" 8 9 "github.com/golang/protobuf/proto" 10 "google.golang.org/protobuf/types/descriptorpb" 11 12 "github.com/jhump/protoreflect/desc" 13 ) 14 15 // EncodeZigZag64 does zig-zag encoding to convert the given 16 // signed 64-bit integer into a form that can be expressed 17 // efficiently as a varint, even for negative values. 18 func EncodeZigZag64(v int64) uint64 { 19 return (uint64(v) << 1) ^ uint64(v>>63) 20 } 21 22 // EncodeZigZag32 does zig-zag encoding to convert the given 23 // signed 32-bit integer into a form that can be expressed 24 // efficiently as a varint, even for negative values. 25 func EncodeZigZag32(v int32) uint64 { 26 return uint64((uint32(v) << 1) ^ uint32((v >> 31))) 27 } 28 29 func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { 30 if fd.IsMap() { 31 mp := val.(map[interface{}]interface{}) 32 entryType := fd.GetMessageType() 33 keyType := entryType.FindFieldByNumber(1) 34 valType := entryType.FindFieldByNumber(2) 35 var entryBuffer Buffer 36 if cb.IsDeterministic() { 37 entryBuffer.SetDeterministic(true) 38 keys := make([]interface{}, 0, len(mp)) 39 for k := range mp { 40 keys = append(keys, k) 41 } 42 sort.Sort(sortable(keys)) 43 for _, k := range keys { 44 v := mp[k] 45 entryBuffer.Reset() 46 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { 47 return err 48 } 49 rv := reflect.ValueOf(v) 50 if rv.Kind() != reflect.Ptr || !rv.IsNil() { 51 if err := entryBuffer.encodeFieldElement(valType, v); err != nil { 52 return err 53 } 54 } 55 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { 56 return err 57 } 58 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { 59 return err 60 } 61 } 62 } else { 63 for k, v := range mp { 64 entryBuffer.Reset() 65 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { 66 return err 67 } 68 rv := reflect.ValueOf(v) 69 if rv.Kind() != reflect.Ptr || !rv.IsNil() { 70 if err := entryBuffer.encodeFieldElement(valType, v); err != nil { 71 return err 72 } 73 } 74 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { 75 return err 76 } 77 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { 78 return err 79 } 80 } 81 } 82 return nil 83 } else if fd.IsRepeated() { 84 sl := val.([]interface{}) 85 wt, err := getWireType(fd.GetType()) 86 if err != nil { 87 return err 88 } 89 if isPacked(fd) && len(sl) > 0 && 90 (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) { 91 // packed repeated field 92 var packedBuffer Buffer 93 for _, v := range sl { 94 if err := packedBuffer.encodeFieldValue(fd, v); err != nil { 95 return err 96 } 97 } 98 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { 99 return err 100 } 101 return cb.EncodeRawBytes(packedBuffer.Bytes()) 102 } else { 103 // non-packed repeated field 104 for _, v := range sl { 105 if err := cb.encodeFieldElement(fd, v); err != nil { 106 return err 107 } 108 } 109 return nil 110 } 111 } else { 112 return cb.encodeFieldElement(fd, val) 113 } 114 } 115 116 func isPacked(fd *desc.FieldDescriptor) bool { 117 opts := fd.AsFieldDescriptorProto().GetOptions() 118 // if set, use that value 119 if opts != nil && opts.Packed != nil { 120 return opts.GetPacked() 121 } 122 // if unset: proto2 defaults to false, proto3 to true 123 return fd.GetFile().IsProto3() 124 } 125 126 // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64), 127 // bools, or strings. 128 type sortable []interface{} 129 130 func (s sortable) Len() int { 131 return len(s) 132 } 133 134 func (s sortable) Less(i, j int) bool { 135 vi := s[i] 136 vj := s[j] 137 switch reflect.TypeOf(vi).Kind() { 138 case reflect.Int32: 139 return vi.(int32) < vj.(int32) 140 case reflect.Int64: 141 return vi.(int64) < vj.(int64) 142 case reflect.Uint32: 143 return vi.(uint32) < vj.(uint32) 144 case reflect.Uint64: 145 return vi.(uint64) < vj.(uint64) 146 case reflect.String: 147 return vi.(string) < vj.(string) 148 case reflect.Bool: 149 return !vi.(bool) && vj.(bool) 150 default: 151 panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi))) 152 } 153 } 154 155 func (s sortable) Swap(i, j int) { 156 s[i], s[j] = s[j], s[i] 157 } 158 159 func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error { 160 wt, err := getWireType(fd.GetType()) 161 if err != nil { 162 return err 163 } 164 if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil { 165 return err 166 } 167 if err := b.encodeFieldValue(fd, val); err != nil { 168 return err 169 } 170 if wt == proto.WireStartGroup { 171 return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup) 172 } 173 return nil 174 } 175 176 func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { 177 switch fd.GetType() { 178 case descriptorpb.FieldDescriptorProto_TYPE_BOOL: 179 v := val.(bool) 180 if v { 181 return b.EncodeVarint(1) 182 } 183 return b.EncodeVarint(0) 184 185 case descriptorpb.FieldDescriptorProto_TYPE_ENUM, 186 descriptorpb.FieldDescriptorProto_TYPE_INT32: 187 v := val.(int32) 188 return b.EncodeVarint(uint64(v)) 189 190 case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: 191 v := val.(int32) 192 return b.EncodeFixed32(uint64(v)) 193 194 case descriptorpb.FieldDescriptorProto_TYPE_SINT32: 195 v := val.(int32) 196 return b.EncodeVarint(EncodeZigZag32(v)) 197 198 case descriptorpb.FieldDescriptorProto_TYPE_UINT32: 199 v := val.(uint32) 200 return b.EncodeVarint(uint64(v)) 201 202 case descriptorpb.FieldDescriptorProto_TYPE_FIXED32: 203 v := val.(uint32) 204 return b.EncodeFixed32(uint64(v)) 205 206 case descriptorpb.FieldDescriptorProto_TYPE_INT64: 207 v := val.(int64) 208 return b.EncodeVarint(uint64(v)) 209 210 case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: 211 v := val.(int64) 212 return b.EncodeFixed64(uint64(v)) 213 214 case descriptorpb.FieldDescriptorProto_TYPE_SINT64: 215 v := val.(int64) 216 return b.EncodeVarint(EncodeZigZag64(v)) 217 218 case descriptorpb.FieldDescriptorProto_TYPE_UINT64: 219 v := val.(uint64) 220 return b.EncodeVarint(v) 221 222 case descriptorpb.FieldDescriptorProto_TYPE_FIXED64: 223 v := val.(uint64) 224 return b.EncodeFixed64(v) 225 226 case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: 227 v := val.(float64) 228 return b.EncodeFixed64(math.Float64bits(v)) 229 230 case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: 231 v := val.(float32) 232 return b.EncodeFixed32(uint64(math.Float32bits(v))) 233 234 case descriptorpb.FieldDescriptorProto_TYPE_BYTES: 235 v := val.([]byte) 236 return b.EncodeRawBytes(v) 237 238 case descriptorpb.FieldDescriptorProto_TYPE_STRING: 239 v := val.(string) 240 return b.EncodeRawBytes(([]byte)(v)) 241 242 case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: 243 return b.EncodeDelimitedMessage(val.(proto.Message)) 244 245 case descriptorpb.FieldDescriptorProto_TYPE_GROUP: 246 // just append the nested message to this buffer 247 return b.EncodeMessage(val.(proto.Message)) 248 // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag 249 250 default: 251 return fmt.Errorf("unrecognized field type: %v", fd.GetType()) 252 } 253 } 254 255 func getWireType(t descriptorpb.FieldDescriptorProto_Type) (int8, error) { 256 switch t { 257 case descriptorpb.FieldDescriptorProto_TYPE_ENUM, 258 descriptorpb.FieldDescriptorProto_TYPE_BOOL, 259 descriptorpb.FieldDescriptorProto_TYPE_INT32, 260 descriptorpb.FieldDescriptorProto_TYPE_SINT32, 261 descriptorpb.FieldDescriptorProto_TYPE_UINT32, 262 descriptorpb.FieldDescriptorProto_TYPE_INT64, 263 descriptorpb.FieldDescriptorProto_TYPE_SINT64, 264 descriptorpb.FieldDescriptorProto_TYPE_UINT64: 265 return proto.WireVarint, nil 266 267 case descriptorpb.FieldDescriptorProto_TYPE_FIXED32, 268 descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, 269 descriptorpb.FieldDescriptorProto_TYPE_FLOAT: 270 return proto.WireFixed32, nil 271 272 case descriptorpb.FieldDescriptorProto_TYPE_FIXED64, 273 descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, 274 descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: 275 return proto.WireFixed64, nil 276 277 case descriptorpb.FieldDescriptorProto_TYPE_BYTES, 278 descriptorpb.FieldDescriptorProto_TYPE_STRING, 279 descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: 280 return proto.WireBytes, nil 281 282 case descriptorpb.FieldDescriptorProto_TYPE_GROUP: 283 return proto.WireStartGroup, nil 284 285 default: 286 return 0, ErrBadWireType 287 } 288 }