github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_native.go (about)

     1  package avro
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"reflect"
     7  	"time"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  //nolint:maintidx // Splitting this would not make it simpler.
    14  func createDecoderOfNative(schema *PrimitiveSchema, typ reflect2.Type) ValDecoder {
    15  	resolved := schema.encodedType != ""
    16  	switch typ.Kind() {
    17  	case reflect.Bool:
    18  		if schema.Type() != Boolean {
    19  			break
    20  		}
    21  		return &boolCodec{}
    22  
    23  	case reflect.Int:
    24  		if schema.Type() != Int {
    25  			break
    26  		}
    27  		return &intCodec[int]{}
    28  
    29  	case reflect.Int8:
    30  		if schema.Type() != Int {
    31  			break
    32  		}
    33  		return &intCodec[int8]{}
    34  
    35  	case reflect.Uint8:
    36  		if schema.Type() != Int {
    37  			break
    38  		}
    39  		return &intCodec[uint8]{}
    40  
    41  	case reflect.Int16:
    42  		if schema.Type() != Int {
    43  			break
    44  		}
    45  		return &intCodec[int16]{}
    46  
    47  	case reflect.Uint16:
    48  		if schema.Type() != Int {
    49  			break
    50  		}
    51  		return &intCodec[uint16]{}
    52  
    53  	case reflect.Int32:
    54  		if schema.Type() != Int {
    55  			break
    56  		}
    57  		return &intCodec[int32]{}
    58  
    59  	case reflect.Uint32:
    60  		if schema.Type() != Long {
    61  			break
    62  		}
    63  		if resolved {
    64  			return &longConvCodec[uint32]{convert: createLongConverter(schema.encodedType)}
    65  		}
    66  		return &longCodec[uint32]{}
    67  
    68  	case reflect.Int64:
    69  		st := schema.Type()
    70  		lt := getLogicalType(schema)
    71  		switch {
    72  		case st == Int && lt == TimeMillis: // time.Duration
    73  			return &timeMillisCodec{}
    74  
    75  		case st == Long && lt == TimeMicros: // time.Duration
    76  			return &timeMicrosCodec{
    77  				convert: createLongConverter(schema.encodedType),
    78  			}
    79  
    80  		case st == Long && lt == "":
    81  			if resolved {
    82  				return &longConvCodec[int64]{convert: createLongConverter(schema.encodedType)}
    83  			}
    84  			return &longCodec[int64]{}
    85  
    86  		case lt != "":
    87  			return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s and logicalType %s",
    88  				typ.String(), schema.Type(), lt)}
    89  
    90  		default:
    91  			break
    92  		}
    93  
    94  	case reflect.Float32:
    95  		if schema.Type() != Float {
    96  			break
    97  		}
    98  		if resolved {
    99  			return &float32ConvCodec{convert: createFloatConverter(schema.encodedType)}
   100  		}
   101  		return &float32Codec{}
   102  
   103  	case reflect.Float64:
   104  		if schema.Type() != Double {
   105  			break
   106  		}
   107  		if resolved {
   108  			return &float64ConvCodec{convert: createDoubleConverter(schema.encodedType)}
   109  		}
   110  		return &float64Codec{}
   111  
   112  	case reflect.String:
   113  		if schema.Type() != String {
   114  			break
   115  		}
   116  		return &stringCodec{}
   117  
   118  	case reflect.Slice:
   119  		if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes {
   120  			break
   121  		}
   122  		return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)}
   123  
   124  	case reflect.Struct:
   125  		st := schema.Type()
   126  		ls := getLogicalSchema(schema)
   127  		lt := getLogicalType(schema)
   128  		isTime := typ.Type1().ConvertibleTo(timeType)
   129  		switch {
   130  		case isTime && st == Int && lt == Date:
   131  			return &dateCodec{}
   132  		case isTime && st == Long && lt == TimestampMillis:
   133  			return &timestampMillisCodec{
   134  				convert: createLongConverter(schema.encodedType),
   135  			}
   136  		case isTime && st == Long && lt == TimestampMicros:
   137  			return &timestampMicrosCodec{
   138  				convert: createLongConverter(schema.encodedType),
   139  			}
   140  		case isTime && st == Long && lt == LocalTimestampMillis:
   141  			return &timestampMillisCodec{
   142  				local:   true,
   143  				convert: createLongConverter(schema.encodedType),
   144  			}
   145  		case isTime && st == Long && lt == LocalTimestampMicros:
   146  			return &timestampMicrosCodec{
   147  				local:   true,
   148  				convert: createLongConverter(schema.encodedType),
   149  			}
   150  		case typ.Type1().ConvertibleTo(ratType) && st == Bytes && lt == Decimal:
   151  			dec := ls.(*DecimalLogicalSchema)
   152  			return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
   153  
   154  		default:
   155  			break
   156  		}
   157  	case reflect.Ptr:
   158  		ptrType := typ.(*reflect2.UnsafePtrType)
   159  		elemType := ptrType.Elem()
   160  		tpy1 := elemType.Type1()
   161  		ls := getLogicalSchema(schema)
   162  		if ls == nil {
   163  			break
   164  		}
   165  		if !tpy1.ConvertibleTo(ratType) || schema.Type() != Bytes || ls.Type() != Decimal {
   166  			break
   167  		}
   168  		dec := ls.(*DecimalLogicalSchema)
   169  
   170  		return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
   171  	}
   172  
   173  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
   174  }
   175  
   176  //nolint:maintidx // Splitting this would not make it simpler.
   177  func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder {
   178  	switch typ.Kind() {
   179  	case reflect.Bool:
   180  		if schema.Type() != Boolean {
   181  			break
   182  		}
   183  		return &boolCodec{}
   184  
   185  	case reflect.Int:
   186  		if schema.Type() != Int {
   187  			break
   188  		}
   189  		return &intCodec[int]{}
   190  
   191  	case reflect.Int8:
   192  		if schema.Type() != Int {
   193  			break
   194  		}
   195  		return &intCodec[int8]{}
   196  
   197  	case reflect.Uint8:
   198  		if schema.Type() != Int {
   199  			break
   200  		}
   201  		return &intCodec[uint8]{}
   202  
   203  	case reflect.Int16:
   204  		if schema.Type() != Int {
   205  			break
   206  		}
   207  		return &intCodec[int16]{}
   208  
   209  	case reflect.Uint16:
   210  		if schema.Type() != Int {
   211  			break
   212  		}
   213  		return &intCodec[uint16]{}
   214  
   215  	case reflect.Int32:
   216  		switch schema.Type() {
   217  		case Long:
   218  			return &longCodec[int32]{}
   219  
   220  		case Int:
   221  			return &intCodec[int32]{}
   222  		}
   223  
   224  	case reflect.Uint32:
   225  		if schema.Type() != Long {
   226  			break
   227  		}
   228  		return &longCodec[uint32]{}
   229  
   230  	case reflect.Int64:
   231  		st := schema.Type()
   232  		lt := getLogicalType(schema)
   233  		switch {
   234  		case st == Int && lt == TimeMillis: // time.Duration
   235  			return &timeMillisCodec{}
   236  
   237  		case st == Long && lt == TimeMicros: // time.Duration
   238  			return &timeMicrosCodec{}
   239  
   240  		case st == Long && lt == "":
   241  			return &longCodec[int64]{}
   242  
   243  		case lt != "":
   244  			return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s and logicalType %s",
   245  				typ.String(), schema.Type(), lt)}
   246  
   247  		default:
   248  			break
   249  		}
   250  
   251  	case reflect.Float32:
   252  		switch schema.Type() {
   253  		case Double:
   254  			return &float32DoubleCodec{}
   255  		case Float:
   256  			return &float32Codec{}
   257  		}
   258  
   259  	case reflect.Float64:
   260  		if schema.Type() != Double {
   261  			break
   262  		}
   263  		return &float64Codec{}
   264  
   265  	case reflect.String:
   266  		if schema.Type() != String {
   267  			break
   268  		}
   269  		return &stringCodec{}
   270  
   271  	case reflect.Slice:
   272  		if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes {
   273  			break
   274  		}
   275  		return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)}
   276  
   277  	case reflect.Struct:
   278  		st := schema.Type()
   279  		lt := getLogicalType(schema)
   280  		isTime := typ.Type1().ConvertibleTo(timeType)
   281  		switch {
   282  		case isTime && st == Int && lt == Date:
   283  			return &dateCodec{}
   284  		case isTime && st == Long && lt == TimestampMillis:
   285  			return &timestampMillisCodec{}
   286  		case isTime && st == Long && lt == TimestampMicros:
   287  			return &timestampMicrosCodec{}
   288  		case isTime && st == Long && lt == LocalTimestampMillis:
   289  			return &timestampMillisCodec{local: true}
   290  		case isTime && st == Long && lt == LocalTimestampMicros:
   291  			return &timestampMicrosCodec{local: true}
   292  		case typ.Type1().ConvertibleTo(ratType) && st != Bytes || lt == Decimal:
   293  			ls := getLogicalSchema(schema)
   294  			dec := ls.(*DecimalLogicalSchema)
   295  			return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
   296  		default:
   297  			break
   298  		}
   299  
   300  	case reflect.Ptr:
   301  		ptrType := typ.(*reflect2.UnsafePtrType)
   302  		elemType := ptrType.Elem()
   303  		tpy1 := elemType.Type1()
   304  		ls := getLogicalSchema(schema)
   305  		if ls == nil {
   306  			break
   307  		}
   308  		if !tpy1.ConvertibleTo(ratType) || schema.Type() != Bytes || ls.Type() != Decimal {
   309  			break
   310  		}
   311  		dec := ls.(*DecimalLogicalSchema)
   312  
   313  		return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
   314  	}
   315  
   316  	if schema.Type() == Null {
   317  		return &nullCodec{}
   318  	}
   319  
   320  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
   321  }
   322  
   323  func getLogicalSchema(schema Schema) LogicalSchema {
   324  	lts, ok := schema.(LogicalTypeSchema)
   325  	if !ok {
   326  		return nil
   327  	}
   328  
   329  	return lts.Logical()
   330  }
   331  
   332  func getLogicalType(schema Schema) LogicalType {
   333  	ls := getLogicalSchema(schema)
   334  	if ls == nil {
   335  		return ""
   336  	}
   337  
   338  	return ls.Type()
   339  }
   340  
   341  type nullCodec struct{}
   342  
   343  func (*nullCodec) Encode(unsafe.Pointer, *Writer) {}
   344  
   345  type boolCodec struct{}
   346  
   347  func (*boolCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   348  	*((*bool)(ptr)) = r.ReadBool()
   349  }
   350  
   351  func (*boolCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   352  	w.WriteBool(*((*bool)(ptr)))
   353  }
   354  
   355  type smallInt interface {
   356  	~int | ~int8 | ~int16 | ~int32 | ~uint | ~uint8 | ~uint16
   357  }
   358  
   359  type intCodec[T smallInt] struct{}
   360  
   361  func (*intCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) {
   362  	*((*T)(ptr)) = T(r.ReadInt())
   363  }
   364  
   365  func (*intCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) {
   366  	w.WriteInt(int32(*((*T)(ptr))))
   367  }
   368  
   369  type largeInt interface {
   370  	~int32 | ~uint32 | int64
   371  }
   372  
   373  type longCodec[T largeInt] struct{}
   374  
   375  func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) {
   376  	*((*T)(ptr)) = T(r.ReadLong())
   377  }
   378  
   379  func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) {
   380  	w.WriteLong(int64(*((*T)(ptr))))
   381  }
   382  
   383  type longConvCodec[T largeInt] struct {
   384  	convert func(*Reader) int64
   385  }
   386  
   387  func (c *longConvCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) {
   388  	*((*T)(ptr)) = T(c.convert(r))
   389  }
   390  
   391  type float32Codec struct{}
   392  
   393  func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   394  	*((*float32)(ptr)) = r.ReadFloat()
   395  }
   396  
   397  func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   398  	w.WriteFloat(*((*float32)(ptr)))
   399  }
   400  
   401  type float32ConvCodec struct {
   402  	convert func(*Reader) float32
   403  }
   404  
   405  func (c *float32ConvCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   406  	*((*float32)(ptr)) = c.convert(r)
   407  }
   408  
   409  type float32DoubleCodec struct{}
   410  
   411  func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   412  	w.WriteDouble(float64(*((*float32)(ptr))))
   413  }
   414  
   415  type float64Codec struct{}
   416  
   417  func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   418  	*((*float64)(ptr)) = r.ReadDouble()
   419  }
   420  
   421  func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   422  	w.WriteDouble(*((*float64)(ptr)))
   423  }
   424  
   425  type float64ConvCodec struct {
   426  	convert func(*Reader) float64
   427  }
   428  
   429  func (c *float64ConvCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   430  	*((*float64)(ptr)) = c.convert(r)
   431  }
   432  
   433  type stringCodec struct{}
   434  
   435  func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   436  	*((*string)(ptr)) = r.ReadString()
   437  }
   438  
   439  func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   440  	w.WriteString(*((*string)(ptr)))
   441  }
   442  
   443  type bytesCodec struct {
   444  	sliceType *reflect2.UnsafeSliceType
   445  }
   446  
   447  func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   448  	b := r.ReadBytes()
   449  	c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b))
   450  }
   451  
   452  func (c *bytesCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   453  	w.WriteBytes(*((*[]byte)(ptr)))
   454  }
   455  
   456  type dateCodec struct{}
   457  
   458  func (c *dateCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   459  	i := r.ReadInt()
   460  	sec := int64(i) * int64(24*time.Hour/time.Second)
   461  	*((*time.Time)(ptr)) = time.Unix(sec, 0).UTC()
   462  }
   463  
   464  func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   465  	t := *((*time.Time)(ptr))
   466  	days := t.Unix() / int64(24*time.Hour/time.Second)
   467  	w.WriteInt(int32(days))
   468  }
   469  
   470  type timestampMillisCodec struct {
   471  	local   bool
   472  	convert func(*Reader) int64
   473  }
   474  
   475  func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   476  	var i int64
   477  	if c.convert != nil {
   478  		i = c.convert(r)
   479  	} else {
   480  		i = r.ReadLong()
   481  	}
   482  	sec := i / 1e3
   483  	nsec := (i - sec*1e3) * 1e6
   484  	t := time.Unix(sec, nsec)
   485  
   486  	if c.local {
   487  		// When doing unix time, Go will convert the time from UTC to Local,
   488  		// changing the time by the number of seconds in the zone offset.
   489  		// Remove those added seconds.
   490  		_, offset := t.Zone()
   491  		t = t.Add(time.Duration(-1*offset) * time.Second)
   492  		*((*time.Time)(ptr)) = t
   493  		return
   494  	}
   495  	*((*time.Time)(ptr)) = t.UTC()
   496  }
   497  
   498  func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   499  	t := *((*time.Time)(ptr))
   500  	if c.local {
   501  		t = t.Local()
   502  		t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC)
   503  	}
   504  	w.WriteLong(t.Unix()*1e3 + int64(t.Nanosecond()/1e6))
   505  }
   506  
   507  type timestampMicrosCodec struct {
   508  	local   bool
   509  	convert func(*Reader) int64
   510  }
   511  
   512  func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   513  	var i int64
   514  	if c.convert != nil {
   515  		i = c.convert(r)
   516  	} else {
   517  		i = r.ReadLong()
   518  	}
   519  	sec := i / 1e6
   520  	nsec := (i - sec*1e6) * 1e3
   521  	t := time.Unix(sec, nsec)
   522  
   523  	if c.local {
   524  		// When doing unix time, Go will convert the time from UTC to Local,
   525  		// changing the time by the number of seconds in the zone offset.
   526  		// Remove those added seconds.
   527  		_, offset := t.Zone()
   528  		t = t.Add(time.Duration(-1*offset) * time.Second)
   529  		*((*time.Time)(ptr)) = t
   530  		return
   531  	}
   532  	*((*time.Time)(ptr)) = t.UTC()
   533  }
   534  
   535  func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   536  	t := *((*time.Time)(ptr))
   537  	if c.local {
   538  		t = t.Local()
   539  		t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC)
   540  	}
   541  	w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3))
   542  }
   543  
   544  type timeMillisCodec struct{}
   545  
   546  func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   547  	i := r.ReadInt()
   548  	*((*time.Duration)(ptr)) = time.Duration(i) * time.Millisecond
   549  }
   550  
   551  func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   552  	d := *((*time.Duration)(ptr))
   553  	w.WriteInt(int32(d.Nanoseconds() / int64(time.Millisecond)))
   554  }
   555  
   556  type timeMicrosCodec struct {
   557  	convert func(*Reader) int64
   558  }
   559  
   560  func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   561  	var i int64
   562  	if c.convert != nil {
   563  		i = c.convert(r)
   564  	} else {
   565  		i = r.ReadLong()
   566  	}
   567  	*((*time.Duration)(ptr)) = time.Duration(i) * time.Microsecond
   568  }
   569  
   570  func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   571  	d := *((*time.Duration)(ptr))
   572  	w.WriteLong(d.Nanoseconds() / int64(time.Microsecond))
   573  }
   574  
   575  var one = big.NewInt(1)
   576  
   577  type bytesDecimalCodec struct {
   578  	prec  int
   579  	scale int
   580  }
   581  
   582  func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   583  	b := r.ReadBytes()
   584  	if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
   585  		i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
   586  	}
   587  	*((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale)
   588  }
   589  
   590  func ratFromBytes(b []byte, scale int) *big.Rat {
   591  	num := (&big.Int{}).SetBytes(b)
   592  	if len(b) > 0 && b[0]&0x80 > 0 {
   593  		num.Sub(num, new(big.Int).Lsh(one, uint(len(b))*8))
   594  	}
   595  	denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil)
   596  	return new(big.Rat).SetFrac(num, denom)
   597  }
   598  
   599  func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   600  	r := (*big.Rat)(ptr)
   601  	scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
   602  	i := (&big.Int{}).Mul(r.Num(), scale)
   603  	i = i.Div(i, r.Denom())
   604  
   605  	var b []byte
   606  	switch i.Sign() {
   607  	case 0:
   608  		b = []byte{0}
   609  
   610  	case 1:
   611  		b = i.Bytes()
   612  		if b[0]&0x80 > 0 {
   613  			b = append([]byte{0}, b...)
   614  		}
   615  
   616  	case -1:
   617  		length := uint(i.BitLen()/8+1) * 8
   618  		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
   619  	}
   620  	w.WriteBytes(b)
   621  }
   622  
   623  type bytesDecimalPtrCodec struct {
   624  	prec  int
   625  	scale int
   626  }
   627  
   628  func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   629  	b := r.ReadBytes()
   630  	if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
   631  		i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
   632  	}
   633  	*((**big.Rat)(ptr)) = ratFromBytes(b, c.scale)
   634  }
   635  
   636  func (c *bytesDecimalPtrCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   637  	r := *((**big.Rat)(ptr))
   638  	scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
   639  	i := (&big.Int{}).Mul(r.Num(), scale)
   640  	i = i.Div(i, r.Denom())
   641  
   642  	var b []byte
   643  	switch i.Sign() {
   644  	case 0:
   645  		b = []byte{0}
   646  
   647  	case 1:
   648  		b = i.Bytes()
   649  		if b[0]&0x80 > 0 {
   650  			b = append([]byte{0}, b...)
   651  		}
   652  
   653  	case -1:
   654  		length := uint(i.BitLen()/8+1) * 8
   655  		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
   656  	}
   657  	w.WriteBytes(b)
   658  }