github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/proto/struct.go (about) 1 package proto 2 3 import ( 4 "fmt" 5 "io" 6 "reflect" 7 "unsafe" 8 ) 9 10 const ( 11 embedded = 1 << 0 12 repeated = 1 << 1 13 zigzag = 1 << 2 14 ) 15 16 type structField struct { 17 number uint16 18 tagsize uint8 19 flags uint8 20 offset uint32 21 codec *codec 22 } 23 24 func (f *structField) String() string { 25 return fmt.Sprintf("[%d,%s]", f.fieldNumber(), f.wireType()) 26 } 27 28 func (f *structField) fieldNumber() fieldNumber { 29 return fieldNumber(f.number) 30 } 31 32 func (f *structField) wireType() wireType { 33 return f.codec.wire 34 } 35 36 func (f *structField) embedded() bool { 37 return (f.flags & embedded) != 0 38 } 39 40 func (f *structField) repeated() bool { 41 return (f.flags & repeated) != 0 42 } 43 44 func (f *structField) pointer(p unsafe.Pointer) unsafe.Pointer { 45 return unsafe.Pointer(uintptr(p) + uintptr(f.offset)) 46 } 47 48 func (f *structField) makeFlags(base flags) flags { 49 return base | flags(f.flags&zigzag) 50 } 51 52 func structCodecOf(t reflect.Type, seen map[reflect.Type]*codec) *codec { 53 c := &codec{wire: varlen} 54 seen[t] = c 55 56 numField := t.NumField() 57 number := fieldNumber(1) 58 fields := make([]structField, 0, numField) 59 60 for i := 0; i < numField; i++ { 61 f := t.Field(i) 62 63 if f.PkgPath != "" { 64 continue // unexported 65 } 66 67 field := structField{ 68 number: uint16(number), 69 offset: uint32(f.Offset), 70 } 71 72 if tag, ok := f.Tag.Lookup("protobuf"); ok { 73 t, err := parseStructTag(tag) 74 if err == nil { 75 field.number = uint16(t.fieldNumber) 76 if t.repeated { 77 field.flags |= repeated 78 } 79 if t.zigzag { 80 field.flags |= zigzag 81 } 82 switch t.wireType { 83 case Fixed32: 84 switch baseKindOf(f.Type) { 85 case reflect.Uint32: 86 field.codec = &fixed32Codec 87 case reflect.Float32: 88 field.codec = &float32Codec 89 } 90 case Fixed64: 91 switch baseKindOf(f.Type) { 92 case reflect.Uint64: 93 field.codec = &fixed64Codec 94 case reflect.Float64: 95 field.codec = &float64Codec 96 } 97 } 98 } 99 } 100 101 if field.codec == nil { 102 switch baseKindOf(f.Type) { 103 case reflect.Struct: 104 field.flags |= embedded 105 field.codec = codecOf(f.Type, seen) 106 107 case reflect.Slice: 108 elem := f.Type.Elem() 109 110 if elem.Kind() == reflect.Uint8 { // []byte 111 field.codec = codecOf(f.Type, seen) 112 } else { 113 if baseKindOf(elem) == reflect.Struct { 114 field.flags |= embedded 115 } 116 field.flags |= repeated 117 field.codec = codecOf(elem, seen) 118 field.codec = sliceCodecOf(f.Type, field, seen) 119 } 120 121 case reflect.Map: 122 key, val := f.Type.Key(), f.Type.Elem() 123 k := codecOf(key, seen) 124 v := codecOf(val, seen) 125 m := &mapField{ 126 number: field.number, 127 keyCodec: k, 128 valCodec: v, 129 } 130 if baseKindOf(key) == reflect.Struct { 131 m.keyFlags |= embedded 132 } 133 if baseKindOf(val) == reflect.Struct { 134 m.valFlags |= embedded 135 } 136 field.flags |= embedded | repeated 137 field.codec = mapCodecOf(f.Type, m, seen) 138 139 default: 140 field.codec = codecOf(f.Type, seen) 141 } 142 } 143 144 field.tagsize = uint8(sizeOfTag(fieldNumber(field.number), wireType(field.codec.wire))) 145 fields = append(fields, field) 146 number++ 147 } 148 149 c.size = structSizeFuncOf(t, fields) 150 c.encode = structEncodeFuncOf(t, fields) 151 c.decode = structDecodeFuncOf(t, fields) 152 return c 153 } 154 155 func baseKindOf(t reflect.Type) reflect.Kind { 156 return baseTypeOf(t).Kind() 157 } 158 159 func baseTypeOf(t reflect.Type) reflect.Type { 160 for t.Kind() == reflect.Ptr { 161 t = t.Elem() 162 } 163 return t 164 } 165 166 func structSizeFuncOf(t reflect.Type, fields []structField) sizeFunc { 167 var inlined = inlined(t) 168 var unique, repeated []*structField 169 170 for i := range fields { 171 f := &fields[i] 172 if f.repeated() { 173 repeated = append(repeated, f) 174 } else { 175 unique = append(unique, f) 176 } 177 } 178 179 return func(p unsafe.Pointer, flags flags) int { 180 if p == nil { 181 return 0 182 } 183 184 if !inlined { 185 flags = flags.without(inline | toplevel) 186 } else { 187 flags = flags.without(toplevel) 188 } 189 n := 0 190 191 for _, f := range unique { 192 size := f.codec.size(f.pointer(p), f.makeFlags(flags)) 193 if size > 0 { 194 n += int(f.tagsize) + size 195 if f.embedded() { 196 n += sizeOfVarint(uint64(size)) 197 } 198 flags = flags.without(wantzero) 199 } 200 } 201 202 for _, f := range repeated { 203 size := f.codec.size(f.pointer(p), f.makeFlags(flags)) 204 if size > 0 { 205 n += size 206 flags = flags.without(wantzero) 207 } 208 } 209 210 return n 211 } 212 } 213 214 func structEncodeFuncOf(t reflect.Type, fields []structField) encodeFunc { 215 var inlined = inlined(t) 216 var unique, repeated []*structField 217 218 for i := range fields { 219 f := &fields[i] 220 if f.repeated() { 221 repeated = append(repeated, f) 222 } else { 223 unique = append(unique, f) 224 } 225 } 226 227 return func(b []byte, p unsafe.Pointer, flags flags) (int, error) { 228 if p == nil { 229 return 0, nil 230 } 231 232 if !inlined { 233 flags = flags.without(inline | toplevel) 234 } else { 235 flags = flags.without(toplevel) 236 } 237 offset := 0 238 239 for _, f := range unique { 240 fieldFlags := f.makeFlags(flags) 241 elem := f.pointer(p) 242 size := f.codec.size(elem, fieldFlags) 243 244 if size > 0 { 245 n, err := encodeTag(b[offset:], f.fieldNumber(), f.wireType()) 246 offset += n 247 if err != nil { 248 return offset, err 249 } 250 251 if f.embedded() { 252 n, err := encodeVarint(b[offset:], uint64(size)) 253 offset += n 254 if err != nil { 255 return offset, err 256 } 257 } 258 259 if (len(b) - offset) < size { 260 return len(b), io.ErrShortBuffer 261 } 262 263 n, err = f.codec.encode(b[offset:offset+size], elem, fieldFlags) 264 offset += n 265 if err != nil { 266 return offset, err 267 } 268 269 flags = flags.without(wantzero) 270 } 271 } 272 273 for _, f := range repeated { 274 n, err := f.codec.encode(b[offset:], f.pointer(p), f.makeFlags(flags)) 275 offset += n 276 if err != nil { 277 return offset, err 278 } 279 if n > 0 { 280 flags = flags.without(wantzero) 281 } 282 } 283 284 return offset, nil 285 } 286 } 287 288 func structDecodeFuncOf(t reflect.Type, fields []structField) decodeFunc { 289 maxFieldNumber := fieldNumber(0) 290 291 for _, f := range fields { 292 if n := f.fieldNumber(); n > maxFieldNumber { 293 maxFieldNumber = n 294 } 295 } 296 297 fieldIndex := make([]*structField, maxFieldNumber+1) 298 299 for i := range fields { 300 f := &fields[i] 301 fieldIndex[f.fieldNumber()] = f 302 } 303 304 return func(b []byte, p unsafe.Pointer, flags flags) (int, error) { 305 flags = flags.without(toplevel) 306 offset := 0 307 308 for offset < len(b) { 309 fieldNumber, wireType, n, err := decodeTag(b[offset:]) 310 offset += n 311 if err != nil { 312 return offset, err 313 } 314 315 i := int(fieldNumber) 316 f := (*structField)(nil) 317 318 if i >= 0 && i < len(fieldIndex) { 319 f = fieldIndex[i] 320 } 321 322 if f == nil { 323 skip := 0 324 size := uint64(0) 325 switch wireType { 326 case varint: 327 _, skip, err = decodeVarint(b[offset:]) 328 case varlen: 329 size, skip, err = decodeVarint(b[offset:]) 330 if err == nil { 331 if size > uint64(len(b)-skip) { 332 err = io.ErrUnexpectedEOF 333 } else { 334 skip += int(size) 335 } 336 } 337 case fixed32: 338 _, skip, err = decodeLE32(b[offset:]) 339 case fixed64: 340 _, skip, err = decodeLE64(b[offset:]) 341 default: 342 err = ErrWireTypeUnknown 343 } 344 if (offset + skip) <= len(b) { 345 offset += skip 346 } else { 347 offset, err = len(b), io.ErrUnexpectedEOF 348 } 349 if err != nil { 350 return offset, fieldError(fieldNumber, wireType, err) 351 } 352 continue 353 } 354 355 if wireType != f.wireType() { 356 return offset, fieldError(fieldNumber, wireType, fmt.Errorf("expected wire type %d", f.wireType())) 357 } 358 359 // `data` will only contain the section of the input buffer where 360 // the data for the next field is available. This is necessary to 361 // limit how many bytes will be consumed by embedded messages. 362 var data []byte 363 switch wireType { 364 case varint: 365 _, n, err := decodeVarint(b[offset:]) 366 if err != nil { 367 return offset, fieldError(fieldNumber, wireType, err) 368 } 369 data = b[offset : offset+n] 370 371 case varlen: 372 l, n, err := decodeVarint(b[offset:]) 373 if err != nil { 374 return offset + n, fieldError(fieldNumber, wireType, err) 375 } 376 if l > uint64(len(b)-(offset+n)) { 377 return len(b), fieldError(fieldNumber, wireType, io.ErrUnexpectedEOF) 378 } 379 if f.embedded() { 380 offset += n 381 data = b[offset : offset+int(l)] 382 } else { 383 data = b[offset : offset+n+int(l)] 384 } 385 386 case fixed32: 387 if (offset + 4) > len(b) { 388 return len(b), fieldError(fieldNumber, wireType, io.ErrUnexpectedEOF) 389 } 390 data = b[offset : offset+4] 391 392 case fixed64: 393 if (offset + 8) > len(b) { 394 return len(b), fieldError(fieldNumber, wireType, io.ErrUnexpectedEOF) 395 } 396 data = b[offset : offset+8] 397 398 default: 399 return offset, fieldError(fieldNumber, wireType, ErrWireTypeUnknown) 400 } 401 402 n, err = f.codec.decode(data, f.pointer(p), f.makeFlags(flags)) 403 offset += n 404 if err != nil { 405 return offset, fieldError(fieldNumber, wireType, err) 406 } 407 } 408 409 return offset, nil 410 } 411 }