github.com/aacfactory/avro@v1.2.12/internal/base/codec_record.go (about) 1 package base 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "reflect" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 14 switch typ.Kind() { 15 case reflect.Struct: 16 return decoderOfStruct(cfg, schema, typ) 17 18 case reflect.Map: 19 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 20 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 21 break 22 } 23 return decoderOfRecord(cfg, schema, typ) 24 25 case reflect.Ptr: 26 return decoderOfPtr(cfg, schema, typ) 27 28 case reflect.Interface: 29 if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok { 30 return &recordIfaceDecoder{schema: schema, valType: ifaceType} 31 } 32 default: 33 break 34 } 35 36 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} 37 } 38 39 func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 40 switch typ.Kind() { 41 case reflect.Struct: 42 return encoderOfStruct(cfg, schema, typ) 43 44 case reflect.Map: 45 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 46 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 47 break 48 } 49 return encoderOfRecord(cfg, schema, typ) 50 51 case reflect.Ptr: 52 return encoderOfPtr(cfg, schema, typ) 53 default: 54 break 55 } 56 57 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} 58 } 59 60 func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 61 processing := cfg.getProcessingDecoderFromCache(schema.Fingerprint(), typ.RType()) 62 if processing != nil { 63 return processing 64 } 65 dec := &structDecoder{typ: typ, fields: nil} 66 cfg.addProcessingDecoderToCache(schema.Fingerprint(), typ.RType(), dec) 67 68 rec := schema.(*RecordSchema) 69 structDesc := describeStruct(cfg.getTagKey(), typ) 70 71 fields := make([]*structFieldDecoder, 0, len(rec.Fields())) 72 for _, field := range rec.Fields() { 73 sf := structDesc.Fields.Get(field.Name()) 74 if sf == nil { 75 for _, alias := range field.Aliases() { 76 sf = structDesc.Fields.Get(alias) 77 if sf != nil { 78 break 79 } 80 } 81 } 82 83 // Skip field if it doesnt exist 84 if sf == nil { 85 fields = append(fields, &structFieldDecoder{ 86 decoder: createSkipDecoder(field.Type()), 87 }) 88 continue 89 } 90 91 fields = append(fields, &structFieldDecoder{ 92 field: sf.Field, 93 decoder: decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()), 94 }) 95 } 96 97 dec.fields = fields 98 return dec 99 } 100 101 type structFieldDecoder struct { 102 field []*reflect2.UnsafeStructField 103 decoder ValDecoder 104 } 105 106 type structDecoder struct { 107 typ reflect2.Type 108 fields []*structFieldDecoder 109 } 110 111 func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 112 for _, field := range d.fields { 113 // Skip case 114 if field.field == nil { 115 field.decoder.Decode(nil, r) 116 continue 117 } 118 119 fieldPtr := ptr 120 for i, f := range field.field { 121 fieldPtr = f.UnsafeGet(fieldPtr) 122 123 if i == len(field.field)-1 { 124 break 125 } 126 127 if f.Type().Kind() == reflect.Ptr { 128 if *((*unsafe.Pointer)(fieldPtr)) == nil { 129 newPtr := f.Type().(*reflect2.UnsafePtrType).Elem().UnsafeNew() 130 *((*unsafe.Pointer)(fieldPtr)) = newPtr 131 } 132 133 fieldPtr = *((*unsafe.Pointer)(fieldPtr)) 134 } 135 } 136 field.decoder.Decode(fieldPtr, r) 137 138 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 139 for _, f := range field.field { 140 r.Error = fmt.Errorf("%s: %w", f.Name(), r.Error) 141 return 142 } 143 } 144 } 145 } 146 147 func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 148 processing := cfg.getProcessingEncoderFromCache(schema.Fingerprint(), typ.RType()) 149 if processing != nil { 150 return processing 151 } 152 enc := &structEncoder{typ: typ, fields: nil} 153 cfg.addProcessingEncoderToCache(schema.Fingerprint(), typ.RType(), enc) 154 155 rec := schema.(*RecordSchema) 156 structDesc := describeStruct(cfg.getTagKey(), typ) 157 158 fields := make([]*structFieldEncoder, 0, len(rec.Fields())) 159 for _, field := range rec.Fields() { 160 sf := structDesc.Fields.Get(field.Name()) 161 if sf != nil { 162 fields = append(fields, &structFieldEncoder{ 163 field: sf.Field, 164 encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()), 165 }) 166 continue 167 } 168 169 if !field.HasDefault() { 170 // In all other cases, this is a required field 171 err := fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name()) 172 return &errorEncoder{err: err} 173 } 174 175 def := field.Default() 176 if field.Default() == nil { 177 if field.Type().Type() == Null { 178 // We write nothing in a Null case, just skip it 179 continue 180 } 181 182 if field.Type().Type() == Union && field.Type().(*UnionSchema).Nullable() { 183 defaultType := reflect2.TypeOf(&def) 184 fields = append(fields, &structFieldEncoder{ 185 defaultPtr: reflect2.PtrOf(&def), 186 encoder: encoderOfPtrUnion(cfg, field.Type(), defaultType), 187 }) 188 continue 189 } 190 } 191 192 defaultType := reflect2.TypeOf(def) 193 defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) 194 if defaultType.LikePtr() { 195 defaultEncoder = &onePtrEncoder{defaultEncoder} 196 } 197 fields = append(fields, &structFieldEncoder{ 198 defaultPtr: reflect2.PtrOf(def), 199 encoder: defaultEncoder, 200 }) 201 } 202 203 enc.fields = fields 204 return enc 205 } 206 207 type structFieldEncoder struct { 208 field []*reflect2.UnsafeStructField 209 defaultPtr unsafe.Pointer 210 encoder ValEncoder 211 } 212 213 type structEncoder struct { 214 typ reflect2.Type 215 fields []*structFieldEncoder 216 } 217 218 func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 219 for _, field := range e.fields { 220 // Default case 221 if field.field == nil { 222 field.encoder.Encode(field.defaultPtr, w) 223 continue 224 } 225 226 fieldPtr := ptr 227 for i, f := range field.field { 228 fieldPtr = f.UnsafeGet(fieldPtr) 229 230 if i == len(field.field)-1 { 231 break 232 } 233 234 if f.Type().Kind() == reflect.Ptr { 235 if *((*unsafe.Pointer)(fieldPtr)) == nil { 236 w.Error = fmt.Errorf("embedded field %q is nil", f.Name()) 237 return 238 } 239 240 fieldPtr = *((*unsafe.Pointer)(fieldPtr)) 241 } 242 } 243 field.encoder.Encode(fieldPtr, w) 244 245 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 246 for _, f := range field.field { 247 w.Error = fmt.Errorf("%s: %w", f.Name(), w.Error) 248 return 249 } 250 } 251 } 252 } 253 254 func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 255 rec := schema.(*RecordSchema) 256 mapType := typ.(*reflect2.UnsafeMapType) 257 258 fields := make([]recordMapDecoderField, len(rec.Fields())) 259 for i, field := range rec.Fields() { 260 fields[i] = recordMapDecoderField{ 261 name: field.Name(), 262 decoder: decoderOfType(cfg, field.Type(), mapType.Elem()), 263 } 264 } 265 266 return &recordMapDecoder{ 267 mapType: mapType, 268 elemType: mapType.Elem(), 269 fields: fields, 270 } 271 } 272 273 type recordMapDecoderField struct { 274 name string 275 decoder ValDecoder 276 } 277 278 type recordMapDecoder struct { 279 mapType *reflect2.UnsafeMapType 280 elemType reflect2.Type 281 fields []recordMapDecoderField 282 } 283 284 func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 285 if d.mapType.UnsafeIsNil(ptr) { 286 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) 287 } 288 289 for _, field := range d.fields { 290 elem := d.elemType.UnsafeNew() 291 field.decoder.Decode(elem, r) 292 293 d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem) 294 } 295 296 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 297 r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) 298 } 299 } 300 301 func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 302 rec := schema.(*RecordSchema) 303 mapType := typ.(*reflect2.UnsafeMapType) 304 305 fields := make([]mapEncoderField, len(rec.Fields())) 306 for i, field := range rec.Fields() { 307 fields[i] = mapEncoderField{ 308 name: field.Name(), 309 hasDef: field.HasDefault(), 310 def: field.Default(), 311 encoder: encoderOfType(cfg, field.Type(), mapType.Elem()), 312 } 313 314 if field.HasDefault() { 315 switch { 316 case field.Type().Type() == Union: 317 union := field.Type().(*UnionSchema) 318 fields[i].def = map[string]any{ 319 string(union.Types()[0].Type()): field.Default(), 320 } 321 case field.Default() == nil: 322 continue 323 } 324 325 defaultType := reflect2.TypeOf(fields[i].def) 326 fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType) 327 if defaultType.LikePtr() { 328 fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder} 329 } 330 } 331 } 332 333 return &recordMapEncoder{ 334 mapType: mapType, 335 fields: fields, 336 } 337 } 338 339 type mapEncoderField struct { 340 name string 341 hasDef bool 342 def any 343 defEncoder ValEncoder 344 encoder ValEncoder 345 } 346 347 type recordMapEncoder struct { 348 mapType *reflect2.UnsafeMapType 349 fields []mapEncoderField 350 } 351 352 func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 353 for _, field := range e.fields { 354 // The first property of mapEncoderField is the name, so a pointer 355 // to field is a pointer to the name. 356 valPtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(field)) 357 if valPtr == nil { 358 // Missing required field 359 if !field.hasDef { 360 w.Error = fmt.Errorf("avro: missing required field %s", field.name) 361 return 362 } 363 364 // Null default 365 if field.def == nil { 366 continue 367 } 368 369 defPtr := reflect2.PtrOf(field.def) 370 field.defEncoder.Encode(defPtr, w) 371 continue 372 } 373 374 field.encoder.Encode(valPtr, w) 375 376 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 377 w.Error = fmt.Errorf("%s: %w", field.name, w.Error) 378 return 379 } 380 } 381 } 382 383 type recordIfaceDecoder struct { 384 schema Schema 385 valType *reflect2.UnsafeIFaceType 386 } 387 388 func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 389 obj := d.valType.UnsafeIndirect(ptr) 390 if reflect2.IsNil(obj) { 391 r.ReportError("decode non empty interface", "can not unmarshal into nil") 392 return 393 } 394 395 r.ReadVal(d.schema, obj) 396 } 397 398 type structDescriptor struct { 399 Type reflect2.Type 400 Fields structFields 401 } 402 403 type structFields []*structField 404 405 func (sf structFields) Get(name string) *structField { 406 for _, f := range sf { 407 if f.Name == name { 408 return f 409 } 410 } 411 412 return nil 413 } 414 415 type structField struct { 416 Name string 417 Field []*reflect2.UnsafeStructField 418 419 anon *reflect2.UnsafeStructType 420 } 421 422 func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor { 423 structType := typ.(*reflect2.UnsafeStructType) 424 fields := structFields{} 425 426 var curr []structField 427 next := []structField{{anon: structType}} 428 429 visited := map[uintptr]bool{} 430 431 for len(next) > 0 { 432 curr, next = next, curr[:0] 433 434 for _, f := range curr { 435 rtype := f.anon.RType() 436 if visited[f.anon.RType()] { 437 continue 438 } 439 visited[rtype] = true 440 441 for i := 0; i < f.anon.NumField(); i++ { 442 field := f.anon.Field(i).(*reflect2.UnsafeStructField) 443 isUnexported := field.PkgPath() != "" 444 445 chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1) 446 copy(chain, f.Field) 447 chain[len(f.Field)] = field 448 449 if field.Anonymous() { 450 t := field.Type() 451 if t.Kind() == reflect.Ptr { 452 t = t.(*reflect2.UnsafePtrType).Elem() 453 } 454 if t.Kind() != reflect.Struct { 455 continue 456 } 457 458 next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)}) 459 continue 460 } 461 462 // Ignore unexported fields. 463 if isUnexported { 464 continue 465 } 466 467 fieldName := field.Name() 468 if tag, ok := field.Tag().Lookup(tagKey); ok { 469 fieldName = tag 470 } 471 472 fields = append(fields, &structField{ 473 Name: fieldName, 474 Field: chain, 475 }) 476 } 477 } 478 } 479 480 return &structDescriptor{ 481 Type: structType, 482 Fields: fields, 483 } 484 }