github.com/aacfactory/avro@v1.2.12/internal/base/codec_map.go (about) 1 package base 2 3 import ( 4 "encoding" 5 "errors" 6 "fmt" 7 "io" 8 "reflect" 9 "unsafe" 10 11 "github.com/modern-go/reflect2" 12 ) 13 14 func createDecoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 15 if typ.Kind() == reflect.Map { 16 keyType := typ.(reflect2.MapType).Key() 17 switch { 18 case keyType.Kind() == reflect.String: 19 return decoderOfMap(cfg, schema, typ) 20 case keyType.Implements(textUnmarshalerType): 21 return decoderOfMapUnmarshaler(cfg, schema, typ) 22 } 23 } 24 25 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 26 } 27 28 func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 29 if typ.Kind() == reflect.Map { 30 keyType := typ.(reflect2.MapType).Key() 31 switch { 32 case keyType.Kind() == reflect.String: 33 return encoderOfMap(cfg, schema, typ) 34 case keyType.Implements(textMarshalerType): 35 return encoderOfMapMarshaler(cfg, schema, typ) 36 } 37 } 38 39 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 40 } 41 42 func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 43 m := schema.(*MapSchema) 44 mapType := typ.(*reflect2.UnsafeMapType) 45 decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) 46 47 return &mapDecoder{ 48 mapType: mapType, 49 elemType: mapType.Elem(), 50 decoder: decoder, 51 } 52 } 53 54 type mapDecoder struct { 55 mapType *reflect2.UnsafeMapType 56 elemType reflect2.Type 57 decoder ValDecoder 58 } 59 60 func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { 61 if d.mapType.UnsafeIsNil(ptr) { 62 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) 63 } 64 65 for { 66 l, _ := r.ReadBlockHeader() 67 if l == 0 { 68 break 69 } 70 71 for i := int64(0); i < l; i++ { 72 keyPtr := reflect2.PtrOf(r.ReadString()) 73 elemPtr := d.elemType.UnsafeNew() 74 d.decoder.Decode(elemPtr, r) 75 76 d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) 77 } 78 } 79 80 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 81 r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) 82 } 83 } 84 85 func decoderOfMapUnmarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { 86 m := schema.(*MapSchema) 87 mapType := typ.(*reflect2.UnsafeMapType) 88 decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) 89 90 return &mapDecoderUnmarshaler{ 91 mapType: mapType, 92 keyType: mapType.Key(), 93 elemType: mapType.Elem(), 94 decoder: decoder, 95 } 96 } 97 98 type mapDecoderUnmarshaler struct { 99 mapType *reflect2.UnsafeMapType 100 keyType reflect2.Type 101 elemType reflect2.Type 102 decoder ValDecoder 103 } 104 105 func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) { 106 if d.mapType.UnsafeIsNil(ptr) { 107 d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) 108 } 109 110 for { 111 l, _ := r.ReadBlockHeader() 112 if l == 0 { 113 break 114 } 115 116 for i := int64(0); i < l; i++ { 117 keyPtr := d.keyType.UnsafeNew() 118 keyObj := d.keyType.UnsafeIndirect(keyPtr) 119 if reflect2.IsNil(keyObj) { 120 ptrType := d.keyType.(*reflect2.UnsafePtrType) 121 newPtr := ptrType.Elem().UnsafeNew() 122 *((*unsafe.Pointer)(keyPtr)) = newPtr 123 keyObj = d.keyType.UnsafeIndirect(keyPtr) 124 } 125 unmarshaler := keyObj.(encoding.TextUnmarshaler) 126 err := unmarshaler.UnmarshalText([]byte(r.ReadString())) 127 if err != nil { 128 r.ReportError("mapDecoderUnmarshaler", err.Error()) 129 return 130 } 131 132 elemPtr := d.elemType.UnsafeNew() 133 d.decoder.Decode(elemPtr, r) 134 135 d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) 136 } 137 } 138 139 if r.Error != nil && !errors.Is(r.Error, io.EOF) { 140 r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error) 141 } 142 } 143 144 func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 145 m := schema.(*MapSchema) 146 mapType := typ.(*reflect2.UnsafeMapType) 147 encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) 148 149 return &mapEncoder{ 150 blockLength: cfg.getBlockLength(), 151 mapType: mapType, 152 encoder: encoder, 153 } 154 } 155 156 type mapEncoder struct { 157 blockLength int 158 mapType *reflect2.UnsafeMapType 159 encoder ValEncoder 160 } 161 162 func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) { 163 blockLength := e.blockLength 164 165 iter := e.mapType.UnsafeIterate(ptr) 166 167 for { 168 wrote := w.WriteBlockCB(func(w *Writer) int64 { 169 var i int 170 for i = 0; iter.HasNext() && i < blockLength; i++ { 171 keyPtr, elemPtr := iter.UnsafeNext() 172 w.WriteString(*((*string)(keyPtr))) 173 e.encoder.Encode(elemPtr, w) 174 } 175 176 return int64(i) 177 }) 178 179 if wrote == 0 { 180 break 181 } 182 } 183 184 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 185 w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error) 186 } 187 } 188 189 func encoderOfMapMarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { 190 m := schema.(*MapSchema) 191 mapType := typ.(*reflect2.UnsafeMapType) 192 encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) 193 194 return &mapEncoderMarshaller{ 195 blockLength: cfg.getBlockLength(), 196 mapType: mapType, 197 keyType: mapType.Key(), 198 encoder: encoder, 199 } 200 } 201 202 type mapEncoderMarshaller struct { 203 blockLength int 204 mapType *reflect2.UnsafeMapType 205 keyType reflect2.Type 206 encoder ValEncoder 207 } 208 209 func (e *mapEncoderMarshaller) Encode(ptr unsafe.Pointer, w *Writer) { 210 blockLength := e.blockLength 211 212 iter := e.mapType.UnsafeIterate(ptr) 213 214 for { 215 wrote := w.WriteBlockCB(func(w *Writer) int64 { 216 var i int 217 for i = 0; iter.HasNext() && i < blockLength; i++ { 218 keyPtr, elemPtr := iter.UnsafeNext() 219 220 obj := e.keyType.UnsafeIndirect(keyPtr) 221 if e.keyType.IsNullable() && reflect2.IsNil(obj) { 222 w.Error = errors.New("avro: mapEncoderMarshaller: encoding nil TextMarshaller") 223 return int64(0) 224 } 225 marshaler := (obj).(encoding.TextMarshaler) 226 b, err := marshaler.MarshalText() 227 if err != nil { 228 w.Error = err 229 return int64(0) 230 } 231 w.WriteString(string(b)) 232 233 e.encoder.Encode(elemPtr, w) 234 } 235 return int64(i) 236 }) 237 238 if wrote == 0 { 239 break 240 } 241 } 242 243 if w.Error != nil && !errors.Is(w.Error, io.EOF) { 244 w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error) 245 } 246 }