github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_record.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Struct:
    16  		return decoderOfStruct(cfg, schema, typ)
    17  
    18  	case reflect.Map:
    19  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    20  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    21  			break
    22  		}
    23  		return decoderOfRecord(cfg, schema, typ)
    24  
    25  	case reflect.Ptr:
    26  		return decoderOfPtr(cfg, schema, typ)
    27  
    28  	case reflect.Interface:
    29  		if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok {
    30  			return &recordIfaceDecoder{schema: schema, valType: ifaceType}
    31  		}
    32  	}
    33  
    34  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())}
    35  }
    36  
    37  func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    38  	switch typ.Kind() {
    39  	case reflect.Struct:
    40  		return encoderOfStruct(cfg, schema, typ)
    41  
    42  	case reflect.Map:
    43  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    44  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    45  			break
    46  		}
    47  		return encoderOfRecord(cfg, schema, typ)
    48  
    49  	case reflect.Ptr:
    50  		return encoderOfPtr(cfg, schema, typ)
    51  	}
    52  
    53  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())}
    54  }
    55  
    56  func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    57  	rec := schema.(*RecordSchema)
    58  	structDesc := describeStruct(cfg.getTagKey(), typ)
    59  
    60  	fields := make([]*structFieldDecoder, 0, len(rec.Fields()))
    61  
    62  	for _, field := range rec.Fields() {
    63  		if field.action == FieldIgnore {
    64  			fields = append(fields, &structFieldDecoder{
    65  				decoder: createSkipDecoder(field.Type()),
    66  			})
    67  			continue
    68  		}
    69  
    70  		sf := structDesc.Fields.Get(field.Name())
    71  		if sf == nil {
    72  			for _, alias := range field.Aliases() {
    73  				sf = structDesc.Fields.Get(alias)
    74  				if sf != nil {
    75  					break
    76  				}
    77  			}
    78  		}
    79  
    80  		// Skip field if it doesnt exist
    81  		if sf == nil {
    82  			fields = append(fields, &structFieldDecoder{
    83  				decoder: createSkipDecoder(field.Type()),
    84  			})
    85  			continue
    86  		}
    87  
    88  		if field.action == FieldSetDefault {
    89  			if field.hasDef {
    90  				fields = append(fields, &structFieldDecoder{
    91  					field:   sf.Field,
    92  					decoder: createDefaultDecoder(cfg, field, sf.Field[len(sf.Field)-1].Type()),
    93  				})
    94  
    95  				continue
    96  			}
    97  		}
    98  
    99  		dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type())
   100  		fields = append(fields, &structFieldDecoder{
   101  			field:   sf.Field,
   102  			decoder: dec,
   103  		})
   104  	}
   105  
   106  	return &structDecoder{typ: typ, fields: fields}
   107  }
   108  
   109  type structFieldDecoder struct {
   110  	field   []*reflect2.UnsafeStructField
   111  	decoder ValDecoder
   112  }
   113  
   114  type structDecoder struct {
   115  	typ    reflect2.Type
   116  	fields []*structFieldDecoder
   117  }
   118  
   119  func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   120  	for _, field := range d.fields {
   121  		// Skip case
   122  		if field.field == nil {
   123  			field.decoder.Decode(nil, r)
   124  			continue
   125  		}
   126  
   127  		fieldPtr := ptr
   128  		for i, f := range field.field {
   129  			fieldPtr = f.UnsafeGet(fieldPtr)
   130  
   131  			if i == len(field.field)-1 {
   132  				break
   133  			}
   134  
   135  			if f.Type().Kind() == reflect.Ptr {
   136  				if *((*unsafe.Pointer)(fieldPtr)) == nil {
   137  					newPtr := f.Type().(*reflect2.UnsafePtrType).Elem().UnsafeNew()
   138  					*((*unsafe.Pointer)(fieldPtr)) = newPtr
   139  				}
   140  
   141  				fieldPtr = *((*unsafe.Pointer)(fieldPtr))
   142  			}
   143  		}
   144  		field.decoder.Decode(fieldPtr, r)
   145  
   146  		if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   147  			for _, f := range field.field {
   148  				r.Error = fmt.Errorf("%s: %w", f.Name(), r.Error)
   149  				return
   150  			}
   151  		}
   152  	}
   153  }
   154  
   155  func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   156  	rec := schema.(*RecordSchema)
   157  	structDesc := describeStruct(cfg.getTagKey(), typ)
   158  
   159  	fields := make([]*structFieldEncoder, 0, len(rec.Fields()))
   160  	for _, field := range rec.Fields() {
   161  		sf := structDesc.Fields.Get(field.Name())
   162  		if sf != nil {
   163  			fields = append(fields, &structFieldEncoder{
   164  				field:   sf.Field,
   165  				encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()),
   166  			})
   167  			continue
   168  		}
   169  
   170  		if !field.HasDefault() {
   171  			// In all other cases, this is a required field
   172  			err := fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name())
   173  			return &errorEncoder{err: err}
   174  		}
   175  
   176  		def := field.Default()
   177  		if field.Default() == nil {
   178  			if field.Type().Type() == Null {
   179  				// We write nothing in a Null case, just skip it
   180  				continue
   181  			}
   182  
   183  			if field.Type().Type() == Union && field.Type().(*UnionSchema).Nullable() {
   184  				defaultType := reflect2.TypeOf(&def)
   185  				fields = append(fields, &structFieldEncoder{
   186  					defaultPtr: reflect2.PtrOf(&def),
   187  					encoder:    encoderOfNullableUnion(cfg, field.Type(), defaultType),
   188  				})
   189  				continue
   190  			}
   191  		}
   192  
   193  		defaultType := reflect2.TypeOf(def)
   194  		defaultEncoder := encoderOfType(cfg, field.Type(), defaultType)
   195  		if defaultType.LikePtr() {
   196  			defaultEncoder = &onePtrEncoder{defaultEncoder}
   197  		}
   198  		fields = append(fields, &structFieldEncoder{
   199  			defaultPtr: reflect2.PtrOf(def),
   200  			encoder:    defaultEncoder,
   201  		})
   202  	}
   203  	return &structEncoder{typ: typ, fields: fields}
   204  }
   205  
   206  type structFieldEncoder struct {
   207  	field      []*reflect2.UnsafeStructField
   208  	defaultPtr unsafe.Pointer
   209  	encoder    ValEncoder
   210  }
   211  
   212  type structEncoder struct {
   213  	typ    reflect2.Type
   214  	fields []*structFieldEncoder
   215  }
   216  
   217  func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   218  	for _, field := range e.fields {
   219  		// Default case
   220  		if field.field == nil {
   221  			field.encoder.Encode(field.defaultPtr, w)
   222  			continue
   223  		}
   224  
   225  		fieldPtr := ptr
   226  		for i, f := range field.field {
   227  			fieldPtr = f.UnsafeGet(fieldPtr)
   228  
   229  			if i == len(field.field)-1 {
   230  				break
   231  			}
   232  
   233  			if f.Type().Kind() == reflect.Ptr {
   234  				if *((*unsafe.Pointer)(fieldPtr)) == nil {
   235  					w.Error = fmt.Errorf("embedded field %q is nil", f.Name())
   236  					return
   237  				}
   238  
   239  				fieldPtr = *((*unsafe.Pointer)(fieldPtr))
   240  			}
   241  		}
   242  		field.encoder.Encode(fieldPtr, w)
   243  
   244  		if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   245  			for _, f := range field.field {
   246  				w.Error = fmt.Errorf("%s: %w", f.Name(), w.Error)
   247  				return
   248  			}
   249  		}
   250  	}
   251  }
   252  
   253  func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
   254  	rec := schema.(*RecordSchema)
   255  	mapType := typ.(*reflect2.UnsafeMapType)
   256  
   257  	fields := make([]recordMapDecoderField, len(rec.Fields()))
   258  	for i, field := range rec.Fields() {
   259  		switch field.action {
   260  		case FieldIgnore:
   261  			fields[i] = recordMapDecoderField{
   262  				name:    field.Name(),
   263  				decoder: createSkipDecoder(field.Type()),
   264  				skip:    true,
   265  			}
   266  			continue
   267  		case FieldSetDefault:
   268  			if field.hasDef {
   269  				fields[i] = recordMapDecoderField{
   270  					name:    field.Name(),
   271  					decoder: createDefaultDecoder(cfg, field, mapType.Elem()),
   272  				}
   273  				continue
   274  			}
   275  		}
   276  
   277  		fields[i] = recordMapDecoderField{
   278  			name:    field.Name(),
   279  			decoder: newEfaceDecoder(cfg, field.Type()),
   280  		}
   281  	}
   282  
   283  	return &recordMapDecoder{
   284  		mapType:  mapType,
   285  		elemType: mapType.Elem(),
   286  		fields:   fields,
   287  	}
   288  }
   289  
   290  type recordMapDecoderField struct {
   291  	name    string
   292  	decoder ValDecoder
   293  	skip    bool
   294  }
   295  
   296  type recordMapDecoder struct {
   297  	mapType  *reflect2.UnsafeMapType
   298  	elemType reflect2.Type
   299  	fields   []recordMapDecoderField
   300  }
   301  
   302  func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   303  	if d.mapType.UnsafeIsNil(ptr) {
   304  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(len(d.fields)))
   305  	}
   306  
   307  	for _, field := range d.fields {
   308  		elemPtr := d.elemType.UnsafeNew()
   309  		field.decoder.Decode(elemPtr, r)
   310  		if field.skip {
   311  			continue
   312  		}
   313  
   314  		d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elemPtr)
   315  	}
   316  
   317  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   318  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
   319  	}
   320  }
   321  
   322  func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   323  	rec := schema.(*RecordSchema)
   324  	mapType := typ.(*reflect2.UnsafeMapType)
   325  
   326  	fields := make([]mapEncoderField, len(rec.Fields()))
   327  	for i, field := range rec.Fields() {
   328  		fields[i] = mapEncoderField{
   329  			name:    field.Name(),
   330  			hasDef:  field.HasDefault(),
   331  			def:     field.Default(),
   332  			encoder: encoderOfType(cfg, field.Type(), mapType.Elem()),
   333  		}
   334  
   335  		if field.HasDefault() {
   336  			switch {
   337  			case field.Type().Type() == Union:
   338  				union := field.Type().(*UnionSchema)
   339  				fields[i].def = map[string]any{
   340  					string(union.Types()[0].Type()): field.Default(),
   341  				}
   342  			case field.Default() == nil:
   343  				continue
   344  			}
   345  
   346  			defaultType := reflect2.TypeOf(fields[i].def)
   347  			fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType)
   348  			if defaultType.LikePtr() {
   349  				fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder}
   350  			}
   351  		}
   352  	}
   353  
   354  	return &recordMapEncoder{
   355  		mapType: mapType,
   356  		fields:  fields,
   357  	}
   358  }
   359  
   360  type mapEncoderField struct {
   361  	name       string
   362  	hasDef     bool
   363  	def        any
   364  	defEncoder ValEncoder
   365  	encoder    ValEncoder
   366  }
   367  
   368  type recordMapEncoder struct {
   369  	mapType *reflect2.UnsafeMapType
   370  	fields  []mapEncoderField
   371  }
   372  
   373  func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   374  	for _, field := range e.fields {
   375  		// The first property of mapEncoderField is the name, so a pointer
   376  		// to field is a pointer to the name.
   377  		valPtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(field))
   378  		if valPtr == nil {
   379  			// Missing required field
   380  			if !field.hasDef {
   381  				w.Error = fmt.Errorf("avro: missing required field %s", field.name)
   382  				return
   383  			}
   384  
   385  			// Null default
   386  			if field.def == nil {
   387  				continue
   388  			}
   389  
   390  			defPtr := reflect2.PtrOf(field.def)
   391  			field.defEncoder.Encode(defPtr, w)
   392  			continue
   393  		}
   394  
   395  		field.encoder.Encode(valPtr, w)
   396  
   397  		if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   398  			w.Error = fmt.Errorf("%s: %w", field.name, w.Error)
   399  			return
   400  		}
   401  	}
   402  }
   403  
   404  type recordIfaceDecoder struct {
   405  	schema  Schema
   406  	valType *reflect2.UnsafeIFaceType
   407  }
   408  
   409  func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   410  	obj := d.valType.UnsafeIndirect(ptr)
   411  	if reflect2.IsNil(obj) {
   412  		r.ReportError("decode non empty interface", "can not unmarshal into nil")
   413  		return
   414  	}
   415  
   416  	r.ReadVal(d.schema, obj)
   417  }
   418  
   419  type structDescriptor struct {
   420  	Type   reflect2.Type
   421  	Fields structFields
   422  }
   423  
   424  type structFields []*structField
   425  
   426  func (sf structFields) Get(name string) *structField {
   427  	for _, f := range sf {
   428  		if f.Name == name {
   429  			return f
   430  		}
   431  	}
   432  
   433  	return nil
   434  }
   435  
   436  type structField struct {
   437  	Name  string
   438  	Field []*reflect2.UnsafeStructField
   439  
   440  	anon *reflect2.UnsafeStructType
   441  }
   442  
   443  func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor {
   444  	structType := typ.(*reflect2.UnsafeStructType)
   445  	fields := structFields{}
   446  
   447  	var curr []structField
   448  	next := []structField{{anon: structType}}
   449  
   450  	visited := map[uintptr]bool{}
   451  
   452  	for len(next) > 0 {
   453  		curr, next = next, curr[:0]
   454  
   455  		for _, f := range curr {
   456  			rtype := f.anon.RType()
   457  			if visited[f.anon.RType()] {
   458  				continue
   459  			}
   460  			visited[rtype] = true
   461  
   462  			for i := 0; i < f.anon.NumField(); i++ {
   463  				field := f.anon.Field(i).(*reflect2.UnsafeStructField)
   464  				isUnexported := field.PkgPath() != ""
   465  
   466  				chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1)
   467  				copy(chain, f.Field)
   468  				chain[len(f.Field)] = field
   469  
   470  				if field.Anonymous() {
   471  					t := field.Type()
   472  					if t.Kind() == reflect.Ptr {
   473  						t = t.(*reflect2.UnsafePtrType).Elem()
   474  					}
   475  					if t.Kind() != reflect.Struct {
   476  						continue
   477  					}
   478  
   479  					next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)})
   480  					continue
   481  				}
   482  
   483  				// Ignore unexported fields.
   484  				if isUnexported {
   485  					continue
   486  				}
   487  
   488  				fieldName := field.Name()
   489  				if tag, ok := field.Tag().Lookup(tagKey); ok {
   490  					fieldName = tag
   491  				}
   492  
   493  				fields = append(fields, &structField{
   494  					Name:  fieldName,
   495  					Field: chain,
   496  				})
   497  			}
   498  		}
   499  	}
   500  
   501  	return &structDescriptor{
   502  		Type:   structType,
   503  		Fields: fields,
   504  	}
   505  }