github.com/aacfactory/avro@v1.2.12/internal/base/codec_map.go (about)

     1  package base
     2  
     3  import (
     4  	"encoding"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"unsafe"
    10  
    11  	"github.com/modern-go/reflect2"
    12  )
    13  
    14  func createDecoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    15  	if typ.Kind() == reflect.Map {
    16  		keyType := typ.(reflect2.MapType).Key()
    17  		switch {
    18  		case keyType.Kind() == reflect.String:
    19  			return decoderOfMap(cfg, schema, typ)
    20  		case keyType.Implements(textUnmarshalerType):
    21  			return decoderOfMapUnmarshaler(cfg, schema, typ)
    22  		}
    23  	}
    24  
    25  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    26  }
    27  
    28  func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    29  	if typ.Kind() == reflect.Map {
    30  		keyType := typ.(reflect2.MapType).Key()
    31  		switch {
    32  		case keyType.Kind() == reflect.String:
    33  			return encoderOfMap(cfg, schema, typ)
    34  		case keyType.Implements(textMarshalerType):
    35  			return encoderOfMapMarshaler(cfg, schema, typ)
    36  		}
    37  	}
    38  
    39  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    40  }
    41  
    42  func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    43  	m := schema.(*MapSchema)
    44  	mapType := typ.(*reflect2.UnsafeMapType)
    45  	decoder := decoderOfType(cfg, m.Values(), mapType.Elem())
    46  
    47  	return &mapDecoder{
    48  		mapType:  mapType,
    49  		elemType: mapType.Elem(),
    50  		decoder:  decoder,
    51  	}
    52  }
    53  
    54  type mapDecoder struct {
    55  	mapType  *reflect2.UnsafeMapType
    56  	elemType reflect2.Type
    57  	decoder  ValDecoder
    58  }
    59  
    60  func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
    61  	if d.mapType.UnsafeIsNil(ptr) {
    62  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
    63  	}
    64  
    65  	for {
    66  		l, _ := r.ReadBlockHeader()
    67  		if l == 0 {
    68  			break
    69  		}
    70  
    71  		for i := int64(0); i < l; i++ {
    72  			keyPtr := reflect2.PtrOf(r.ReadString())
    73  			elemPtr := d.elemType.UnsafeNew()
    74  			d.decoder.Decode(elemPtr, r)
    75  
    76  			d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
    77  		}
    78  	}
    79  
    80  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
    81  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
    82  	}
    83  }
    84  
    85  func decoderOfMapUnmarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    86  	m := schema.(*MapSchema)
    87  	mapType := typ.(*reflect2.UnsafeMapType)
    88  	decoder := decoderOfType(cfg, m.Values(), mapType.Elem())
    89  
    90  	return &mapDecoderUnmarshaler{
    91  		mapType:  mapType,
    92  		keyType:  mapType.Key(),
    93  		elemType: mapType.Elem(),
    94  		decoder:  decoder,
    95  	}
    96  }
    97  
    98  type mapDecoderUnmarshaler struct {
    99  	mapType  *reflect2.UnsafeMapType
   100  	keyType  reflect2.Type
   101  	elemType reflect2.Type
   102  	decoder  ValDecoder
   103  }
   104  
   105  func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) {
   106  	if d.mapType.UnsafeIsNil(ptr) {
   107  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
   108  	}
   109  
   110  	for {
   111  		l, _ := r.ReadBlockHeader()
   112  		if l == 0 {
   113  			break
   114  		}
   115  
   116  		for i := int64(0); i < l; i++ {
   117  			keyPtr := d.keyType.UnsafeNew()
   118  			keyObj := d.keyType.UnsafeIndirect(keyPtr)
   119  			if reflect2.IsNil(keyObj) {
   120  				ptrType := d.keyType.(*reflect2.UnsafePtrType)
   121  				newPtr := ptrType.Elem().UnsafeNew()
   122  				*((*unsafe.Pointer)(keyPtr)) = newPtr
   123  				keyObj = d.keyType.UnsafeIndirect(keyPtr)
   124  			}
   125  			unmarshaler := keyObj.(encoding.TextUnmarshaler)
   126  			err := unmarshaler.UnmarshalText([]byte(r.ReadString()))
   127  			if err != nil {
   128  				r.ReportError("mapDecoderUnmarshaler", err.Error())
   129  				return
   130  			}
   131  
   132  			elemPtr := d.elemType.UnsafeNew()
   133  			d.decoder.Decode(elemPtr, r)
   134  
   135  			d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
   136  		}
   137  	}
   138  
   139  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   140  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
   141  	}
   142  }
   143  
   144  func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   145  	m := schema.(*MapSchema)
   146  	mapType := typ.(*reflect2.UnsafeMapType)
   147  	encoder := encoderOfType(cfg, m.Values(), mapType.Elem())
   148  
   149  	return &mapEncoder{
   150  		blockLength: cfg.getBlockLength(),
   151  		mapType:     mapType,
   152  		encoder:     encoder,
   153  	}
   154  }
   155  
   156  type mapEncoder struct {
   157  	blockLength int
   158  	mapType     *reflect2.UnsafeMapType
   159  	encoder     ValEncoder
   160  }
   161  
   162  func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   163  	blockLength := e.blockLength
   164  
   165  	iter := e.mapType.UnsafeIterate(ptr)
   166  
   167  	for {
   168  		wrote := w.WriteBlockCB(func(w *Writer) int64 {
   169  			var i int
   170  			for i = 0; iter.HasNext() && i < blockLength; i++ {
   171  				keyPtr, elemPtr := iter.UnsafeNext()
   172  				w.WriteString(*((*string)(keyPtr)))
   173  				e.encoder.Encode(elemPtr, w)
   174  			}
   175  
   176  			return int64(i)
   177  		})
   178  
   179  		if wrote == 0 {
   180  			break
   181  		}
   182  	}
   183  
   184  	if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   185  		w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
   186  	}
   187  }
   188  
   189  func encoderOfMapMarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   190  	m := schema.(*MapSchema)
   191  	mapType := typ.(*reflect2.UnsafeMapType)
   192  	encoder := encoderOfType(cfg, m.Values(), mapType.Elem())
   193  
   194  	return &mapEncoderMarshaller{
   195  		blockLength: cfg.getBlockLength(),
   196  		mapType:     mapType,
   197  		keyType:     mapType.Key(),
   198  		encoder:     encoder,
   199  	}
   200  }
   201  
   202  type mapEncoderMarshaller struct {
   203  	blockLength int
   204  	mapType     *reflect2.UnsafeMapType
   205  	keyType     reflect2.Type
   206  	encoder     ValEncoder
   207  }
   208  
   209  func (e *mapEncoderMarshaller) Encode(ptr unsafe.Pointer, w *Writer) {
   210  	blockLength := e.blockLength
   211  
   212  	iter := e.mapType.UnsafeIterate(ptr)
   213  
   214  	for {
   215  		wrote := w.WriteBlockCB(func(w *Writer) int64 {
   216  			var i int
   217  			for i = 0; iter.HasNext() && i < blockLength; i++ {
   218  				keyPtr, elemPtr := iter.UnsafeNext()
   219  
   220  				obj := e.keyType.UnsafeIndirect(keyPtr)
   221  				if e.keyType.IsNullable() && reflect2.IsNil(obj) {
   222  					w.Error = errors.New("avro: mapEncoderMarshaller: encoding nil TextMarshaller")
   223  					return int64(0)
   224  				}
   225  				marshaler := (obj).(encoding.TextMarshaler)
   226  				b, err := marshaler.MarshalText()
   227  				if err != nil {
   228  					w.Error = err
   229  					return int64(0)
   230  				}
   231  				w.WriteString(string(b))
   232  
   233  				e.encoder.Encode(elemPtr, w)
   234  			}
   235  			return int64(i)
   236  		})
   237  
   238  		if wrote == 0 {
   239  			break
   240  		}
   241  	}
   242  
   243  	if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   244  		w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
   245  	}
   246  }