github.com/snowflakedb/gosnowflake@v1.9.0/multistatement_test.go (about) 1 // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "encoding/json" 8 "errors" 9 "io" 10 "net/http" 11 "net/url" 12 "os" 13 "reflect" 14 "testing" 15 "time" 16 ) 17 18 func TestMultiStatementExecuteNoResultSet(t *testing.T) { 19 ctx, _ := WithMultiStatement(context.Background(), 4) 20 multiStmtQuery := "begin;\n" + 21 "delete from test_multi_statement_txn;\n" + 22 "insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" + 23 "commit;" 24 25 runDBTest(t, func(dbt *DBTest) { 26 dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`) 27 28 res := dbt.mustExecContext(ctx, multiStmtQuery) 29 count, err := res.RowsAffected() 30 if err != nil { 31 t.Fatalf("res.RowsAffected() returned error: %v", err) 32 } 33 if count != 3 { 34 t.Fatalf("expected 3 affected rows, got %d", count) 35 } 36 }) 37 } 38 39 func TestMultiStatementQueryResultSet(t *testing.T) { 40 ctx, _ := WithMultiStatement(context.Background(), 4) 41 multiStmtQuery := "select 123;\n" + 42 "select 456;\n" + 43 "select 789;\n" + 44 "select '000';" 45 46 var v1, v2, v3 int64 47 var v4 string 48 49 runDBTest(t, func(dbt *DBTest) { 50 rows := dbt.mustQueryContext(ctx, multiStmtQuery) 51 defer rows.Close() 52 53 // first statement 54 if rows.Next() { 55 if err := rows.Scan(&v1); err != nil { 56 t.Errorf("failed to scan: %#v", err) 57 } 58 if v1 != 123 { 59 t.Fatalf("failed to fetch. value: %v", v1) 60 } 61 } else { 62 t.Error("failed to query") 63 } 64 65 // second statement 66 if !rows.NextResultSet() { 67 t.Error("failed to retrieve next result set") 68 } 69 if rows.Next() { 70 if err := rows.Scan(&v2); err != nil { 71 t.Errorf("failed to scan: %#v", err) 72 } 73 if v2 != 456 { 74 t.Fatalf("failed to fetch. value: %v", v2) 75 } 76 } else { 77 t.Error("failed to query") 78 } 79 80 // third statement 81 if !rows.NextResultSet() { 82 t.Error("failed to retrieve next result set") 83 } 84 if rows.Next() { 85 if err := rows.Scan(&v3); err != nil { 86 t.Errorf("failed to scan: %#v", err) 87 } 88 if v3 != 789 { 89 t.Fatalf("failed to fetch. value: %v", v3) 90 } 91 } else { 92 t.Error("failed to query") 93 } 94 95 // fourth statement 96 if !rows.NextResultSet() { 97 t.Error("failed to retrieve next result set") 98 } 99 if rows.Next() { 100 if err := rows.Scan(&v4); err != nil { 101 t.Errorf("failed to scan: %#v", err) 102 } 103 if v4 != "000" { 104 t.Fatalf("failed to fetch. value: %v", v4) 105 } 106 } else { 107 t.Error("failed to query") 108 } 109 }) 110 } 111 112 func TestMultiStatementExecuteResultSet(t *testing.T) { 113 ctx, _ := WithMultiStatement(context.Background(), 6) 114 multiStmtQuery := "begin;\n" + 115 "delete from test_multi_statement_txn_rb;\n" + 116 "insert into test_multi_statement_txn_rb values (1, 'a'), (2, 'b');\n" + 117 "select 1;\n" + 118 "select 2;\n" + 119 "rollback;" 120 121 runDBTest(t, func(dbt *DBTest) { 122 dbt.mustExec("drop table if exists test_multi_statement_txn_rb") 123 dbt.mustExec(`create or replace table test_multi_statement_txn_rb( 124 c1 number, c2 string) as select 10, 'z'`) 125 defer dbt.mustExec("drop table if exists test_multi_statement_txn_rb") 126 127 res := dbt.mustExecContext(ctx, multiStmtQuery) 128 count, err := res.RowsAffected() 129 if err != nil { 130 t.Fatalf("res.RowsAffected() returned error: %v", err) 131 } 132 if count != 3 { 133 t.Fatalf("expected 3 affected rows, got %d", count) 134 } 135 }) 136 } 137 138 func TestMultiStatementQueryNoResultSet(t *testing.T) { 139 ctx, _ := WithMultiStatement(context.Background(), 4) 140 multiStmtQuery := "begin;\n" + 141 "delete from test_multi_statement_txn;\n" + 142 "insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" + 143 "commit;" 144 145 runDBTest(t, func(dbt *DBTest) { 146 dbt.mustExec("drop table if exists test_multi_statement_txn") 147 dbt.mustExec(`create or replace table test_multi_statement_txn( 148 c1 number, c2 string) as select 10, 'z'`) 149 defer dbt.mustExec("drop table if exists tfmuest_multi_statement_txn") 150 151 rows := dbt.mustQueryContext(ctx, multiStmtQuery) 152 defer rows.Close() 153 }) 154 } 155 156 func TestMultiStatementExecuteMix(t *testing.T) { 157 ctx, _ := WithMultiStatement(context.Background(), 3) 158 multiStmtQuery := "create or replace temporary table test_multi (cola int);\n" + 159 "insert into test_multi values (1), (2);\n" + 160 "select cola from test_multi order by cola asc;" 161 162 runDBTest(t, func(dbt *DBTest) { 163 dbt.mustExec("drop table if exists test_multi_statement_txn") 164 dbt.mustExec(`create or replace table test_multi_statement_txn( 165 c1 number, c2 string) as select 10, 'z'`) 166 defer dbt.mustExec("drop table if exists test_multi_statement_txn") 167 168 res := dbt.mustExecContext(ctx, multiStmtQuery) 169 count, err := res.RowsAffected() 170 if err != nil { 171 t.Fatalf("res.RowsAffected() returned error: %v", err) 172 } 173 if count != 2 { 174 t.Fatalf("expected 2 affected rows, got %d", count) 175 } 176 }) 177 } 178 179 func TestMultiStatementQueryMix(t *testing.T) { 180 ctx, _ := WithMultiStatement(context.Background(), 3) 181 multiStmtQuery := "create or replace temporary table test_multi (cola int);\n" + 182 "insert into test_multi values (1), (2);\n" + 183 "select cola from test_multi order by cola asc;" 184 185 var count, v int 186 runDBTest(t, func(dbt *DBTest) { 187 dbt.mustExec("drop table if exists test_multi_statement_txn") 188 dbt.mustExec(`create or replace table test_multi_statement_txn( 189 c1 number, c2 string) as select 10, 'z'`) 190 defer dbt.mustExec("drop table if exists test_multi_statement_txn") 191 192 rows := dbt.mustQueryContext(ctx, multiStmtQuery) 193 defer rows.Close() 194 195 // first statement 196 if !rows.Next() { 197 t.Error("failed to query") 198 } 199 200 // second statement 201 rows.NextResultSet() 202 if rows.Next() { 203 if err := rows.Scan(&count); err != nil { 204 t.Errorf("failed to scan: %#v", err) 205 } 206 if count != 2 { 207 t.Fatalf("expected 2 affected rows, got %d", count) 208 } 209 } 210 211 expected := 1 212 // third statement 213 rows.NextResultSet() 214 for rows.Next() { 215 if err := rows.Scan(&v); err != nil { 216 t.Errorf("failed to scan: %#v", err) 217 } 218 if v != expected { 219 t.Fatalf("failed to fetch. value: %v", v) 220 } 221 expected++ 222 } 223 }) 224 } 225 226 func TestMultiStatementCountZero(t *testing.T) { 227 ctx, _ := WithMultiStatement(context.Background(), 0) 228 var v1 int 229 var v2 string 230 var v3 float64 231 var v4 bool 232 233 runDBTest(t, func(dbt *DBTest) { 234 // first query 235 multiStmtQuery1 := "select 123;\n" + 236 "select '456';" 237 rows1 := dbt.mustQueryContext(ctx, multiStmtQuery1) 238 defer rows1.Close() 239 // first statement 240 if rows1.Next() { 241 if err := rows1.Scan(&v1); err != nil { 242 t.Errorf("failed to scan: %#v", err) 243 } 244 if v1 != 123 { 245 t.Fatalf("failed to fetch. value: %v", v1) 246 } 247 } else { 248 t.Error("failed to query") 249 } 250 251 // second statement 252 if !rows1.NextResultSet() { 253 t.Error("failed to retrieve next result set") 254 } 255 if rows1.Next() { 256 if err := rows1.Scan(&v2); err != nil { 257 t.Errorf("failed to scan: %#v", err) 258 } 259 if v2 != "456" { 260 t.Fatalf("failed to fetch. value: %v", v2) 261 } 262 } else { 263 t.Error("failed to query") 264 } 265 266 // second query 267 multiStmtQuery2 := "select 789;\n" + 268 "select 'foo';\n" + 269 "select 0.123;\n" + 270 "select true;" 271 rows2 := dbt.mustQueryContext(ctx, multiStmtQuery2) 272 defer rows2.Close() 273 // first statement 274 if rows2.Next() { 275 if err := rows2.Scan(&v1); err != nil { 276 t.Errorf("failed to scan: %#v", err) 277 } 278 if v1 != 789 { 279 t.Fatalf("failed to fetch. value: %v", v1) 280 } 281 } else { 282 t.Error("failed to query") 283 } 284 285 // second statement 286 if !rows2.NextResultSet() { 287 t.Error("failed to retrieve next result set") 288 } 289 if rows2.Next() { 290 if err := rows2.Scan(&v2); err != nil { 291 t.Errorf("failed to scan: %#v", err) 292 } 293 if v2 != "foo" { 294 t.Fatalf("failed to fetch. value: %v", v2) 295 } 296 } else { 297 t.Error("failed to query") 298 } 299 300 // third statement 301 if !rows2.NextResultSet() { 302 t.Error("failed to retrieve next result set") 303 } 304 if rows2.Next() { 305 if err := rows2.Scan(&v3); err != nil { 306 t.Errorf("failed to scan: %#v", err) 307 } 308 if v3 != 0.123 { 309 t.Fatalf("failed to fetch. value: %v", v3) 310 } 311 } else { 312 t.Error("failed to query") 313 } 314 315 // fourth statement 316 if !rows2.NextResultSet() { 317 t.Error("failed to retrieve next result set") 318 } 319 if rows2.Next() { 320 if err := rows2.Scan(&v4); err != nil { 321 t.Errorf("failed to scan: %#v", err) 322 } 323 if v4 != true { 324 t.Fatalf("failed to fetch. value: %v", v4) 325 } 326 } else { 327 t.Error("failed to query") 328 } 329 }) 330 } 331 332 func TestMultiStatementCountMismatch(t *testing.T) { 333 conn := openConn(t) 334 defer conn.Close() 335 336 multiStmtQuery := "select 123;\n" + 337 "select 456;\n" + 338 "select 789;\n" + 339 "select '000';" 340 341 ctx, _ := WithMultiStatement(context.Background(), 3) 342 if _, err := conn.QueryContext(ctx, multiStmtQuery); err == nil { 343 t.Fatal("should have failed to query multiple statements") 344 } 345 } 346 347 func TestMultiStatementVaryingColumnCount(t *testing.T) { 348 multiStmtQuery := "select c1 from test_tbl;\n" + 349 "select c1,c2 from test_tbl;" 350 ctx, _ := WithMultiStatement(context.Background(), 0) 351 352 var v1, v2 int 353 runDBTest(t, func(dbt *DBTest) { 354 dbt.mustExec("create or replace table test_tbl(c1 int, c2 int)") 355 dbt.mustExec("insert into test_tbl values(1, 0)") 356 defer dbt.mustExec("drop table if exists test_tbl") 357 358 rows := dbt.mustQueryContext(ctx, multiStmtQuery) 359 defer rows.Close() 360 361 if rows.Next() { 362 if err := rows.Scan(&v1); err != nil { 363 t.Errorf("failed to scan: %#v", err) 364 } 365 if v1 != 1 { 366 t.Fatalf("failed to fetch. value: %v", v1) 367 } 368 } else { 369 t.Error("failed to query") 370 } 371 372 if !rows.NextResultSet() { 373 t.Error("failed to retrieve next result set") 374 } 375 376 if rows.Next() { 377 if err := rows.Scan(&v1, &v2); err != nil { 378 t.Errorf("failed to scan: %#v", err) 379 } 380 if v1 != 1 || v2 != 0 { 381 t.Fatalf("failed to fetch. value: %v, %v", v1, v2) 382 } 383 } else { 384 t.Error("failed to query") 385 } 386 }) 387 } 388 389 // The total completion time should be similar to the duration of the query on Snowflake UI. 390 func TestMultiStatementExecutePerformance(t *testing.T) { 391 ctx, _ := WithMultiStatement(context.Background(), 100) 392 runDBTest(t, func(dbt *DBTest) { 393 file, err := os.Open("test_data/multistatements.sql") 394 if err != nil { 395 t.Fatalf("failed opening file: %s", err) 396 } 397 defer file.Close() 398 statements, err := io.ReadAll(file) 399 if err != nil { 400 t.Fatalf("failed reading file: %s", err) 401 } 402 403 sql := string(statements) 404 405 start := time.Now() 406 res := dbt.mustExecContext(ctx, sql) 407 duration := time.Since(start) 408 409 count, err := res.RowsAffected() 410 if err != nil { 411 t.Fatalf("res.RowsAffected() returned error: %v", err) 412 } 413 if count != 0 { 414 t.Fatalf("expected 0 affected rows, got %d", count) 415 } 416 t.Logf("The total completion time was %v", duration) 417 418 file, err = os.Open("test_data/multistatements_drop.sql") 419 if err != nil { 420 t.Fatalf("failed opening file: %s", err) 421 } 422 defer file.Close() 423 statements, err = io.ReadAll(file) 424 if err != nil { 425 t.Fatalf("failed reading file: %s", err) 426 } 427 sql = string(statements) 428 dbt.mustExecContext(ctx, sql) 429 }) 430 } 431 432 func TestUnitGetChildResults(t *testing.T) { 433 testcases := []struct { 434 ids string 435 types string 436 out []childResult 437 }{ 438 {"", "", nil}, 439 {"", "4096", nil}, 440 {"01aa3265-0405-ab7c-0000-53b106343aba,02aa3265-0405-ab7c-0000-53b106343aba", "12544,12544", []childResult{ 441 {"01aa3265-0405-ab7c-0000-53b106343aba", "12544"}, 442 {"02aa3265-0405-ab7c-0000-53b106343aba", "12544"}}}, 443 {"01aa3265-0405-ab7c-0000-53b106343aba,02aa3265-0405-ab7c-0000-53b106343aba,03aa3265-0405-ab7c-0000-53b106343aba", "25344,4096,12544", []childResult{ 444 {"01aa3265-0405-ab7c-0000-53b106343aba", "25344"}, 445 {"02aa3265-0405-ab7c-0000-53b106343aba", "4096"}, 446 {"03aa3265-0405-ab7c-0000-53b106343aba", "12544"}}}, 447 } 448 for _, test := range testcases { 449 t.Run(test.ids, func(t *testing.T) { 450 res := getChildResults(test.ids, test.types) 451 if !reflect.DeepEqual(res, test.out) { 452 t.Fatalf("Child result should be equal, expected %v, actual %v", res, test.out) 453 } 454 }) 455 } 456 } 457 458 func funcGetQueryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { 459 return nil, errors.New("failed to get query response") 460 } 461 462 func funcGetQueryRespError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { 463 dd := &execResponseData{} 464 er := &execResponse{ 465 Data: *dd, 466 Message: "query failed", 467 Code: "261000", 468 Success: false, 469 } 470 ba, err := json.Marshal(er) 471 if err != nil { 472 panic(err) 473 } 474 475 return &http.Response{ 476 StatusCode: http.StatusOK, 477 Body: &fakeResponseBody{body: ba}, 478 }, nil 479 } 480 481 func TestUnitHandleMultiExec(t *testing.T) { 482 runSnowflakeConnTest(t, func(sct *SCTest) { 483 data := execResponseData{ 484 ResultIDs: "", 485 ResultTypes: "", 486 } 487 _, err := sct.sc.handleMultiExec(context.Background(), data) 488 if err == nil { 489 t.Fatalf("should have failed") 490 } 491 driverErr, ok := err.(*SnowflakeError) 492 if !ok { 493 t.Fatalf("should be snowflake error. err: %v", err) 494 } 495 if driverErr.Number != ErrNoResultIDs { 496 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) 497 } 498 499 data = execResponseData{ 500 ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", 501 ResultTypes: "12544,12544", 502 } 503 sct.sc.rest = &snowflakeRestful{ 504 FuncGet: funcGetQueryRespFail, 505 FuncCloseSession: closeSessionMock, 506 TokenAccessor: getSimpleTokenAccessor(), 507 } 508 _, err = sct.sc.handleMultiExec(context.Background(), data) 509 if err == nil { 510 t.Fatalf("should have failed") 511 } 512 513 sct.sc.rest.FuncGet = funcGetQueryRespError 514 data.SQLState = "01112" 515 _, err = sct.sc.handleMultiExec(context.Background(), data) 516 if err == nil { 517 t.Fatalf("should have failed") 518 } 519 driverErr, ok = err.(*SnowflakeError) 520 if !ok { 521 t.Fatalf("should be snowflake error. err: %v", err) 522 } 523 if driverErr.Number != ErrFailedToPostQuery { 524 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) 525 } 526 }) 527 } 528 529 func TestUnitHandleMultiQuery(t *testing.T) { 530 runSnowflakeConnTest(t, func(sct *SCTest) { 531 data := execResponseData{ 532 ResultIDs: "", 533 ResultTypes: "", 534 } 535 rows := new(snowflakeRows) 536 err := sct.sc.handleMultiQuery(context.Background(), data, rows) 537 if err == nil { 538 t.Fatalf("should have failed") 539 } 540 driverErr, ok := err.(*SnowflakeError) 541 if !ok { 542 t.Fatalf("should be snowflake error. err: %v", err) 543 } 544 if driverErr.Number != ErrNoResultIDs { 545 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number) 546 } 547 data = execResponseData{ 548 ResultIDs: "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre", 549 ResultTypes: "12544,12544", 550 } 551 sct.sc.rest = &snowflakeRestful{ 552 FuncGet: funcGetQueryRespFail, 553 FuncCloseSession: closeSessionMock, 554 TokenAccessor: getSimpleTokenAccessor(), 555 } 556 err = sct.sc.handleMultiQuery(context.Background(), data, rows) 557 if err == nil { 558 t.Fatalf("should have failed") 559 } 560 561 sct.sc.rest.FuncGet = funcGetQueryRespError 562 data.SQLState = "01112" 563 err = sct.sc.handleMultiQuery(context.Background(), data, rows) 564 if err == nil { 565 t.Fatalf("should have failed") 566 } 567 driverErr, ok = err.(*SnowflakeError) 568 if !ok { 569 t.Fatalf("should be snowflake error. err: %v", err) 570 } 571 if driverErr.Number != ErrFailedToPostQuery { 572 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number) 573 } 574 }) 575 }