github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_record.go (about) 1 package avro 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "reflect" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 14 switch typ.Kind() { 15 case reflect.Struct: 16 return decoderOfStruct(cfg, schema, typ) 17 18 case reflect.Map: 19 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 20 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 21 break 22 } 23 return decoderOfRecord(cfg, schema, typ) 24 25 case reflect.Ptr: 26 return decoderOfPtr(cfg, schema, typ) 27 28 case reflect.Interface: 29 if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok { 30 return &recordIfaceDecoder{schema: schema, valType: ifaceType} 31 } 32 } 33 34 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} 35 } 36 37 func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 38 switch typ.Kind() { 39 case reflect.Struct: 40 return encoderOfStruct(cfg, schema, typ) 41 42 case reflect.Map: 43 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 44 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 45 break 46 } 47 return encoderOfRecord(cfg, schema, typ) 48 49 case reflect.Ptr: 50 return encoderOfPtr(cfg, schema, typ) 51 } 52 53 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} 54 } 55 56 func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 57 rec := schema.(*RecordSchema) 58 structDesc := describeStruct(cfg.getTagKey(), typ) 59 60 fields := make([]*structFieldDecoder, 0, len(rec.Fields())) 61 62 for _, field := range rec.Fields() { 63 if field.action == FieldIgnore { 64 fields = append(fields, &structFieldDecoder{ 65 decoder: createSkipDecoder(field.Type()), 66 }) 67 continue 68 } 69 70 sf := structDesc.Fields.Get(field.Name()) 71 if sf == nil { 72 for _, alias := range field.Aliases() { 73 sf = structDesc.Fields.Get(alias) 74 if sf != nil { 75 break 76 } 77 } 78 } 79 80 // Skip field if it doesnt exist 81 if sf == nil { 82 fields = append(fields, &structFieldDecoder{ 83 decoder: createSkipDecoder(field.Type()), 84 }) 85 continue 86 } 87 88 if field.action == FieldSetDefault { 89 if field.hasDef { 90 fields = append(fields, &structFieldDecoder{ 91 field: sf.Field, 92 decoder: createDefaultDecoder(cfg, field, sf.Field[len(sf.Field)-1].Type()), 93 }) 94 95 continue 96 } 97 } 98 99 dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()) 100 fields = append(fields, &structFieldDecoder{ 101 field: sf.Field, 102 decoder: dec, 103 }) 104 } 105 106 return &structDecoder{typ: typ, fields: fields} 107 } 108 109 type structFieldDecoder struct { 110 field []*reflect2.UnsafeStructField 111 decoder ValDecoder 112 } 113 114 type structDecoder struct { 115 typ reflect2.Type 116 fields []*structFieldDecoder 117 } 118 119 func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 120 for _, field := range d.fields { 121 // Skip case 122 if field.field == nil { 123 field.decoder.Decode(nil, r) 124 continue 125 } 126 127 fieldPtr := ptr 128 for i, f := range field.field { 129 fieldPtr = f.UnsafeGet(fieldPtr) 130 131 if i == len(field.field)-1 { 132 break 133 } 134 135 if f.Type().Kind() == reflect.Ptr { 136 if *((*unsafe.Pointer)(fieldPtr)) == nil { 137 newPtr := f.Type().(*reflect2.UnsafePtrType).Elem().UnsafeNew() 138 *((*unsafe.Pointer)(fieldPtr)) = newPtr 139 } 140 141 fieldPtr = *((*unsafe.Pointer)(fieldPtr)) 142 } 143 } 144 field.decoder.Decode(fieldPtr, r) 145 146 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 147 for _, f := range field.field { 148 r.Error = fmt.Errorf("%s: %w", f.Name(), r.Error) 149 return 150 } 151 } 152 } 153 } 154 155 func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 156 rec := schema.(*RecordSchema) 157 structDesc := describeStruct(cfg.getTagKey(), typ) 158 159 fields := make([]*structFieldEncoder, 0, len(rec.Fields())) 160 for _, field := range rec.Fields() { 161 sf := structDesc.Fields.Get(field.Name()) 162 if sf != nil { 163 fields = append(fields, &structFieldEncoder{ 164 field: sf.Field, 165 encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()), 166 }) 167 continue 168 } 169 170 if !field.HasDefault() { 171 // In all other cases, this is a required field 172 err := fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name()) 173 return &errorEncoder{err: err} 174 } 175 176 def := field.Default() 177 if field.Default() == nil { 178 if field.Type().Type() == Null { 179 // We write nothing in a Null case, just skip it 180 continue 181 } 182 183 if field.Type().Type() == Union && field.Type().(*UnionSchema).Nullable() { 184 defaultType := reflect2.TypeOf(&def) 185 fields = append(fields, &structFieldEncoder{ 186 defaultPtr: reflect2.PtrOf(&def), 187 encoder: encoderOfNullableUnion(cfg, field.Type(), defaultType), 188 }) 189 continue 190 } 191 } 192 193 defaultType := reflect2.TypeOf(def) 194 defaultEncoder := encoderOfType(cfg, field.Type(), defaultType) 195 if defaultType.LikePtr() { 196 defaultEncoder = &onePtrEncoder{defaultEncoder} 197 } 198 fields = append(fields, &structFieldEncoder{ 199 defaultPtr: reflect2.PtrOf(def), 200 encoder: defaultEncoder, 201 }) 202 } 203 return &structEncoder{typ: typ, fields: fields} 204 } 205 206 type structFieldEncoder struct { 207 field []*reflect2.UnsafeStructField 208 defaultPtr unsafe.Pointer 209 encoder ValEncoder 210 } 211 212 type structEncoder struct { 213 typ reflect2.Type 214 fields []*structFieldEncoder 215 } 216 217 func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 218 for _, field := range e.fields { 219 // Default case 220 if field.field == nil { 221 field.encoder.Encode(field.defaultPtr, w) 222 continue 223 } 224 225 fieldPtr := ptr 226 for i, f := range field.field { 227 fieldPtr = f.UnsafeGet(fieldPtr) 228 229 if i == len(field.field)-1 { 230 break 231 } 232 233 if f.Type().Kind() == reflect.Ptr { 234 if *((*unsafe.Pointer)(fieldPtr)) == nil { 235 w.Error = fmt.Errorf("embedded field %q is nil", f.Name()) 236 return 237 } 238 239 fieldPtr = *((*unsafe.Pointer)(fieldPtr)) 240 } 241 } 242 field.encoder.Encode(fieldPtr, w) 243 244 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 245 for _, f := range field.field { 246 w.Error = fmt.Errorf("%s: %w", f.Name(), w.Error) 247 return 248 } 249 } 250 } 251 } 252 253 func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 254 rec := schema.(*RecordSchema) 255 mapType := typ.(*reflect2.UnsafeMapType) 256 257 fields := make([]recordMapDecoderField, len(rec.Fields())) 258 for i, field := range rec.Fields() { 259 switch field.action { 260 case FieldIgnore: 261 fields[i] = recordMapDecoderField{ 262 name: field.Name(), 263 decoder: createSkipDecoder(field.Type()), 264 skip: true, 265 } 266 continue 267 case FieldSetDefault: 268 if field.hasDef { 269 fields[i] = recordMapDecoderField{ 270 name: field.Name(), 271 decoder: createDefaultDecoder(cfg, field, mapType.Elem()), 272 } 273 continue 274 } 275 } 276 277 fields[i] = recordMapDecoderField{ 278 name: field.Name(), 279 decoder: newEfaceDecoder(cfg, field.Type()), 280 } 281 } 282 283 return &recordMapDecoder{ 284 mapType: mapType, 285 elemType: mapType.Elem(), 286 fields: fields, 287 } 288 } 289 290 type recordMapDecoderField struct { 291 name string 292 decoder ValDecoder 293 skip bool 294 } 295 296 type recordMapDecoder struct { 297 mapType *reflect2.UnsafeMapType 298 elemType reflect2.Type 299 fields []recordMapDecoderField 300 } 301 302 func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 303 if d.mapType.UnsafeIsNil(ptr) { 304 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(len(d.fields))) 305 } 306 307 for _, field := range d.fields { 308 elemPtr := d.elemType.UnsafeNew() 309 field.decoder.Decode(elemPtr, r) 310 if field.skip { 311 continue 312 } 313 314 d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elemPtr) 315 } 316 317 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 318 r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) 319 } 320 } 321 322 func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 323 rec := schema.(*RecordSchema) 324 mapType := typ.(*reflect2.UnsafeMapType) 325 326 fields := make([]mapEncoderField, len(rec.Fields())) 327 for i, field := range rec.Fields() { 328 fields[i] = mapEncoderField{ 329 name: field.Name(), 330 hasDef: field.HasDefault(), 331 def: field.Default(), 332 encoder: encoderOfType(cfg, field.Type(), mapType.Elem()), 333 } 334 335 if field.HasDefault() { 336 switch { 337 case field.Type().Type() == Union: 338 union := field.Type().(*UnionSchema) 339 fields[i].def = map[string]any{ 340 string(union.Types()[0].Type()): field.Default(), 341 } 342 case field.Default() == nil: 343 continue 344 } 345 346 defaultType := reflect2.TypeOf(fields[i].def) 347 fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType) 348 if defaultType.LikePtr() { 349 fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder} 350 } 351 } 352 } 353 354 return &recordMapEncoder{ 355 mapType: mapType, 356 fields: fields, 357 } 358 } 359 360 type mapEncoderField struct { 361 name string 362 hasDef bool 363 def any 364 defEncoder ValEncoder 365 encoder ValEncoder 366 } 367 368 type recordMapEncoder struct { 369 mapType *reflect2.UnsafeMapType 370 fields []mapEncoderField 371 } 372 373 func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 374 for _, field := range e.fields { 375 // The first property of mapEncoderField is the name, so a pointer 376 // to field is a pointer to the name. 377 valPtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(field)) 378 if valPtr == nil { 379 // Missing required field 380 if !field.hasDef { 381 w.Error = fmt.Errorf("avro: missing required field %s", field.name) 382 return 383 } 384 385 // Null default 386 if field.def == nil { 387 continue 388 } 389 390 defPtr := reflect2.PtrOf(field.def) 391 field.defEncoder.Encode(defPtr, w) 392 continue 393 } 394 395 field.encoder.Encode(valPtr, w) 396 397 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 398 w.Error = fmt.Errorf("%s: %w", field.name, w.Error) 399 return 400 } 401 } 402 } 403 404 type recordIfaceDecoder struct { 405 schema Schema 406 valType *reflect2.UnsafeIFaceType 407 } 408 409 func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 410 obj := d.valType.UnsafeIndirect(ptr) 411 if reflect2.IsNil(obj) { 412 r.ReportError("decode non empty interface", "can not unmarshal into nil") 413 return 414 } 415 416 r.ReadVal(d.schema, obj) 417 } 418 419 type structDescriptor struct { 420 Type reflect2.Type 421 Fields structFields 422 } 423 424 type structFields []*structField 425 426 func (sf structFields) Get(name string) *structField { 427 for _, f := range sf { 428 if f.Name == name { 429 return f 430 } 431 } 432 433 return nil 434 } 435 436 type structField struct { 437 Name string 438 Field []*reflect2.UnsafeStructField 439 440 anon *reflect2.UnsafeStructType 441 } 442 443 func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor { 444 structType := typ.(*reflect2.UnsafeStructType) 445 fields := structFields{} 446 447 var curr []structField 448 next := []structField{{anon: structType}} 449 450 visited := map[uintptr]bool{} 451 452 for len(next) > 0 { 453 curr, next = next, curr[:0] 454 455 for _, f := range curr { 456 rtype := f.anon.RType() 457 if visited[f.anon.RType()] { 458 continue 459 } 460 visited[rtype] = true 461 462 for i := 0; i < f.anon.NumField(); i++ { 463 field := f.anon.Field(i).(*reflect2.UnsafeStructField) 464 isUnexported := field.PkgPath() != "" 465 466 chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1) 467 copy(chain, f.Field) 468 chain[len(f.Field)] = field 469 470 if field.Anonymous() { 471 t := field.Type() 472 if t.Kind() == reflect.Ptr { 473 t = t.(*reflect2.UnsafePtrType).Elem() 474 } 475 if t.Kind() != reflect.Struct { 476 continue 477 } 478 479 next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)}) 480 continue 481 } 482 483 // Ignore unexported fields. 484 if isUnexported { 485 continue 486 } 487 488 fieldName := field.Name() 489 if tag, ok := field.Tag().Lookup(tagKey); ok { 490 fieldName = tag 491 } 492 493 fields = append(fields, &structField{ 494 Name: fieldName, 495 Field: chain, 496 }) 497 } 498 } 499 } 500 501 return &structDescriptor{ 502 Type: structType, 503 Fields: fields, 504 } 505 }