github.com/aacfactory/avro@v1.2.12/internal/base/codec_record.go (about)

     1  package base
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Struct:
    16  		return decoderOfStruct(cfg, schema, typ)
    17  
    18  	case reflect.Map:
    19  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    20  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    21  			break
    22  		}
    23  		return decoderOfRecord(cfg, schema, typ)
    24  
    25  	case reflect.Ptr:
    26  		return decoderOfPtr(cfg, schema, typ)
    27  
    28  	case reflect.Interface:
    29  		if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok {
    30  			return &recordIfaceDecoder{schema: schema, valType: ifaceType}
    31  		}
    32  	default:
    33  		break
    34  	}
    35  
    36  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())}
    37  }
    38  
    39  func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    40  	switch typ.Kind() {
    41  	case reflect.Struct:
    42  		return encoderOfStruct(cfg, schema, typ)
    43  
    44  	case reflect.Map:
    45  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    46  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    47  			break
    48  		}
    49  		return encoderOfRecord(cfg, schema, typ)
    50  
    51  	case reflect.Ptr:
    52  		return encoderOfPtr(cfg, schema, typ)
    53  	default:
    54  		break
    55  	}
    56  
    57  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())}
    58  }
    59  
    60  func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    61  	processing := cfg.getProcessingDecoderFromCache(schema.Fingerprint(), typ.RType())
    62  	if processing != nil {
    63  		return processing
    64  	}
    65  	dec := &structDecoder{typ: typ, fields: nil}
    66  	cfg.addProcessingDecoderToCache(schema.Fingerprint(), typ.RType(), dec)
    67  
    68  	rec := schema.(*RecordSchema)
    69  	structDesc := describeStruct(cfg.getTagKey(), typ)
    70  
    71  	fields := make([]*structFieldDecoder, 0, len(rec.Fields()))
    72  	for _, field := range rec.Fields() {
    73  		sf := structDesc.Fields.Get(field.Name())
    74  		if sf == nil {
    75  			for _, alias := range field.Aliases() {
    76  				sf = structDesc.Fields.Get(alias)
    77  				if sf != nil {
    78  					break
    79  				}
    80  			}
    81  		}
    82  
    83  		// Skip field if it doesnt exist
    84  		if sf == nil {
    85  			fields = append(fields, &structFieldDecoder{
    86  				decoder: createSkipDecoder(field.Type()),
    87  			})
    88  			continue
    89  		}
    90  
    91  		fields = append(fields, &structFieldDecoder{
    92  			field:   sf.Field,
    93  			decoder: decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()),
    94  		})
    95  	}
    96  
    97  	dec.fields = fields
    98  	return dec
    99  }
   100  
   101  type structFieldDecoder struct {
   102  	field   []*reflect2.UnsafeStructField
   103  	decoder ValDecoder
   104  }
   105  
   106  type structDecoder struct {
   107  	typ    reflect2.Type
   108  	fields []*structFieldDecoder
   109  }
   110  
   111  func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   112  	for _, field := range d.fields {
   113  		// Skip case
   114  		if field.field == nil {
   115  			field.decoder.Decode(nil, r)
   116  			continue
   117  		}
   118  
   119  		fieldPtr := ptr
   120  		for i, f := range field.field {
   121  			fieldPtr = f.UnsafeGet(fieldPtr)
   122  
   123  			if i == len(field.field)-1 {
   124  				break
   125  			}
   126  
   127  			if f.Type().Kind() == reflect.Ptr {
   128  				if *((*unsafe.Pointer)(fieldPtr)) == nil {
   129  					newPtr := f.Type().(*reflect2.UnsafePtrType).Elem().UnsafeNew()
   130  					*((*unsafe.Pointer)(fieldPtr)) = newPtr
   131  				}
   132  
   133  				fieldPtr = *((*unsafe.Pointer)(fieldPtr))
   134  			}
   135  		}
   136  		field.decoder.Decode(fieldPtr, r)
   137  
   138  		if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   139  			for _, f := range field.field {
   140  				r.Error = fmt.Errorf("%s: %w", f.Name(), r.Error)
   141  				return
   142  			}
   143  		}
   144  	}
   145  }
   146  
   147  func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   148  	processing := cfg.getProcessingEncoderFromCache(schema.Fingerprint(), typ.RType())
   149  	if processing != nil {
   150  		return processing
   151  	}
   152  	enc := &structEncoder{typ: typ, fields: nil}
   153  	cfg.addProcessingEncoderToCache(schema.Fingerprint(), typ.RType(), enc)
   154  
   155  	rec := schema.(*RecordSchema)
   156  	structDesc := describeStruct(cfg.getTagKey(), typ)
   157  
   158  	fields := make([]*structFieldEncoder, 0, len(rec.Fields()))
   159  	for _, field := range rec.Fields() {
   160  		sf := structDesc.Fields.Get(field.Name())
   161  		if sf != nil {
   162  			fields = append(fields, &structFieldEncoder{
   163  				field:   sf.Field,
   164  				encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()),
   165  			})
   166  			continue
   167  		}
   168  
   169  		if !field.HasDefault() {
   170  			// In all other cases, this is a required field
   171  			err := fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name())
   172  			return &errorEncoder{err: err}
   173  		}
   174  
   175  		def := field.Default()
   176  		if field.Default() == nil {
   177  			if field.Type().Type() == Null {
   178  				// We write nothing in a Null case, just skip it
   179  				continue
   180  			}
   181  
   182  			if field.Type().Type() == Union && field.Type().(*UnionSchema).Nullable() {
   183  				defaultType := reflect2.TypeOf(&def)
   184  				fields = append(fields, &structFieldEncoder{
   185  					defaultPtr: reflect2.PtrOf(&def),
   186  					encoder:    encoderOfPtrUnion(cfg, field.Type(), defaultType),
   187  				})
   188  				continue
   189  			}
   190  		}
   191  
   192  		defaultType := reflect2.TypeOf(def)
   193  		defaultEncoder := encoderOfType(cfg, field.Type(), defaultType)
   194  		if defaultType.LikePtr() {
   195  			defaultEncoder = &onePtrEncoder{defaultEncoder}
   196  		}
   197  		fields = append(fields, &structFieldEncoder{
   198  			defaultPtr: reflect2.PtrOf(def),
   199  			encoder:    defaultEncoder,
   200  		})
   201  	}
   202  
   203  	enc.fields = fields
   204  	return enc
   205  }
   206  
   207  type structFieldEncoder struct {
   208  	field      []*reflect2.UnsafeStructField
   209  	defaultPtr unsafe.Pointer
   210  	encoder    ValEncoder
   211  }
   212  
   213  type structEncoder struct {
   214  	typ    reflect2.Type
   215  	fields []*structFieldEncoder
   216  }
   217  
   218  func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   219  	for _, field := range e.fields {
   220  		// Default case
   221  		if field.field == nil {
   222  			field.encoder.Encode(field.defaultPtr, w)
   223  			continue
   224  		}
   225  
   226  		fieldPtr := ptr
   227  		for i, f := range field.field {
   228  			fieldPtr = f.UnsafeGet(fieldPtr)
   229  
   230  			if i == len(field.field)-1 {
   231  				break
   232  			}
   233  
   234  			if f.Type().Kind() == reflect.Ptr {
   235  				if *((*unsafe.Pointer)(fieldPtr)) == nil {
   236  					w.Error = fmt.Errorf("embedded field %q is nil", f.Name())
   237  					return
   238  				}
   239  
   240  				fieldPtr = *((*unsafe.Pointer)(fieldPtr))
   241  			}
   242  		}
   243  		field.encoder.Encode(fieldPtr, w)
   244  
   245  		if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   246  			for _, f := range field.field {
   247  				w.Error = fmt.Errorf("%s: %w", f.Name(), w.Error)
   248  				return
   249  			}
   250  		}
   251  	}
   252  }
   253  
   254  func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
   255  	rec := schema.(*RecordSchema)
   256  	mapType := typ.(*reflect2.UnsafeMapType)
   257  
   258  	fields := make([]recordMapDecoderField, len(rec.Fields()))
   259  	for i, field := range rec.Fields() {
   260  		fields[i] = recordMapDecoderField{
   261  			name:    field.Name(),
   262  			decoder: decoderOfType(cfg, field.Type(), mapType.Elem()),
   263  		}
   264  	}
   265  
   266  	return &recordMapDecoder{
   267  		mapType:  mapType,
   268  		elemType: mapType.Elem(),
   269  		fields:   fields,
   270  	}
   271  }
   272  
   273  type recordMapDecoderField struct {
   274  	name    string
   275  	decoder ValDecoder
   276  }
   277  
   278  type recordMapDecoder struct {
   279  	mapType  *reflect2.UnsafeMapType
   280  	elemType reflect2.Type
   281  	fields   []recordMapDecoderField
   282  }
   283  
   284  func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   285  	if d.mapType.UnsafeIsNil(ptr) {
   286  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
   287  	}
   288  
   289  	for _, field := range d.fields {
   290  		elem := d.elemType.UnsafeNew()
   291  		field.decoder.Decode(elem, r)
   292  
   293  		d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem)
   294  	}
   295  
   296  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   297  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
   298  	}
   299  }
   300  
   301  func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   302  	rec := schema.(*RecordSchema)
   303  	mapType := typ.(*reflect2.UnsafeMapType)
   304  
   305  	fields := make([]mapEncoderField, len(rec.Fields()))
   306  	for i, field := range rec.Fields() {
   307  		fields[i] = mapEncoderField{
   308  			name:    field.Name(),
   309  			hasDef:  field.HasDefault(),
   310  			def:     field.Default(),
   311  			encoder: encoderOfType(cfg, field.Type(), mapType.Elem()),
   312  		}
   313  
   314  		if field.HasDefault() {
   315  			switch {
   316  			case field.Type().Type() == Union:
   317  				union := field.Type().(*UnionSchema)
   318  				fields[i].def = map[string]any{
   319  					string(union.Types()[0].Type()): field.Default(),
   320  				}
   321  			case field.Default() == nil:
   322  				continue
   323  			}
   324  
   325  			defaultType := reflect2.TypeOf(fields[i].def)
   326  			fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType)
   327  			if defaultType.LikePtr() {
   328  				fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder}
   329  			}
   330  		}
   331  	}
   332  
   333  	return &recordMapEncoder{
   334  		mapType: mapType,
   335  		fields:  fields,
   336  	}
   337  }
   338  
   339  type mapEncoderField struct {
   340  	name       string
   341  	hasDef     bool
   342  	def        any
   343  	defEncoder ValEncoder
   344  	encoder    ValEncoder
   345  }
   346  
   347  type recordMapEncoder struct {
   348  	mapType *reflect2.UnsafeMapType
   349  	fields  []mapEncoderField
   350  }
   351  
   352  func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   353  	for _, field := range e.fields {
   354  		// The first property of mapEncoderField is the name, so a pointer
   355  		// to field is a pointer to the name.
   356  		valPtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(field))
   357  		if valPtr == nil {
   358  			// Missing required field
   359  			if !field.hasDef {
   360  				w.Error = fmt.Errorf("avro: missing required field %s", field.name)
   361  				return
   362  			}
   363  
   364  			// Null default
   365  			if field.def == nil {
   366  				continue
   367  			}
   368  
   369  			defPtr := reflect2.PtrOf(field.def)
   370  			field.defEncoder.Encode(defPtr, w)
   371  			continue
   372  		}
   373  
   374  		field.encoder.Encode(valPtr, w)
   375  
   376  		if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   377  			w.Error = fmt.Errorf("%s: %w", field.name, w.Error)
   378  			return
   379  		}
   380  	}
   381  }
   382  
   383  type recordIfaceDecoder struct {
   384  	schema  Schema
   385  	valType *reflect2.UnsafeIFaceType
   386  }
   387  
   388  func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   389  	obj := d.valType.UnsafeIndirect(ptr)
   390  	if reflect2.IsNil(obj) {
   391  		r.ReportError("decode non empty interface", "can not unmarshal into nil")
   392  		return
   393  	}
   394  
   395  	r.ReadVal(d.schema, obj)
   396  }
   397  
   398  type structDescriptor struct {
   399  	Type   reflect2.Type
   400  	Fields structFields
   401  }
   402  
   403  type structFields []*structField
   404  
   405  func (sf structFields) Get(name string) *structField {
   406  	for _, f := range sf {
   407  		if f.Name == name {
   408  			return f
   409  		}
   410  	}
   411  
   412  	return nil
   413  }
   414  
   415  type structField struct {
   416  	Name  string
   417  	Field []*reflect2.UnsafeStructField
   418  
   419  	anon *reflect2.UnsafeStructType
   420  }
   421  
   422  func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor {
   423  	structType := typ.(*reflect2.UnsafeStructType)
   424  	fields := structFields{}
   425  
   426  	var curr []structField
   427  	next := []structField{{anon: structType}}
   428  
   429  	visited := map[uintptr]bool{}
   430  
   431  	for len(next) > 0 {
   432  		curr, next = next, curr[:0]
   433  
   434  		for _, f := range curr {
   435  			rtype := f.anon.RType()
   436  			if visited[f.anon.RType()] {
   437  				continue
   438  			}
   439  			visited[rtype] = true
   440  
   441  			for i := 0; i < f.anon.NumField(); i++ {
   442  				field := f.anon.Field(i).(*reflect2.UnsafeStructField)
   443  				isUnexported := field.PkgPath() != ""
   444  
   445  				chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1)
   446  				copy(chain, f.Field)
   447  				chain[len(f.Field)] = field
   448  
   449  				if field.Anonymous() {
   450  					t := field.Type()
   451  					if t.Kind() == reflect.Ptr {
   452  						t = t.(*reflect2.UnsafePtrType).Elem()
   453  					}
   454  					if t.Kind() != reflect.Struct {
   455  						continue
   456  					}
   457  
   458  					next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)})
   459  					continue
   460  				}
   461  
   462  				// Ignore unexported fields.
   463  				if isUnexported {
   464  					continue
   465  				}
   466  
   467  				fieldName := field.Name()
   468  				if tag, ok := field.Tag().Lookup(tagKey); ok {
   469  					fieldName = tag
   470  				}
   471  
   472  				fields = append(fields, &structField{
   473  					Name:  fieldName,
   474  					Field: chain,
   475  				})
   476  			}
   477  		}
   478  	}
   479  
   480  	return &structDescriptor{
   481  		Type:   structType,
   482  		Fields: fields,
   483  	}
   484  }