github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_enum.go (about) 1 package avro 2 3 import ( 4 "encoding" 5 "errors" 6 "fmt" 7 "reflect" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 func createDecoderOfEnum(schema Schema, typ reflect2.Type) ValDecoder { 14 switch { 15 case typ.Kind() == reflect.String: 16 return &enumCodec{enum: schema.(*EnumSchema)} 17 case typ.Implements(textUnmarshalerType): 18 return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)} 19 case reflect2.PtrTo(typ).Implements(textUnmarshalerType): 20 return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true} 21 } 22 23 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 24 } 25 26 func createEncoderOfEnum(schema Schema, typ reflect2.Type) ValEncoder { 27 switch { 28 case typ.Kind() == reflect.String: 29 return &enumCodec{enum: schema.(*EnumSchema)} 30 case typ.Implements(textMarshalerType): 31 return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema)} 32 case reflect2.PtrTo(typ).Implements(textMarshalerType): 33 return &enumTextMarshalerCodec{typ: typ, enum: schema.(*EnumSchema), ptr: true} 34 } 35 36 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 37 } 38 39 type enumCodec struct { 40 enum *EnumSchema 41 } 42 43 func (c *enumCodec) Decode(ptr unsafe.Pointer, r *Reader) { 44 i := int(r.ReadInt()) 45 46 symbol, ok := c.enum.Symbol(i) 47 if !ok { 48 r.ReportError("decode enum symbol", "unknown enum symbol") 49 return 50 } 51 52 *((*string)(ptr)) = symbol 53 } 54 55 func (c *enumCodec) Encode(ptr unsafe.Pointer, w *Writer) { 56 str := *((*string)(ptr)) 57 for i, sym := range c.enum.symbols { 58 if str != sym { 59 continue 60 } 61 62 w.WriteInt(int32(i)) 63 return 64 } 65 66 w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str) 67 } 68 69 type enumTextMarshalerCodec struct { 70 typ reflect2.Type 71 enum *EnumSchema 72 ptr bool 73 } 74 75 func (c *enumTextMarshalerCodec) Decode(ptr unsafe.Pointer, r *Reader) { 76 i := int(r.ReadInt()) 77 78 symbol, ok := c.enum.Symbol(i) 79 if !ok { 80 r.ReportError("decode enum symbol", "unknown enum symbol") 81 return 82 } 83 84 var obj any 85 if c.ptr { 86 obj = c.typ.PackEFace(ptr) 87 } else { 88 obj = c.typ.UnsafeIndirect(ptr) 89 } 90 if reflect2.IsNil(obj) { 91 ptrType := c.typ.(*reflect2.UnsafePtrType) 92 newPtr := ptrType.Elem().UnsafeNew() 93 *((*unsafe.Pointer)(ptr)) = newPtr 94 obj = c.typ.UnsafeIndirect(ptr) 95 } 96 unmarshaler := (obj).(encoding.TextUnmarshaler) 97 if err := unmarshaler.UnmarshalText([]byte(symbol)); err != nil { 98 r.ReportError("decode enum text unmarshaler", err.Error()) 99 } 100 } 101 102 func (c *enumTextMarshalerCodec) Encode(ptr unsafe.Pointer, w *Writer) { 103 var obj any 104 if c.ptr { 105 obj = c.typ.PackEFace(ptr) 106 } else { 107 obj = c.typ.UnsafeIndirect(ptr) 108 } 109 if c.typ.IsNullable() && reflect2.IsNil(obj) { 110 w.Error = errors.New("encoding nil enum text marshaler") 111 return 112 } 113 marshaler := (obj).(encoding.TextMarshaler) 114 b, err := marshaler.MarshalText() 115 if err != nil { 116 w.Error = err 117 return 118 } 119 120 str := string(b) 121 for i, sym := range c.enum.symbols { 122 if str != sym { 123 continue 124 } 125 126 w.WriteInt(int32(i)) 127 return 128 } 129 130 w.Error = fmt.Errorf("avro: unknown enum symbol: %s", str) 131 }