github.com/aacfactory/avro@v1.2.12/internal/base/codec_enum.go (about) 1 package base 2 3 import ( 4 "encoding" 5 "errors" 6 "fmt" 7 "reflect" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder { 14 switch { 15 case typ.Kind() == reflect.String: 16 return &enumCodec{symbols: schema.(*EnumSchema).Symbols()} 17 case typ.Implements(textUnmarshalerType): 18 return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()} 19 case reflect2.PtrTo(typ).Implements(textUnmarshalerType): 20 return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true} 21 } 22 23 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 24 } 25 26 func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder { 27 switch { 28 case typ.Kind() == reflect.String: 29 return &enumCodec{symbols: schema.(*EnumSchema).Symbols()} 30 case typ.Implements(textMarshalerType): 31 return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols()} 32 case reflect2.PtrTo(typ).Implements(textMarshalerType): 33 return &enumTextMarshalerCodec{typ: typ, symbols: schema.(*EnumSchema).Symbols(), ptr: true} 34 } 35 36 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 37 } 38 39 type enumCodec struct { 40 symbols []string 41 } 42 43 func (c *enumCodec) Decode(ptr unsafe.Pointer, r *Reader) { 44 i := int(r.ReadInt()) 45 46 if i < 0 || i >= len(c.symbols) { 47 r.ReportError("decode enum symbol", "unknown enum symbol") 48 return 49 } 50 51 *((*string)(ptr)) = c.symbols[i] 52 } 53 54 func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) { 55 str := *((*string)(ptr)) 56 for i, sym := range c.symbols { 57 if str != sym { 58 continue 59 } 60 61 w.WriteInt(int32(i)) 62 return 63 } 64 65 w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str) 66 } 67 68 type enumTextMarshalerCodec struct { 69 typ reflect2.Type 70 symbols []string 71 ptr bool 72 } 73 74 func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { 75 i := int(r.ReadInt()) 76 77 if i < 0 || i >= len(c.symbols) { 78 r.ReportError("decode enum symbol", "unknown enum symbol") 79 return 80 } 81 82 var obj any 83 if c.ptr { 84 obj = c.typ.PackEFace(ptr) 85 } else { 86 obj = c.typ.UnsafeIndirect(ptr) 87 } 88 if reflect2.IsNil(obj) { 89 ptrType := c.typ.(*reflect2.UnsafePtrType) 90 newPtr := ptrType.Elem().UnsafeNew() 91 *((*unsafe.Pointer)(ptr)) = newPtr 92 obj = c.typ.UnsafeIndirect(ptr) 93 } 94 unmarshaler := (obj).(encoding.TextUnmarshaler) 95 if err := unmarshaler.UnmarshalText([]byte(c.symbols[i])); err != nil { 96 r.ReportError("decode enum text unmarshaler", err.Error()) 97 } 98 } 99 100 func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) { 101 var obj any 102 if c.ptr { 103 obj = c.typ.PackEFace(ptr) 104 } else { 105 obj = c.typ.UnsafeIndirect(ptr) 106 } 107 if c.typ.IsNullable() && reflect2.IsNil(obj) { 108 w.Error = errors.New("encoding nil enum text marshaler") 109 return 110 } 111 marshaler := (obj).(encoding.TextMarshaler) 112 b, err := marshaler.MarshalText() 113 if err != nil { 114 w.Error = err 115 return 116 } 117 118 str := string(b) 119 for i, sym := range c.symbols { 120 if str != sym { 121 continue 122 } 123 124 w.WriteInt(int32(i)) 125 return 126 } 127 128 w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str) 129 }