github.com/snowflakedb/gosnowflake@v1.9.0/converter.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "database/sql" 8 "database/sql/driver" 9 "encoding/hex" 10 "fmt" 11 "math" 12 "math/big" 13 "reflect" 14 "strconv" 15 "strings" 16 "time" 17 "unicode/utf8" 18 19 "github.com/apache/arrow/go/v15/arrow" 20 "github.com/apache/arrow/go/v15/arrow/array" 21 "github.com/apache/arrow/go/v15/arrow/compute" 22 "github.com/apache/arrow/go/v15/arrow/decimal128" 23 "github.com/apache/arrow/go/v15/arrow/memory" 24 ) 25 26 const format = "2006-01-02 15:04:05.999999999" 27 28 type timezoneType int 29 30 const ( 31 // TimestampNTZType denotes a NTZ timezoneType for array binds 32 TimestampNTZType timezoneType = iota 33 // TimestampLTZType denotes a LTZ timezoneType for array binds 34 TimestampLTZType 35 // TimestampTZType denotes a TZ timezoneType for array binds 36 TimestampTZType 37 // DateType denotes a date type for array binds 38 DateType 39 // TimeType denotes a time type for array binds 40 TimeType 41 ) 42 43 type snowflakeArrowBatchesTimestampOption int 44 45 const ( 46 // UseNanosecondTimestamp uses arrow.Timestamp in nanosecond precision, could cause ErrTooHighTimestampPrecision if arrow.Timestamp cannot fit original timestamp values. 47 UseNanosecondTimestamp snowflakeArrowBatchesTimestampOption = iota 48 // UseMicrosecondTimestamp uses arrow.Timestamp in microsecond precision 49 UseMicrosecondTimestamp 50 // UseMillisecondTimestamp uses arrow.Timestamp in millisecond precision 51 UseMillisecondTimestamp 52 // UseSecondTimestamp uses arrow.Timestamp in second precision 53 UseSecondTimestamp 54 // UseOriginalTimestamp uses original timestamp struct returned by Snowflake. It can be used in case arrow.Timestamp cannot fit original timestamp values. 55 UseOriginalTimestamp 56 ) 57 58 type interfaceArrayBinding struct { 59 hasTimezone bool 60 tzType timezoneType 61 timezoneTypeArray interface{} 62 } 63 64 func isInterfaceArrayBinding(t interface{}) bool { 65 switch t.(type) { 66 case interfaceArrayBinding: 67 return true 68 case *interfaceArrayBinding: 69 return true 70 default: 71 return false 72 } 73 } 74 75 // goTypeToSnowflake translates Go data type to Snowflake data type. 76 func goTypeToSnowflake(v driver.Value, tsmode snowflakeType) snowflakeType { 77 switch t := v.(type) { 78 case int64, sql.NullInt64: 79 return fixedType 80 case float64, sql.NullFloat64: 81 return realType 82 case bool, sql.NullBool: 83 return booleanType 84 case string, sql.NullString: 85 return textType 86 case []byte: 87 if tsmode == binaryType { 88 return binaryType // may be redundant but ensures BINARY type 89 } 90 if t == nil { 91 return nullType // invalid byte array. won't take as BINARY 92 } 93 if len(t) != 1 { 94 return unSupportedType 95 } 96 if _, err := dataTypeMode(t); err != nil { 97 return unSupportedType 98 } 99 return changeType 100 case time.Time, sql.NullTime: 101 return tsmode 102 } 103 if supportedArrayBind(&driver.NamedValue{Value: v}) { 104 return sliceType 105 } 106 return unSupportedType 107 } 108 109 // snowflakeTypeToGo translates Snowflake data type to Go data type. 110 func snowflakeTypeToGo(dbtype snowflakeType, scale int64) reflect.Type { 111 switch dbtype { 112 case fixedType: 113 if scale == 0 { 114 return reflect.TypeOf(int64(0)) 115 } 116 return reflect.TypeOf(float64(0)) 117 case realType: 118 return reflect.TypeOf(float64(0)) 119 case textType, variantType, objectType, arrayType: 120 return reflect.TypeOf("") 121 case dateType, timeType, timestampLtzType, timestampNtzType, timestampTzType: 122 return reflect.TypeOf(time.Now()) 123 case binaryType: 124 return reflect.TypeOf([]byte{}) 125 case booleanType: 126 return reflect.TypeOf(true) 127 } 128 logger.Errorf("unsupported dbtype is specified. %v", dbtype) 129 return reflect.TypeOf("") 130 } 131 132 // valueToString converts arbitrary golang type to a string. This is mainly used in binding data with placeholders 133 // in queries. 134 func valueToString(v driver.Value, tsmode snowflakeType) (*string, error) { 135 logger.Debugf("TYPE: %v, %v", reflect.TypeOf(v), reflect.ValueOf(v)) 136 if v == nil { 137 return nil, nil 138 } 139 v1 := reflect.ValueOf(v) 140 switch v1.Kind() { 141 case reflect.Bool: 142 s := strconv.FormatBool(v1.Bool()) 143 return &s, nil 144 case reflect.Int64: 145 s := strconv.FormatInt(v1.Int(), 10) 146 return &s, nil 147 case reflect.Float64: 148 s := strconv.FormatFloat(v1.Float(), 'g', -1, 32) 149 return &s, nil 150 case reflect.String: 151 s := v1.String() 152 return &s, nil 153 case reflect.Slice, reflect.Map: 154 if v1.IsNil() { 155 return nil, nil 156 } 157 if bd, ok := v.([]byte); ok { 158 if tsmode == binaryType { 159 s := hex.EncodeToString(bd) 160 return &s, nil 161 } 162 } 163 // TODO: is this good enough? 164 s := v1.String() 165 return &s, nil 166 case reflect.Struct: 167 switch typedVal := v.(type) { 168 case time.Time: 169 return timeTypeValueToString(typedVal, tsmode) 170 case sql.NullTime: 171 if !typedVal.Valid { 172 return nil, nil 173 } 174 return timeTypeValueToString(typedVal.Time, tsmode) 175 case sql.NullBool: 176 if !typedVal.Valid { 177 return nil, nil 178 } 179 s := strconv.FormatBool(typedVal.Bool) 180 return &s, nil 181 case sql.NullInt64: 182 if !typedVal.Valid { 183 return nil, nil 184 } 185 s := strconv.FormatInt(typedVal.Int64, 10) 186 return &s, nil 187 case sql.NullFloat64: 188 if !typedVal.Valid { 189 return nil, nil 190 } 191 s := strconv.FormatFloat(typedVal.Float64, 'g', -1, 32) 192 return &s, nil 193 case sql.NullString: 194 if !typedVal.Valid { 195 return nil, nil 196 } 197 return &typedVal.String, nil 198 } 199 } 200 return nil, fmt.Errorf("unsupported type: %v", v1.Kind()) 201 } 202 203 func timeTypeValueToString(tm time.Time, tsmode snowflakeType) (*string, error) { 204 switch tsmode { 205 case dateType: 206 _, offset := tm.Zone() 207 tm = tm.Add(time.Second * time.Duration(offset)) 208 s := strconv.FormatInt(tm.Unix()*1000, 10) 209 return &s, nil 210 case timeType: 211 s := fmt.Sprintf("%d", 212 (tm.Hour()*3600+tm.Minute()*60+tm.Second())*1e9+tm.Nanosecond()) 213 return &s, nil 214 case timestampNtzType, timestampLtzType: 215 unixTime, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Unix()), 10) 216 m, _ := new(big.Int).SetString(strconv.FormatInt(1e9, 10), 10) 217 unixTime.Mul(unixTime, m) 218 tmNanos, _ := new(big.Int).SetString(fmt.Sprintf("%d", tm.Nanosecond()), 10) 219 s := unixTime.Add(unixTime, tmNanos).String() 220 return &s, nil 221 case timestampTzType: 222 _, offset := tm.Zone() 223 s := fmt.Sprintf("%v %v", tm.UnixNano(), offset/60+1440) 224 return &s, nil 225 } 226 return nil, fmt.Errorf("unsupported time type: %v", tsmode) 227 } 228 229 // extractTimestamp extracts the internal timestamp data to epoch time in seconds and milliseconds 230 func extractTimestamp(srcValue *string) (sec int64, nsec int64, err error) { 231 logger.Debugf("SRC: %v", srcValue) 232 var i int 233 for i = 0; i < len(*srcValue); i++ { 234 if (*srcValue)[i] == '.' { 235 sec, err = strconv.ParseInt((*srcValue)[0:i], 10, 64) 236 if err != nil { 237 return 0, 0, err 238 } 239 break 240 } 241 } 242 if i == len(*srcValue) { 243 // no fraction 244 sec, err = strconv.ParseInt(*srcValue, 10, 64) 245 if err != nil { 246 return 0, 0, err 247 } 248 nsec = 0 249 } else { 250 s := (*srcValue)[i+1:] 251 nsec, err = strconv.ParseInt(s+strings.Repeat("0", 9-len(s)), 10, 64) 252 if err != nil { 253 return 0, 0, err 254 } 255 } 256 logger.Infof("sec: %v, nsec: %v", sec, nsec) 257 return sec, nsec, nil 258 } 259 260 // stringToValue converts a pointer of string data to an arbitrary golang variable 261 // This is mainly used in fetching data. 262 func stringToValue( 263 dest *driver.Value, 264 srcColumnMeta execResponseRowType, 265 srcValue *string, 266 loc *time.Location, 267 ) error { 268 if srcValue == nil { 269 logger.Debugf("snowflake data type: %v, raw value: nil", srcColumnMeta.Type) 270 *dest = nil 271 return nil 272 } 273 logger.Debugf("snowflake data type: %v, raw value: %v", srcColumnMeta.Type, *srcValue) 274 switch srcColumnMeta.Type { 275 case "text", "fixed", "real", "variant", "object": 276 *dest = *srcValue 277 return nil 278 case "date": 279 v, err := strconv.ParseInt(*srcValue, 10, 64) 280 if err != nil { 281 return err 282 } 283 *dest = time.Unix(v*86400, 0).UTC() 284 return nil 285 case "time": 286 sec, nsec, err := extractTimestamp(srcValue) 287 if err != nil { 288 return err 289 } 290 t0 := time.Time{} 291 *dest = t0.Add(time.Duration(sec*1e9 + nsec)) 292 return nil 293 case "timestamp_ntz": 294 sec, nsec, err := extractTimestamp(srcValue) 295 if err != nil { 296 return err 297 } 298 *dest = time.Unix(sec, nsec).UTC() 299 return nil 300 case "timestamp_ltz": 301 sec, nsec, err := extractTimestamp(srcValue) 302 if err != nil { 303 return err 304 } 305 if loc == nil { 306 loc = time.Now().Location() 307 } 308 *dest = time.Unix(sec, nsec).In(loc) 309 return nil 310 case "timestamp_tz": 311 logger.Debugf("tz: %v", *srcValue) 312 313 tm := strings.Split(*srcValue, " ") 314 if len(tm) != 2 { 315 return &SnowflakeError{ 316 Number: ErrInvalidTimestampTz, 317 SQLState: SQLStateInvalidDataTimeFormat, 318 Message: fmt.Sprintf("invalid TIMESTAMP_TZ data. The value doesn't consist of two numeric values separated by a space: %v", *srcValue), 319 } 320 } 321 sec, nsec, err := extractTimestamp(&tm[0]) 322 if err != nil { 323 return err 324 } 325 offset, err := strconv.ParseInt(tm[1], 10, 64) 326 if err != nil { 327 return &SnowflakeError{ 328 Number: ErrInvalidTimestampTz, 329 SQLState: SQLStateInvalidDataTimeFormat, 330 Message: fmt.Sprintf("invalid TIMESTAMP_TZ data. The offset value is not integer: %v", tm[1]), 331 } 332 } 333 loc := Location(int(offset) - 1440) 334 tt := time.Unix(sec, nsec) 335 *dest = tt.In(loc) 336 return nil 337 case "binary": 338 b, err := hex.DecodeString(*srcValue) 339 if err != nil { 340 return &SnowflakeError{ 341 Number: ErrInvalidBinaryHexForm, 342 SQLState: SQLStateNumericValueOutOfRange, 343 Message: err.Error(), 344 } 345 } 346 *dest = b 347 return nil 348 } 349 *dest = *srcValue 350 return nil 351 } 352 353 var decimalShift = new(big.Int).Exp(big.NewInt(2), big.NewInt(64), nil) 354 355 func intToBigFloat(val int64, scale int64) *big.Float { 356 f := new(big.Float).SetInt64(val) 357 s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil)) 358 return new(big.Float).Quo(f, s) 359 } 360 361 func decimalToBigInt(num decimal128.Num) *big.Int { 362 high := new(big.Int).SetInt64(num.HighBits()) 363 low := new(big.Int).SetUint64(num.LowBits()) 364 return new(big.Int).Add(new(big.Int).Mul(high, decimalShift), low) 365 } 366 367 func decimalToBigFloat(num decimal128.Num, scale int64) *big.Float { 368 f := new(big.Float).SetInt(decimalToBigInt(num)) 369 s := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(scale), nil)) 370 return new(big.Float).Quo(f, s) 371 } 372 373 // ArrowSnowflakeTimestampToTime converts original timestamp returned by Snowflake to time.Time 374 func (rb *ArrowBatch) ArrowSnowflakeTimestampToTime(rec arrow.Record, colIdx int, recIdx int) *time.Time { 375 scale := int(rb.scd.RowSet.RowType[colIdx].Scale) 376 dbType := rb.scd.RowSet.RowType[colIdx].Type 377 return arrowSnowflakeTimestampToTime(rec.Column(colIdx), getSnowflakeType(dbType), scale, recIdx, rb.loc) 378 } 379 380 func arrowSnowflakeTimestampToTime( 381 column arrow.Array, 382 sfType snowflakeType, 383 scale int, 384 recIdx int, 385 loc *time.Location) *time.Time { 386 387 if column.IsNull(recIdx) { 388 return nil 389 } 390 var ret time.Time 391 switch sfType { 392 case timestampNtzType: 393 if column.DataType().ID() == arrow.STRUCT { 394 structData := column.(*array.Struct) 395 epoch := structData.Field(0).(*array.Int64).Int64Values() 396 fraction := structData.Field(1).(*array.Int32).Int32Values() 397 ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).UTC() 398 } else { 399 intData := column.(*array.Int64) 400 value := intData.Value(recIdx) 401 epoch := extractEpoch(value, scale) 402 fraction := extractFraction(value, scale) 403 ret = time.Unix(epoch, fraction).UTC() 404 } 405 case timestampLtzType: 406 if column.DataType().ID() == arrow.STRUCT { 407 structData := column.(*array.Struct) 408 epoch := structData.Field(0).(*array.Int64).Int64Values() 409 fraction := structData.Field(1).(*array.Int32).Int32Values() 410 ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(loc) 411 } else { 412 intData := column.(*array.Int64) 413 value := intData.Value(recIdx) 414 epoch := extractEpoch(value, scale) 415 fraction := extractFraction(value, scale) 416 ret = time.Unix(epoch, fraction).In(loc) 417 } 418 case timestampTzType: 419 structData := column.(*array.Struct) 420 if structData.NumField() == 2 { 421 value := structData.Field(0).(*array.Int64).Int64Values() 422 timezone := structData.Field(1).(*array.Int32).Int32Values() 423 epoch := extractEpoch(value[recIdx], scale) 424 fraction := extractFraction(value[recIdx], scale) 425 locTz := Location(int(timezone[recIdx]) - 1440) 426 ret = time.Unix(epoch, fraction).In(locTz) 427 } else { 428 epoch := structData.Field(0).(*array.Int64).Int64Values() 429 fraction := structData.Field(1).(*array.Int32).Int32Values() 430 timezone := structData.Field(2).(*array.Int32).Int32Values() 431 locTz := Location(int(timezone[recIdx]) - 1440) 432 ret = time.Unix(epoch[recIdx], int64(fraction[recIdx])).In(locTz) 433 } 434 } 435 return &ret 436 } 437 438 func extractEpoch(value int64, scale int) int64 { 439 return value / int64(math.Pow10(scale)) 440 } 441 442 func extractFraction(value int64, scale int) int64 { 443 return (value % int64(math.Pow10(scale))) * int64(math.Pow10(9-scale)) 444 } 445 446 // Arrow Interface (Column) converter. This is called when Arrow chunks are 447 // downloaded to convert to the corresponding row type. 448 func arrowToValue( 449 destcol []snowflakeValue, 450 srcColumnMeta execResponseRowType, 451 srcValue arrow.Array, 452 loc *time.Location, 453 higherPrecision bool) error { 454 455 var err error 456 if len(destcol) != srcValue.Len() { 457 err = fmt.Errorf("array interface length mismatch") 458 } 459 logger.Debugf("snowflake data type: %v, arrow data type: %v", srcColumnMeta.Type, srcValue.DataType()) 460 461 snowflakeType := getSnowflakeType(srcColumnMeta.Type) 462 switch snowflakeType { 463 case fixedType: 464 // Snowflake data types that are fixed-point numbers will fall into this category 465 // e.g. NUMBER, DECIMAL/NUMERIC, INT/INTEGER 466 switch data := srcValue.(type) { 467 case *array.Decimal128: 468 for i, num := range data.Values() { 469 if !srcValue.IsNull(i) { 470 if srcColumnMeta.Scale == 0 { 471 if higherPrecision { 472 destcol[i] = num.BigInt() 473 } else { 474 destcol[i] = num.ToString(0) 475 } 476 } else { 477 f := decimalToBigFloat(num, srcColumnMeta.Scale) 478 if higherPrecision { 479 destcol[i] = f 480 } else { 481 destcol[i] = fmt.Sprintf("%.*f", srcColumnMeta.Scale, f) 482 } 483 } 484 } 485 } 486 case *array.Int64: 487 for i, val := range data.Int64Values() { 488 if !srcValue.IsNull(i) { 489 if srcColumnMeta.Scale == 0 { 490 if higherPrecision { 491 destcol[i] = val 492 } else { 493 destcol[i] = fmt.Sprintf("%d", val) 494 } 495 } else { 496 if higherPrecision { 497 f := intToBigFloat(val, srcColumnMeta.Scale) 498 destcol[i] = f 499 } else { 500 destcol[i] = fmt.Sprintf("%.*f", srcColumnMeta.Scale, float64(val)/math.Pow10(int(srcColumnMeta.Scale))) 501 } 502 } 503 } 504 } 505 case *array.Int32: 506 for i, val := range data.Int32Values() { 507 if !srcValue.IsNull(i) { 508 if srcColumnMeta.Scale == 0 { 509 if higherPrecision { 510 destcol[i] = int64(val) 511 } else { 512 destcol[i] = fmt.Sprintf("%d", val) 513 } 514 } else { 515 if higherPrecision { 516 f := intToBigFloat(int64(val), srcColumnMeta.Scale) 517 destcol[i] = f 518 } else { 519 destcol[i] = fmt.Sprintf("%.*f", srcColumnMeta.Scale, float64(val)/math.Pow10(int(srcColumnMeta.Scale))) 520 } 521 } 522 } 523 } 524 case *array.Int16: 525 for i, val := range data.Int16Values() { 526 if !srcValue.IsNull(i) { 527 if srcColumnMeta.Scale == 0 { 528 if higherPrecision { 529 destcol[i] = int64(val) 530 } else { 531 destcol[i] = fmt.Sprintf("%d", val) 532 } 533 } else { 534 if higherPrecision { 535 f := intToBigFloat(int64(val), srcColumnMeta.Scale) 536 destcol[i] = f 537 } else { 538 destcol[i] = fmt.Sprintf("%.*f", srcColumnMeta.Scale, float64(val)/math.Pow10(int(srcColumnMeta.Scale))) 539 } 540 } 541 } 542 } 543 case *array.Int8: 544 for i, val := range data.Int8Values() { 545 if !srcValue.IsNull(i) { 546 if srcColumnMeta.Scale == 0 { 547 if higherPrecision { 548 destcol[i] = int64(val) 549 } else { 550 destcol[i] = fmt.Sprintf("%d", val) 551 } 552 } else { 553 if higherPrecision { 554 f := intToBigFloat(int64(val), srcColumnMeta.Scale) 555 destcol[i] = f 556 } else { 557 destcol[i] = fmt.Sprintf("%.*f", srcColumnMeta.Scale, float64(val)/math.Pow10(int(srcColumnMeta.Scale))) 558 } 559 } 560 } 561 } 562 } 563 return err 564 case booleanType: 565 boolData := srcValue.(*array.Boolean) 566 for i := range destcol { 567 if !srcValue.IsNull(i) { 568 destcol[i] = boolData.Value(i) 569 } 570 } 571 return err 572 case realType: 573 // Snowflake data types that are floating-point numbers will fall in this category 574 // e.g. FLOAT/REAL/DOUBLE 575 for i, flt64 := range srcValue.(*array.Float64).Float64Values() { 576 if !srcValue.IsNull(i) { 577 destcol[i] = flt64 578 } 579 } 580 return err 581 case textType, arrayType, variantType, objectType: 582 strings := srcValue.(*array.String) 583 for i := range destcol { 584 if !srcValue.IsNull(i) { 585 destcol[i] = strings.Value(i) 586 } 587 } 588 return err 589 case binaryType: 590 binaryData := srcValue.(*array.Binary) 591 for i := range destcol { 592 if !srcValue.IsNull(i) { 593 destcol[i] = binaryData.Value(i) 594 } 595 } 596 return err 597 case dateType: 598 for i, date32 := range srcValue.(*array.Date32).Date32Values() { 599 if !srcValue.IsNull(i) { 600 t0 := time.Unix(int64(date32)*86400, 0).UTC() 601 destcol[i] = t0 602 } 603 } 604 return err 605 case timeType: 606 t0 := time.Time{} 607 if srcValue.DataType().ID() == arrow.INT64 { 608 for i, i64 := range srcValue.(*array.Int64).Int64Values() { 609 if !srcValue.IsNull(i) { 610 destcol[i] = t0.Add(time.Duration(i64 * int64(math.Pow10(9-int(srcColumnMeta.Scale))))) 611 } 612 } 613 } else { 614 for i, i32 := range srcValue.(*array.Int32).Int32Values() { 615 if !srcValue.IsNull(i) { 616 destcol[i] = t0.Add(time.Duration(int64(i32) * int64(math.Pow10(9-int(srcColumnMeta.Scale))))) 617 } 618 } 619 } 620 return err 621 case timestampNtzType, timestampLtzType, timestampTzType: 622 for i := range destcol { 623 var ts = arrowSnowflakeTimestampToTime(srcValue, snowflakeType, int(srcColumnMeta.Scale), i, loc) 624 if ts != nil { 625 destcol[i] = *ts 626 } 627 } 628 return err 629 } 630 631 return fmt.Errorf("unsupported data type") 632 } 633 634 type ( 635 intArray []int 636 int32Array []int32 637 int64Array []int64 638 float64Array []float64 639 float32Array []float32 640 boolArray []bool 641 stringArray []string 642 byteArray [][]byte 643 timestampNtzArray []time.Time 644 timestampLtzArray []time.Time 645 timestampTzArray []time.Time 646 dateArray []time.Time 647 timeArray []time.Time 648 ) 649 650 // Array takes in a column of a row to be inserted via array binding, bulk or 651 // otherwise, and converts it into a native snowflake type for binding 652 func Array(a interface{}, typ ...timezoneType) interface{} { 653 switch t := a.(type) { 654 case []int: 655 return (*intArray)(&t) 656 case []int32: 657 return (*int32Array)(&t) 658 case []int64: 659 return (*int64Array)(&t) 660 case []float64: 661 return (*float64Array)(&t) 662 case []float32: 663 return (*float32Array)(&t) 664 case []bool: 665 return (*boolArray)(&t) 666 case []string: 667 return (*stringArray)(&t) 668 case [][]byte: 669 return (*byteArray)(&t) 670 case []time.Time: 671 if len(typ) < 1 { 672 return a 673 } 674 switch typ[0] { 675 case TimestampNTZType: 676 return (*timestampNtzArray)(&t) 677 case TimestampLTZType: 678 return (*timestampLtzArray)(&t) 679 case TimestampTZType: 680 return (*timestampTzArray)(&t) 681 case DateType: 682 return (*dateArray)(&t) 683 case TimeType: 684 return (*timeArray)(&t) 685 default: 686 return a 687 } 688 case *[]int: 689 return (*intArray)(t) 690 case *[]int32: 691 return (*int32Array)(t) 692 case *[]int64: 693 return (*int64Array)(t) 694 case *[]float64: 695 return (*float64Array)(t) 696 case *[]float32: 697 return (*float32Array)(t) 698 case *[]bool: 699 return (*boolArray)(t) 700 case *[]string: 701 return (*stringArray)(t) 702 case *[][]byte: 703 return (*byteArray)(t) 704 case *[]time.Time: 705 if len(typ) < 1 { 706 return a 707 } 708 switch typ[0] { 709 case TimestampNTZType: 710 return (*timestampNtzArray)(t) 711 case TimestampLTZType: 712 return (*timestampLtzArray)(t) 713 case TimestampTZType: 714 return (*timestampTzArray)(t) 715 case DateType: 716 return (*dateArray)(t) 717 case TimeType: 718 return (*timeArray)(t) 719 default: 720 return a 721 } 722 case []interface{}, *[]interface{}: 723 // Support for bulk array binding insertion using []interface{} 724 if len(typ) < 1 { 725 return interfaceArrayBinding{ 726 hasTimezone: false, 727 timezoneTypeArray: a, 728 } 729 } 730 return interfaceArrayBinding{ 731 hasTimezone: true, 732 tzType: typ[0], 733 timezoneTypeArray: a, 734 } 735 default: 736 return a 737 } 738 } 739 740 // snowflakeArrayToString converts the array binding to snowflake's native 741 // string type. The string value differs whether it's directly bound or 742 // uploaded via stream. 743 func snowflakeArrayToString(nv *driver.NamedValue, stream bool) (snowflakeType, []*string) { 744 var t snowflakeType 745 var arr []*string 746 switch reflect.TypeOf(nv.Value) { 747 case reflect.TypeOf(&intArray{}): 748 t = fixedType 749 a := nv.Value.(*intArray) 750 for _, x := range *a { 751 v := strconv.Itoa(x) 752 arr = append(arr, &v) 753 } 754 case reflect.TypeOf(&int64Array{}): 755 t = fixedType 756 a := nv.Value.(*int64Array) 757 for _, x := range *a { 758 v := strconv.FormatInt(x, 10) 759 arr = append(arr, &v) 760 } 761 case reflect.TypeOf(&int32Array{}): 762 t = fixedType 763 a := nv.Value.(*int32Array) 764 for _, x := range *a { 765 v := strconv.Itoa(int(x)) 766 arr = append(arr, &v) 767 } 768 case reflect.TypeOf(&float64Array{}): 769 t = realType 770 a := nv.Value.(*float64Array) 771 for _, x := range *a { 772 v := fmt.Sprintf("%g", x) 773 arr = append(arr, &v) 774 } 775 case reflect.TypeOf(&float32Array{}): 776 t = realType 777 a := nv.Value.(*float32Array) 778 for _, x := range *a { 779 v := fmt.Sprintf("%g", x) 780 arr = append(arr, &v) 781 } 782 case reflect.TypeOf(&boolArray{}): 783 t = booleanType 784 a := nv.Value.(*boolArray) 785 for _, x := range *a { 786 v := strconv.FormatBool(x) 787 arr = append(arr, &v) 788 } 789 case reflect.TypeOf(&stringArray{}): 790 t = textType 791 a := nv.Value.(*stringArray) 792 for _, x := range *a { 793 v := x // necessary for address to be not overwritten 794 arr = append(arr, &v) 795 } 796 case reflect.TypeOf(&byteArray{}): 797 t = binaryType 798 a := nv.Value.(*byteArray) 799 for _, x := range *a { 800 v := hex.EncodeToString(x) 801 arr = append(arr, &v) 802 } 803 case reflect.TypeOf(×tampNtzArray{}): 804 t = timestampNtzType 805 a := nv.Value.(*timestampNtzArray) 806 for _, x := range *a { 807 v := strconv.FormatInt(x.UnixNano(), 10) 808 arr = append(arr, &v) 809 } 810 case reflect.TypeOf(×tampLtzArray{}): 811 t = timestampLtzType 812 a := nv.Value.(*timestampLtzArray) 813 for _, x := range *a { 814 v := strconv.FormatInt(x.UnixNano(), 10) 815 arr = append(arr, &v) 816 } 817 case reflect.TypeOf(×tampTzArray{}): 818 t = timestampTzType 819 a := nv.Value.(*timestampTzArray) 820 for _, x := range *a { 821 var v string 822 if stream { 823 v = x.Format(format) 824 } else { 825 _, offset := x.Zone() 826 v = fmt.Sprintf("%v %v", x.UnixNano(), offset/60+1440) 827 } 828 arr = append(arr, &v) 829 } 830 case reflect.TypeOf(&dateArray{}): 831 t = dateType 832 a := nv.Value.(*dateArray) 833 for _, x := range *a { 834 _, offset := x.Zone() 835 x = x.Add(time.Second * time.Duration(offset)) 836 v := fmt.Sprintf("%d", x.Unix()*1000) 837 arr = append(arr, &v) 838 } 839 case reflect.TypeOf(&timeArray{}): 840 t = timeType 841 a := nv.Value.(*timeArray) 842 for _, x := range *a { 843 var v string 844 if stream { 845 v = fmt.Sprintf("%02d:%02d:%02d.%09d", x.Hour(), x.Minute(), x.Second(), x.Nanosecond()) 846 } else { 847 h, m, s := x.Clock() 848 tm := int64(h)*int64(time.Hour) + int64(m)*int64(time.Minute) + int64(s)*int64(time.Second) + int64(x.Nanosecond()) 849 v = strconv.FormatInt(tm, 10) 850 } 851 arr = append(arr, &v) 852 } 853 default: 854 // Support for bulk array binding insertion using []interface{} 855 nvValue := reflect.ValueOf(nv) 856 if nvValue.Kind() == reflect.Ptr { 857 value := reflect.Indirect(reflect.ValueOf(nv.Value)) 858 if isInterfaceArrayBinding(value.Interface()) { 859 timeStruct, ok := value.Interface().(interfaceArrayBinding) 860 if ok { 861 timeInterfaceSlice := reflect.Indirect(reflect.ValueOf(timeStruct.timezoneTypeArray)) 862 if timeStruct.hasTimezone { 863 return interfaceSliceToString(timeInterfaceSlice, stream, timeStruct.tzType) 864 } 865 return interfaceSliceToString(timeInterfaceSlice, stream) 866 } 867 } 868 } 869 return unSupportedType, nil 870 } 871 return t, arr 872 } 873 874 func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType ...timezoneType) (snowflakeType, []*string) { 875 var t snowflakeType 876 var arr []*string 877 878 for i := 0; i < interfaceSlice.Len(); i++ { 879 val := interfaceSlice.Index(i) 880 if val.CanInterface() { 881 switch val.Interface().(type) { 882 case int: 883 t = fixedType 884 x := val.Interface().(int) 885 v := strconv.Itoa(x) 886 arr = append(arr, &v) 887 case int32: 888 t = fixedType 889 x := val.Interface().(int32) 890 v := strconv.Itoa(int(x)) 891 arr = append(arr, &v) 892 case int64: 893 t = fixedType 894 x := val.Interface().(int64) 895 v := strconv.FormatInt(x, 10) 896 arr = append(arr, &v) 897 case float32: 898 t = realType 899 x := val.Interface().(float32) 900 v := fmt.Sprintf("%g", x) 901 arr = append(arr, &v) 902 case float64: 903 t = realType 904 x := val.Interface().(float64) 905 v := fmt.Sprintf("%g", x) 906 arr = append(arr, &v) 907 case bool: 908 t = booleanType 909 x := val.Interface().(bool) 910 v := strconv.FormatBool(x) 911 arr = append(arr, &v) 912 case string: 913 t = textType 914 x := val.Interface().(string) 915 arr = append(arr, &x) 916 case []byte: 917 t = binaryType 918 x := val.Interface().([]byte) 919 v := hex.EncodeToString(x) 920 arr = append(arr, &v) 921 case time.Time: 922 if len(tzType) < 1 { 923 return unSupportedType, nil 924 } 925 926 x := val.Interface().(time.Time) 927 switch tzType[0] { 928 case TimestampNTZType: 929 t = timestampNtzType 930 v := strconv.FormatInt(x.UnixNano(), 10) 931 arr = append(arr, &v) 932 case TimestampLTZType: 933 t = timestampLtzType 934 v := strconv.FormatInt(x.UnixNano(), 10) 935 arr = append(arr, &v) 936 case TimestampTZType: 937 t = timestampTzType 938 var v string 939 if stream { 940 v = x.Format(format) 941 } else { 942 _, offset := x.Zone() 943 v = fmt.Sprintf("%v %v", x.UnixNano(), offset/60+1440) 944 } 945 arr = append(arr, &v) 946 case DateType: 947 t = dateType 948 _, offset := x.Zone() 949 x = x.Add(time.Second * time.Duration(offset)) 950 v := fmt.Sprintf("%d", x.Unix()*1000) 951 arr = append(arr, &v) 952 case TimeType: 953 t = timeType 954 var v string 955 if stream { 956 v = x.Format(format[11:19]) 957 } else { 958 h, m, s := x.Clock() 959 tm := int64(h)*int64(time.Hour) + int64(m)*int64(time.Minute) + int64(s)*int64(time.Second) + int64(x.Nanosecond()) 960 v = strconv.FormatInt(tm, 10) 961 } 962 arr = append(arr, &v) 963 default: 964 return unSupportedType, nil 965 } 966 default: 967 if val.Interface() != nil { 968 return unSupportedType, nil 969 } 970 971 arr = append(arr, nil) 972 } 973 } 974 } 975 return t, arr 976 } 977 978 func higherPrecisionEnabled(ctx context.Context) bool { 979 v := ctx.Value(enableHigherPrecision) 980 if v == nil { 981 return false 982 } 983 d, ok := v.(bool) 984 return ok && d 985 } 986 987 func arrowBatchesUtf8ValidationEnabled(ctx context.Context) bool { 988 v := ctx.Value(enableArrowBatchesUtf8Validation) 989 if v == nil { 990 return false 991 } 992 d, ok := v.(bool) 993 return ok && d 994 } 995 996 func getArrowBatchesTimestampOption(ctx context.Context) snowflakeArrowBatchesTimestampOption { 997 v := ctx.Value(arrowBatchesTimestampOption) 998 if v == nil { 999 return UseNanosecondTimestamp 1000 } 1001 o, ok := v.(snowflakeArrowBatchesTimestampOption) 1002 if !ok { 1003 return UseNanosecondTimestamp 1004 } 1005 return o 1006 } 1007 1008 func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) { 1009 arrowBatchesTimestampOption := getArrowBatchesTimestampOption(ctx) 1010 higherPrecisionEnabled := higherPrecisionEnabled(ctx) 1011 1012 s, err := recordToSchema(record.Schema(), rowType, loc, arrowBatchesTimestampOption, higherPrecisionEnabled) 1013 if err != nil { 1014 return nil, err 1015 } 1016 1017 var cols []arrow.Array 1018 numRows := record.NumRows() 1019 ctxAlloc := compute.WithAllocator(ctx, pool) 1020 1021 for i, col := range record.Columns() { 1022 srcColumnMeta := rowType[i] 1023 1024 // TODO: confirm that it is okay to be using higher precision logic for conversions 1025 newCol := col 1026 snowflakeType := getSnowflakeType(srcColumnMeta.Type) 1027 switch snowflakeType { 1028 case fixedType: 1029 var toType arrow.DataType 1030 if higherPrecisionEnabled { 1031 // do nothing - return decimal as is 1032 } else if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 { 1033 if srcColumnMeta.Scale == 0 { 1034 toType = arrow.PrimitiveTypes.Int64 1035 } else { 1036 toType = arrow.PrimitiveTypes.Float64 1037 } 1038 // we're fine truncating so no error for data loss here. 1039 // so we use UnsafeCastOptions. 1040 newCol, err = compute.CastArray(ctxAlloc, col, compute.UnsafeCastOptions(toType)) 1041 if err != nil { 1042 return nil, err 1043 } 1044 defer newCol.Release() 1045 } else if srcColumnMeta.Scale != 0 && col.DataType().ID() != arrow.INT64 { 1046 result, err := compute.Divide(ctxAlloc, compute.ArithmeticOptions{NoCheckOverflow: true}, 1047 &compute.ArrayDatum{Value: newCol.Data()}, 1048 compute.NewDatum(math.Pow10(int(srcColumnMeta.Scale)))) 1049 if err != nil { 1050 return nil, err 1051 } 1052 defer result.Release() 1053 newCol = result.(*compute.ArrayDatum).MakeArray() 1054 defer newCol.Release() 1055 } else if srcColumnMeta.Scale != 0 && col.DataType().ID() == arrow.INT64 { 1056 // gosnowflake driver uses compute.Divide() which could bring `integer value not in range: -9007199254740992 to 9007199254740992` error 1057 // if we convert int64 to BigDecimal and then use compute.CastArray to convert BigDecimal to float64, we won't have enough precision. 1058 // e.g 0.1 as (38,19) will result 0.09999999999999999 1059 values := col.(*array.Int64).Int64Values() 1060 floatValues := make([]float64, len(values)) 1061 for i, val := range values { 1062 floatValues[i], _ = intToBigFloat(val, srcColumnMeta.Scale).Float64() 1063 } 1064 builder := array.NewFloat64Builder(memory.NewCheckedAllocator(memory.NewGoAllocator())) 1065 builder.AppendValues(floatValues, nil) 1066 newCol = builder.NewArray() 1067 builder.Release() 1068 defer newCol.Release() 1069 } 1070 case timeType: 1071 newCol, err = compute.CastArray(ctxAlloc, col, compute.SafeCastOptions(arrow.FixedWidthTypes.Time64ns)) 1072 if err != nil { 1073 return nil, err 1074 } 1075 defer newCol.Release() 1076 case timestampNtzType, timestampLtzType, timestampTzType: 1077 if arrowBatchesTimestampOption == UseOriginalTimestamp { 1078 // do nothing - return timestamp as is 1079 } else { 1080 var unit arrow.TimeUnit 1081 switch arrowBatchesTimestampOption { 1082 case UseMicrosecondTimestamp: 1083 unit = arrow.Microsecond 1084 case UseMillisecondTimestamp: 1085 unit = arrow.Millisecond 1086 case UseSecondTimestamp: 1087 unit = arrow.Second 1088 case UseNanosecondTimestamp: 1089 unit = arrow.Nanosecond 1090 } 1091 var tb *array.TimestampBuilder 1092 if snowflakeType == timestampLtzType { 1093 tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit, TimeZone: loc.String()}) 1094 } else { 1095 tb = array.NewTimestampBuilder(pool, &arrow.TimestampType{Unit: unit}) 1096 } 1097 defer tb.Release() 1098 1099 for i := 0; i < int(numRows); i++ { 1100 ts := arrowSnowflakeTimestampToTime(col, snowflakeType, int(srcColumnMeta.Scale), i, loc) 1101 if ts != nil { 1102 var ar arrow.Timestamp 1103 switch arrowBatchesTimestampOption { 1104 case UseMicrosecondTimestamp: 1105 ar = arrow.Timestamp(ts.UnixMicro()) 1106 case UseMillisecondTimestamp: 1107 ar = arrow.Timestamp(ts.UnixMilli()) 1108 case UseSecondTimestamp: 1109 ar = arrow.Timestamp(ts.Unix()) 1110 case UseNanosecondTimestamp: 1111 ar = arrow.Timestamp(ts.UnixNano()) 1112 // in case of overflow in arrow timestamp return error 1113 // this could only happen for nanosecond case 1114 if ts.UTC().Year() != ar.ToTime(arrow.Nanosecond).Year() { 1115 return nil, &SnowflakeError{ 1116 Number: ErrTooHighTimestampPrecision, 1117 SQLState: SQLStateInvalidDataTimeFormat, 1118 Message: fmt.Sprintf("Cannot convert timestamp %v in column %v to Arrow.Timestamp data type due to too high precision. Please use context with WithOriginalTimestamp.", ts.UTC(), srcColumnMeta.Name), 1119 } 1120 } 1121 } 1122 tb.Append(ar) 1123 } else { 1124 tb.AppendNull() 1125 } 1126 } 1127 1128 newCol = tb.NewArray() 1129 defer newCol.Release() 1130 } 1131 case textType: 1132 if arrowBatchesUtf8ValidationEnabled(ctx) && col.DataType().ID() == arrow.STRING { 1133 tb := array.NewStringBuilder(pool) 1134 defer tb.Release() 1135 1136 for i := 0; i < int(numRows); i++ { 1137 if col.(*array.String).IsValid(i) { 1138 stringValue := col.(*array.String).Value(i) 1139 if !utf8.ValidString(stringValue) { 1140 logger.Error("Invalid UTF-8 characters detected while reading query response, column: ", srcColumnMeta.Name) 1141 stringValue = strings.ToValidUTF8(stringValue, "�") 1142 } 1143 tb.Append(stringValue) 1144 } else { 1145 tb.AppendNull() 1146 } 1147 } 1148 newCol = tb.NewArray() 1149 defer newCol.Release() 1150 } 1151 } 1152 cols = append(cols, newCol) 1153 } 1154 return array.NewRecord(s, cols, numRows), nil 1155 } 1156 1157 func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption, withHigherPrecision bool) (*arrow.Schema, error) { 1158 var fields []arrow.Field 1159 for i := 0; i < len(sc.Fields()); i++ { 1160 f := sc.Field(i) 1161 srcColumnMeta := rowType[i] 1162 converted := true 1163 1164 var t arrow.DataType 1165 switch getSnowflakeType(srcColumnMeta.Type) { 1166 case fixedType: 1167 switch f.Type.ID() { 1168 case arrow.DECIMAL: 1169 if withHigherPrecision { 1170 converted = false 1171 } else if srcColumnMeta.Scale == 0 { 1172 t = &arrow.Int64Type{} 1173 } else { 1174 t = &arrow.Float64Type{} 1175 } 1176 default: 1177 if withHigherPrecision { 1178 converted = false 1179 } else if srcColumnMeta.Scale != 0 { 1180 t = &arrow.Float64Type{} 1181 } else { 1182 converted = false 1183 } 1184 } 1185 case timeType: 1186 t = &arrow.Time64Type{Unit: arrow.Nanosecond} 1187 case timestampNtzType, timestampTzType: 1188 if timestampOption == UseOriginalTimestamp { 1189 // do nothing - return timestamp as is 1190 converted = false 1191 } else if timestampOption == UseMicrosecondTimestamp { 1192 t = &arrow.TimestampType{Unit: arrow.Microsecond} 1193 } else if timestampOption == UseMillisecondTimestamp { 1194 t = &arrow.TimestampType{Unit: arrow.Millisecond} 1195 } else if timestampOption == UseSecondTimestamp { 1196 t = &arrow.TimestampType{Unit: arrow.Second} 1197 } else { 1198 t = &arrow.TimestampType{Unit: arrow.Nanosecond} 1199 } 1200 case timestampLtzType: 1201 if timestampOption == UseOriginalTimestamp { 1202 // do nothing - return timestamp as is 1203 converted = false 1204 } else if timestampOption == UseMicrosecondTimestamp { 1205 t = &arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: loc.String()} 1206 } else if timestampOption == UseMillisecondTimestamp { 1207 t = &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: loc.String()} 1208 } else if timestampOption == UseSecondTimestamp { 1209 t = &arrow.TimestampType{Unit: arrow.Second, TimeZone: loc.String()} 1210 } else { 1211 t = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: loc.String()} 1212 } 1213 default: 1214 converted = false 1215 } 1216 1217 newField := f 1218 if converted { 1219 newField = arrow.Field{ 1220 Name: f.Name, 1221 Type: t, 1222 Nullable: f.Nullable, 1223 Metadata: f.Metadata, 1224 } 1225 } 1226 fields = append(fields, newField) 1227 } 1228 meta := sc.Metadata() 1229 return arrow.NewSchema(fields, &meta), nil 1230 } 1231 1232 // TypedNullTime is required to properly bind the null value with the snowflakeType as the Snowflake functions 1233 // require the type of the field to be provided explicitly for the null values 1234 type TypedNullTime struct { 1235 Time sql.NullTime 1236 TzType timezoneType 1237 } 1238 1239 func convertTzTypeToSnowflakeType(tzType timezoneType) snowflakeType { 1240 switch tzType { 1241 case TimestampNTZType: 1242 return timestampNtzType 1243 case TimestampLTZType: 1244 return timestampLtzType 1245 case TimestampTZType: 1246 return timestampTzType 1247 case DateType: 1248 return dateType 1249 case TimeType: 1250 return timeType 1251 } 1252 return unSupportedType 1253 }