github.com/hamba/avro@v1.8.0/codec_record.go (about) 1 package avro 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "reflect" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 14 switch typ.Kind() { 15 case reflect.Struct: 16 return decoderOfStruct(cfg, schema, typ) 17 18 case reflect.Map: 19 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 20 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 21 break 22 } 23 return decoderOfRecord(cfg, schema, typ) 24 25 case reflect.Ptr: 26 return decoderOfPtr(cfg, schema, typ) 27 28 case reflect.Interface: 29 if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok { 30 return &recordIfaceDecoder{schema: schema, valType: ifaceType} 31 } 32 } 33 34 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} 35 } 36 37 func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 38 switch typ.Kind() { 39 case reflect.Struct: 40 return encoderOfStruct(cfg, schema, typ) 41 42 case reflect.Map: 43 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 44 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 45 break 46 } 47 return encoderOfRecord(cfg, schema, typ) 48 49 case reflect.Ptr: 50 return encoderOfPtr(cfg, schema, typ) 51 } 52 53 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())} 54 } 55 56 func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 57 rec := schema.(*RecordSchema) 58 structDesc := describeStruct(cfg.getTagKey(), typ) 59 60 fields := make([]*structFieldDecoder, 0, len(rec.Fields())) 61 for _, field := range rec.Fields() { 62 sf := structDesc.Fields.Get(field.Name()) 63 64 // Skip field if it doesnt exist 65 if sf == nil { 66 fields = append(fields, &structFieldDecoder{ 67 decoder: createSkipDecoder(field.Type()), 68 }) 69 continue 70 } 71 72 dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()) 73 fields = append(fields, &structFieldDecoder{ 74 field: sf.Field, 75 decoder: dec, 76 }) 77 } 78 79 return &structDecoder{typ: typ, fields: fields} 80 } 81 82 type structFieldDecoder struct { 83 field []*reflect2.UnsafeStructField 84 decoder ValDecoder 85 } 86 87 type structDecoder struct { 88 typ reflect2.Type 89 fields []*structFieldDecoder 90 } 91 92 func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 93 for _, field := range d.fields { 94 // Skip case 95 if field.field == nil { 96 field.decoder.Decode(nil, r) 97 continue 98 } 99 100 fieldPtr := ptr 101 for i, f := range field.field { 102 fieldPtr = f.UnsafeGet(fieldPtr) 103 104 if i == len(field.field)-1 { 105 break 106 } 107 108 if f.Type().Kind() == reflect.Ptr { 109 if *((*unsafe.Pointer)(ptr)) == nil { 110 newPtr := f.Type().UnsafeNew() 111 *((*unsafe.Pointer)(fieldPtr)) = newPtr 112 } 113 114 fieldPtr = *((*unsafe.Pointer)(fieldPtr)) 115 } 116 } 117 field.decoder.Decode(fieldPtr, r) 118 119 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 120 for _, f := range field.field { 121 r.Error = fmt.Errorf("%s: %w", f.Name(), r.Error) 122 } 123 return 124 } 125 } 126 } 127 128 func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 129 rec := schema.(*RecordSchema) 130 structDesc := describeStruct(cfg.getTagKey(), typ) 131 132 fields := make([]*structFieldEncoder, 0, len(rec.Fields())) 133 for _, field := range rec.Fields() { 134 sf := structDesc.Fields.Get(field.Name()) 135 136 if sf == nil { 137 if !field.HasDefault() { 138 // In all other cases, this is a required field 139 return &errorEncoder{err: fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name())} 140 } 141 142 def := field.Default() 143 if field.Default() == nil { 144 if field.Type().Type() == Null { 145 // We write nothing in a Null case, just skip it 146 continue 147 } 148 149 if field.Type().Type() == Union && field.Type().(*UnionSchema).Nullable() { 150 defaultType := reflect2.TypeOf(&def) 151 fields = append(fields, &structFieldEncoder{ 152 defaultPtr: reflect2.PtrOf(&def), 153 encoder: encoderOfPtrUnion(cfg, field.Type(), defaultType), 154 }) 155 continue 156 } 157 } 158 159 defaultType := reflect2.TypeOf(def) 160 fields = append(fields, &structFieldEncoder{ 161 defaultPtr: reflect2.PtrOf(def), 162 encoder: encoderOfType(cfg, field.Type(), defaultType), 163 }) 164 165 continue 166 } 167 168 fields = append(fields, &structFieldEncoder{ 169 field: sf.Field, 170 encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()), 171 }) 172 } 173 174 return &structEncoder{typ: typ, fields: fields} 175 } 176 177 type structFieldEncoder struct { 178 field []*reflect2.UnsafeStructField 179 defaultPtr unsafe.Pointer 180 encoder ValEncoder 181 } 182 183 type structEncoder struct { 184 typ reflect2.Type 185 fields []*structFieldEncoder 186 } 187 188 func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 189 for _, field := range e.fields { 190 // Default case 191 if field.field == nil { 192 field.encoder.Encode(field.defaultPtr, w) 193 continue 194 } 195 196 fieldPtr := ptr 197 for i, f := range field.field { 198 fieldPtr = f.UnsafeGet(fieldPtr) 199 200 if i == len(field.field)-1 { 201 break 202 } 203 204 if f.Type().Kind() == reflect.Ptr { 205 if *((*unsafe.Pointer)(ptr)) == nil { 206 w.Error = fmt.Errorf("embedded field %q is nil", f.Name()) 207 return 208 } 209 210 fieldPtr = *((*unsafe.Pointer)(fieldPtr)) 211 } 212 } 213 field.encoder.Encode(fieldPtr, w) 214 215 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 216 for _, f := range field.field { 217 w.Error = fmt.Errorf("%s: %w", f.Name(), w.Error) 218 } 219 } 220 } 221 } 222 223 func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 224 rec := schema.(*RecordSchema) 225 mapType := typ.(*reflect2.UnsafeMapType) 226 227 fields := make([]recordMapDecoderField, len(rec.Fields())) 228 for i, field := range rec.Fields() { 229 fields[i] = recordMapDecoderField{ 230 name: field.Name(), 231 decoder: decoderOfType(cfg, field.Type(), mapType.Elem()), 232 } 233 } 234 235 return &recordMapDecoder{ 236 mapType: mapType, 237 elemType: mapType.Elem(), 238 fields: fields, 239 } 240 } 241 242 type recordMapDecoderField struct { 243 name string 244 decoder ValDecoder 245 } 246 247 type recordMapDecoder struct { 248 mapType *reflect2.UnsafeMapType 249 elemType reflect2.Type 250 fields []recordMapDecoderField 251 } 252 253 func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 254 if d.mapType.UnsafeIsNil(ptr) { 255 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) 256 } 257 258 for _, field := range d.fields { 259 elem := d.elemType.UnsafeNew() 260 field.decoder.Decode(elem, r) 261 262 d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem) 263 } 264 265 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 266 r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) 267 } 268 } 269 270 func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 271 rec := schema.(*RecordSchema) 272 mapType := typ.(*reflect2.UnsafeMapType) 273 274 fields := make([]mapEncoderField, len(rec.Fields())) 275 for i, field := range rec.Fields() { 276 fields[i] = mapEncoderField{ 277 name: field.Name(), 278 hasDef: field.HasDefault(), 279 def: field.Default(), 280 encoder: encoderOfType(cfg, field.Type(), mapType.Elem()), 281 } 282 283 if field.HasDefault() { 284 switch { 285 case field.Type().Type() == Union: 286 union := field.Type().(*UnionSchema) 287 fields[i].def = map[string]interface{}{ 288 string(union.Types()[0].Type()): field.Default(), 289 } 290 case field.Default() == nil: 291 continue 292 } 293 294 defaultType := reflect2.TypeOf(fields[i].def) 295 fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType) 296 if defaultType.LikePtr() { 297 fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder} 298 } 299 } 300 } 301 302 return &recordMapEncoder{ 303 mapType: mapType, 304 fields: fields, 305 } 306 } 307 308 type mapEncoderField struct { 309 name string 310 hasDef bool 311 def interface{} 312 defEncoder ValEncoder 313 encoder ValEncoder 314 } 315 316 type recordMapEncoder struct { 317 mapType *reflect2.UnsafeMapType 318 fields []mapEncoderField 319 } 320 321 func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 322 for _, field := range e.fields { 323 valPtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(field)) 324 if valPtr == nil { 325 // Missing required field 326 if !field.hasDef { 327 w.Error = fmt.Errorf("avro: missing required field %s", field.name) 328 return 329 } 330 331 // Null default 332 if field.def == nil { 333 continue 334 } 335 336 defPtr := reflect2.PtrOf(field.def) 337 field.defEncoder.Encode(defPtr, w) 338 continue 339 } 340 341 field.encoder.Encode(valPtr, w) 342 } 343 } 344 345 type recordIfaceDecoder struct { 346 schema Schema 347 valType *reflect2.UnsafeIFaceType 348 } 349 350 func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 351 obj := d.valType.UnsafeIndirect(ptr) 352 if reflect2.IsNil(obj) { 353 r.ReportError("decode non empty interface", "can not unmarshal into nil") 354 return 355 } 356 357 r.ReadVal(d.schema, obj) 358 } 359 360 type structDescriptor struct { 361 Type reflect2.Type 362 Fields structFields 363 } 364 365 type structFields []*structField 366 367 func (sf structFields) Get(name string) *structField { 368 for _, f := range sf { 369 if f.Name == name { 370 return f 371 } 372 } 373 374 return nil 375 } 376 377 type structField struct { 378 Name string 379 Field []*reflect2.UnsafeStructField 380 381 anon *reflect2.UnsafeStructType 382 } 383 384 func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor { 385 structType := typ.(*reflect2.UnsafeStructType) 386 fields := structFields{} 387 388 var curr []structField 389 next := []structField{{anon: structType}} 390 391 visited := map[uintptr]bool{} 392 393 for len(next) > 0 { 394 curr, next = next, curr[:0] 395 396 for _, f := range curr { 397 rtype := f.anon.RType() 398 if visited[f.anon.RType()] { 399 continue 400 } 401 visited[rtype] = true 402 403 for i := 0; i < f.anon.NumField(); i++ { 404 field := f.anon.Field(i).(*reflect2.UnsafeStructField) 405 isUnexported := field.PkgPath() != "" 406 407 chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1) 408 copy(chain, f.Field) 409 chain[len(f.Field)] = field 410 411 if field.Anonymous() { 412 t := field.Type() 413 if t.Kind() == reflect.Ptr { 414 t = t.(*reflect2.UnsafePtrType).Elem() 415 } 416 if t.Kind() != reflect.Struct { 417 continue 418 } 419 420 next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)}) 421 continue 422 } 423 424 // Ignore unexported fields. 425 if isUnexported { 426 continue 427 } 428 429 fieldName := field.Name() 430 if tag, ok := field.Tag().Lookup(tagKey); ok { 431 fieldName = tag 432 } 433 434 fields = append(fields, &structField{ 435 Name: fieldName, 436 Field: chain, 437 }) 438 } 439 } 440 } 441 442 return &structDescriptor{ 443 Type: structType, 444 Fields: fields, 445 } 446 }