github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_map.go (about)

     1  package avro
     2  
     3  import (
     4  	"encoding"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"unsafe"
    10  
    11  	"github.com/modern-go/reflect2"
    12  )
    13  
    14  func createDecoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    15  	if typ.Kind() == reflect.Map {
    16  		keyType := typ.(reflect2.MapType).Key()
    17  		switch {
    18  		case keyType.Kind() == reflect.String:
    19  			return decoderOfMap(cfg, schema, typ)
    20  		case keyType.Implements(textUnmarshalerType):
    21  			return decoderOfMapUnmarshaler(cfg, schema, typ)
    22  		}
    23  	}
    24  
    25  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    26  }
    27  
    28  func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    29  	if typ.Kind() == reflect.Map {
    30  		keyType := typ.(reflect2.MapType).Key()
    31  		switch {
    32  		case keyType.Kind() == reflect.String:
    33  			return encoderOfMap(cfg, schema, typ)
    34  		case keyType.Implements(textMarshalerType):
    35  			return encoderOfMapMarshaler(cfg, schema, typ)
    36  		}
    37  	}
    38  
    39  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
    40  }
    41  
    42  func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    43  	m := schema.(*MapSchema)
    44  	mapType := typ.(*reflect2.UnsafeMapType)
    45  	decoder := decoderOfType(cfg, m.Values(), mapType.Elem())
    46  
    47  	return &mapDecoder{
    48  		mapType:  mapType,
    49  		elemType: mapType.Elem(),
    50  		decoder:  decoder,
    51  	}
    52  }
    53  
    54  type mapDecoder struct {
    55  	mapType  *reflect2.UnsafeMapType
    56  	elemType reflect2.Type
    57  	decoder  ValDecoder
    58  }
    59  
    60  func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
    61  	if d.mapType.UnsafeIsNil(ptr) {
    62  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
    63  	}
    64  
    65  	for {
    66  		l, _ := r.ReadBlockHeader()
    67  		if l == 0 {
    68  			break
    69  		}
    70  
    71  		for i := int64(0); i < l; i++ {
    72  			keyPtr := reflect2.PtrOf(r.ReadString())
    73  			elemPtr := d.elemType.UnsafeNew()
    74  			d.decoder.Decode(elemPtr, r)
    75  			if r.Error != nil {
    76  				r.Error = fmt.Errorf("reading map[string]%s: %w", d.elemType.String(), r.Error)
    77  				return
    78  			}
    79  
    80  			d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
    81  		}
    82  	}
    83  
    84  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
    85  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
    86  	}
    87  }
    88  
    89  func decoderOfMapUnmarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    90  	m := schema.(*MapSchema)
    91  	mapType := typ.(*reflect2.UnsafeMapType)
    92  	decoder := decoderOfType(cfg, m.Values(), mapType.Elem())
    93  
    94  	return &mapDecoderUnmarshaler{
    95  		mapType:  mapType,
    96  		keyType:  mapType.Key(),
    97  		elemType: mapType.Elem(),
    98  		decoder:  decoder,
    99  	}
   100  }
   101  
   102  type mapDecoderUnmarshaler struct {
   103  	mapType  *reflect2.UnsafeMapType
   104  	keyType  reflect2.Type
   105  	elemType reflect2.Type
   106  	decoder  ValDecoder
   107  }
   108  
   109  func (d *mapDecoderUnmarshaler) Decode(ptr unsafe.Pointer, r *Reader) {
   110  	if d.mapType.UnsafeIsNil(ptr) {
   111  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
   112  	}
   113  
   114  	for {
   115  		l, _ := r.ReadBlockHeader()
   116  		if l == 0 {
   117  			break
   118  		}
   119  
   120  		for i := int64(0); i < l; i++ {
   121  			keyPtr := d.keyType.UnsafeNew()
   122  			keyObj := d.keyType.UnsafeIndirect(keyPtr)
   123  			if reflect2.IsNil(keyObj) {
   124  				ptrType := d.keyType.(*reflect2.UnsafePtrType)
   125  				newPtr := ptrType.Elem().UnsafeNew()
   126  				*((*unsafe.Pointer)(keyPtr)) = newPtr
   127  				keyObj = d.keyType.UnsafeIndirect(keyPtr)
   128  			}
   129  			unmarshaler := keyObj.(encoding.TextUnmarshaler)
   130  			err := unmarshaler.UnmarshalText([]byte(r.ReadString()))
   131  			if err != nil {
   132  				r.ReportError("mapDecoderUnmarshaler", err.Error())
   133  				return
   134  			}
   135  
   136  			elemPtr := d.elemType.UnsafeNew()
   137  			d.decoder.Decode(elemPtr, r)
   138  
   139  			d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr)
   140  		}
   141  	}
   142  
   143  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   144  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
   145  	}
   146  }
   147  
   148  func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   149  	m := schema.(*MapSchema)
   150  	mapType := typ.(*reflect2.UnsafeMapType)
   151  	encoder := encoderOfType(cfg, m.Values(), mapType.Elem())
   152  
   153  	return &mapEncoder{
   154  		blockLength: cfg.getBlockLength(),
   155  		mapType:     mapType,
   156  		encoder:     encoder,
   157  	}
   158  }
   159  
   160  type mapEncoder struct {
   161  	blockLength int
   162  	mapType     *reflect2.UnsafeMapType
   163  	encoder     ValEncoder
   164  }
   165  
   166  func (e *mapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   167  	blockLength := e.blockLength
   168  
   169  	iter := e.mapType.UnsafeIterate(ptr)
   170  
   171  	for {
   172  		wrote := w.WriteBlockCB(func(w *Writer) int64 {
   173  			var i int
   174  			for i = 0; iter.HasNext() && i < blockLength; i++ {
   175  				keyPtr, elemPtr := iter.UnsafeNext()
   176  				w.WriteString(*((*string)(keyPtr)))
   177  				e.encoder.Encode(elemPtr, w)
   178  			}
   179  
   180  			return int64(i)
   181  		})
   182  
   183  		if wrote == 0 {
   184  			break
   185  		}
   186  	}
   187  
   188  	if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   189  		w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
   190  	}
   191  }
   192  
   193  func encoderOfMapMarshaler(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   194  	m := schema.(*MapSchema)
   195  	mapType := typ.(*reflect2.UnsafeMapType)
   196  	encoder := encoderOfType(cfg, m.Values(), mapType.Elem())
   197  
   198  	return &mapEncoderMarshaller{
   199  		blockLength: cfg.getBlockLength(),
   200  		mapType:     mapType,
   201  		keyType:     mapType.Key(),
   202  		encoder:     encoder,
   203  	}
   204  }
   205  
   206  type mapEncoderMarshaller struct {
   207  	blockLength int
   208  	mapType     *reflect2.UnsafeMapType
   209  	keyType     reflect2.Type
   210  	encoder     ValEncoder
   211  }
   212  
   213  func (e *mapEncoderMarshaller) Encode(ptr unsafe.Pointer, w *Writer) {
   214  	blockLength := e.blockLength
   215  
   216  	iter := e.mapType.UnsafeIterate(ptr)
   217  
   218  	for {
   219  		wrote := w.WriteBlockCB(func(w *Writer) int64 {
   220  			var i int
   221  			for i = 0; iter.HasNext() && i < blockLength; i++ {
   222  				keyPtr, elemPtr := iter.UnsafeNext()
   223  
   224  				obj := e.keyType.UnsafeIndirect(keyPtr)
   225  				if e.keyType.IsNullable() && reflect2.IsNil(obj) {
   226  					w.Error = errors.New("avro: mapEncoderMarshaller: encoding nil TextMarshaller")
   227  					return int64(0)
   228  				}
   229  				marshaler := (obj).(encoding.TextMarshaler)
   230  				b, err := marshaler.MarshalText()
   231  				if err != nil {
   232  					w.Error = err
   233  					return int64(0)
   234  				}
   235  				w.WriteString(string(b))
   236  
   237  				e.encoder.Encode(elemPtr, w)
   238  			}
   239  			return int64(i)
   240  		})
   241  
   242  		if wrote == 0 {
   243  			break
   244  		}
   245  	}
   246  
   247  	if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   248  		w.Error = fmt.Errorf("%v: %w", e.mapType, w.Error)
   249  	}
   250  }