github.com/aacfactory/avro@v1.2.12/internal/base/codec_enum.go (about)

     1  package base
     2  
     3  import (
     4  	"encoding"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch {
    15  	case typ.Kind() == reflect.String:
    16  		return &enumCodec{symbols: schema.(*EnumSchema).Symbols()}
    17  	case typ.Implements(textUnmarshalerType):
    18  		return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()}
    19  	case reflect2.PtrTo(typ).Implements(textUnmarshalerType):
    20  		return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true}
    21  	}
    22  
    23  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    24  }
    25  
    26  func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder {
    27  	switch {
    28  	case typ.Kind() == reflect.String:
    29  		return &enumCodec{symbols: schema.(*EnumSchema).Symbols()}
    30  	case typ.Implements(textMarshalerType):
    31  		return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()}
    32  	case reflect2.PtrTo(typ).Implements(textMarshalerType):
    33  		return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true}
    34  	}
    35  
    36  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    37  }
    38  
    39  type enumCodec struct {
    40  	symbols []string
    41  }
    42  
    43  func (c *enumCodec) Decode(ptr unsafe.Pointer, r *Reader) {
    44  	i := int(r.ReadInt())
    45  
    46  	if i < 0 || i >= len(c.symbols) {
    47  		r.ReportError("decode enum symbol", "unknown enum symbol")
    48  		return
    49  	}
    50  
    51  	*((*string)(ptr)) = c.symbols[i]
    52  }
    53  
    54  func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) {
    55  	str := *((*string)(ptr))
    56  	for i, sym := range c.symbols {
    57  		if str != sym {
    58  			continue
    59  		}
    60  
    61  		w.WriteInt(int32(i))
    62  		return
    63  	}
    64  
    65  	w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str)
    66  }
    67  
    68  type enumTextMarshalerCodec struct {
    69  	typ     reflect2.Type
    70  	symbols []string
    71  	ptr     bool
    72  }
    73  
    74  func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) {
    75  	i := int(r.ReadInt())
    76  
    77  	if i < 0 || i >= len(c.symbols) {
    78  		r.ReportError("decode enum symbol", "unknown enum symbol")
    79  		return
    80  	}
    81  
    82  	var obj any
    83  	if c.ptr {
    84  		obj = c.typ.PackEFace(ptr)
    85  	} else {
    86  		obj = c.typ.UnsafeIndirect(ptr)
    87  	}
    88  	if reflect2.IsNil(obj) {
    89  		ptrType := c.typ.(*reflect2.UnsafePtrType)
    90  		newPtr := ptrType.Elem().UnsafeNew()
    91  		*((*unsafe.Pointer)(ptr)) = newPtr
    92  		obj = c.typ.UnsafeIndirect(ptr)
    93  	}
    94  	unmarshaler := (obj).(encoding.TextUnmarshaler)
    95  	if err := unmarshaler.UnmarshalText([]byte(c.symbols[i])); err != nil {
    96  		r.ReportError("decode enum text unmarshaler", err.Error())
    97  	}
    98  }
    99  
   100  func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   101  	var obj any
   102  	if c.ptr {
   103  		obj = c.typ.PackEFace(ptr)
   104  	} else {
   105  		obj = c.typ.UnsafeIndirect(ptr)
   106  	}
   107  	if c.typ.IsNullable() && reflect2.IsNil(obj) {
   108  		w.Error = errors.New("encoding nil enum text marshaler")
   109  		return
   110  	}
   111  	marshaler := (obj).(encoding.TextMarshaler)
   112  	b, err := marshaler.MarshalText()
   113  	if err != nil {
   114  		w.Error = err
   115  		return
   116  	}
   117  
   118  	str := string(b)
   119  	for i, sym := range c.symbols {
   120  		if str != sym {
   121  			continue
   122  		}
   123  
   124  		w.WriteInt(int32(i))
   125  		return
   126  	}
   127  
   128  	w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str)
   129  }