github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_union.go (about) 1 package avro 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "strings" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 14 switch typ.Kind() { 15 case reflect.Map: 16 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 17 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 18 break 19 } 20 return decoderOfMapUnion(cfg, schema, typ) 21 case reflect.Slice: 22 if !schema.(*UnionSchema).Nullable() { 23 break 24 } 25 return decoderOfNullableUnion(cfg, schema, typ) 26 case reflect.Ptr: 27 if !schema.(*UnionSchema).Nullable() { 28 break 29 } 30 return decoderOfNullableUnion(cfg, schema, typ) 31 case reflect.Interface: 32 if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { 33 dec, err := decoderOfResolvedUnion(cfg, schema) 34 if err != nil { 35 return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)} 36 } 37 return dec 38 } 39 } 40 41 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 42 } 43 44 func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 45 switch typ.Kind() { 46 case reflect.Map: 47 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 48 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 49 break 50 } 51 return encoderOfMapUnion(cfg, schema, typ) 52 case reflect.Slice: 53 if !schema.(*UnionSchema).Nullable() { 54 break 55 } 56 return encoderOfNullableUnion(cfg, schema, typ) 57 case reflect.Ptr: 58 if !schema.(*UnionSchema).Nullable() { 59 break 60 } 61 return encoderOfNullableUnion(cfg, schema, typ) 62 } 63 return encoderOfResolverUnion(cfg, schema, typ) 64 } 65 66 func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 67 union := schema.(*UnionSchema) 68 mapType := typ.(*reflect2.UnsafeMapType) 69 70 typeDecs := make([]ValDecoder, len(union.Types())) 71 for i, s := range union.Types() { 72 if s.Type() == Null { 73 continue 74 } 75 typeDecs[i] = newEfaceDecoder(cfg, s) 76 } 77 78 return &mapUnionDecoder{ 79 cfg: cfg, 80 schema: union, 81 mapType: mapType, 82 elemType: mapType.Elem(), 83 typeDecs: typeDecs, 84 } 85 } 86 87 type mapUnionDecoder struct { 88 cfg *frozenConfig 89 schema *UnionSchema 90 mapType *reflect2.UnsafeMapType 91 elemType reflect2.Type 92 typeDecs []ValDecoder 93 } 94 95 func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 96 idx, resSchema := getUnionSchema(d.schema, r) 97 if resSchema == nil { 98 return 99 } 100 101 // In a null case, just return 102 if resSchema.Type() == Null { 103 return 104 } 105 106 if d.mapType.UnsafeIsNil(ptr) { 107 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(1)) 108 } 109 110 key := schemaTypeName(resSchema) 111 keyPtr := reflect2.PtrOf(key) 112 113 elemPtr := d.elemType.UnsafeNew() 114 d.typeDecs[idx].Decode(elemPtr, r) 115 116 d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) 117 } 118 119 func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder { 120 union := schema.(*UnionSchema) 121 122 return &mapUnionEncoder{ 123 cfg: cfg, 124 schema: union, 125 } 126 } 127 128 type mapUnionEncoder struct { 129 cfg *frozenConfig 130 schema *UnionSchema 131 } 132 133 func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 134 m := *((*map[string]any)(ptr)) 135 136 if len(m) > 1 { 137 w.Error = errors.New("avro: cannot encode union map with multiple entries") 138 return 139 } 140 141 name := "null" 142 val := any(nil) 143 for k, v := range m { 144 name = k 145 val = v 146 break 147 } 148 149 schema, pos := e.schema.Types().Get(name) 150 if schema == nil { 151 w.Error = fmt.Errorf("avro: unknown union type %s", name) 152 return 153 } 154 155 w.WriteInt(int32(pos)) 156 157 if schema.Type() == Null && val == nil { 158 return 159 } 160 161 elemType := reflect2.TypeOf(val) 162 elemPtr := reflect2.PtrOf(val) 163 164 encoder := encoderOfType(e.cfg, schema, elemType) 165 if elemType.LikePtr() { 166 encoder = &onePtrEncoder{encoder} 167 } 168 encoder.Encode(elemPtr, w) 169 } 170 171 func decoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 172 union := schema.(*UnionSchema) 173 _, typeIdx := union.Indices() 174 175 var ( 176 baseTyp reflect2.Type 177 isPtr bool 178 ) 179 switch v := typ.(type) { 180 case *reflect2.UnsafePtrType: 181 baseTyp = v.Elem() 182 isPtr = true 183 case *reflect2.UnsafeSliceType: 184 baseTyp = v 185 } 186 decoder := decoderOfType(cfg, union.Types()[typeIdx], baseTyp) 187 188 return &unionNullableDecoder{ 189 schema: union, 190 typ: baseTyp, 191 isPtr: isPtr, 192 decoder: decoder, 193 } 194 } 195 196 type unionNullableDecoder struct { 197 schema *UnionSchema 198 typ reflect2.Type 199 isPtr bool 200 decoder ValDecoder 201 } 202 203 func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 204 _, schema := getUnionSchema(d.schema, r) 205 if schema == nil { 206 return 207 } 208 209 if schema.Type() == Null { 210 *((*unsafe.Pointer)(ptr)) = nil 211 return 212 } 213 214 // Handle the non-ptr case separately. 215 if !d.isPtr { 216 if d.typ.UnsafeIsNil(ptr) { 217 // Create a new instance. 218 newPtr := d.typ.UnsafeNew() 219 d.decoder.Decode(newPtr, r) 220 d.typ.UnsafeSet(ptr, newPtr) 221 return 222 } 223 224 // Reuse the existing instance. 225 d.decoder.Decode(ptr, r) 226 return 227 } 228 229 if *((*unsafe.Pointer)(ptr)) == nil { 230 // Create new instance. 231 newPtr := d.typ.UnsafeNew() 232 d.decoder.Decode(newPtr, r) 233 *((*unsafe.Pointer)(ptr)) = newPtr 234 return 235 } 236 237 // Reuse existing instance. 238 d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) 239 } 240 241 func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 242 union := schema.(*UnionSchema) 243 nullIdx, typeIdx := union.Indices() 244 245 var ( 246 baseTyp reflect2.Type 247 isPtr bool 248 ) 249 switch v := typ.(type) { 250 case *reflect2.UnsafePtrType: 251 baseTyp = v.Elem() 252 isPtr = true 253 case *reflect2.UnsafeSliceType: 254 baseTyp = v 255 } 256 encoder := encoderOfType(cfg, union.Types()[typeIdx], baseTyp) 257 258 return &unionNullableEncoder{ 259 schema: union, 260 encoder: encoder, 261 isPtr: isPtr, 262 nullIdx: int32(nullIdx), 263 typeIdx: int32(typeIdx), 264 } 265 } 266 267 type unionNullableEncoder struct { 268 schema *UnionSchema 269 encoder ValEncoder 270 isPtr bool 271 nullIdx int32 272 typeIdx int32 273 } 274 275 func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 276 if *((*unsafe.Pointer)(ptr)) == nil { 277 w.WriteInt(e.nullIdx) 278 return 279 } 280 281 w.WriteInt(e.typeIdx) 282 newPtr := ptr 283 if e.isPtr { 284 newPtr = *((*unsafe.Pointer)(ptr)) 285 } 286 e.encoder.Encode(newPtr, w) 287 } 288 289 func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) { 290 union := schema.(*UnionSchema) 291 292 types := make([]reflect2.Type, len(union.Types())) 293 decoders := make([]ValDecoder, len(union.Types())) 294 for i, schema := range union.Types() { 295 name := unionResolutionName(schema) 296 297 typ, err := cfg.resolver.Type(name) 298 if err != nil { 299 if cfg.config.UnionResolutionError { 300 return nil, err 301 } 302 303 if cfg.config.PartialUnionTypeResolution { 304 decoders[i] = nil 305 types[i] = nil 306 continue 307 } 308 309 decoders = []ValDecoder{} 310 types = []reflect2.Type{} 311 break 312 } 313 314 decoder := decoderOfType(cfg, schema, typ) 315 decoders[i] = decoder 316 types[i] = typ 317 } 318 319 return &unionResolvedDecoder{ 320 cfg: cfg, 321 schema: union, 322 types: types, 323 decoders: decoders, 324 }, nil 325 } 326 327 type unionResolvedDecoder struct { 328 cfg *frozenConfig 329 schema *UnionSchema 330 types []reflect2.Type 331 decoders []ValDecoder 332 } 333 334 func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 335 i, schema := getUnionSchema(d.schema, r) 336 if schema == nil { 337 return 338 } 339 340 pObj := (*any)(ptr) 341 342 if schema.Type() == Null { 343 *pObj = nil 344 return 345 } 346 347 if i >= len(d.decoders) || d.decoders[i] == nil { 348 if d.cfg.config.UnionResolutionError { 349 r.ReportError("decode union type", "unknown union type") 350 return 351 } 352 353 // We cannot resolve this, set it to the map type 354 name := schemaTypeName(schema) 355 obj := map[string]any{} 356 vTyp, err := genericReceiver(schema) 357 if err != nil { 358 r.ReportError("Union", err.Error()) 359 return 360 } 361 obj[name] = genericDecode(vTyp, decoderOfType(d.cfg, schema, vTyp), r) 362 363 *pObj = obj 364 return 365 } 366 367 typ := d.types[i] 368 var newPtr unsafe.Pointer 369 switch typ.Kind() { 370 case reflect.Map: 371 mapType := typ.(*reflect2.UnsafeMapType) 372 newPtr = mapType.UnsafeMakeMap(1) 373 374 case reflect.Slice: 375 mapType := typ.(*reflect2.UnsafeSliceType) 376 newPtr = mapType.UnsafeMakeSlice(1, 1) 377 378 case reflect.Ptr: 379 elemType := typ.(*reflect2.UnsafePtrType).Elem() 380 newPtr = elemType.UnsafeNew() 381 382 default: 383 newPtr = typ.UnsafeNew() 384 } 385 386 d.decoders[i].Decode(newPtr, r) 387 *pObj = typ.UnsafeIndirect(newPtr) 388 } 389 390 func unionResolutionName(schema Schema) string { 391 name := schemaTypeName(schema) 392 switch schema.Type() { 393 case Map: 394 name += ":" 395 valSchema := schema.(*MapSchema).Values() 396 valName := schemaTypeName(valSchema) 397 398 name += valName 399 400 case Array: 401 name += ":" 402 itemSchema := schema.(*ArraySchema).Items() 403 itemName := schemaTypeName(itemSchema) 404 405 name += itemName 406 } 407 408 return name 409 } 410 411 func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 412 union := schema.(*UnionSchema) 413 414 names, err := cfg.resolver.Name(typ) 415 if err != nil { 416 return &errorEncoder{err: err} 417 } 418 419 var pos int 420 for _, name := range names { 421 if idx := strings.Index(name, ":"); idx > 0 { 422 name = name[:idx] 423 } 424 425 schema, pos = union.Types().Get(name) 426 if schema != nil { 427 break 428 } 429 } 430 if schema == nil { 431 return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])} 432 } 433 434 encoder := encoderOfType(cfg, schema, typ) 435 436 return &unionResolverEncoder{ 437 pos: pos, 438 encoder: encoder, 439 } 440 } 441 442 type unionResolverEncoder struct { 443 pos int 444 encoder ValEncoder 445 } 446 447 func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 448 w.WriteInt(int32(e.pos)) 449 450 e.encoder.Encode(ptr, w) 451 } 452 453 func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) { 454 types := schema.Types() 455 456 idx := int(r.ReadInt()) 457 if idx < 0 || idx > len(types)-1 { 458 r.ReportError("decode union type", "unknown union type") 459 return 0, nil 460 } 461 462 return idx, types[idx] 463 }