github.com/hamba/avro@v1.8.0/codec_native.go (about)

     1  package avro
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"reflect"
     7  	"time"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Bool:
    16  		if schema.Type() != Boolean {
    17  			break
    18  		}
    19  		return &boolCodec{}
    20  
    21  	case reflect.Int:
    22  		if schema.Type() != Int {
    23  			break
    24  		}
    25  		return &intCodec{}
    26  
    27  	case reflect.Int8:
    28  		if schema.Type() != Int {
    29  			break
    30  		}
    31  		return &int8Codec{}
    32  
    33  	case reflect.Int16:
    34  		if schema.Type() != Int {
    35  			break
    36  		}
    37  		return &int16Codec{}
    38  
    39  	case reflect.Int32:
    40  		if schema.Type() != Int {
    41  			break
    42  		}
    43  		return &int32Codec{}
    44  
    45  	case reflect.Int64:
    46  		st := schema.Type()
    47  		lt := getLogicalType(schema)
    48  		switch {
    49  		case st == Int && lt == TimeMillis: // time.Duration
    50  			return &timeMillisCodec{}
    51  
    52  		case st == Long && lt == TimeMicros: // time.Duration
    53  			return &timeMicrosCodec{}
    54  
    55  		case st == Long:
    56  			return &int64Codec{}
    57  
    58  		default:
    59  			break
    60  		}
    61  
    62  	case reflect.Float32:
    63  		if schema.Type() != Float {
    64  			break
    65  		}
    66  		return &float32Codec{}
    67  
    68  	case reflect.Float64:
    69  		if schema.Type() != Double {
    70  			break
    71  		}
    72  		return &float64Codec{}
    73  
    74  	case reflect.String:
    75  		if schema.Type() != String {
    76  			break
    77  		}
    78  		return &stringCodec{}
    79  
    80  	case reflect.Slice:
    81  		if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes {
    82  			break
    83  		}
    84  		return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)}
    85  
    86  	case reflect.Struct:
    87  		st := schema.Type()
    88  		ls := getLogicalSchema(schema)
    89  		lt := getLogicalType(schema)
    90  		switch {
    91  		case typ.RType() == timeRType && st == Int && lt == Date:
    92  			return &dateCodec{}
    93  
    94  		case typ.RType() == timeRType && st == Long && lt == TimestampMillis:
    95  			return &timestampMillisCodec{}
    96  
    97  		case typ.RType() == timeRType && st == Long && lt == TimestampMicros:
    98  			return &timestampMicrosCodec{}
    99  
   100  		case typ.RType() == ratRType && st == Bytes && lt == Decimal:
   101  			dec := ls.(*DecimalLogicalSchema)
   102  
   103  			return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
   104  
   105  		default:
   106  			break
   107  		}
   108  	case reflect.Ptr:
   109  		ptrType := typ.(*reflect2.UnsafePtrType)
   110  		elemType := ptrType.Elem()
   111  
   112  		ls := getLogicalSchema(schema)
   113  		if ls == nil {
   114  			break
   115  		}
   116  		if elemType.RType() != ratRType || schema.Type() != Bytes || ls.Type() != Decimal {
   117  			break
   118  		}
   119  		dec := ls.(*DecimalLogicalSchema)
   120  
   121  		return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
   122  	}
   123  
   124  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
   125  }
   126  
   127  func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder {
   128  	switch typ.Kind() {
   129  	case reflect.Bool:
   130  		if schema.Type() != Boolean {
   131  			break
   132  		}
   133  		return &boolCodec{}
   134  
   135  	case reflect.Int:
   136  		if schema.Type() != Int {
   137  			break
   138  		}
   139  		return &intCodec{}
   140  
   141  	case reflect.Int8:
   142  		if schema.Type() != Int {
   143  			break
   144  		}
   145  		return &int8Codec{}
   146  
   147  	case reflect.Int16:
   148  		if schema.Type() != Int {
   149  			break
   150  		}
   151  		return &int16Codec{}
   152  
   153  	case reflect.Int32:
   154  		switch schema.Type() {
   155  		case Long:
   156  			return &int32LongCodec{}
   157  
   158  		case Int:
   159  			return &int32Codec{}
   160  		}
   161  
   162  	case reflect.Int64:
   163  		st := schema.Type()
   164  		lt := getLogicalType(schema)
   165  		switch {
   166  		case st == Int && lt == TimeMillis: // time.Duration
   167  			return &timeMillisCodec{}
   168  
   169  		case st == Long && lt == TimeMicros: // time.Duration
   170  			return &timeMicrosCodec{}
   171  
   172  		case st == Long:
   173  			return &int64Codec{}
   174  
   175  		default:
   176  			break
   177  		}
   178  
   179  	case reflect.Float32:
   180  		switch schema.Type() {
   181  		case Double:
   182  			return &float32DoubleCodec{}
   183  
   184  		case Float:
   185  			return &float32Codec{}
   186  		}
   187  
   188  	case reflect.Float64:
   189  		if schema.Type() != Double {
   190  			break
   191  		}
   192  		return &float64Codec{}
   193  
   194  	case reflect.String:
   195  		if schema.Type() != String {
   196  			break
   197  		}
   198  		return &stringCodec{}
   199  
   200  	case reflect.Slice:
   201  		if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes {
   202  			break
   203  		}
   204  		return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)}
   205  
   206  	case reflect.Struct:
   207  		st := schema.Type()
   208  		lt := getLogicalType(schema)
   209  		switch {
   210  		case typ.RType() == timeRType && st == Int && lt == Date:
   211  			return &dateCodec{}
   212  
   213  		case typ.RType() == timeRType && st == Long && lt == TimestampMillis:
   214  			return &timestampMillisCodec{}
   215  
   216  		case typ.RType() == timeRType && st == Long && lt == TimestampMicros:
   217  			return &timestampMicrosCodec{}
   218  
   219  		case typ.RType() == ratRType && st != Bytes || lt == Decimal:
   220  			ls := getLogicalSchema(schema)
   221  			dec := ls.(*DecimalLogicalSchema)
   222  
   223  			return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
   224  
   225  		default:
   226  			break
   227  		}
   228  
   229  	case reflect.Ptr:
   230  		ptrType := typ.(*reflect2.UnsafePtrType)
   231  		elemType := ptrType.Elem()
   232  
   233  		ls := getLogicalSchema(schema)
   234  		if ls == nil {
   235  			break
   236  		}
   237  		if elemType.RType() != ratRType || schema.Type() != Bytes || ls.Type() != Decimal {
   238  			break
   239  		}
   240  		dec := ls.(*DecimalLogicalSchema)
   241  
   242  		return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
   243  	}
   244  
   245  	if schema.Type() == Null {
   246  		return &nullCodec{}
   247  	}
   248  
   249  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
   250  }
   251  
   252  func getLogicalSchema(schema Schema) LogicalSchema {
   253  	lts, ok := schema.(LogicalTypeSchema)
   254  	if !ok {
   255  		return nil
   256  	}
   257  
   258  	return lts.Logical()
   259  }
   260  
   261  func getLogicalType(schema Schema) LogicalType {
   262  	ls := getLogicalSchema(schema)
   263  	if ls == nil {
   264  		return ""
   265  	}
   266  
   267  	return ls.Type()
   268  }
   269  
   270  type nullCodec struct{}
   271  
   272  func (*nullCodec) Encode(ptr unsafe.Pointer, w *Writer) {}
   273  
   274  type boolCodec struct{}
   275  
   276  func (*boolCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   277  	*((*bool)(ptr)) = r.ReadBool()
   278  }
   279  
   280  func (*boolCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   281  	w.WriteBool(*((*bool)(ptr)))
   282  }
   283  
   284  type intCodec struct{}
   285  
   286  func (*intCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   287  	*((*int)(ptr)) = int(r.ReadInt())
   288  }
   289  
   290  func (*intCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   291  	w.WriteInt(int32(*((*int)(ptr))))
   292  }
   293  
   294  type int8Codec struct{}
   295  
   296  func (*int8Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   297  	*((*int8)(ptr)) = int8(r.ReadInt())
   298  }
   299  
   300  func (*int8Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   301  	w.WriteInt(int32(*((*int8)(ptr))))
   302  }
   303  
   304  type int16Codec struct{}
   305  
   306  func (*int16Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   307  	*((*int16)(ptr)) = int16(r.ReadInt())
   308  }
   309  
   310  func (*int16Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   311  	w.WriteInt(int32(*((*int16)(ptr))))
   312  }
   313  
   314  type int32Codec struct{}
   315  
   316  func (*int32Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   317  	*((*int32)(ptr)) = r.ReadInt()
   318  }
   319  
   320  func (*int32Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   321  	w.WriteInt(*((*int32)(ptr)))
   322  }
   323  
   324  type int32LongCodec struct{}
   325  
   326  func (*int32LongCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   327  	w.WriteLong(int64(*((*int32)(ptr))))
   328  }
   329  
   330  type int64Codec struct{}
   331  
   332  func (*int64Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   333  	*((*int64)(ptr)) = r.ReadLong()
   334  }
   335  
   336  func (*int64Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   337  	w.WriteLong(*((*int64)(ptr)))
   338  }
   339  
   340  type float32Codec struct{}
   341  
   342  func (*float32Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   343  	*((*float32)(ptr)) = r.ReadFloat()
   344  }
   345  
   346  func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   347  	w.WriteFloat(*((*float32)(ptr)))
   348  }
   349  
   350  type float32DoubleCodec struct{}
   351  
   352  func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   353  	w.WriteDouble(float64(*((*float32)(ptr))))
   354  }
   355  
   356  type float64Codec struct{}
   357  
   358  func (*float64Codec) Decode(ptr unsafe.Pointer, r *Reader) {
   359  	*((*float64)(ptr)) = r.ReadDouble()
   360  }
   361  
   362  func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) {
   363  	w.WriteDouble(*((*float64)(ptr)))
   364  }
   365  
   366  type stringCodec struct{}
   367  
   368  func (*stringCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   369  	*((*string)(ptr)) = r.ReadString()
   370  }
   371  
   372  func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   373  	w.WriteString(*((*string)(ptr)))
   374  }
   375  
   376  type bytesCodec struct {
   377  	sliceType *reflect2.UnsafeSliceType
   378  }
   379  
   380  func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   381  	b := r.ReadBytes()
   382  	c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b))
   383  }
   384  
   385  func (c *bytesCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   386  	w.WriteBytes(*((*[]byte)(ptr)))
   387  }
   388  
   389  type dateCodec struct{}
   390  
   391  func (c *dateCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   392  	i := r.ReadInt()
   393  	sec := int64(i) * int64(24*time.Hour/time.Second)
   394  	*((*time.Time)(ptr)) = time.Unix(sec, 0).UTC()
   395  }
   396  
   397  func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   398  	t := *((*time.Time)(ptr))
   399  	days := t.Unix() / int64(24*time.Hour/time.Second)
   400  	w.WriteInt(int32(days))
   401  }
   402  
   403  type timestampMillisCodec struct{}
   404  
   405  func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   406  	i := r.ReadLong()
   407  	sec := i / 1e3
   408  	nsec := (i - sec*1e3) * 1e6
   409  	*((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC()
   410  }
   411  
   412  func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   413  	t := *((*time.Time)(ptr))
   414  	w.WriteLong(t.Unix()*1e3 + int64(t.Nanosecond()/1e6))
   415  }
   416  
   417  type timestampMicrosCodec struct{}
   418  
   419  func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   420  	i := r.ReadLong()
   421  	sec := i / 1e6
   422  	nsec := (i - sec*1e6) * 1e3
   423  	*((*time.Time)(ptr)) = time.Unix(sec, nsec).UTC()
   424  }
   425  
   426  func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   427  	t := *((*time.Time)(ptr))
   428  	w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3))
   429  }
   430  
   431  type timeMillisCodec struct{}
   432  
   433  func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   434  	i := r.ReadInt()
   435  	*((*time.Duration)(ptr)) = time.Duration(i) * time.Millisecond
   436  }
   437  
   438  func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   439  	d := *((*time.Duration)(ptr))
   440  	w.WriteInt(int32(d.Nanoseconds() / int64(time.Millisecond)))
   441  }
   442  
   443  type timeMicrosCodec struct{}
   444  
   445  func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   446  	i := r.ReadLong()
   447  	*((*time.Duration)(ptr)) = time.Duration(i) * time.Microsecond
   448  }
   449  
   450  func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   451  	d := *((*time.Duration)(ptr))
   452  	w.WriteLong(d.Nanoseconds() / int64(time.Microsecond))
   453  }
   454  
   455  var one = big.NewInt(1)
   456  
   457  type bytesDecimalCodec struct {
   458  	prec  int
   459  	scale int
   460  }
   461  
   462  func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   463  	b := r.ReadBytes()
   464  	if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
   465  		i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
   466  	}
   467  	*((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale)
   468  }
   469  
   470  func ratFromBytes(b []byte, scale int) *big.Rat {
   471  	num := (&big.Int{}).SetBytes(b)
   472  	if len(b) > 0 && b[0]&0x80 > 0 {
   473  		num.Sub(num, new(big.Int).Lsh(one, uint(len(b))*8))
   474  	}
   475  	denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil)
   476  	return new(big.Rat).SetFrac(num, denom)
   477  }
   478  
   479  func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   480  	r := (*big.Rat)(ptr)
   481  	scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
   482  	i := (&big.Int{}).Mul(r.Num(), scale)
   483  	i = i.Div(i, r.Denom())
   484  
   485  	var b []byte
   486  	switch i.Sign() {
   487  	case 0:
   488  		b = []byte{0}
   489  
   490  	case 1:
   491  		b = i.Bytes()
   492  		if b[0]&0x80 > 0 {
   493  			b = append([]byte{0}, b...)
   494  		}
   495  
   496  	case -1:
   497  		length := uint(i.BitLen()/8+1) * 8
   498  		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
   499  	}
   500  	w.WriteBytes(b)
   501  }
   502  
   503  type bytesDecimalPtrCodec struct {
   504  	prec  int
   505  	scale int
   506  }
   507  
   508  func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) {
   509  	b := r.ReadBytes()
   510  	if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 {
   511  		i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8))
   512  	}
   513  	*((**big.Rat)(ptr)) = ratFromBytes(b, c.scale)
   514  }
   515  
   516  func (c *bytesDecimalPtrCodec) Encode(ptr unsafe.Pointer, w *Writer) {
   517  	r := *((**big.Rat)(ptr))
   518  	scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
   519  	i := (&big.Int{}).Mul(r.Num(), scale)
   520  	i = i.Div(i, r.Denom())
   521  
   522  	var b []byte
   523  	switch i.Sign() {
   524  	case 0:
   525  		b = []byte{0}
   526  
   527  	case 1:
   528  		b = i.Bytes()
   529  		if b[0]&0x80 > 0 {
   530  			b = append([]byte{0}, b...)
   531  		}
   532  
   533  	case -1:
   534  		length := uint(i.BitLen()/8+1) * 8
   535  		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
   536  	}
   537  	w.WriteBytes(b)
   538  }