github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_union.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Map:
    16  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    17  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    18  			break
    19  		}
    20  		return decoderOfMapUnion(cfg, schema, typ)
    21  	case reflect.Slice:
    22  		if !schema.(*UnionSchema).Nullable() {
    23  			break
    24  		}
    25  		return decoderOfNullableUnion(cfg, schema, typ)
    26  	case reflect.Ptr:
    27  		if !schema.(*UnionSchema).Nullable() {
    28  			break
    29  		}
    30  		return decoderOfNullableUnion(cfg, schema, typ)
    31  	case reflect.Interface:
    32  		if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
    33  			dec, err := decoderOfResolvedUnion(cfg, schema)
    34  			if err != nil {
    35  				return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)}
    36  			}
    37  			return dec
    38  		}
    39  	}
    40  
    41  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    42  }
    43  
    44  func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    45  	switch typ.Kind() {
    46  	case reflect.Map:
    47  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    48  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    49  			break
    50  		}
    51  		return encoderOfMapUnion(cfg, schema, typ)
    52  	case reflect.Slice:
    53  		if !schema.(*UnionSchema).Nullable() {
    54  			break
    55  		}
    56  		return encoderOfNullableUnion(cfg, schema, typ)
    57  	case reflect.Ptr:
    58  		if !schema.(*UnionSchema).Nullable() {
    59  			break
    60  		}
    61  		return encoderOfNullableUnion(cfg, schema, typ)
    62  	}
    63  	return encoderOfResolverUnion(cfg, schema, typ)
    64  }
    65  
    66  func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    67  	union := schema.(*UnionSchema)
    68  	mapType := typ.(*reflect2.UnsafeMapType)
    69  
    70  	typeDecs := make([]ValDecoder, len(union.Types()))
    71  	for i, s := range union.Types() {
    72  		if s.Type() == Null {
    73  			continue
    74  		}
    75  		typeDecs[i] = newEfaceDecoder(cfg, s)
    76  	}
    77  
    78  	return &mapUnionDecoder{
    79  		cfg:      cfg,
    80  		schema:   union,
    81  		mapType:  mapType,
    82  		elemType: mapType.Elem(),
    83  		typeDecs: typeDecs,
    84  	}
    85  }
    86  
    87  type mapUnionDecoder struct {
    88  	cfg      *frozenConfig
    89  	schema   *UnionSchema
    90  	mapType  *reflect2.UnsafeMapType
    91  	elemType reflect2.Type
    92  	typeDecs []ValDecoder
    93  }
    94  
    95  func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
    96  	idx, resSchema := getUnionSchema(d.schema, r)
    97  	if resSchema == nil {
    98  		return
    99  	}
   100  
   101  	// In a null case, just return
   102  	if resSchema.Type() == Null {
   103  		return
   104  	}
   105  
   106  	if d.mapType.UnsafeIsNil(ptr) {
   107  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(1))
   108  	}
   109  
   110  	key := schemaTypeName(resSchema)
   111  	keyPtr := reflect2.PtrOf(key)
   112  
   113  	elemPtr := d.elemType.UnsafeNew()
   114  	d.typeDecs[idx].Decode(elemPtr, r)
   115  
   116  	d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
   117  }
   118  
   119  func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder {
   120  	union := schema.(*UnionSchema)
   121  
   122  	return &mapUnionEncoder{
   123  		cfg:    cfg,
   124  		schema: union,
   125  	}
   126  }
   127  
   128  type mapUnionEncoder struct {
   129  	cfg    *frozenConfig
   130  	schema *UnionSchema
   131  }
   132  
   133  func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   134  	m := *((*map[string]any)(ptr))
   135  
   136  	if len(m) > 1 {
   137  		w.Error = errors.New("avro: cannot encode union map with multiple entries")
   138  		return
   139  	}
   140  
   141  	name := "null"
   142  	val := any(nil)
   143  	for k, v := range m {
   144  		name = k
   145  		val = v
   146  		break
   147  	}
   148  
   149  	schema, pos := e.schema.Types().Get(name)
   150  	if schema == nil {
   151  		w.Error = fmt.Errorf("avro: unknown union type %s", name)
   152  		return
   153  	}
   154  
   155  	w.WriteInt(int32(pos))
   156  
   157  	if schema.Type() == Null && val == nil {
   158  		return
   159  	}
   160  
   161  	elemType := reflect2.TypeOf(val)
   162  	elemPtr := reflect2.PtrOf(val)
   163  
   164  	encoder := encoderOfType(e.cfg, schema, elemType)
   165  	if elemType.LikePtr() {
   166  		encoder = &onePtrEncoder{encoder}
   167  	}
   168  	encoder.Encode(elemPtr, w)
   169  }
   170  
   171  func decoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
   172  	union := schema.(*UnionSchema)
   173  	_, typeIdx := union.Indices()
   174  
   175  	var (
   176  		baseTyp reflect2.Type
   177  		isPtr   bool
   178  	)
   179  	switch v := typ.(type) {
   180  	case *reflect2.UnsafePtrType:
   181  		baseTyp = v.Elem()
   182  		isPtr = true
   183  	case *reflect2.UnsafeSliceType:
   184  		baseTyp = v
   185  	}
   186  	decoder := decoderOfType(cfg, union.Types()[typeIdx], baseTyp)
   187  
   188  	return &unionNullableDecoder{
   189  		schema:  union,
   190  		typ:     baseTyp,
   191  		isPtr:   isPtr,
   192  		decoder: decoder,
   193  	}
   194  }
   195  
   196  type unionNullableDecoder struct {
   197  	schema  *UnionSchema
   198  	typ     reflect2.Type
   199  	isPtr   bool
   200  	decoder ValDecoder
   201  }
   202  
   203  func (d *unionNullableDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   204  	_, schema := getUnionSchema(d.schema, r)
   205  	if schema == nil {
   206  		return
   207  	}
   208  
   209  	if schema.Type() == Null {
   210  		*((*unsafe.Pointer)(ptr)) = nil
   211  		return
   212  	}
   213  
   214  	// Handle the non-ptr case separately.
   215  	if !d.isPtr {
   216  		if d.typ.UnsafeIsNil(ptr) {
   217  			// Create a new instance.
   218  			newPtr := d.typ.UnsafeNew()
   219  			d.decoder.Decode(newPtr, r)
   220  			d.typ.UnsafeSet(ptr, newPtr)
   221  			return
   222  		}
   223  
   224  		// Reuse the existing instance.
   225  		d.decoder.Decode(ptr, r)
   226  		return
   227  	}
   228  
   229  	if *((*unsafe.Pointer)(ptr)) == nil {
   230  		// Create new instance.
   231  		newPtr := d.typ.UnsafeNew()
   232  		d.decoder.Decode(newPtr, r)
   233  		*((*unsafe.Pointer)(ptr)) = newPtr
   234  		return
   235  	}
   236  
   237  	// Reuse existing instance.
   238  	d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r)
   239  }
   240  
   241  func encoderOfNullableUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   242  	union := schema.(*UnionSchema)
   243  	nullIdx, typeIdx := union.Indices()
   244  
   245  	var (
   246  		baseTyp reflect2.Type
   247  		isPtr   bool
   248  	)
   249  	switch v := typ.(type) {
   250  	case *reflect2.UnsafePtrType:
   251  		baseTyp = v.Elem()
   252  		isPtr = true
   253  	case *reflect2.UnsafeSliceType:
   254  		baseTyp = v
   255  	}
   256  	encoder := encoderOfType(cfg, union.Types()[typeIdx], baseTyp)
   257  
   258  	return &unionNullableEncoder{
   259  		schema:  union,
   260  		encoder: encoder,
   261  		isPtr:   isPtr,
   262  		nullIdx: int32(nullIdx),
   263  		typeIdx: int32(typeIdx),
   264  	}
   265  }
   266  
   267  type unionNullableEncoder struct {
   268  	schema  *UnionSchema
   269  	encoder ValEncoder
   270  	isPtr   bool
   271  	nullIdx int32
   272  	typeIdx int32
   273  }
   274  
   275  func (e *unionNullableEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   276  	if *((*unsafe.Pointer)(ptr)) == nil {
   277  		w.WriteInt(e.nullIdx)
   278  		return
   279  	}
   280  
   281  	w.WriteInt(e.typeIdx)
   282  	newPtr := ptr
   283  	if e.isPtr {
   284  		newPtr = *((*unsafe.Pointer)(ptr))
   285  	}
   286  	e.encoder.Encode(newPtr, w)
   287  }
   288  
   289  func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) {
   290  	union := schema.(*UnionSchema)
   291  
   292  	types := make([]reflect2.Type, len(union.Types()))
   293  	decoders := make([]ValDecoder, len(union.Types()))
   294  	for i, schema := range union.Types() {
   295  		name := unionResolutionName(schema)
   296  
   297  		typ, err := cfg.resolver.Type(name)
   298  		if err != nil {
   299  			if cfg.config.UnionResolutionError {
   300  				return nil, err
   301  			}
   302  
   303  			if cfg.config.PartialUnionTypeResolution {
   304  				decoders[i] = nil
   305  				types[i] = nil
   306  				continue
   307  			}
   308  
   309  			decoders = []ValDecoder{}
   310  			types = []reflect2.Type{}
   311  			break
   312  		}
   313  
   314  		decoder := decoderOfType(cfg, schema, typ)
   315  		decoders[i] = decoder
   316  		types[i] = typ
   317  	}
   318  
   319  	return &unionResolvedDecoder{
   320  		cfg:      cfg,
   321  		schema:   union,
   322  		types:    types,
   323  		decoders: decoders,
   324  	}, nil
   325  }
   326  
   327  type unionResolvedDecoder struct {
   328  	cfg      *frozenConfig
   329  	schema   *UnionSchema
   330  	types    []reflect2.Type
   331  	decoders []ValDecoder
   332  }
   333  
   334  func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   335  	i, schema := getUnionSchema(d.schema, r)
   336  	if schema == nil {
   337  		return
   338  	}
   339  
   340  	pObj := (*any)(ptr)
   341  
   342  	if schema.Type() == Null {
   343  		*pObj = nil
   344  		return
   345  	}
   346  
   347  	if i >= len(d.decoders) || d.decoders[i] == nil {
   348  		if d.cfg.config.UnionResolutionError {
   349  			r.ReportError("decode union type", "unknown union type")
   350  			return
   351  		}
   352  
   353  		// We cannot resolve this, set it to the map type
   354  		name := schemaTypeName(schema)
   355  		obj := map[string]any{}
   356  		vTyp, err := genericReceiver(schema)
   357  		if err != nil {
   358  			r.ReportError("Union", err.Error())
   359  			return
   360  		}
   361  		obj[name] = genericDecode(vTyp, decoderOfType(d.cfg, schema, vTyp), r)
   362  
   363  		*pObj = obj
   364  		return
   365  	}
   366  
   367  	typ := d.types[i]
   368  	var newPtr unsafe.Pointer
   369  	switch typ.Kind() {
   370  	case reflect.Map:
   371  		mapType := typ.(*reflect2.UnsafeMapType)
   372  		newPtr = mapType.UnsafeMakeMap(1)
   373  
   374  	case reflect.Slice:
   375  		mapType := typ.(*reflect2.UnsafeSliceType)
   376  		newPtr = mapType.UnsafeMakeSlice(1, 1)
   377  
   378  	case reflect.Ptr:
   379  		elemType := typ.(*reflect2.UnsafePtrType).Elem()
   380  		newPtr = elemType.UnsafeNew()
   381  
   382  	default:
   383  		newPtr = typ.UnsafeNew()
   384  	}
   385  
   386  	d.decoders[i].Decode(newPtr, r)
   387  	*pObj = typ.UnsafeIndirect(newPtr)
   388  }
   389  
   390  func unionResolutionName(schema Schema) string {
   391  	name := schemaTypeName(schema)
   392  	switch schema.Type() {
   393  	case Map:
   394  		name += ":"
   395  		valSchema := schema.(*MapSchema).Values()
   396  		valName := schemaTypeName(valSchema)
   397  
   398  		name += valName
   399  
   400  	case Array:
   401  		name += ":"
   402  		itemSchema := schema.(*ArraySchema).Items()
   403  		itemName := schemaTypeName(itemSchema)
   404  
   405  		name += itemName
   406  	}
   407  
   408  	return name
   409  }
   410  
   411  func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   412  	union := schema.(*UnionSchema)
   413  
   414  	names, err := cfg.resolver.Name(typ)
   415  	if err != nil {
   416  		return &errorEncoder{err: err}
   417  	}
   418  
   419  	var pos int
   420  	for _, name := range names {
   421  		if idx := strings.Index(name, ":"); idx > 0 {
   422  			name = name[:idx]
   423  		}
   424  
   425  		schema, pos = union.Types().Get(name)
   426  		if schema != nil {
   427  			break
   428  		}
   429  	}
   430  	if schema == nil {
   431  		return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])}
   432  	}
   433  
   434  	encoder := encoderOfType(cfg, schema, typ)
   435  
   436  	return &unionResolverEncoder{
   437  		pos:     pos,
   438  		encoder: encoder,
   439  	}
   440  }
   441  
   442  type unionResolverEncoder struct {
   443  	pos     int
   444  	encoder ValEncoder
   445  }
   446  
   447  func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   448  	w.WriteInt(int32(e.pos))
   449  
   450  	e.encoder.Encode(ptr, w)
   451  }
   452  
   453  func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) {
   454  	types := schema.Types()
   455  
   456  	idx := int(r.ReadInt())
   457  	if idx < 0 || idx > len(types)-1 {
   458  		r.ReportError("decode union type", "unknown union type")
   459  		return 0, nil
   460  	}
   461  
   462  	return idx, types[idx]
   463  }