github.com/aacfactory/avro@v1.2.12/internal/base/codec_native.go (about)

     1  package base
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"reflect"
     7  	"time"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Bool:
    16  		if schema.Type() != Boolean {
    17  			break
    18  		}
    19  		return &boolCodec{}
    20  
    21  	case reflect.Int:
    22  		if schema.Type() != Int {
    23  			break
    24  		}
    25  		return &intCodec[int]{}
    26  
    27  	case reflect.Int8:
    28  		if schema.Type() != Int {
    29  			break
    30  		}
    31  		return &intCodec[int8]{}
    32  
    33  	case reflect.Uint8:
    34  		if schema.Type() != Int {
    35  			break
    36  		}
    37  		return &intCodec[uint8]{}
    38  
    39  	case reflect.Int16:
    40  		if schema.Type() != Int {
    41  			break
    42  		}
    43  		return &intCodec[int16]{}
    44  
    45  	case reflect.Uint16:
    46  		if schema.Type() != Int {
    47  			break
    48  		}
    49  		return &intCodec[uint16]{}
    50  
    51  	case reflect.Int32:
    52  		if schema.Type() != Int {
    53  			break
    54  		}
    55  		return &intCodec[int32]{}
    56  
    57  	case reflect.Uint32:
    58  		if schema.Type() != Long {
    59  			break
    60  		}
    61  		return &longCodec[uint32]{}
    62  
    63  	case reflect.Int64:
    64  		st := schema.Type()
    65  		lt := getLogicalType(schema)
    66  		switch {
    67  		case st == Int && lt == TimeMillis: // time.Duration
    68  			return &timeMillisCodec{}
    69  
    70  		case st == Long && lt == TimeMicros: // time.Duration
    71  			return &timeMicrosCodec{}
    72  
    73  		case st == Long:
    74  			return &longCodec[int64]{}
    75  
    76  		default:
    77  			break
    78  		}
    79  
    80  	case reflect.Float32:
    81  		if schema.Type() != Float {
    82  			break
    83  		}
    84  		return &float32Codec{}
    85  
    86  	case reflect.Float64:
    87  		if schema.Type() != Double {
    88  			break
    89  		}
    90  		return &float64Codec{}
    91  
    92  	case reflect.String:
    93  		if schema.Type() != String {
    94  			break
    95  		}
    96  		return &stringCodec{}
    97  
    98  	case reflect.Slice:
    99  		if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes {
   100  			break
   101  		}
   102  		return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)}
   103  
   104  	case reflect.Struct:
   105  		st := schema.Type()
   106  		ls := getLogicalSchema(schema)
   107  		lt := getLogicalType(schema)
   108  		tpy1 := typ.Type1()
   109  		Istpy1Time := tpy1.ConvertibleTo(timeType)
   110  		Istpy1Rat := tpy1.ConvertibleTo(ratType)
   111  		switch {
   112  		case Istpy1Time && st == Int && lt == Date:
   113  			return &dateCodec{}
   114  
   115  		case Istpy1Time && st == Long && lt == TimestampMillis:
   116  			return &timestampMillisCodec{}
   117  
   118  		case Istpy1Time && st == Long && lt == TimestampMicros:
   119  			return &timestampMicrosCodec{}
   120  
   121  		case Istpy1Rat && st == Bytes && lt == Decimal:
   122  			dec := ls.(*DecimalLogicalSchema)
   123  
   124  			return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
   125  
   126  		default:
   127  			break
   128  		}
   129  	case reflect.Ptr:
   130  		ptrType := typ.(*reflect2.UnsafePtrType)
   131  		elemType := ptrType.Elem()
   132  		tpy1 := elemType.Type1()
   133  		ls := getLogicalSchema(schema)
   134  		if ls == nil {
   135  			break
   136  		}
   137  		if !tpy1.ConvertibleTo(ratType) || schema.Type() != Bytes || ls.Type() != Decimal {
   138  			break
   139  		}
   140  		dec := ls.(*DecimalLogicalSchema)
   141  
   142  		return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
   143  	default:
   144  		break
   145  	}
   146  
   147  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
   148  }
   149  
   150  func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder {
   151  	switch typ.Kind() {
   152  	case reflect.Bool:
   153  		if schema.Type() != Boolean {
   154  			break
   155  		}
   156  		return &boolCodec{}
   157  
   158  	case reflect.Int:
   159  		if schema.Type() != Int {
   160  			break
   161  		}
   162  		return &intCodec[int]{}
   163  
   164  	case reflect.Int8:
   165  		if schema.Type() != Int {
   166  			break
   167  		}
   168  		return &intCodec[int8]{}
   169  
   170  	case reflect.Uint8:
   171  		if schema.Type() != Int {
   172  			break
   173  		}
   174  		return &intCodec[uint8]{}
   175  
   176  	case reflect.Int16:
   177  		if schema.Type() != Int {
   178  			break
   179  		}
   180  		return &intCodec[int16]{}
   181  
   182  	case reflect.Uint16:
   183  		if schema.Type() != Int {
   184  			break
   185  		}
   186  		return &intCodec[uint16]{}
   187  
   188  	case reflect.Int32:
   189  		switch schema.Type() {
   190  		case Long:
   191  			return &longCodec[int32]{}
   192  
   193  		case Int:
   194  			return &intCodec[int32]{}
   195  		}
   196  
   197  	case reflect.Uint32:
   198  		if schema.Type() != Long {
   199  			break
   200  		}
   201  		return &longCodec[uint32]{}
   202  
   203  	case reflect.Int64:
   204  		st := schema.Type()
   205  		lt := getLogicalType(schema)
   206  		switch {
   207  		case st == Int && lt == TimeMillis: // time.Duration
   208  			return &timeMillisCodec{}
   209  
   210  		case st == Long && lt == TimeMicros: // time.Duration
   211  			return &timeMicrosCodec{}
   212  
   213  		case st == Long:
   214  			return &longCodec[int64]{}
   215  
   216  		default:
   217  			break
   218  		}
   219  
   220  	case reflect.Float32:
   221  		switch schema.Type() {
   222  		case Double:
   223  			return &float32DoubleCodec{}
   224  
   225  		case Float:
   226  			return &float32Codec{}
   227  		}
   228  
   229  	case reflect.Float64:
   230  		if schema.Type() != Double {
   231  			break
   232  		}
   233  		return &float64Codec{}
   234  
   235  	case reflect.String:
   236  		if schema.Type() != String {
   237  			break
   238  		}
   239  		return &stringCodec{}
   240  
   241  	case reflect.Slice:
   242  		if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes {
   243  			break
   244  		}
   245  		return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)}
   246  
   247  	case reflect.Struct:
   248  		st := schema.Type()
   249  		lt := getLogicalType(schema)
   250  		tpy1 := typ.Type1()
   251  		isTpy1Time := tpy1.ConvertibleTo(timeType)
   252  		isTpy1Rat := tpy1.ConvertibleTo(ratType)
   253  		switch {
   254  		case isTpy1Time && st == Int && lt == Date:
   255  			return &dateCodec{}
   256  		case isTpy1Time && st == Long && lt == TimestampMillis:
   257  			return &timestampMillisCodec{}
   258  
   259  		case isTpy1Time && st == Long && lt == TimestampMicros:
   260  			return &timestampMicrosCodec{}
   261  
   262  		case isTpy1Rat && st != Bytes || lt == Decimal:
   263  			ls := getLogicalSchema(schema)
   264  			dec := ls.(*DecimalLogicalSchema)
   265  
   266  			return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
   267  
   268  		default:
   269  			break
   270  		}
   271  
   272  	case reflect.Ptr:
   273  		ptrType := typ.(*reflect2.UnsafePtrType)
   274  		elemType := ptrType.Elem()
   275  		tpy1 := elemType.Type1()
   276  		ls := getLogicalSchema(schema)
   277  		if ls == nil {
   278  			break
   279  		}
   280  		if !tpy1.ConvertibleTo(ratType) || schema.Type() != Bytes || ls.Type() != Decimal {
   281  			break
   282  		}
   283  		dec := ls.(*DecimalLogicalSchema)
   284  
   285  		return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
   286  	default:
   287  		break
   288  	}
   289  
   290  	if schema.Type() == Null {
   291  		return &nullCodec{}
   292  	}
   293  
   294  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
   295  }
   296  
   297  func getLogicalSchema(schema Schema) LogicalSchema {
   298  	lts, ok := schema.(LogicalTypeSchema)
   299  	if !ok {
   300  		return nil
   301  	}
   302  
   303  	return lts.Logical()
   304  }
   305  
   306  func getLogicalType(schema Schema) LogicalType {
   307  	ls := getLogicalSchema(schema)
   308  	if ls == nil {
   309  		return ""
   310  	}
   311  
   312  	return ls.Type()
   313  }
   314  
   315  type nullCodec struct{}
   316  
   317  func (*nullCodec) Encode(unsafe.Pointer, *Writer) {}
   318  
   319  type boolCodec struct{}
   320  
   321  func (*boolCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   322  	*((*bool)(ptr)) = r.ReadBool()
   323  }
   324  
   325  func (*boolCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   326  	w.WriteBool(*((*bool)(ptr)))
   327  }
   328  
   329  type smallInt interface {
   330  	~int | ~int8 | ~int16 | ~int32 | ~uint | ~uint8 | ~uint16
   331  }
   332  
   333  type intCodec[T smallInt] struct{}
   334  
   335  func (*intCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) {
   336  	*((*T)(ptr)) = T(r.ReadInt())
   337  }
   338  
   339  func (*intCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) {
   340  	w.WriteInt(int32(*((*T)(ptr))))
   341  }
   342  
   343  type largeInt interface {
   344  	~int32 | ~uint32 | int64
   345  }
   346  
   347  type longCodec[T largeInt] struct{}
   348  
   349  func (*longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) {
   350  	*((*T)(ptr)) = T(r.ReadLong())
   351  }
   352  
   353  func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) {
   354  	w.WriteLong(int64(*((*T)(ptr))))
   355  }
   356  
   357  type float32Codec struct{}
   358  
   359  func (*float32Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   360  	*((*float32)(ptr)) = r.ReadFloat()
   361  }
   362  
   363  func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   364  	w.WriteFloat(*((*float32)(ptr)))
   365  }
   366  
   367  type float32DoubleCodec struct{}
   368  
   369  func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   370  	w.WriteDouble(float64(*((*float32)(ptr))))
   371  }
   372  
   373  type float64Codec struct{}
   374  
   375  func (*float64Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   376  	*((*float64)(ptr)) = r.ReadDouble()
   377  }
   378  
   379  func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   380  	w.WriteDouble(*((*float64)(ptr)))
   381  }
   382  
   383  type stringCodec struct{}
   384  
   385  func (*stringCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   386  	*((*string)(ptr)) = r.ReadString()
   387  }
   388  
   389  func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   390  	w.WriteString(*((*string)(ptr)))
   391  }
   392  
   393  type bytesCodec struct {
   394  	sliceType *reflect2.UnsafeSliceType
   395  }
   396  
   397  func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   398  	b := r.ReadBytes()
   399  	c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b))
   400  }
   401  
   402  func (c *bytesCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   403  	w.WriteBytes(*((*[]byte)(ptr)))
   404  }
   405  
   406  type dateCodec struct{}
   407  
   408  func (c *dateCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   409  	i := r.ReadInt()
   410  	sec := int64(i) * int64(24*time.Hour/time.Second)
   411  	*((*time.Time)(ptr)) = time.Unix(sec, 0).UTC()
   412  }
   413  
   414  func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   415  	t := *((*time.Time)(ptr))
   416  	days := t.Unix() / int64(24*time.Hour/time.Second)
   417  	w.WriteInt(int32(days))
   418  }
   419  
   420  type timestampMillisCodec struct{}
   421  
   422  func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   423  	i := r.ReadLong()
   424  	sec := i / 1e3
   425  	nsec := (i - sec*1e3) * 1e6
   426  	*((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC()
   427  }
   428  
   429  func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   430  	t := *((*time.Time)(ptr))
   431  	w.WriteLong(t.Unix()*1e3 + int64(t.Nanosecond()/1e6))
   432  }
   433  
   434  type timestampMicrosCodec struct{}
   435  
   436  func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   437  	i := r.ReadLong()
   438  	sec := i / 1e6
   439  	nsec := (i - sec*1e6) * 1e3
   440  	*((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC()
   441  }
   442  
   443  func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   444  	t := *((*time.Time)(ptr))
   445  	w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3))
   446  }
   447  
   448  type timeMillisCodec struct{}
   449  
   450  func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   451  	i := r.ReadInt()
   452  	*((*time.Duration)(ptr)) = time.Duration(i) * time.Millisecond
   453  }
   454  
   455  func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   456  	d := *((*time.Duration)(ptr))
   457  	w.WriteInt(int32(d.Nanoseconds() / int64(time.Millisecond)))
   458  }
   459  
   460  type timeMicrosCodec struct{}
   461  
   462  func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   463  	i := r.ReadLong()
   464  	*((*time.Duration)(ptr)) = time.Duration(i) * time.Microsecond
   465  }
   466  
   467  func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   468  	d := *((*time.Duration)(ptr))
   469  	w.WriteLong(d.Nanoseconds() / int64(time.Microsecond))
   470  }
   471  
   472  var one = big.NewInt(1)
   473  
   474  type bytesDecimalCodec struct {
   475  	prec  int
   476  	scale int
   477  }
   478  
   479  func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   480  	b := r.ReadBytes()
   481  	if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
   482  		i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
   483  	}
   484  	*((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale)
   485  }
   486  
   487  func ratFromBytes(b []byte, scale int) *big.Rat {
   488  	num := (&big.Int{}).SetBytes(b)
   489  	if len(b) > 0 && b[0]&0x80 > 0 {
   490  		num.Sub(num, new(big.Int).Lsh(one, uint(len(b))*8))
   491  	}
   492  	denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil)
   493  	return new(big.Rat).SetFrac(num, denom)
   494  }
   495  
   496  func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   497  	r := (*big.Rat)(ptr)
   498  	scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
   499  	i := (&big.Int{}).Mul(r.Num(), scale)
   500  	i = i.Div(i, r.Denom())
   501  
   502  	var b []byte
   503  	switch i.Sign() {
   504  	case 0:
   505  		b = []byte{0}
   506  
   507  	case 1:
   508  		b = i.Bytes()
   509  		if b[0]&0x80 > 0 {
   510  			b = append([]byte{0}, b...)
   511  		}
   512  
   513  	case -1:
   514  		length := uint(i.BitLen()/8+1) * 8
   515  		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
   516  	}
   517  	w.WriteBytes(b)
   518  }
   519  
   520  type bytesDecimalPtrCodec struct {
   521  	prec  int
   522  	scale int
   523  }
   524  
   525  func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   526  	b := r.ReadBytes()
   527  	if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
   528  		i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
   529  	}
   530  	*((**big.Rat)(ptr)) = ratFromBytes(b, c.scale)
   531  }
   532  
   533  func (c *bytesDecimalPtrCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   534  	r := *((**big.Rat)(ptr))
   535  	scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
   536  	i := (&big.Int{}).Mul(r.Num(), scale)
   537  	i = i.Div(i, r.Denom())
   538  
   539  	var b []byte
   540  	switch i.Sign() {
   541  	case 0:
   542  		b = []byte{0}
   543  
   544  	case 1:
   545  		b = i.Bytes()
   546  		if b[0]&0x80 > 0 {
   547  			b = append([]byte{0}, b...)
   548  		}
   549  
   550  	case -1:
   551  		length := uint(i.BitLen()/8+1) * 8
   552  		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
   553  	}
   554  	w.WriteBytes(b)
   555  }