github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/codec/encode_fields.go (about) 1 package codec 2 3 import ( 4 "fmt" 5 "math" 6 "reflect" 7 "sort" 8 9 "github.com/golang/protobuf/proto" 10 "github.com/golang/protobuf/protoc-gen-go/descriptor" 11 12 "github.com/hoveychen/protoreflect/desc" 13 ) 14 15 func (cb *Buffer) EncodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { 16 if fd.IsMap() { 17 mp := val.(map[interface{}]interface{}) 18 entryType := fd.GetMessageType() 19 keyType := entryType.FindFieldByNumber(1) 20 valType := entryType.FindFieldByNumber(2) 21 var entryBuffer Buffer 22 if cb.deterministic { 23 keys := make([]interface{}, 0, len(mp)) 24 for k := range mp { 25 keys = append(keys, k) 26 } 27 sort.Sort(sortable(keys)) 28 for _, k := range keys { 29 v := mp[k] 30 entryBuffer.Reset() 31 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { 32 return err 33 } 34 if err := entryBuffer.encodeFieldElement(valType, v); err != nil { 35 return err 36 } 37 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { 38 return err 39 } 40 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { 41 return err 42 } 43 } 44 } else { 45 for k, v := range mp { 46 entryBuffer.Reset() 47 if err := entryBuffer.encodeFieldElement(keyType, k); err != nil { 48 return err 49 } 50 if err := entryBuffer.encodeFieldElement(valType, v); err != nil { 51 return err 52 } 53 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { 54 return err 55 } 56 if err := cb.EncodeRawBytes(entryBuffer.Bytes()); err != nil { 57 return err 58 } 59 } 60 } 61 return nil 62 } else if fd.IsRepeated() { 63 sl := val.([]interface{}) 64 wt, err := getWireType(fd.GetType()) 65 if err != nil { 66 return err 67 } 68 if isPacked(fd) && len(sl) > 1 && 69 (wt == proto.WireVarint || wt == proto.WireFixed32 || wt == proto.WireFixed64) { 70 // packed repeated field 71 var packedBuffer Buffer 72 for _, v := range sl { 73 if err := packedBuffer.encodeFieldValue(fd, v); err != nil { 74 return err 75 } 76 } 77 if err := cb.EncodeTagAndWireType(fd.GetNumber(), proto.WireBytes); err != nil { 78 return err 79 } 80 return cb.EncodeRawBytes(packedBuffer.Bytes()) 81 } else { 82 // non-packed repeated field 83 for _, v := range sl { 84 if err := cb.encodeFieldElement(fd, v); err != nil { 85 return err 86 } 87 } 88 return nil 89 } 90 } else { 91 return cb.encodeFieldElement(fd, val) 92 } 93 } 94 95 func isPacked(fd *desc.FieldDescriptor) bool { 96 opts := fd.AsFieldDescriptorProto().GetOptions() 97 // if set, use that value 98 if opts != nil && opts.Packed != nil { 99 return opts.GetPacked() 100 } 101 // if unset: proto2 defaults to false, proto3 to true 102 return fd.GetFile().IsProto3() 103 } 104 105 // sortable is used to sort map keys. Values will be integers (int32, int64, uint32, and uint64), 106 // bools, or strings. 107 type sortable []interface{} 108 109 func (s sortable) Len() int { 110 return len(s) 111 } 112 113 func (s sortable) Less(i, j int) bool { 114 vi := s[i] 115 vj := s[j] 116 switch reflect.TypeOf(vi).Kind() { 117 case reflect.Int32: 118 return vi.(int32) < vj.(int32) 119 case reflect.Int64: 120 return vi.(int64) < vj.(int64) 121 case reflect.Uint32: 122 return vi.(uint32) < vj.(uint32) 123 case reflect.Uint64: 124 return vi.(uint64) < vj.(uint64) 125 case reflect.String: 126 return vi.(string) < vj.(string) 127 case reflect.Bool: 128 return !vi.(bool) && vj.(bool) 129 default: 130 panic(fmt.Sprintf("cannot compare keys of type %v", reflect.TypeOf(vi))) 131 } 132 } 133 134 func (s sortable) Swap(i, j int) { 135 s[i], s[j] = s[j], s[i] 136 } 137 138 func (b *Buffer) encodeFieldElement(fd *desc.FieldDescriptor, val interface{}) error { 139 wt, err := getWireType(fd.GetType()) 140 if err != nil { 141 return err 142 } 143 if err := b.EncodeTagAndWireType(fd.GetNumber(), wt); err != nil { 144 return err 145 } 146 if err := b.encodeFieldValue(fd, val); err != nil { 147 return err 148 } 149 if wt == proto.WireStartGroup { 150 return b.EncodeTagAndWireType(fd.GetNumber(), proto.WireEndGroup) 151 } 152 return nil 153 } 154 155 func (b *Buffer) encodeFieldValue(fd *desc.FieldDescriptor, val interface{}) error { 156 switch fd.GetType() { 157 case descriptor.FieldDescriptorProto_TYPE_BOOL: 158 v := val.(bool) 159 if v { 160 return b.EncodeVarint(1) 161 } 162 return b.EncodeVarint(0) 163 164 case descriptor.FieldDescriptorProto_TYPE_ENUM, 165 descriptor.FieldDescriptorProto_TYPE_INT32: 166 v := val.(int32) 167 return b.EncodeVarint(uint64(v)) 168 169 case descriptor.FieldDescriptorProto_TYPE_SFIXED32: 170 v := val.(int32) 171 return b.EncodeFixed32(uint64(v)) 172 173 case descriptor.FieldDescriptorProto_TYPE_SINT32: 174 v := val.(int32) 175 return b.EncodeVarint(EncodeZigZag32(v)) 176 177 case descriptor.FieldDescriptorProto_TYPE_UINT32: 178 v := val.(uint32) 179 return b.EncodeVarint(uint64(v)) 180 181 case descriptor.FieldDescriptorProto_TYPE_FIXED32: 182 v := val.(uint32) 183 return b.EncodeFixed32(uint64(v)) 184 185 case descriptor.FieldDescriptorProto_TYPE_INT64: 186 v := val.(int64) 187 return b.EncodeVarint(uint64(v)) 188 189 case descriptor.FieldDescriptorProto_TYPE_SFIXED64: 190 v := val.(int64) 191 return b.EncodeFixed64(uint64(v)) 192 193 case descriptor.FieldDescriptorProto_TYPE_SINT64: 194 v := val.(int64) 195 return b.EncodeVarint(EncodeZigZag64(v)) 196 197 case descriptor.FieldDescriptorProto_TYPE_UINT64: 198 v := val.(uint64) 199 return b.EncodeVarint(v) 200 201 case descriptor.FieldDescriptorProto_TYPE_FIXED64: 202 v := val.(uint64) 203 return b.EncodeFixed64(v) 204 205 case descriptor.FieldDescriptorProto_TYPE_DOUBLE: 206 v := val.(float64) 207 return b.EncodeFixed64(math.Float64bits(v)) 208 209 case descriptor.FieldDescriptorProto_TYPE_FLOAT: 210 v := val.(float32) 211 return b.EncodeFixed32(uint64(math.Float32bits(v))) 212 213 case descriptor.FieldDescriptorProto_TYPE_BYTES: 214 v := val.([]byte) 215 return b.EncodeRawBytes(v) 216 217 case descriptor.FieldDescriptorProto_TYPE_STRING: 218 v := val.(string) 219 return b.EncodeRawBytes(([]byte)(v)) 220 221 case descriptor.FieldDescriptorProto_TYPE_MESSAGE: 222 return b.EncodeDelimitedMessage(val.(proto.Message)) 223 224 case descriptor.FieldDescriptorProto_TYPE_GROUP: 225 // just append the nested message to this buffer 226 return b.EncodeMessage(val.(proto.Message)) 227 // whosoever writeth start-group tag (e.g. caller) is responsible for writing end-group tag 228 229 default: 230 return fmt.Errorf("unrecognized field type: %v", fd.GetType()) 231 } 232 } 233 234 func getWireType(t descriptor.FieldDescriptorProto_Type) (int8, error) { 235 switch t { 236 case descriptor.FieldDescriptorProto_TYPE_ENUM, 237 descriptor.FieldDescriptorProto_TYPE_BOOL, 238 descriptor.FieldDescriptorProto_TYPE_INT32, 239 descriptor.FieldDescriptorProto_TYPE_SINT32, 240 descriptor.FieldDescriptorProto_TYPE_UINT32, 241 descriptor.FieldDescriptorProto_TYPE_INT64, 242 descriptor.FieldDescriptorProto_TYPE_SINT64, 243 descriptor.FieldDescriptorProto_TYPE_UINT64: 244 return proto.WireVarint, nil 245 246 case descriptor.FieldDescriptorProto_TYPE_FIXED32, 247 descriptor.FieldDescriptorProto_TYPE_SFIXED32, 248 descriptor.FieldDescriptorProto_TYPE_FLOAT: 249 return proto.WireFixed32, nil 250 251 case descriptor.FieldDescriptorProto_TYPE_FIXED64, 252 descriptor.FieldDescriptorProto_TYPE_SFIXED64, 253 descriptor.FieldDescriptorProto_TYPE_DOUBLE: 254 return proto.WireFixed64, nil 255 256 case descriptor.FieldDescriptorProto_TYPE_BYTES, 257 descriptor.FieldDescriptorProto_TYPE_STRING, 258 descriptor.FieldDescriptorProto_TYPE_MESSAGE: 259 return proto.WireBytes, nil 260 261 case descriptor.FieldDescriptorProto_TYPE_GROUP: 262 return proto.WireStartGroup, nil 263 264 default: 265 return 0, ErrBadWireType 266 } 267 }