github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/codec_native.go (about) 1 package avro 2 3 import ( 4 "fmt" 5 "math/big" 6 "reflect" 7 "time" 8 "unsafe" 9 10 "github.com/modern-go/reflect2" 11 ) 12 13 //nolint:maintidx // Splitting this would not make it simpler. 14 func createDecoderOfNative(schema *PrimitiveSchema, typ reflect2.Type) ValDecoder { 15 resolved := schema.encodedType != "" 16 switch typ.Kind() { 17 case reflect.Bool: 18 if schema.Type() != Boolean { 19 break 20 } 21 return &boolCodec{} 22 23 case reflect.Int: 24 if schema.Type() != Int { 25 break 26 } 27 return &intCodec[int]{} 28 29 case reflect.Int8: 30 if schema.Type() != Int { 31 break 32 } 33 return &intCodec[int8]{} 34 35 case reflect.Uint8: 36 if schema.Type() != Int { 37 break 38 } 39 return &intCodec[uint8]{} 40 41 case reflect.Int16: 42 if schema.Type() != Int { 43 break 44 } 45 return &intCodec[int16]{} 46 47 case reflect.Uint16: 48 if schema.Type() != Int { 49 break 50 } 51 return &intCodec[uint16]{} 52 53 case reflect.Int32: 54 if schema.Type() != Int { 55 break 56 } 57 return &intCodec[int32]{} 58 59 case reflect.Uint32: 60 if schema.Type() != Long { 61 break 62 } 63 if resolved { 64 return &longConvCodec[uint32]{convert: createLongConverter(schema.encodedType)} 65 } 66 return &longCodec[uint32]{} 67 68 case reflect.Int64: 69 st := schema.Type() 70 lt := getLogicalType(schema) 71 switch { 72 case st == Int && lt == TimeMillis: // time.Duration 73 return &timeMillisCodec{} 74 75 case st == Long && lt == TimeMicros: // time.Duration 76 return &timeMicrosCodec{ 77 convert: createLongConverter(schema.encodedType), 78 } 79 80 case st == Long && lt == "": 81 if resolved { 82 return &longConvCodec[int64]{convert: createLongConverter(schema.encodedType)} 83 } 84 return &longCodec[int64]{} 85 86 case lt != "": 87 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s and logicalType %s", 88 typ.String(), schema.Type(), lt)} 89 90 default: 91 break 92 } 93 94 case reflect.Float32: 95 if schema.Type() != Float { 96 break 97 } 98 if resolved { 99 return &float32ConvCodec{convert: createFloatConverter(schema.encodedType)} 100 } 101 return &float32Codec{} 102 103 case reflect.Float64: 104 if schema.Type() != Double { 105 break 106 } 107 if resolved { 108 return &float64ConvCodec{convert: createDoubleConverter(schema.encodedType)} 109 } 110 return &float64Codec{} 111 112 case reflect.String: 113 if schema.Type() != String { 114 break 115 } 116 return &stringCodec{} 117 118 case reflect.Slice: 119 if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes { 120 break 121 } 122 return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)} 123 124 case reflect.Struct: 125 st := schema.Type() 126 ls := getLogicalSchema(schema) 127 lt := getLogicalType(schema) 128 isTime := typ.Type1().ConvertibleTo(timeType) 129 switch { 130 case isTime && st == Int && lt == Date: 131 return &dateCodec{} 132 case isTime && st == Long && lt == TimestampMillis: 133 return ×tampMillisCodec{ 134 convert: createLongConverter(schema.encodedType), 135 } 136 case isTime && st == Long && lt == TimestampMicros: 137 return ×tampMicrosCodec{ 138 convert: createLongConverter(schema.encodedType), 139 } 140 case isTime && st == Long && lt == LocalTimestampMillis: 141 return ×tampMillisCodec{ 142 local: true, 143 convert: createLongConverter(schema.encodedType), 144 } 145 case isTime && st == Long && lt == LocalTimestampMicros: 146 return ×tampMicrosCodec{ 147 local: true, 148 convert: createLongConverter(schema.encodedType), 149 } 150 case typ.Type1().ConvertibleTo(ratType) && st == Bytes && lt == Decimal: 151 dec := ls.(*DecimalLogicalSchema) 152 return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()} 153 154 default: 155 break 156 } 157 case reflect.Ptr: 158 ptrType := typ.(*reflect2.UnsafePtrType) 159 elemType := ptrType.Elem() 160 tpy1 := elemType.Type1() 161 ls := getLogicalSchema(schema) 162 if ls == nil { 163 break 164 } 165 if !tpy1.ConvertibleTo(ratType) || schema.Type() != Bytes || ls.Type() != Decimal { 166 break 167 } 168 dec := ls.(*DecimalLogicalSchema) 169 170 return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()} 171 } 172 173 return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 174 } 175 176 //nolint:maintidx // Splitting this would not make it simpler. 177 func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder { 178 switch typ.Kind() { 179 case reflect.Bool: 180 if schema.Type() != Boolean { 181 break 182 } 183 return &boolCodec{} 184 185 case reflect.Int: 186 if schema.Type() != Int { 187 break 188 } 189 return &intCodec[int]{} 190 191 case reflect.Int8: 192 if schema.Type() != Int { 193 break 194 } 195 return &intCodec[int8]{} 196 197 case reflect.Uint8: 198 if schema.Type() != Int { 199 break 200 } 201 return &intCodec[uint8]{} 202 203 case reflect.Int16: 204 if schema.Type() != Int { 205 break 206 } 207 return &intCodec[int16]{} 208 209 case reflect.Uint16: 210 if schema.Type() != Int { 211 break 212 } 213 return &intCodec[uint16]{} 214 215 case reflect.Int32: 216 switch schema.Type() { 217 case Long: 218 return &longCodec[int32]{} 219 220 case Int: 221 return &intCodec[int32]{} 222 } 223 224 case reflect.Uint32: 225 if schema.Type() != Long { 226 break 227 } 228 return &longCodec[uint32]{} 229 230 case reflect.Int64: 231 st := schema.Type() 232 lt := getLogicalType(schema) 233 switch { 234 case st == Int && lt == TimeMillis: // time.Duration 235 return &timeMillisCodec{} 236 237 case st == Long && lt == TimeMicros: // time.Duration 238 return &timeMicrosCodec{} 239 240 case st == Long && lt == "": 241 return &longCodec[int64]{} 242 243 case lt != "": 244 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s and logicalType %s", 245 typ.String(), schema.Type(), lt)} 246 247 default: 248 break 249 } 250 251 case reflect.Float32: 252 switch schema.Type() { 253 case Double: 254 return &float32DoubleCodec{} 255 case Float: 256 return &float32Codec{} 257 } 258 259 case reflect.Float64: 260 if schema.Type() != Double { 261 break 262 } 263 return &float64Codec{} 264 265 case reflect.String: 266 if schema.Type() != String { 267 break 268 } 269 return &stringCodec{} 270 271 case reflect.Slice: 272 if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes { 273 break 274 } 275 return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)} 276 277 case reflect.Struct: 278 st := schema.Type() 279 lt := getLogicalType(schema) 280 isTime := typ.Type1().ConvertibleTo(timeType) 281 switch { 282 case isTime && st == Int && lt == Date: 283 return &dateCodec{} 284 case isTime && st == Long && lt == TimestampMillis: 285 return ×tampMillisCodec{} 286 case isTime && st == Long && lt == TimestampMicros: 287 return ×tampMicrosCodec{} 288 case isTime && st == Long && lt == LocalTimestampMillis: 289 return ×tampMillisCodec{local: true} 290 case isTime && st == Long && lt == LocalTimestampMicros: 291 return ×tampMicrosCodec{local: true} 292 case typ.Type1().ConvertibleTo(ratType) && st != Bytes || lt == Decimal: 293 ls := getLogicalSchema(schema) 294 dec := ls.(*DecimalLogicalSchema) 295 return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()} 296 default: 297 break 298 } 299 300 case reflect.Ptr: 301 ptrType := typ.(*reflect2.UnsafePtrType) 302 elemType := ptrType.Elem() 303 tpy1 := elemType.Type1() 304 ls := getLogicalSchema(schema) 305 if ls == nil { 306 break 307 } 308 if !tpy1.ConvertibleTo(ratType) || schema.Type() != Bytes || ls.Type() != Decimal { 309 break 310 } 311 dec := ls.(*DecimalLogicalSchema) 312 313 return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()} 314 } 315 316 if schema.Type() == Null { 317 return &nullCodec{} 318 } 319 320 return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} 321 } 322 323 func getLogicalSchema(schema Schema) LogicalSchema { 324 lts, ok := schema.(LogicalTypeSchema) 325 if !ok { 326 return nil 327 } 328 329 return lts.Logical() 330 } 331 332 func getLogicalType(schema Schema) LogicalType { 333 ls := getLogicalSchema(schema) 334 if ls == nil { 335 return "" 336 } 337 338 return ls.Type() 339 } 340 341 type nullCodec struct{} 342 343 func (*nullCodec) Encode(unsafe.Pointer, *Writer) {} 344 345 type boolCodec struct{} 346 347 func (*boolCodec) Decode(ptr unsafe.Pointer, r *Reader) { 348 *((*bool)(ptr)) = r.ReadBool() 349 } 350 351 func (*boolCodec) Encode(ptr unsafe.Pointer, w *Writer) { 352 w.WriteBool(*((*bool)(ptr))) 353 } 354 355 type smallInt interface { 356 ~int | ~int8 | ~int16 | ~int32 | ~uint | ~uint8 | ~uint16 357 } 358 359 type intCodec[T smallInt] struct{} 360 361 func (*intCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { 362 *((*T)(ptr)) = T(r.ReadInt()) 363 } 364 365 func (*intCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { 366 w.WriteInt(int32(*((*T)(ptr)))) 367 } 368 369 type largeInt interface { 370 ~int32 | ~uint32 | int64 371 } 372 373 type longCodec[T largeInt] struct{} 374 375 func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { 376 *((*T)(ptr)) = T(r.ReadLong()) 377 } 378 379 func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { 380 w.WriteLong(int64(*((*T)(ptr)))) 381 } 382 383 type longConvCodec[T largeInt] struct { 384 convert func(*Reader) int64 385 } 386 387 func (c *longConvCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { 388 *((*T)(ptr)) = T(c.convert(r)) 389 } 390 391 type float32Codec struct{} 392 393 func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { 394 *((*float32)(ptr)) = r.ReadFloat() 395 } 396 397 func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) { 398 w.WriteFloat(*((*float32)(ptr))) 399 } 400 401 type float32ConvCodec struct { 402 convert func(*Reader) float32 403 } 404 405 func (c *float32ConvCodec) Decode(ptr unsafe.Pointer, r *Reader) { 406 *((*float32)(ptr)) = c.convert(r) 407 } 408 409 type float32DoubleCodec struct{} 410 411 func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) { 412 w.WriteDouble(float64(*((*float32)(ptr)))) 413 } 414 415 type float64Codec struct{} 416 417 func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { 418 *((*float64)(ptr)) = r.ReadDouble() 419 } 420 421 func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) { 422 w.WriteDouble(*((*float64)(ptr))) 423 } 424 425 type float64ConvCodec struct { 426 convert func(*Reader) float64 427 } 428 429 func (c *float64ConvCodec) Decode(ptr unsafe.Pointer, r *Reader) { 430 *((*float64)(ptr)) = c.convert(r) 431 } 432 433 type stringCodec struct{} 434 435 func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { 436 *((*string)(ptr)) = r.ReadString() 437 } 438 439 func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { 440 w.WriteString(*((*string)(ptr))) 441 } 442 443 type bytesCodec struct { 444 sliceType *reflect2.UnsafeSliceType 445 } 446 447 func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) { 448 b := r.ReadBytes() 449 c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b)) 450 } 451 452 func (c *bytesCodec) Encode(ptr unsafe.Pointer, w *Writer) { 453 w.WriteBytes(*((*[]byte)(ptr))) 454 } 455 456 type dateCodec struct{} 457 458 func (c *dateCodec) Decode(ptr unsafe.Pointer, r *Reader) { 459 i := r.ReadInt() 460 sec := int64(i) * int64(24*time.Hour/time.Second) 461 *((*time.Time)(ptr)) = time.Unix(sec, 0).UTC() 462 } 463 464 func (c *dateCodec) Encode(ptr unsafe.Pointer, w *Writer) { 465 t := *((*time.Time)(ptr)) 466 days := t.Unix() / int64(24*time.Hour/time.Second) 467 w.WriteInt(int32(days)) 468 } 469 470 type timestampMillisCodec struct { 471 local bool 472 convert func(*Reader) int64 473 } 474 475 func (c *timestampMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { 476 var i int64 477 if c.convert != nil { 478 i = c.convert(r) 479 } else { 480 i = r.ReadLong() 481 } 482 sec := i / 1e3 483 nsec := (i - sec*1e3) * 1e6 484 t := time.Unix(sec, nsec) 485 486 if c.local { 487 // When doing unix time, Go will convert the time from UTC to Local, 488 // changing the time by the number of seconds in the zone offset. 489 // Remove those added seconds. 490 _, offset := t.Zone() 491 t = t.Add(time.Duration(-1*offset) * time.Second) 492 *((*time.Time)(ptr)) = t 493 return 494 } 495 *((*time.Time)(ptr)) = t.UTC() 496 } 497 498 func (c *timestampMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { 499 t := *((*time.Time)(ptr)) 500 if c.local { 501 t = t.Local() 502 t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) 503 } 504 w.WriteLong(t.Unix()*1e3 + int64(t.Nanosecond()/1e6)) 505 } 506 507 type timestampMicrosCodec struct { 508 local bool 509 convert func(*Reader) int64 510 } 511 512 func (c *timestampMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { 513 var i int64 514 if c.convert != nil { 515 i = c.convert(r) 516 } else { 517 i = r.ReadLong() 518 } 519 sec := i / 1e6 520 nsec := (i - sec*1e6) * 1e3 521 t := time.Unix(sec, nsec) 522 523 if c.local { 524 // When doing unix time, Go will convert the time from UTC to Local, 525 // changing the time by the number of seconds in the zone offset. 526 // Remove those added seconds. 527 _, offset := t.Zone() 528 t = t.Add(time.Duration(-1*offset) * time.Second) 529 *((*time.Time)(ptr)) = t 530 return 531 } 532 *((*time.Time)(ptr)) = t.UTC() 533 } 534 535 func (c *timestampMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { 536 t := *((*time.Time)(ptr)) 537 if c.local { 538 t = t.Local() 539 t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.UTC) 540 } 541 w.WriteLong(t.Unix()*1e6 + int64(t.Nanosecond()/1e3)) 542 } 543 544 type timeMillisCodec struct{} 545 546 func (c *timeMillisCodec) Decode(ptr unsafe.Pointer, r *Reader) { 547 i := r.ReadInt() 548 *((*time.Duration)(ptr)) = time.Duration(i) * time.Millisecond 549 } 550 551 func (c *timeMillisCodec) Encode(ptr unsafe.Pointer, w *Writer) { 552 d := *((*time.Duration)(ptr)) 553 w.WriteInt(int32(d.Nanoseconds() / int64(time.Millisecond))) 554 } 555 556 type timeMicrosCodec struct { 557 convert func(*Reader) int64 558 } 559 560 func (c *timeMicrosCodec) Decode(ptr unsafe.Pointer, r *Reader) { 561 var i int64 562 if c.convert != nil { 563 i = c.convert(r) 564 } else { 565 i = r.ReadLong() 566 } 567 *((*time.Duration)(ptr)) = time.Duration(i) * time.Microsecond 568 } 569 570 func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { 571 d := *((*time.Duration)(ptr)) 572 w.WriteLong(d.Nanoseconds() / int64(time.Microsecond)) 573 } 574 575 var one = big.NewInt(1) 576 577 type bytesDecimalCodec struct { 578 prec int 579 scale int 580 } 581 582 func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { 583 b := r.ReadBytes() 584 if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { 585 i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) 586 } 587 *((*big.Rat)(ptr)) = *ratFromBytes(b, c.scale) 588 } 589 590 func ratFromBytes(b []byte, scale int) *big.Rat { 591 num := (&big.Int{}).SetBytes(b) 592 if len(b) > 0 && b[0]&0x80 > 0 { 593 num.Sub(num, new(big.Int).Lsh(one, uint(len(b))*8)) 594 } 595 denom := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(scale)), nil) 596 return new(big.Rat).SetFrac(num, denom) 597 } 598 599 func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { 600 r := (*big.Rat)(ptr) 601 scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil) 602 i := (&big.Int{}).Mul(r.Num(), scale) 603 i = i.Div(i, r.Denom()) 604 605 var b []byte 606 switch i.Sign() { 607 case 0: 608 b = []byte{0} 609 610 case 1: 611 b = i.Bytes() 612 if b[0]&0x80 > 0 { 613 b = append([]byte{0}, b...) 614 } 615 616 case -1: 617 length := uint(i.BitLen()/8+1) * 8 618 b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes() 619 } 620 w.WriteBytes(b) 621 } 622 623 type bytesDecimalPtrCodec struct { 624 prec int 625 scale int 626 } 627 628 func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) { 629 b := r.ReadBytes() 630 if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { 631 i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) 632 } 633 *((**big.Rat)(ptr)) = ratFromBytes(b, c.scale) 634 } 635 636 func (c *bytesDecimalPtrCodec) Encode(ptr unsafe.Pointer, w *Writer) { 637 r := *((**big.Rat)(ptr)) 638 scale := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil) 639 i := (&big.Int{}).Mul(r.Num(), scale) 640 i = i.Div(i, r.Denom()) 641 642 var b []byte 643 switch i.Sign() { 644 case 0: 645 b = []byte{0} 646 647 case 1: 648 b = i.Bytes() 649 if b[0]&0x80 > 0 { 650 b = append([]byte{0}, b...) 651 } 652 653 case -1: 654 length := uint(i.BitLen()/8+1) * 8 655 b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes() 656 } 657 w.WriteBytes(b) 658 }