github.com/aacfactory/avro@v1.2.12/internal/base/codec_union.go (about)

     1  package base
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Map:
    16  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    17  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    18  			break
    19  		}
    20  		return decoderOfMapUnion(cfg, schema, typ)
    21  
    22  	case reflect.Ptr:
    23  		if !schema.(*UnionSchema).Nullable() {
    24  			break
    25  		}
    26  		return decoderOfPtrUnion(cfg, schema, typ)
    27  
    28  	case reflect.Interface:
    29  		if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
    30  			dec, err := decoderOfResolvedUnion(cfg, schema)
    31  			if err != nil {
    32  				return &errorDecoder{err: fmt.Errorf("avro: problem resolving decoder for Avro %s: %w", schema.Type(), err)}
    33  			}
    34  
    35  			return dec
    36  		}
    37  	default:
    38  		break
    39  	}
    40  
    41  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    42  }
    43  
    44  func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    45  	switch typ.Kind() {
    46  	case reflect.Map:
    47  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    48  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    49  			break
    50  		}
    51  		return encoderOfMapUnion(cfg, schema, typ)
    52  
    53  	case reflect.Ptr:
    54  		if !schema.(*UnionSchema).Nullable() {
    55  			break
    56  		}
    57  		return encoderOfPtrUnion(cfg, schema, typ)
    58  	default:
    59  		break
    60  	}
    61  
    62  	return encoderOfResolverUnion(cfg, schema, typ)
    63  }
    64  
    65  func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    66  	union := schema.(*UnionSchema)
    67  	mapType := typ.(*reflect2.UnsafeMapType)
    68  
    69  	return &mapUnionDecoder{
    70  		cfg:      cfg,
    71  		schema:   union,
    72  		mapType:  mapType,
    73  		elemType: mapType.Elem(),
    74  	}
    75  }
    76  
    77  type mapUnionDecoder struct {
    78  	cfg      *frozenConfig
    79  	schema   *UnionSchema
    80  	mapType  *reflect2.UnsafeMapType
    81  	elemType reflect2.Type
    82  }
    83  
    84  func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
    85  	_, resSchema := getUnionSchema(d.schema, r)
    86  	if resSchema == nil {
    87  		return
    88  	}
    89  
    90  	// In a null case, just return
    91  	if resSchema.Type() == Null {
    92  		return
    93  	}
    94  
    95  	if d.mapType.UnsafeIsNil(ptr) {
    96  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
    97  	}
    98  
    99  	key := schemaTypeName(resSchema)
   100  	keyPtr := reflect2.PtrOf(key)
   101  
   102  	elemPtr := d.elemType.UnsafeNew()
   103  	decoderOfType(d.cfg, resSchema, d.elemType).Decode(elemPtr, r)
   104  
   105  	d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
   106  }
   107  
   108  func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder {
   109  	union := schema.(*UnionSchema)
   110  
   111  	return &mapUnionEncoder{
   112  		cfg:    cfg,
   113  		schema: union,
   114  	}
   115  }
   116  
   117  type mapUnionEncoder struct {
   118  	cfg    *frozenConfig
   119  	schema *UnionSchema
   120  }
   121  
   122  func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   123  	m := *((*map[string]any)(ptr))
   124  
   125  	if len(m) > 1 {
   126  		w.Error = errors.New("avro: cannot encode union map with multiple entries")
   127  		return
   128  	}
   129  
   130  	name := "null"
   131  	val := any(nil)
   132  	for k, v := range m {
   133  		name = k
   134  		val = v
   135  		break
   136  	}
   137  
   138  	schema, pos := e.schema.Types().Get(name)
   139  	if schema == nil {
   140  		w.Error = fmt.Errorf("avro: unknown union type %s", name)
   141  		return
   142  	}
   143  
   144  	w.WriteLong(int64(pos))
   145  
   146  	if schema.Type() == Null && val == nil {
   147  		return
   148  	}
   149  
   150  	elemType := reflect2.TypeOf(val)
   151  	elemPtr := reflect2.PtrOf(val)
   152  
   153  	encoder := encoderOfType(e.cfg, schema, elemType)
   154  	if elemType.LikePtr() {
   155  		encoder = &onePtrEncoder{encoder}
   156  	}
   157  	encoder.Encode(elemPtr, w)
   158  }
   159  
   160  func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
   161  	union := schema.(*UnionSchema)
   162  	_, typeIdx := union.Indices()
   163  	ptrType := typ.(*reflect2.UnsafePtrType)
   164  	elemType := ptrType.Elem()
   165  	decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType)
   166  
   167  	return &unionPtrDecoder{
   168  		schema:  union,
   169  		typ:     elemType,
   170  		decoder: decoder,
   171  	}
   172  }
   173  
   174  type unionPtrDecoder struct {
   175  	schema  *UnionSchema
   176  	typ     reflect2.Type
   177  	decoder ValDecoder
   178  }
   179  
   180  func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   181  	_, schema := getUnionSchema(d.schema, r)
   182  	if schema == nil {
   183  		return
   184  	}
   185  
   186  	if schema.Type() == Null {
   187  		*((*unsafe.Pointer)(ptr)) = nil
   188  		return
   189  	}
   190  
   191  	if *((*unsafe.Pointer)(ptr)) == nil {
   192  		// Create new instance
   193  		newPtr := d.typ.UnsafeNew()
   194  		d.decoder.Decode(newPtr, r)
   195  		*((*unsafe.Pointer)(ptr)) = newPtr
   196  		return
   197  	}
   198  
   199  	// Reuse existing instance
   200  	d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r)
   201  }
   202  
   203  func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   204  	union := schema.(*UnionSchema)
   205  	nullIdx, typeIdx := union.Indices()
   206  	ptrType := typ.(*reflect2.UnsafePtrType)
   207  	encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem())
   208  
   209  	return &unionPtrEncoder{
   210  		schema:  union,
   211  		encoder: encoder,
   212  		nullIdx: int64(nullIdx),
   213  		typeIdx: int64(typeIdx),
   214  	}
   215  }
   216  
   217  type unionPtrEncoder struct {
   218  	schema  *UnionSchema
   219  	encoder ValEncoder
   220  	nullIdx int64
   221  	typeIdx int64
   222  }
   223  
   224  func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   225  	if *((*unsafe.Pointer)(ptr)) == nil {
   226  		w.WriteLong(e.nullIdx)
   227  		return
   228  	}
   229  
   230  	w.WriteLong(e.typeIdx)
   231  	e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w)
   232  }
   233  
   234  func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) (ValDecoder, error) {
   235  	union := schema.(*UnionSchema)
   236  
   237  	types := make([]reflect2.Type, len(union.Types()))
   238  	decoders := make([]ValDecoder, len(union.Types()))
   239  	for i, schema := range union.Types() {
   240  		name := unionResolutionName(schema)
   241  
   242  		typ, err := cfg.resolver.Type(name)
   243  		if err != nil {
   244  			if cfg.config.UnionResolutionError {
   245  				return nil, err
   246  			}
   247  
   248  			if cfg.config.PartialUnionTypeResolution {
   249  				decoders[i] = nil
   250  				types[i] = nil
   251  				continue
   252  			}
   253  
   254  			decoders = []ValDecoder{}
   255  			types = []reflect2.Type{}
   256  			break
   257  		}
   258  
   259  		decoder := decoderOfType(cfg, schema, typ)
   260  		decoders[i] = decoder
   261  		types[i] = typ
   262  	}
   263  
   264  	return &unionResolvedDecoder{
   265  		cfg:      cfg,
   266  		schema:   union,
   267  		types:    types,
   268  		decoders: decoders,
   269  	}, nil
   270  }
   271  
   272  type unionResolvedDecoder struct {
   273  	cfg      *frozenConfig
   274  	schema   *UnionSchema
   275  	types    []reflect2.Type
   276  	decoders []ValDecoder
   277  }
   278  
   279  func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   280  	i, schema := getUnionSchema(d.schema, r)
   281  	if schema == nil {
   282  		return
   283  	}
   284  
   285  	pObj := (*any)(ptr)
   286  
   287  	if schema.Type() == Null {
   288  		*pObj = nil
   289  		return
   290  	}
   291  
   292  	if i >= len(d.decoders) || d.decoders[i] == nil {
   293  		if d.cfg.config.UnionResolutionError {
   294  			r.ReportError("decode union type", "unknown union type")
   295  			return
   296  		}
   297  
   298  		// We cannot resolve this, set it to the map type
   299  		name := schemaTypeName(schema)
   300  		obj := map[string]any{}
   301  		obj[name] = r.ReadNext(schema)
   302  
   303  		*pObj = obj
   304  		return
   305  	}
   306  
   307  	typ := d.types[i]
   308  	var newPtr unsafe.Pointer
   309  	switch typ.Kind() {
   310  	case reflect.Map:
   311  		mapType := typ.(*reflect2.UnsafeMapType)
   312  		newPtr = mapType.UnsafeMakeMap(1)
   313  
   314  	case reflect.Slice:
   315  		mapType := typ.(*reflect2.UnsafeSliceType)
   316  		newPtr = mapType.UnsafeMakeSlice(1, 1)
   317  
   318  	case reflect.Ptr:
   319  		elemType := typ.(*reflect2.UnsafePtrType).Elem()
   320  		newPtr = elemType.UnsafeNew()
   321  
   322  	default:
   323  		newPtr = typ.UnsafeNew()
   324  	}
   325  
   326  	d.decoders[i].Decode(newPtr, r)
   327  	*pObj = typ.UnsafeIndirect(newPtr)
   328  }
   329  
   330  func unionResolutionName(schema Schema) string {
   331  	name := schemaTypeName(schema)
   332  	switch schema.Type() {
   333  	case Map:
   334  		name += ":"
   335  		valSchema := schema.(*MapSchema).Values()
   336  		valName := schemaTypeName(valSchema)
   337  
   338  		name += valName
   339  
   340  	case Array:
   341  		name += ":"
   342  		itemSchema := schema.(*ArraySchema).Items()
   343  		itemName := schemaTypeName(itemSchema)
   344  
   345  		name += itemName
   346  	}
   347  
   348  	return name
   349  }
   350  
   351  func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   352  	union := schema.(*UnionSchema)
   353  
   354  	names, err := cfg.resolver.Name(typ)
   355  	if err != nil {
   356  		return &errorEncoder{err: err}
   357  	}
   358  
   359  	var pos int
   360  	for _, name := range names {
   361  		if idx := strings.Index(name, ":"); idx > 0 {
   362  			name = name[:idx]
   363  		}
   364  
   365  		schema, pos = union.Types().Get(name)
   366  		if schema != nil {
   367  			break
   368  		}
   369  	}
   370  	if schema == nil {
   371  		return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])}
   372  	}
   373  
   374  	encoder := encoderOfType(cfg, schema, typ)
   375  
   376  	return &unionResolverEncoder{
   377  		pos:     pos,
   378  		encoder: encoder,
   379  	}
   380  }
   381  
   382  type unionResolverEncoder struct {
   383  	pos     int
   384  	encoder ValEncoder
   385  }
   386  
   387  func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   388  	w.WriteLong(int64(e.pos))
   389  
   390  	e.encoder.Encode(ptr, w)
   391  }
   392  
   393  func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) {
   394  	types := schema.Types()
   395  
   396  	idx := int(r.ReadLong())
   397  	if idx < 0 || idx > len(types)-1 {
   398  		r.ReportError("decode union type", "unknown union type")
   399  		return 0, nil
   400  	}
   401  
   402  	return idx, types[idx]
   403  }