github.com/snowflakedb/gosnowflake@v1.9.0/arrow_test.go (about) 1 // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "fmt" 9 "math/big" 10 "reflect" 11 "strings" 12 "testing" 13 "time" 14 15 "database/sql/driver" 16 ) 17 18 // A test just to show Snowflake version 19 func TestCheckVersion(t *testing.T) { 20 conn := openConn(t) 21 defer conn.Close() 22 23 rows, err := conn.QueryContext(context.Background(), "SELECT current_version()") 24 if err != nil { 25 t.Error(err) 26 } 27 defer rows.Close() 28 29 if !rows.Next() { 30 t.Fatalf("failed to find any row") 31 } 32 var s string 33 if err = rows.Scan(&s); err != nil { 34 t.Fatal(err) 35 } 36 println(s) 37 } 38 39 func TestArrowBatchHighPrecision(t *testing.T) { 40 runDBTest(t, func(dbt *DBTest) { 41 ctx := WithArrowBatches(context.Background()) 42 query := "select '0.1':: DECIMAL(38, 19) as c" 43 44 var rows driver.Rows 45 var err error 46 47 // must use conn.Raw so we can get back driver rows (an interface) 48 // which can be cast to snowflakeRows which exposes GetArrowBatch 49 err = dbt.conn.Raw(func(x interface{}) error { 50 queryer, implementsQueryContext := x.(driver.QueryerContext) 51 assertTrueF(t, implementsQueryContext, "snowflake connection driver does not implement queryerContext") 52 53 rows, err = queryer.QueryContext(WithArrowBatches(ctx), query, nil) 54 return err 55 }) 56 57 assertNilF(t, err, "error running select query") 58 59 sfRows, isSfRows := rows.(SnowflakeRows) 60 assertTrueF(t, isSfRows, "rows should be snowflakeRows") 61 62 arrowBatches, err := sfRows.GetArrowBatches() 63 assertNilF(t, err, "error getting arrow batches") 64 assertNotEqualF(t, len(arrowBatches), 0, "should have at least one batch") 65 66 c, err := arrowBatches[0].Fetch() 67 assertNilF(t, err, "error fetching first batch") 68 69 chunk := *c 70 assertNotEqualF(t, len(chunk), 0, "should have at least one chunk") 71 72 strVal := chunk[0].Column(0).ValueStr(0) 73 expected := "0.1" 74 assertEqualF(t, strVal, expected, fmt.Sprintf("should have returned 0.1, but got: %s", strVal)) 75 }) 76 } 77 78 func TestArrowBigInt(t *testing.T) { 79 conn := openConn(t) 80 defer conn.Close() 81 dbt := &DBTest{t, conn} 82 83 testcases := []struct { 84 num string 85 prec int 86 sc int 87 }{ 88 {"10000000000000000000000000000000000000", 38, 0}, 89 {"-10000000000000000000000000000000000000", 38, 0}, 90 {"12345678901234567890123456789012345678", 38, 0}, // #pragma: allowlist secret 91 {"-12345678901234567890123456789012345678", 38, 0}, 92 {"99999999999999999999999999999999999999", 38, 0}, 93 {"-99999999999999999999999999999999999999", 38, 0}, 94 } 95 96 for _, tc := range testcases { 97 rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), 98 fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 99 if !rows.Next() { 100 dbt.Error("failed to query") 101 } 102 defer rows.Close() 103 var v *big.Int 104 if err := rows.Scan(&v); err != nil { 105 dbt.Errorf("failed to scan. %#v", err) 106 } 107 108 b, ok := new(big.Int).SetString(tc.num, 10) 109 if !ok { 110 dbt.Errorf("failed to convert %v big.Int.", tc.num) 111 } 112 if v.Cmp(b) != 0 { 113 dbt.Errorf("big.Int value mismatch: expected %v, got %v", b, v) 114 } 115 } 116 } 117 118 func TestArrowBigFloat(t *testing.T) { 119 conn := openConn(t) 120 defer conn.Close() 121 dbt := &DBTest{t, conn} 122 123 testcases := []struct { 124 num string 125 prec int 126 sc int 127 }{ 128 {"1.23", 30, 2}, 129 {"1.0000000000000000000000000000000000000", 38, 37}, 130 {"-1.0000000000000000000000000000000000000", 38, 37}, 131 {"1.2345678901234567890123456789012345678", 38, 37}, 132 {"-1.2345678901234567890123456789012345678", 38, 37}, 133 {"9.9999999999999999999999999999999999999", 38, 37}, 134 {"-9.9999999999999999999999999999999999999", 38, 37}, 135 } 136 137 for _, tc := range testcases { 138 rows := dbt.mustQueryContext(WithHigherPrecision(context.Background()), 139 fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 140 if !rows.Next() { 141 dbt.Error("failed to query") 142 } 143 defer rows.Close() 144 var v *big.Float 145 if err := rows.Scan(&v); err != nil { 146 dbt.Errorf("failed to scan. %#v", err) 147 } 148 149 prec := v.Prec() 150 b, ok := new(big.Float).SetPrec(prec).SetString(tc.num) 151 if !ok { 152 dbt.Errorf("failed to convert %v to big.Float.", tc.num) 153 } 154 if v.Cmp(b) != 0 { 155 dbt.Errorf("big.Float value mismatch: expected %v, got %v", b, v) 156 } 157 } 158 } 159 160 func TestArrowIntPrecision(t *testing.T) { 161 db := openDB(t) 162 defer db.Close() 163 164 _, err := db.Exec(forceJSON) 165 if err != nil { 166 t.Fatalf("failed to set JSON as result type: %v", err) 167 } 168 169 intTestcases := []struct { 170 num string 171 prec int 172 sc int 173 }{ 174 {"10000000000000000000000000000000000000", 38, 0}, 175 {"-10000000000000000000000000000000000000", 38, 0}, 176 {"12345678901234567890123456789012345678", 38, 0}, // pragma: allowlist secret 177 {"-12345678901234567890123456789012345678", 38, 0}, 178 {"99999999999999999999999999999999999999", 38, 0}, 179 {"-99999999999999999999999999999999999999", 38, 0}, 180 } 181 182 t.Run("arrow_disabled_scan_int64", func(t *testing.T) { 183 for _, tc := range intTestcases { 184 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 185 if err != nil { 186 t.Fatalf("failed to query: %v", err) 187 } 188 defer rows.Close() 189 if !rows.Next() { 190 t.Error("failed to query") 191 } 192 var v int64 193 if err := rows.Scan(&v); err == nil { 194 t.Error("should fail to scan") 195 } 196 } 197 }) 198 t.Run("arrow_disabled_scan_string", func(t *testing.T) { 199 for _, tc := range intTestcases { 200 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 201 if err != nil { 202 t.Fatalf("failed to query: %v", err) 203 } 204 defer rows.Close() 205 if !rows.Next() { 206 t.Error("failed to query") 207 } 208 defer rows.Close() 209 var v int64 210 if err := rows.Scan(&v); err == nil { 211 t.Error("should fail to scan") 212 } 213 } 214 }) 215 t.Run("arrow_enabled_scan_big_int", func(t *testing.T) { 216 for _, tc := range intTestcases { 217 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 218 if err != nil { 219 t.Fatalf("failed to query: %v", err) 220 } 221 defer rows.Close() 222 if !rows.Next() { 223 t.Error("failed to query") 224 } 225 var v string 226 if err := rows.Scan(&v); err != nil { 227 t.Errorf("failed to scan. %#v", err) 228 } 229 if !strings.EqualFold(v, tc.num) { 230 t.Errorf("int value mismatch: expected %v, got %v", tc.num, v) 231 } 232 } 233 }) 234 t.Run("arrow_high_precision_enabled_scan_big_int", func(t *testing.T) { 235 for _, tc := range intTestcases { 236 rows, err := db.QueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 237 if err != nil { 238 t.Fatalf("failed to query: %v", err) 239 } 240 defer rows.Close() 241 if !rows.Next() { 242 t.Error("failed to query") 243 } 244 var v *big.Int 245 if err := rows.Scan(&v); err != nil { 246 t.Errorf("failed to scan. %#v", err) 247 } 248 249 b, ok := new(big.Int).SetString(tc.num, 10) 250 if !ok { 251 t.Errorf("failed to convert %v big.Int.", tc.num) 252 } 253 if v.Cmp(b) != 0 { 254 t.Errorf("big.Int value mismatch: expected %v, got %v", b, v) 255 } 256 } 257 }) 258 } 259 260 // TestArrowFloatPrecision tests the different variable types allowed in the 261 // rows.Scan() method. Note that for lower precision types we do not attempt 262 // to check the value as precision could be lost. 263 func TestArrowFloatPrecision(t *testing.T) { 264 db := openDB(t) 265 defer db.Close() 266 267 _, err := db.Exec(forceJSON) 268 if err != nil { 269 t.Fatalf("failed to set JSON as result type: %v", err) 270 } 271 272 fltTestcases := []struct { 273 num string 274 prec int 275 sc int 276 }{ 277 {"1.23", 30, 2}, 278 {"1.0000000000000000000000000000000000000", 38, 37}, 279 {"-1.0000000000000000000000000000000000000", 38, 37}, 280 {"1.2345678901234567890123456789012345678", 38, 37}, 281 {"-1.2345678901234567890123456789012345678", 38, 37}, 282 {"9.9999999999999999999999999999999999999", 38, 37}, 283 {"-9.9999999999999999999999999999999999999", 38, 37}, 284 } 285 286 t.Run("arrow_disabled_scan_float64", func(t *testing.T) { 287 for _, tc := range fltTestcases { 288 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 289 if err != nil { 290 t.Fatalf("failed to query: %v", err) 291 } 292 defer rows.Close() 293 if !rows.Next() { 294 t.Error("failed to query") 295 } 296 var v float64 297 if err := rows.Scan(&v); err != nil { 298 t.Errorf("failed to scan. %#v", err) 299 } 300 } 301 }) 302 t.Run("arrow_disabled_scan_float32", func(t *testing.T) { 303 for _, tc := range fltTestcases { 304 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 305 if err != nil { 306 t.Fatalf("failed to query: %v", err) 307 } 308 defer rows.Close() 309 if !rows.Next() { 310 t.Error("failed to query") 311 } 312 var v float32 313 if err := rows.Scan(&v); err != nil { 314 t.Errorf("failed to scan. %#v", err) 315 } 316 } 317 }) 318 t.Run("arrow_disabled_scan_string", func(t *testing.T) { 319 for _, tc := range fltTestcases { 320 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 321 if err != nil { 322 t.Fatalf("failed to query: %v", err) 323 } 324 defer rows.Close() 325 if !rows.Next() { 326 t.Error("failed to query") 327 } 328 var v string 329 if err := rows.Scan(&v); err != nil { 330 t.Errorf("failed to scan. %#v", err) 331 } 332 if !strings.EqualFold(v, tc.num) { 333 t.Errorf("int value mismatch: expected %v, got %v", tc.num, v) 334 } 335 } 336 }) 337 t.Run("arrow_enabled_scan_float64", func(t *testing.T) { 338 for _, tc := range fltTestcases { 339 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 340 if err != nil { 341 t.Fatalf("failed to query: %v", err) 342 } 343 defer rows.Close() 344 if !rows.Next() { 345 t.Error("failed to query") 346 } 347 var v float64 348 if err := rows.Scan(&v); err != nil { 349 t.Errorf("failed to scan. %#v", err) 350 } 351 } 352 }) 353 t.Run("arrow_enabled_scan_float32", func(t *testing.T) { 354 for _, tc := range fltTestcases { 355 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 356 if err != nil { 357 t.Fatalf("failed to query: %v", err) 358 } 359 defer rows.Close() 360 if !rows.Next() { 361 t.Error("failed to query") 362 } 363 var v float32 364 if err := rows.Scan(&v); err != nil { 365 t.Errorf("failed to scan. %#v", err) 366 } 367 } 368 }) 369 t.Run("arrow_enabled_scan_string", func(t *testing.T) { 370 for _, tc := range fltTestcases { 371 rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 372 if err != nil { 373 t.Fatalf("failed to query: %v", err) 374 } 375 defer rows.Close() 376 if !rows.Next() { 377 t.Error("failed to query") 378 } 379 var v string 380 if err := rows.Scan(&v); err != nil { 381 t.Errorf("failed to scan. %#v", err) 382 } 383 } 384 }) 385 t.Run("arrow_high_precision_enabled_scan_big_float", func(t *testing.T) { 386 for _, tc := range fltTestcases { 387 rows, err := db.QueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) 388 if err != nil { 389 t.Fatalf("failed to query: %v", err) 390 } 391 defer rows.Close() 392 if !rows.Next() { 393 t.Error("failed to query") 394 } 395 var v *big.Float 396 if err := rows.Scan(&v); err != nil { 397 t.Errorf("failed to scan. %#v", err) 398 } 399 400 prec := v.Prec() 401 b, ok := new(big.Float).SetPrec(prec).SetString(tc.num) 402 if !ok { 403 t.Errorf("failed to convert %v to big.Float.", tc.num) 404 } 405 if v.Cmp(b) != 0 { 406 t.Errorf("big.Float value mismatch: expected %v, got %v", b, v) 407 } 408 } 409 }) 410 } 411 412 func TestArrowTimePrecision(t *testing.T) { 413 runDBTest(t, func(dbt *DBTest) { 414 dbt.mustExec("CREATE TABLE t (col5 TIME(5), col6 TIME(6), col7 TIME(7), col8 TIME(8));") 415 defer dbt.mustExec("DROP TABLE IF EXISTS t") 416 dbt.mustExec("INSERT INTO t VALUES ('23:59:59.99999', '23:59:59.999999', '23:59:59.9999999', '23:59:59.99999999');") 417 418 rows := dbt.mustQuery("select * from t") 419 defer rows.Close() 420 var c5, c6, c7, c8 time.Time 421 for rows.Next() { 422 if err := rows.Scan(&c5, &c6, &c7, &c8); err != nil { 423 t.Errorf("values were not scanned: %v", err) 424 } 425 } 426 427 nano := 999999990 428 expected := time.Time{}.Add(23*time.Hour + 59*time.Minute + 59*time.Second + 99*time.Millisecond) 429 if c8.Unix() != expected.Unix() || c8.Nanosecond() != nano { 430 t.Errorf("the value did not match. expected: %v, got: %v", expected, c8) 431 } 432 if c7.Unix() != expected.Unix() || c7.Nanosecond() != nano-(nano%1e2) { 433 t.Errorf("the value did not match. expected: %v, got: %v", expected, c7) 434 } 435 if c6.Unix() != expected.Unix() || c6.Nanosecond() != nano-(nano%1e3) { 436 t.Errorf("the value did not match. expected: %v, got: %v", expected, c6) 437 } 438 if c5.Unix() != expected.Unix() || c5.Nanosecond() != nano-(nano%1e4) { 439 t.Errorf("the value did not match. expected: %v, got: %v", expected, c5) 440 } 441 442 dbt.mustExec(`CREATE TABLE t_ntz ( 443 col1 TIMESTAMP_NTZ(1), 444 col2 TIMESTAMP_NTZ(2), 445 col3 TIMESTAMP_NTZ(3), 446 col4 TIMESTAMP_NTZ(4), 447 col5 TIMESTAMP_NTZ(5), 448 col6 TIMESTAMP_NTZ(6), 449 col7 TIMESTAMP_NTZ(7), 450 col8 TIMESTAMP_NTZ(8) 451 );`) 452 defer dbt.mustExec("DROP TABLE IF EXISTS t_ntz") 453 dbt.mustExec(`INSERT INTO t_ntz VALUES ( 454 '9999-12-31T23:59:59.9', 455 '9999-12-31T23:59:59.99', 456 '9999-12-31T23:59:59.999', 457 '9999-12-31T23:59:59.9999', 458 '9999-12-31T23:59:59.99999', 459 '9999-12-31T23:59:59.999999', 460 '9999-12-31T23:59:59.9999999', 461 '9999-12-31T23:59:59.99999999' 462 );`) 463 464 rows2 := dbt.mustQuery("select * from t_ntz") 465 defer rows2.Close() 466 var c1, c2, c3, c4 time.Time 467 for rows2.Next() { 468 if err := rows2.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8); err != nil { 469 t.Errorf("values were not scanned: %v", err) 470 } 471 } 472 473 expected = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) 474 if c8.Unix() != expected.Unix() || c8.Nanosecond() != nano { 475 t.Errorf("the value did not match. expected: %v, got: %v", expected, c8) 476 } 477 if c7.Unix() != expected.Unix() || c7.Nanosecond() != nano-(nano%1e2) { 478 t.Errorf("the value did not match. expected: %v, got: %v", expected, c7) 479 } 480 if c6.Unix() != expected.Unix() || c6.Nanosecond() != nano-(nano%1e3) { 481 t.Errorf("the value did not match. expected: %v, got: %v", expected, c6) 482 } 483 if c5.Unix() != expected.Unix() || c5.Nanosecond() != nano-(nano%1e4) { 484 t.Errorf("the value did not match. expected: %v, got: %v", expected, c5) 485 } 486 if c4.Unix() != expected.Unix() || c4.Nanosecond() != nano-(nano%1e5) { 487 t.Errorf("the value did not match. expected: %v, got: %v", expected, c4) 488 } 489 if c3.Unix() != expected.Unix() || c3.Nanosecond() != nano-(nano%1e6) { 490 t.Errorf("the value did not match. expected: %v, got: %v", expected, c3) 491 } 492 if c2.Unix() != expected.Unix() || c2.Nanosecond() != nano-(nano%1e7) { 493 t.Errorf("the value did not match. expected: %v, got: %v", expected, c2) 494 } 495 if c1.Unix() != expected.Unix() || c1.Nanosecond() != nano-(nano%1e8) { 496 t.Errorf("the value did not match. expected: %v, got: %v", expected, c1) 497 } 498 }) 499 } 500 501 func TestArrowVariousTypes(t *testing.T) { 502 runDBTest(t, func(dbt *DBTest) { 503 rows := dbt.mustQueryContext( 504 WithHigherPrecision(context.Background()), selectVariousTypes) 505 defer rows.Close() 506 if !rows.Next() { 507 dbt.Error("failed to query") 508 } 509 cc, err := rows.Columns() 510 if err != nil { 511 dbt.Errorf("columns: %v", cc) 512 } 513 ct, err := rows.ColumnTypes() 514 if err != nil { 515 dbt.Errorf("column types: %v", ct) 516 } 517 var v1 *big.Float 518 var v2 int 519 var v3 string 520 var v4 float64 521 var v5 []byte 522 var v6 bool 523 if err = rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6); err != nil { 524 dbt.Errorf("failed to scan: %#v", err) 525 } 526 if v1.Cmp(big.NewFloat(1.0)) != 0 { 527 dbt.Errorf("failed to scan. %#v", *v1) 528 } 529 if ct[0].Name() != "C1" || ct[1].Name() != "C2" || ct[2].Name() != "C3" || ct[3].Name() != "C4" || ct[4].Name() != "C5" || ct[5].Name() != "C6" { 530 dbt.Errorf("failed to get column names: %#v", ct) 531 } 532 if ct[0].ScanType() != reflect.TypeOf(float64(0)) { 533 dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeOf(float64(0)), ct[0].ScanType()) 534 } 535 if ct[1].ScanType() != reflect.TypeOf(int64(0)) { 536 dbt.Errorf("failed to get scan type. expected: %v, got: %v", reflect.TypeOf(int64(0)), ct[1].ScanType()) 537 } 538 var pr, sc int64 539 var cLen int64 540 pr, sc = dbt.mustDecimalSize(ct[0]) 541 if pr != 30 || sc != 2 { 542 dbt.Errorf("failed to get precision and scale. %#v", ct[0]) 543 } 544 dbt.mustFailLength(ct[0]) 545 if canNull := dbt.mustNullable(ct[0]); canNull { 546 dbt.Errorf("failed to get nullable. %#v", ct[0]) 547 } 548 if cLen != 0 { 549 dbt.Errorf("failed to get length. %#v", ct[0]) 550 } 551 if v2 != 2 { 552 dbt.Errorf("failed to scan. %#v", v2) 553 } 554 pr, sc = dbt.mustDecimalSize(ct[1]) 555 if pr != 38 || sc != 0 { 556 dbt.Errorf("failed to get precision and scale. %#v", ct[1]) 557 } 558 dbt.mustFailLength(ct[1]) 559 if canNull := dbt.mustNullable(ct[1]); canNull { 560 dbt.Errorf("failed to get nullable. %#v", ct[1]) 561 } 562 if v3 != "t3" { 563 dbt.Errorf("failed to scan. %#v", v3) 564 } 565 dbt.mustFailDecimalSize(ct[2]) 566 if cLen = dbt.mustLength(ct[2]); cLen != 2 { 567 dbt.Errorf("failed to get length. %#v", ct[2]) 568 } 569 if canNull := dbt.mustNullable(ct[2]); canNull { 570 dbt.Errorf("failed to get nullable. %#v", ct[2]) 571 } 572 if v4 != 4.2 { 573 dbt.Errorf("failed to scan. %#v", v4) 574 } 575 dbt.mustFailDecimalSize(ct[3]) 576 dbt.mustFailLength(ct[3]) 577 if canNull := dbt.mustNullable(ct[3]); canNull { 578 dbt.Errorf("failed to get nullable. %#v", ct[3]) 579 } 580 if !bytes.Equal(v5, []byte{0xab, 0xcd}) { 581 dbt.Errorf("failed to scan. %#v", v5) 582 } 583 dbt.mustFailDecimalSize(ct[4]) 584 if cLen = dbt.mustLength(ct[4]); cLen != 8388608 { // BINARY 585 dbt.Errorf("failed to get length. %#v", ct[4]) 586 } 587 if canNull := dbt.mustNullable(ct[4]); canNull { 588 dbt.Errorf("failed to get nullable. %#v", ct[4]) 589 } 590 if !v6 { 591 dbt.Errorf("failed to scan. %#v", v6) 592 } 593 dbt.mustFailDecimalSize(ct[5]) 594 dbt.mustFailLength(ct[5]) 595 /*canNull = dbt.mustNullable(ct[5]) 596 if canNull { 597 dbt.Errorf("failed to get nullable. %#v", ct[5]) 598 }*/ 599 }) 600 }