github.com/aacfactory/avro@v1.2.12/internal/base/codec_union.go (about) 1 package base 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "strings" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 14 switch typ.Kind() { 15 case reflect.Map: 16 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 17 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 18 break 19 } 20 return decoderOfMapUnion(cfg, schema, typ) 21 22 case reflect.Ptr: 23 if !schema.(*UnionSchema).Nullable() { 24 break 25 } 26 return decoderOfPtrUnion(cfg, schema, typ) 27 28 case reflect.Interface: 29 if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { 30 dec, err := decoderOfResolvedUnion(cfg, schema) 31 if err != nil { 32 return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)} 33 } 34 35 return dec 36 } 37 default: 38 break 39 } 40 41 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 42 } 43 44 func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 45 switch typ.Kind() { 46 case reflect.Map: 47 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 48 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 49 break 50 } 51 return encoderOfMapUnion(cfg, schema, typ) 52 53 case reflect.Ptr: 54 if !schema.(*UnionSchema).Nullable() { 55 break 56 } 57 return encoderOfPtrUnion(cfg, schema, typ) 58 default: 59 break 60 } 61 62 return encoderOfResolverUnion(cfg, schema, typ) 63 } 64 65 func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 66 union := schema.(*UnionSchema) 67 mapType := typ.(*reflect2.UnsafeMapType) 68 69 return &mapUnionDecoder{ 70 cfg: cfg, 71 schema: union, 72 mapType: mapType, 73 elemType: mapType.Elem(), 74 } 75 } 76 77 type mapUnionDecoder struct { 78 cfg *frozenConfig 79 schema *UnionSchema 80 mapType *reflect2.UnsafeMapType 81 elemType reflect2.Type 82 } 83 84 func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 85 _, resSchema := getUnionSchema(d.schema, r) 86 if resSchema == nil { 87 return 88 } 89 90 // In a null case, just return 91 if resSchema.Type() == Null { 92 return 93 } 94 95 if d.mapType.UnsafeIsNil(ptr) { 96 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) 97 } 98 99 key := schemaTypeName(resSchema) 100 keyPtr := reflect2.PtrOf(key) 101 102 elemPtr := d.elemType.UnsafeNew() 103 decoderOfType(d.cfg, resSchema, d.elemType).Decode(elemPtr, r) 104 105 d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) 106 } 107 108 func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder { 109 union := schema.(*UnionSchema) 110 111 return &mapUnionEncoder{ 112 cfg: cfg, 113 schema: union, 114 } 115 } 116 117 type mapUnionEncoder struct { 118 cfg *frozenConfig 119 schema *UnionSchema 120 } 121 122 func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 123 m := *((*map[string]any)(ptr)) 124 125 if len(m) > 1 { 126 w.Error = errors.New("avro: cannot encode union map with multiple entries") 127 return 128 } 129 130 name := "null" 131 val := any(nil) 132 for k, v := range m { 133 name = k 134 val = v 135 break 136 } 137 138 schema, pos := e.schema.Types().Get(name) 139 if schema == nil { 140 w.Error = fmt.Errorf("avro: unknown union type %s", name) 141 return 142 } 143 144 w.WriteLong(int64(pos)) 145 146 if schema.Type() == Null && val == nil { 147 return 148 } 149 150 elemType := reflect2.TypeOf(val) 151 elemPtr := reflect2.PtrOf(val) 152 153 encoder := encoderOfType(e.cfg, schema, elemType) 154 if elemType.LikePtr() { 155 encoder = &onePtrEncoder{encoder} 156 } 157 encoder.Encode(elemPtr, w) 158 } 159 160 func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 161 union := schema.(*UnionSchema) 162 _, typeIdx := union.Indices() 163 ptrType := typ.(*reflect2.UnsafePtrType) 164 elemType := ptrType.Elem() 165 decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType) 166 167 return &unionPtrDecoder{ 168 schema: union, 169 typ: elemType, 170 decoder: decoder, 171 } 172 } 173 174 type unionPtrDecoder struct { 175 schema *UnionSchema 176 typ reflect2.Type 177 decoder ValDecoder 178 } 179 180 func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 181 _, schema := getUnionSchema(d.schema, r) 182 if schema == nil { 183 return 184 } 185 186 if schema.Type() == Null { 187 *((*unsafe.Pointer)(ptr)) = nil 188 return 189 } 190 191 if *((*unsafe.Pointer)(ptr)) == nil { 192 // Create new instance 193 newPtr := d.typ.UnsafeNew() 194 d.decoder.Decode(newPtr, r) 195 *((*unsafe.Pointer)(ptr)) = newPtr 196 return 197 } 198 199 // Reuse existing instance 200 d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) 201 } 202 203 func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 204 union := schema.(*UnionSchema) 205 nullIdx, typeIdx := union.Indices() 206 ptrType := typ.(*reflect2.UnsafePtrType) 207 encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem()) 208 209 return &unionPtrEncoder{ 210 schema: union, 211 encoder: encoder, 212 nullIdx: int64(nullIdx), 213 typeIdx: int64(typeIdx), 214 } 215 } 216 217 type unionPtrEncoder struct { 218 schema *UnionSchema 219 encoder ValEncoder 220 nullIdx int64 221 typeIdx int64 222 } 223 224 func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 225 if *((*unsafe.Pointer)(ptr)) == nil { 226 w.WriteLong(e.nullIdx) 227 return 228 } 229 230 w.WriteLong(e.typeIdx) 231 e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w) 232 } 233 234 func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) { 235 union := schema.(*UnionSchema) 236 237 types := make([]reflect2.Type, len(union.Types())) 238 decoders := make([]ValDecoder, len(union.Types())) 239 for i, schema := range union.Types() { 240 name := unionResolutionName(schema) 241 242 typ, err := cfg.resolver.Type(name) 243 if err != nil { 244 if cfg.config.UnionResolutionError { 245 return nil, err 246 } 247 248 if cfg.config.PartialUnionTypeResolution { 249 decoders[i] = nil 250 types[i] = nil 251 continue 252 } 253 254 decoders = []ValDecoder{} 255 types = []reflect2.Type{} 256 break 257 } 258 259 decoder := decoderOfType(cfg, schema, typ) 260 decoders[i] = decoder 261 types[i] = typ 262 } 263 264 return &unionResolvedDecoder{ 265 cfg: cfg, 266 schema: union, 267 types: types, 268 decoders: decoders, 269 }, nil 270 } 271 272 type unionResolvedDecoder struct { 273 cfg *frozenConfig 274 schema *UnionSchema 275 types []reflect2.Type 276 decoders []ValDecoder 277 } 278 279 func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 280 i, schema := getUnionSchema(d.schema, r) 281 if schema == nil { 282 return 283 } 284 285 pObj := (*any)(ptr) 286 287 if schema.Type() == Null { 288 *pObj = nil 289 return 290 } 291 292 if i >= len(d.decoders) || d.decoders[i] == nil { 293 if d.cfg.config.UnionResolutionError { 294 r.ReportError("decode union type", "unknown union type") 295 return 296 } 297 298 // We cannot resolve this, set it to the map type 299 name := schemaTypeName(schema) 300 obj := map[string]any{} 301 obj[name] = r.ReadNext(schema) 302 303 *pObj = obj 304 return 305 } 306 307 typ := d.types[i] 308 var newPtr unsafe.Pointer 309 switch typ.Kind() { 310 case reflect.Map: 311 mapType := typ.(*reflect2.UnsafeMapType) 312 newPtr = mapType.UnsafeMakeMap(1) 313 314 case reflect.Slice: 315 mapType := typ.(*reflect2.UnsafeSliceType) 316 newPtr = mapType.UnsafeMakeSlice(1, 1) 317 318 case reflect.Ptr: 319 elemType := typ.(*reflect2.UnsafePtrType).Elem() 320 newPtr = elemType.UnsafeNew() 321 322 default: 323 newPtr = typ.UnsafeNew() 324 } 325 326 d.decoders[i].Decode(newPtr, r) 327 *pObj = typ.UnsafeIndirect(newPtr) 328 } 329 330 func unionResolutionName(schema Schema) string { 331 name := schemaTypeName(schema) 332 switch schema.Type() { 333 case Map: 334 name += ":" 335 valSchema := schema.(*MapSchema).Values() 336 valName := schemaTypeName(valSchema) 337 338 name += valName 339 340 case Array: 341 name += ":" 342 itemSchema := schema.(*ArraySchema).Items() 343 itemName := schemaTypeName(itemSchema) 344 345 name += itemName 346 } 347 348 return name 349 } 350 351 func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 352 union := schema.(*UnionSchema) 353 354 names, err := cfg.resolver.Name(typ) 355 if err != nil { 356 return &errorEncoder{err: err} 357 } 358 359 var pos int 360 for _, name := range names { 361 if idx := strings.Index(name, ":"); idx > 0 { 362 name = name[:idx] 363 } 364 365 schema, pos = union.Types().Get(name) 366 if schema != nil { 367 break 368 } 369 } 370 if schema == nil { 371 return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])} 372 } 373 374 encoder := encoderOfType(cfg, schema, typ) 375 376 return &unionResolverEncoder{ 377 pos: pos, 378 encoder: encoder, 379 } 380 } 381 382 type unionResolverEncoder struct { 383 pos int 384 encoder ValEncoder 385 } 386 387 func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 388 w.WriteLong(int64(e.pos)) 389 390 e.encoder.Encode(ptr, w) 391 } 392 393 func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) { 394 types := schema.Types() 395 396 idx := int(r.ReadLong()) 397 if idx < 0 || idx > len(types)-1 { 398 r.ReportError("decode union type", "unknown union type") 399 return 0, nil 400 } 401 402 return idx, types[idx] 403 }