github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_enum.go (about)

     1  package avro
     2  
     3  import (
     4  	"encoding"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch {
    15  	case typ.Kind() == reflect.String:
    16  		return &enumCodec{enum: schema.(*EnumSchema)}
    17  	case typ.Implements(textUnmarshalerType):
    18  		return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
    19  	case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
    20  		return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
    21  	}
    22  
    23  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    24  }
    25  
    26  func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
    27  	switch {
    28  	case typ.Kind() == reflect.String:
    29  		return &enumCodec{enum: schema.(*EnumSchema)}
    30  	case typ.Implements(textMarshalerType):
    31  		return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)}
    32  	case reflect2.PtrTo(typ).Implements(textMarshalerType):
    33  		return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true}
    34  	}
    35  
    36  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    37  }
    38  
// enumCodec reads and writes Avro enum values that are held in Go strings.
// It is returned by the enum codec factories for string-kinded types.
type enumCodec struct {
	// enum is the schema used to translate between symbols and their indexes.
	enum *EnumSchema
}
    42  
    43  func (c *enumCodec) Decode(ptr unsafe.Pointer, r *Reader) {
    44  	i := int(r.ReadInt())
    45  
    46  	symbol, ok := c.enum.Symbol(i)
    47  	if !ok {
    48  		r.ReportError("decode enum symbol", "unknown enum symbol")
    49  		return
    50  	}
    51  
    52  	*((*string)(ptr)) = symbol
    53  }
    54  
    55  func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) {
    56  	str := *((*string)(ptr))
    57  	for i, sym := range c.enum.symbols {
    58  		if str != sym {
    59  			continue
    60  		}
    61  
    62  		w.WriteInt(int32(i))
    63  		return
    64  	}
    65  
    66  	w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str)
    67  }
    68  
// enumTextMarshalerCodec reads and writes Avro enum values through the
// encoding.TextUnmarshaler / encoding.TextMarshaler implementation of a Go
// type, using the symbol text as the marshaled form.
type enumTextMarshalerCodec struct {
	// typ is the Go type being encoded or decoded.
	typ reflect2.Type
	// enum is the schema used to translate between symbols and their indexes.
	enum *EnumSchema
	// ptr is set by the factories when only a pointer to typ implements the
	// text (un)marshaler interface, rather than typ itself.
	ptr bool
}
    74  
    75  func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
    76  	i := int(r.ReadInt())
    77  
    78  	symbol, ok := c.enum.Symbol(i)
    79  	if !ok {
    80  		r.ReportError("decode enum symbol", "unknown enum symbol")
    81  		return
    82  	}
    83  
    84  	var obj any
    85  	if c.ptr {
    86  		obj = c.typ.PackEFace(ptr)
    87  	} else {
    88  		obj = c.typ.UnsafeIndirect(ptr)
    89  	}
    90  	if reflect2.IsNil(obj) {
    91  		ptrType := c.typ.(*reflect2.UnsafePtrType)
    92  		newPtr := ptrType.Elem().UnsafeNew()
    93  		*((*unsafe.Pointer)(ptr)) = newPtr
    94  		obj = c.typ.UnsafeIndirect(ptr)
    95  	}
    96  	unmarshaler := (obj).(encoding.TextUnmarshaler)
    97  	if err := unmarshaler.UnmarshalText([]byte(symbol)); err != nil {
    98  		r.ReportError("decode enum text unmarshaler", err.Error())
    99  	}
   100  }
   101  
   102  func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   103  	var obj any
   104  	if c.ptr {
   105  		obj = c.typ.PackEFace(ptr)
   106  	} else {
   107  		obj = c.typ.UnsafeIndirect(ptr)
   108  	}
   109  	if c.typ.IsNullable() && reflect2.IsNil(obj) {
   110  		w.Error = errors.New("encoding nil enum text marshaler")
   111  		return
   112  	}
   113  	marshaler := (obj).(encoding.TextMarshaler)
   114  	b, err := marshaler.MarshalText()
   115  	if err != nil {
   116  		w.Error = err
   117  		return
   118  	}
   119  
   120  	str := string(b)
   121  	for i, sym := range c.enum.symbols {
   122  		if str != sym {
   123  			continue
   124  		}
   125  
   126  		w.WriteInt(int32(i))
   127  		return
   128  	}
   129  
   130  	w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str)
   131  }