github.com/hamba/avro@v1.8.0/codec_union.go (about) 1 package avro 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 "strings" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 14 switch typ.Kind() { 15 case reflect.Map: 16 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 17 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 18 break 19 } 20 return decoderOfMapUnion(cfg, schema, typ) 21 22 case reflect.Ptr: 23 if !schema.(*UnionSchema).Nullable() { 24 break 25 } 26 return decoderOfPtrUnion(cfg, schema, typ) 27 28 case reflect.Interface: 29 if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { 30 return decoderOfResolvedUnion(cfg, schema) 31 } 32 } 33 34 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 35 } 36 37 func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 38 switch typ.Kind() { 39 case reflect.Map: 40 if typ.(reflect2.MapType).Key().Kind() != reflect.String || 41 typ.(reflect2.MapType).Elem().Kind() != reflect.Interface { 42 break 43 } 44 return encoderOfMapUnion(cfg, schema, typ) 45 46 case reflect.Ptr: 47 if !schema.(*UnionSchema).Nullable() { 48 break 49 } 50 return encoderOfPtrUnion(cfg, schema, typ) 51 } 52 53 return encoderOfResolverUnion(cfg, schema, typ) 54 } 55 56 func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 57 union := schema.(*UnionSchema) 58 mapType := typ.(*reflect2.UnsafeMapType) 59 60 return &mapUnionDecoder{ 61 cfg: cfg, 62 schema: union, 63 mapType: mapType, 64 elemType: mapType.Elem(), 65 } 66 } 67 68 type mapUnionDecoder struct { 69 cfg *frozenConfig 70 schema *UnionSchema 71 mapType *reflect2.UnsafeMapType 72 elemType reflect2.Type 73 } 74 75 func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 76 _, resSchema := getUnionSchema(d.schema, r) 77 if resSchema == nil { 78 return 79 } 80 81 // In a null case, just return 82 if resSchema.Type() == Null { 83 return 84 } 85 86 if d.mapType.UnsafeIsNil(ptr) { 87 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) 88 } 89 90 key := schemaTypeName(resSchema) 91 keyPtr := reflect2.PtrOf(key) 92 93 elemPtr := d.elemType.UnsafeNew() 94 decoderOfType(d.cfg, resSchema, d.elemType).Decode(elemPtr, r) 95 96 d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) 97 } 98 99 func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder { 100 union := schema.(*UnionSchema) 101 102 return &mapUnionEncoder{ 103 cfg: cfg, 104 schema: union, 105 } 106 } 107 108 type mapUnionEncoder struct { 109 cfg *frozenConfig 110 schema *UnionSchema 111 } 112 113 func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 114 m := *((*map[string]interface{})(ptr)) 115 116 if len(m) > 1 { 117 w.Error = errors.New("avro: cannot encode union map with multiple entries") 118 return 119 } 120 121 name := "null" 122 val := interface{}(nil) 123 for k, v := range m { 124 name = k 125 val = v 126 break 127 } 128 129 schema, pos := e.schema.Types().Get(name) 130 if schema == nil { 131 w.Error = fmt.Errorf("avro: unknown union type %s", name) 132 return 133 } 134 135 w.WriteLong(int64(pos)) 136 137 if schema.Type() == Null && val == nil { 138 return 139 } 140 141 elemType := reflect2.TypeOf(val) 142 elemPtr := reflect2.PtrOf(val) 143 144 encoder := encoderOfType(e.cfg, schema, elemType) 145 if elemType.LikePtr() { 146 encoder = &onePtrEncoder{encoder} 147 } 148 encoder.Encode(elemPtr, w) 149 } 150 151 func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 152 union := schema.(*UnionSchema) 153 _, typeIdx := union.Indices() 154 ptrType := typ.(*reflect2.UnsafePtrType) 155 elemType := ptrType.Elem() 156 decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType) 157 158 return &unionPtrDecoder{ 159 schema: union, 160 typ: elemType, 161 decoder: decoder, 162 } 163 } 164 165 type unionPtrDecoder struct { 166 schema *UnionSchema 167 typ reflect2.Type 168 decoder ValDecoder 169 } 170 171 func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 172 _, schema := getUnionSchema(d.schema, r) 173 if schema == nil { 174 return 175 } 176 177 if schema.Type() == Null { 178 *((*unsafe.Pointer)(ptr)) = nil 179 return 180 } 181 182 if *((*unsafe.Pointer)(ptr)) == nil { 183 // Create new instance 184 newPtr := d.typ.UnsafeNew() 185 d.decoder.Decode(newPtr, r) 186 *((*unsafe.Pointer)(ptr)) = newPtr 187 return 188 } 189 190 // Reuse existing instance 191 d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r) 192 } 193 194 func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 195 union := schema.(*UnionSchema) 196 nullIdx, typeIdx := union.Indices() 197 ptrType := typ.(*reflect2.UnsafePtrType) 198 encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem()) 199 200 return &unionPtrEncoder{ 201 schema: union, 202 encoder: encoder, 203 nullIdx: int64(nullIdx), 204 typeIdx: int64(typeIdx), 205 } 206 } 207 208 type unionPtrEncoder struct { 209 schema *UnionSchema 210 encoder ValEncoder 211 nullIdx int64 212 typeIdx int64 213 } 214 215 func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 216 if *((*unsafe.Pointer)(ptr)) == nil { 217 w.WriteLong(e.nullIdx) 218 return 219 } 220 221 w.WriteLong(e.typeIdx) 222 e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w) 223 } 224 225 func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) ValDecoder { 226 union := schema.(*UnionSchema) 227 228 types := make([]reflect2.Type, len(union.Types())) 229 decoders := make([]ValDecoder, len(union.Types())) 230 for i, schema := range union.Types() { 231 name := unionResolutionName(schema) 232 if typ, err := cfg.resolver.Type(name); err == nil { 233 decoder := decoderOfType(cfg, schema, typ) 234 decoders[i] = decoder 235 types[i] = typ 236 continue 237 } 238 239 decoders = []ValDecoder{} 240 types = []reflect2.Type{} 241 break 242 } 243 244 return &unionResolvedDecoder{ 245 cfg: cfg, 246 schema: union, 247 types: types, 248 decoders: decoders, 249 } 250 } 251 252 type unionResolvedDecoder struct { 253 cfg *frozenConfig 254 schema *UnionSchema 255 types []reflect2.Type 256 decoders []ValDecoder 257 } 258 259 func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 260 i, schema := getUnionSchema(d.schema, r) 261 if schema == nil { 262 return 263 } 264 265 pObj := (*interface{})(ptr) 266 267 if schema.Type() == Null { 268 *pObj = nil 269 return 270 } 271 272 if i >= len(d.decoders) { 273 if d.cfg.config.UnionResolutionError { 274 r.ReportError("decode union type", "unknown union type") 275 return 276 } 277 278 // We cannot resolve this, set it to the map type 279 name := schemaTypeName(schema) 280 obj := map[string]interface{}{} 281 obj[name] = r.ReadNext(schema) 282 283 *pObj = obj 284 return 285 } 286 287 typ := d.types[i] 288 var newPtr unsafe.Pointer 289 switch typ.Kind() { 290 case reflect.Map: 291 mapType := typ.(*reflect2.UnsafeMapType) 292 newPtr = mapType.UnsafeMakeMap(1) 293 294 case reflect.Slice: 295 mapType := typ.(*reflect2.UnsafeSliceType) 296 newPtr = mapType.UnsafeMakeSlice(1, 1) 297 298 case reflect.Ptr: 299 elemType := typ.(*reflect2.UnsafePtrType).Elem() 300 newPtr = elemType.UnsafeNew() 301 302 default: 303 newPtr = typ.UnsafeNew() 304 } 305 306 d.decoders[i].Decode(newPtr, r) 307 *pObj = typ.UnsafeIndirect(newPtr) 308 } 309 310 func unionResolutionName(schema Schema) string { 311 name := schemaTypeName(schema) 312 switch schema.Type() { 313 case Map: 314 name += ":" 315 valSchema := schema.(*MapSchema).Values() 316 valName := schemaTypeName(valSchema) 317 318 name += valName 319 320 case Array: 321 name += ":" 322 itemSchema := schema.(*ArraySchema).Items() 323 itemName := schemaTypeName(itemSchema) 324 325 name += itemName 326 } 327 328 return name 329 } 330 331 func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 332 union := schema.(*UnionSchema) 333 334 names, err := cfg.resolver.Name(typ) 335 if err != nil { 336 return &errorEncoder{err: err} 337 } 338 339 var pos int 340 for _, name := range names { 341 if idx := strings.Index(name, ":"); idx > 0 { 342 name = name[:idx] 343 } 344 345 schema, pos = union.Types().Get(name) 346 if schema != nil { 347 break 348 } 349 } 350 if schema == nil { 351 return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])} 352 } 353 354 encoder := encoderOfType(cfg, schema, typ) 355 356 return &unionResolverEncoder{ 357 pos: pos, 358 encoder: encoder, 359 } 360 } 361 362 type unionResolverEncoder struct { 363 pos int 364 encoder ValEncoder 365 } 366 367 func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 368 w.WriteLong(int64(e.pos)) 369 370 e.encoder.Encode(ptr, w) 371 } 372 373 func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) { 374 types := schema.Types() 375 376 idx := int(r.ReadLong()) 377 if idx < 0 || idx > len(types)-1 { 378 r.ReportError("decode union type", "unknown union type") 379 return 0, nil 380 } 381 382 return idx, types[idx] 383 }