github.com/hamba/avro@v1.8.0/codec_union.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Map:
    16  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    17  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    18  			break
    19  		}
    20  		return decoderOfMapUnion(cfg, schema, typ)
    21  
    22  	case reflect.Ptr:
    23  		if !schema.(*UnionSchema).Nullable() {
    24  			break
    25  		}
    26  		return decoderOfPtrUnion(cfg, schema, typ)
    27  
    28  	case reflect.Interface:
    29  		if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
    30  			return decoderOfResolvedUnion(cfg, schema)
    31  		}
    32  	}
    33  
    34  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    35  }
    36  
    37  func createEncoderOfUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    38  	switch typ.Kind() {
    39  	case reflect.Map:
    40  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    41  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    42  			break
    43  		}
    44  		return encoderOfMapUnion(cfg, schema, typ)
    45  
    46  	case reflect.Ptr:
    47  		if !schema.(*UnionSchema).Nullable() {
    48  			break
    49  		}
    50  		return encoderOfPtrUnion(cfg, schema, typ)
    51  	}
    52  
    53  	return encoderOfResolverUnion(cfg, schema, typ)
    54  }
    55  
    56  func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    57  	union := schema.(*UnionSchema)
    58  	mapType := typ.(*reflect2.UnsafeMapType)
    59  
    60  	return &mapUnionDecoder{
    61  		cfg:      cfg,
    62  		schema:   union,
    63  		mapType:  mapType,
    64  		elemType: mapType.Elem(),
    65  	}
    66  }
    67  
    68  type mapUnionDecoder struct {
    69  	cfg      *frozenConfig
    70  	schema   *UnionSchema
    71  	mapType  *reflect2.UnsafeMapType
    72  	elemType reflect2.Type
    73  }
    74  
    75  func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
    76  	_, resSchema := getUnionSchema(d.schema, r)
    77  	if resSchema == nil {
    78  		return
    79  	}
    80  
    81  	// In a null case, just return
    82  	if resSchema.Type() == Null {
    83  		return
    84  	}
    85  
    86  	if d.mapType.UnsafeIsNil(ptr) {
    87  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
    88  	}
    89  
    90  	key := schemaTypeName(resSchema)
    91  	keyPtr := reflect2.PtrOf(key)
    92  
    93  	elemPtr := d.elemType.UnsafeNew()
    94  	decoderOfType(d.cfg, resSchema, d.elemType).Decode(elemPtr, r)
    95  
    96  	d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
    97  }
    98  
    99  func encoderOfMapUnion(cfg *frozenConfig, schema Schema, _ reflect2.Type) ValEncoder {
   100  	union := schema.(*UnionSchema)
   101  
   102  	return &mapUnionEncoder{
   103  		cfg:    cfg,
   104  		schema: union,
   105  	}
   106  }
   107  
   108  type mapUnionEncoder struct {
   109  	cfg    *frozenConfig
   110  	schema *UnionSchema
   111  }
   112  
   113  func (e *mapUnionEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   114  	m := *((*map[string]interface{})(ptr))
   115  
   116  	if len(m) > 1 {
   117  		w.Error = errors.New("avro: cannot encode union map with multiple entries")
   118  		return
   119  	}
   120  
   121  	name := "null"
   122  	val := interface{}(nil)
   123  	for k, v := range m {
   124  		name = k
   125  		val = v
   126  		break
   127  	}
   128  
   129  	schema, pos := e.schema.Types().Get(name)
   130  	if schema == nil {
   131  		w.Error = fmt.Errorf("avro: unknown union type %s", name)
   132  		return
   133  	}
   134  
   135  	w.WriteLong(int64(pos))
   136  
   137  	if schema.Type() == Null && val == nil {
   138  		return
   139  	}
   140  
   141  	elemType := reflect2.TypeOf(val)
   142  	elemPtr := reflect2.PtrOf(val)
   143  
   144  	encoder := encoderOfType(e.cfg, schema, elemType)
   145  	if elemType.LikePtr() {
   146  		encoder = &onePtrEncoder{encoder}
   147  	}
   148  	encoder.Encode(elemPtr, w)
   149  }
   150  
   151  func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
   152  	union := schema.(*UnionSchema)
   153  	_, typeIdx := union.Indices()
   154  	ptrType := typ.(*reflect2.UnsafePtrType)
   155  	elemType := ptrType.Elem()
   156  	decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType)
   157  
   158  	return &unionPtrDecoder{
   159  		schema:  union,
   160  		typ:     elemType,
   161  		decoder: decoder,
   162  	}
   163  }
   164  
   165  type unionPtrDecoder struct {
   166  	schema  *UnionSchema
   167  	typ     reflect2.Type
   168  	decoder ValDecoder
   169  }
   170  
   171  func (d *unionPtrDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   172  	_, schema := getUnionSchema(d.schema, r)
   173  	if schema == nil {
   174  		return
   175  	}
   176  
   177  	if schema.Type() == Null {
   178  		*((*unsafe.Pointer)(ptr)) = nil
   179  		return
   180  	}
   181  
   182  	if *((*unsafe.Pointer)(ptr)) == nil {
   183  		// Create new instance
   184  		newPtr := d.typ.UnsafeNew()
   185  		d.decoder.Decode(newPtr, r)
   186  		*((*unsafe.Pointer)(ptr)) = newPtr
   187  		return
   188  	}
   189  
   190  	// Reuse existing instance
   191  	d.decoder.Decode(*((*unsafe.Pointer)(ptr)), r)
   192  }
   193  
   194  func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   195  	union := schema.(*UnionSchema)
   196  	nullIdx, typeIdx := union.Indices()
   197  	ptrType := typ.(*reflect2.UnsafePtrType)
   198  	encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem())
   199  
   200  	return &unionPtrEncoder{
   201  		schema:  union,
   202  		encoder: encoder,
   203  		nullIdx: int64(nullIdx),
   204  		typeIdx: int64(typeIdx),
   205  	}
   206  }
   207  
   208  type unionPtrEncoder struct {
   209  	schema  *UnionSchema
   210  	encoder ValEncoder
   211  	nullIdx int64
   212  	typeIdx int64
   213  }
   214  
   215  func (e *unionPtrEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   216  	if *((*unsafe.Pointer)(ptr)) == nil {
   217  		w.WriteLong(e.nullIdx)
   218  		return
   219  	}
   220  
   221  	w.WriteLong(e.typeIdx)
   222  	e.encoder.Encode(*((*unsafe.Pointer)(ptr)), w)
   223  }
   224  
   225  func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) ValDecoder {
   226  	union := schema.(*UnionSchema)
   227  
   228  	types := make([]reflect2.Type, len(union.Types()))
   229  	decoders := make([]ValDecoder, len(union.Types()))
   230  	for i, schema := range union.Types() {
   231  		name := unionResolutionName(schema)
   232  		if typ, err := cfg.resolver.Type(name); err == nil {
   233  			decoder := decoderOfType(cfg, schema, typ)
   234  			decoders[i] = decoder
   235  			types[i] = typ
   236  			continue
   237  		}
   238  
   239  		decoders = []ValDecoder{}
   240  		types = []reflect2.Type{}
   241  		break
   242  	}
   243  
   244  	return &unionResolvedDecoder{
   245  		cfg:      cfg,
   246  		schema:   union,
   247  		types:    types,
   248  		decoders: decoders,
   249  	}
   250  }
   251  
   252  type unionResolvedDecoder struct {
   253  	cfg      *frozenConfig
   254  	schema   *UnionSchema
   255  	types    []reflect2.Type
   256  	decoders []ValDecoder
   257  }
   258  
   259  func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   260  	i, schema := getUnionSchema(d.schema, r)
   261  	if schema == nil {
   262  		return
   263  	}
   264  
   265  	pObj := (*interface{})(ptr)
   266  
   267  	if schema.Type() == Null {
   268  		*pObj = nil
   269  		return
   270  	}
   271  
   272  	if i >= len(d.decoders) {
   273  		if d.cfg.config.UnionResolutionError {
   274  			r.ReportError("decode union type", "unknown union type")
   275  			return
   276  		}
   277  
   278  		// We cannot resolve this, set it to the map type
   279  		name := schemaTypeName(schema)
   280  		obj := map[string]interface{}{}
   281  		obj[name] = r.ReadNext(schema)
   282  
   283  		*pObj = obj
   284  		return
   285  	}
   286  
   287  	typ := d.types[i]
   288  	var newPtr unsafe.Pointer
   289  	switch typ.Kind() {
   290  	case reflect.Map:
   291  		mapType := typ.(*reflect2.UnsafeMapType)
   292  		newPtr = mapType.UnsafeMakeMap(1)
   293  
   294  	case reflect.Slice:
   295  		mapType := typ.(*reflect2.UnsafeSliceType)
   296  		newPtr = mapType.UnsafeMakeSlice(1, 1)
   297  
   298  	case reflect.Ptr:
   299  		elemType := typ.(*reflect2.UnsafePtrType).Elem()
   300  		newPtr = elemType.UnsafeNew()
   301  
   302  	default:
   303  		newPtr = typ.UnsafeNew()
   304  	}
   305  
   306  	d.decoders[i].Decode(newPtr, r)
   307  	*pObj = typ.UnsafeIndirect(newPtr)
   308  }
   309  
   310  func unionResolutionName(schema Schema) string {
   311  	name := schemaTypeName(schema)
   312  	switch schema.Type() {
   313  	case Map:
   314  		name += ":"
   315  		valSchema := schema.(*MapSchema).Values()
   316  		valName := schemaTypeName(valSchema)
   317  
   318  		name += valName
   319  
   320  	case Array:
   321  		name += ":"
   322  		itemSchema := schema.(*ArraySchema).Items()
   323  		itemName := schemaTypeName(itemSchema)
   324  
   325  		name += itemName
   326  	}
   327  
   328  	return name
   329  }
   330  
   331  func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   332  	union := schema.(*UnionSchema)
   333  
   334  	names, err := cfg.resolver.Name(typ)
   335  	if err != nil {
   336  		return &errorEncoder{err: err}
   337  	}
   338  
   339  	var pos int
   340  	for _, name := range names {
   341  		if idx := strings.Index(name, ":"); idx > 0 {
   342  			name = name[:idx]
   343  		}
   344  
   345  		schema, pos = union.Types().Get(name)
   346  		if schema != nil {
   347  			break
   348  		}
   349  	}
   350  	if schema == nil {
   351  		return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])}
   352  	}
   353  
   354  	encoder := encoderOfType(cfg, schema, typ)
   355  
   356  	return &unionResolverEncoder{
   357  		pos:     pos,
   358  		encoder: encoder,
   359  	}
   360  }
   361  
   362  type unionResolverEncoder struct {
   363  	pos     int
   364  	encoder ValEncoder
   365  }
   366  
   367  func (e *unionResolverEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   368  	w.WriteLong(int64(e.pos))
   369  
   370  	e.encoder.Encode(ptr, w)
   371  }
   372  
   373  func getUnionSchema(schema *UnionSchema, r *Reader) (int, Schema) {
   374  	types := schema.Types()
   375  
   376  	idx := int(r.ReadLong())
   377  	if idx < 0 || idx > len(types)-1 {
   378  		r.ReportError("decode union type", "unknown union type")
   379  		return 0, nil
   380  	}
   381  
   382  	return idx, types[idx]
   383  }