github.com/jackc/pgx/v5@v5.5.5/stdlib/sql_test.go (about) 1 package stdlib_test 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql" 7 "encoding/json" 8 "fmt" 9 "math" 10 "os" 11 "reflect" 12 "regexp" 13 "strconv" 14 "sync" 15 "testing" 16 "time" 17 18 "github.com/jackc/pgx/v5" 19 "github.com/jackc/pgx/v5/pgconn" 20 "github.com/jackc/pgx/v5/pgtype" 21 "github.com/jackc/pgx/v5/pgxpool" 22 "github.com/jackc/pgx/v5/stdlib" 23 "github.com/jackc/pgx/v5/tracelog" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 ) 27 28 func openDB(t testing.TB) *sql.DB { 29 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 30 require.NoError(t, err) 31 return stdlib.OpenDB(*config) 32 } 33 34 func closeDB(t testing.TB, db *sql.DB) { 35 err := db.Close() 36 require.NoError(t, err) 37 } 38 39 func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { 40 conn, err := db.Conn(context.Background()) 41 require.NoError(t, err) 42 defer conn.Close() 43 44 err = conn.Raw(func(driverConn any) error { 45 conn := driverConn.(*stdlib.Conn).Conn() 46 if conn.PgConn().ParameterStatus("crdb_version") != "" { 47 t.Skip(msg) 48 } 49 return nil 50 }) 51 require.NoError(t, err) 52 } 53 54 func skipPostgreSQLVersionLessThan(t testing.TB, db *sql.DB, minVersion int64) { 55 conn, err := db.Conn(context.Background()) 56 require.NoError(t, err) 57 defer conn.Close() 58 59 err = conn.Raw(func(driverConn any) error { 60 conn := driverConn.(*stdlib.Conn).Conn() 61 serverVersionStr := conn.PgConn().ParameterStatus("server_version") 62 serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) 63 // if not PostgreSQL do nothing 64 if serverVersionStr == "" { 65 return nil 66 } 67 68 serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) 69 if err != nil { 70 return err 71 } 72 73 if serverVersion < minVersion { 74 t.Skipf("Test requires PostgreSQL v%d+", minVersion) 75 } 76 77 return nil 78 }) 79 require.NoError(t, err) 80 } 81 82 func 
testWithAllQueryExecModes(t *testing.T, f func(t *testing.T, db *sql.DB)) { 83 for _, mode := range []pgx.QueryExecMode{ 84 pgx.QueryExecModeCacheStatement, 85 pgx.QueryExecModeCacheDescribe, 86 pgx.QueryExecModeDescribeExec, 87 pgx.QueryExecModeExec, 88 pgx.QueryExecModeSimpleProtocol, 89 } { 90 t.Run(mode.String(), 91 func(t *testing.T) { 92 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 93 require.NoError(t, err) 94 95 config.DefaultQueryExecMode = mode 96 db := stdlib.OpenDB(*config) 97 defer func() { 98 err := db.Close() 99 require.NoError(t, err) 100 }() 101 102 f(t, db) 103 104 ensureDBValid(t, db) 105 }, 106 ) 107 } 108 } 109 110 // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should 111 // cover broken connections. 112 func ensureDBValid(t testing.TB, db *sql.DB) { 113 var sum, rowCount int32 114 115 rows, err := db.Query("select generate_series(1,$1)", 10) 116 require.NoError(t, err) 117 defer rows.Close() 118 119 for rows.Next() { 120 var n int32 121 rows.Scan(&n) 122 sum += n 123 rowCount++ 124 } 125 126 require.NoError(t, rows.Err()) 127 128 if rowCount != 10 { 129 t.Error("Select called onDataRow wrong number of times") 130 } 131 if sum != 55 { 132 t.Error("Wrong values returned") 133 } 134 } 135 136 type preparer interface { 137 Prepare(query string) (*sql.Stmt, error) 138 } 139 140 func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { 141 stmt, err := p.Prepare(sql) 142 require.NoError(t, err) 143 return stmt 144 } 145 146 func closeStmt(t *testing.T, stmt *sql.Stmt) { 147 err := stmt.Close() 148 require.NoError(t, err) 149 } 150 151 func TestSQLOpen(t *testing.T) { 152 tests := []struct { 153 driverName string 154 }{ 155 {driverName: "pgx"}, 156 {driverName: "pgx/v5"}, 157 } 158 159 for _, tt := range tests { 160 tt := tt 161 162 t.Run(tt.driverName, func(t *testing.T) { 163 db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE")) 164 
require.NoError(t, err) 165 closeDB(t, db) 166 }) 167 } 168 } 169 170 func TestSQLOpenFromPool(t *testing.T) { 171 pool, err := pgxpool.New(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 172 require.NoError(t, err) 173 t.Cleanup(pool.Close) 174 175 db := stdlib.OpenDBFromPool(pool) 176 ensureDBValid(t, db) 177 178 db.Close() 179 } 180 181 func TestNormalLifeCycle(t *testing.T) { 182 db := openDB(t) 183 defer closeDB(t, db) 184 185 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") 186 187 stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") 188 defer closeStmt(t, stmt) 189 190 rows, err := stmt.Query(int32(1), int32(10)) 191 require.NoError(t, err) 192 193 rowCount := int64(0) 194 195 for rows.Next() { 196 rowCount++ 197 198 var s string 199 var n int64 200 err := rows.Scan(&s, &n) 201 require.NoError(t, err) 202 203 if s != "foo" { 204 t.Errorf(`Expected "foo", received "%v"`, s) 205 } 206 if n != rowCount { 207 t.Errorf("Expected %d, received %d", rowCount, n) 208 } 209 } 210 require.NoError(t, rows.Err()) 211 212 require.EqualValues(t, 10, rowCount) 213 214 err = rows.Close() 215 require.NoError(t, err) 216 217 ensureDBValid(t, db) 218 } 219 220 func TestStmtExec(t *testing.T) { 221 db := openDB(t) 222 defer closeDB(t, db) 223 224 tx, err := db.Begin() 225 require.NoError(t, err) 226 227 createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") 228 _, err = createStmt.Exec() 229 require.NoError(t, err) 230 closeStmt(t, createStmt) 231 232 insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") 233 result, err := insertStmt.Exec("foo") 234 require.NoError(t, err) 235 236 n, err := result.RowsAffected() 237 require.NoError(t, err) 238 require.EqualValues(t, 1, n) 239 closeStmt(t, insertStmt) 240 241 ensureDBValid(t, db) 242 } 243 244 func TestQueryCloseRowsEarly(t *testing.T) { 245 db := openDB(t) 246 defer 
closeDB(t, db) 247 248 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") 249 250 stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") 251 defer closeStmt(t, stmt) 252 253 rows, err := stmt.Query(int32(1), int32(10)) 254 require.NoError(t, err) 255 256 // Close rows immediately without having read them 257 err = rows.Close() 258 require.NoError(t, err) 259 260 // Run the query again to ensure the connection and statement are still ok 261 rows, err = stmt.Query(int32(1), int32(10)) 262 require.NoError(t, err) 263 264 rowCount := int64(0) 265 266 for rows.Next() { 267 rowCount++ 268 269 var s string 270 var n int64 271 err := rows.Scan(&s, &n) 272 require.NoError(t, err) 273 if s != "foo" { 274 t.Errorf(`Expected "foo", received "%v"`, s) 275 } 276 if n != rowCount { 277 t.Errorf("Expected %d, received %d", rowCount, n) 278 } 279 } 280 require.NoError(t, rows.Err()) 281 require.EqualValues(t, 10, rowCount) 282 283 err = rows.Close() 284 require.NoError(t, err) 285 286 ensureDBValid(t, db) 287 } 288 289 func TestConnExec(t *testing.T) { 290 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 291 _, err := db.Exec("create temporary table t(a varchar not null)") 292 require.NoError(t, err) 293 294 result, err := db.Exec("insert into t values('hey')") 295 require.NoError(t, err) 296 297 n, err := result.RowsAffected() 298 require.NoError(t, err) 299 require.EqualValues(t, 1, n) 300 }) 301 } 302 303 func TestConnQuery(t *testing.T) { 304 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 305 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") 306 307 rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) 308 require.NoError(t, err) 309 310 rowCount := int64(0) 311 312 for rows.Next() { 313 rowCount++ 314 315 var s string 
316 var n int64 317 err := rows.Scan(&s, &n) 318 require.NoError(t, err) 319 if s != "foo" { 320 t.Errorf(`Expected "foo", received "%v"`, s) 321 } 322 if n != rowCount { 323 t.Errorf("Expected %d, received %d", rowCount, n) 324 } 325 } 326 require.NoError(t, rows.Err()) 327 require.EqualValues(t, 10, rowCount) 328 329 err = rows.Close() 330 require.NoError(t, err) 331 }) 332 } 333 334 func TestConnConcurrency(t *testing.T) { 335 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 336 _, err := db.Exec("create table t (id integer primary key, str text, dur_str interval)") 337 require.NoError(t, err) 338 339 defer func() { 340 _, err := db.Exec("drop table t") 341 require.NoError(t, err) 342 }() 343 344 var wg sync.WaitGroup 345 346 concurrency := 50 347 errChan := make(chan error, concurrency) 348 349 for i := 1; i <= concurrency; i++ { 350 wg.Add(1) 351 352 go func(idx int) { 353 defer wg.Done() 354 355 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) 356 defer cancel() 357 358 str := strconv.Itoa(idx) 359 duration := time.Duration(idx) * time.Second 360 _, err := db.ExecContext(ctx, "insert into t values($1)", idx) 361 if err != nil { 362 errChan <- fmt.Errorf("insert failed: %d %w", idx, err) 363 return 364 } 365 _, err = db.ExecContext(ctx, "update t set str = $1 where id = $2", str, idx) 366 if err != nil { 367 errChan <- fmt.Errorf("update 1 failed: %d %w", idx, err) 368 return 369 } 370 _, err = db.ExecContext(ctx, "update t set dur_str = $1 where id = $2", duration, idx) 371 if err != nil { 372 errChan <- fmt.Errorf("update 2 failed: %d %w", idx, err) 373 return 374 } 375 376 errChan <- nil 377 }(i) 378 } 379 wg.Wait() 380 for i := 1; i <= concurrency; i++ { 381 err := <-errChan 382 require.NoError(t, err) 383 } 384 385 for i := 1; i <= concurrency; i++ { 386 wg.Add(1) 387 388 go func(idx int) { 389 defer wg.Done() 390 391 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) 392 defer cancel() 393 394 
var id int 395 var str string 396 var duration pgtype.Interval 397 err := db.QueryRowContext(ctx, "select id,str,dur_str from t where id = $1", idx).Scan(&id, &str, &duration) 398 if err != nil { 399 errChan <- fmt.Errorf("select failed: %d %w", idx, err) 400 return 401 } 402 if id != idx { 403 errChan <- fmt.Errorf("id mismatch: %d %d", idx, id) 404 return 405 } 406 if str != strconv.Itoa(idx) { 407 errChan <- fmt.Errorf("str mismatch: %d %s", idx, str) 408 return 409 } 410 expectedDuration := pgtype.Interval{ 411 Microseconds: int64(idx) * time.Second.Microseconds(), 412 Valid: true, 413 } 414 if duration != expectedDuration { 415 errChan <- fmt.Errorf("duration mismatch: %d %v", idx, duration) 416 return 417 } 418 419 errChan <- nil 420 }(i) 421 } 422 wg.Wait() 423 for i := 1; i <= concurrency; i++ { 424 err := <-errChan 425 require.NoError(t, err) 426 } 427 }) 428 } 429 430 // https://github.com/jackc/pgx/issues/781 431 func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { 432 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 433 var s string 434 var b bool 435 436 rows, err := db.Query("select true, 'foo'") 437 require.NoError(t, err) 438 439 require.True(t, rows.Next()) 440 require.NoError(t, rows.Scan(&b, &s)) 441 assert.Equal(t, true, b) 442 assert.Equal(t, "foo", s) 443 }) 444 } 445 446 func TestConnQueryNull(t *testing.T) { 447 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 448 rows, err := db.Query("select $1::int", nil) 449 require.NoError(t, err) 450 451 rowCount := int64(0) 452 453 for rows.Next() { 454 rowCount++ 455 456 var n sql.NullInt64 457 err := rows.Scan(&n) 458 require.NoError(t, err) 459 if n.Valid != false { 460 t.Errorf("Expected n to be null, but it was %v", n) 461 } 462 } 463 require.NoError(t, rows.Err()) 464 require.EqualValues(t, 1, rowCount) 465 466 err = rows.Close() 467 require.NoError(t, err) 468 }) 469 } 470 471 func TestConnQueryRowByteSlice(t *testing.T) { 472 testWithAllQueryExecModes(t, 
func(t *testing.T, db *sql.DB) { 473 expected := []byte{222, 173, 190, 239} 474 var actual []byte 475 476 err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) 477 require.NoError(t, err) 478 require.EqualValues(t, expected, actual) 479 }) 480 } 481 482 func TestConnQueryFailure(t *testing.T) { 483 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 484 _, err := db.Query("select 'foo") 485 require.Error(t, err) 486 require.IsType(t, new(pgconn.PgError), err) 487 }) 488 } 489 490 func TestConnSimpleSlicePassThrough(t *testing.T) { 491 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 492 skipCockroachDB(t, db, "Server does not support cardinality function") 493 494 var n int64 495 err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n) 496 require.NoError(t, err) 497 assert.EqualValues(t, 3, n) 498 }) 499 } 500 501 func TestConnQueryScanGoArray(t *testing.T) { 502 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 503 m := pgtype.NewMap() 504 505 var a []int64 506 err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) 507 require.NoError(t, err) 508 assert.Equal(t, []int64{1, 2, 3}, a) 509 }) 510 } 511 512 func TestConnQueryScanArray(t *testing.T) { 513 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 514 m := pgtype.NewMap() 515 516 var a pgtype.Array[int64] 517 err := db.QueryRow("select '{1,2,3}'::bigint[]").Scan(m.SQLScanner(&a)) 518 require.NoError(t, err) 519 assert.Equal(t, pgtype.Array[int64]{Elements: []int64{1, 2, 3}, Dims: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Valid: true}, a) 520 521 err = db.QueryRow("select null::bigint[]").Scan(m.SQLScanner(&a)) 522 require.NoError(t, err) 523 assert.Equal(t, pgtype.Array[int64]{Elements: nil, Dims: nil, Valid: false}, a) 524 }) 525 } 526 527 func TestConnQueryScanRange(t *testing.T) { 528 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 529 skipCockroachDB(t, db, "Server does not 
support int4range") 530 531 m := pgtype.NewMap() 532 533 var r pgtype.Range[pgtype.Int4] 534 err := db.QueryRow("select int4range(1, 5)").Scan(m.SQLScanner(&r)) 535 require.NoError(t, err) 536 assert.Equal( 537 t, 538 pgtype.Range[pgtype.Int4]{ 539 Lower: pgtype.Int4{Int32: 1, Valid: true}, 540 Upper: pgtype.Int4{Int32: 5, Valid: true}, 541 LowerType: pgtype.Inclusive, 542 UpperType: pgtype.Exclusive, 543 Valid: true, 544 }, 545 r) 546 }) 547 } 548 549 // Test type that pgx would handle natively in binary, but since it is not a 550 // database/sql native type should be passed through as a string 551 func TestConnQueryRowPgxBinary(t *testing.T) { 552 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 553 sql := "select $1::int4[]" 554 expected := "{1,2,3}" 555 var actual string 556 557 err := db.QueryRow(sql, expected).Scan(&actual) 558 require.NoError(t, err) 559 require.EqualValues(t, expected, actual) 560 }) 561 } 562 563 func TestConnQueryRowUnknownType(t *testing.T) { 564 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 565 skipCockroachDB(t, db, "Server does not support point type") 566 567 sql := "select $1::point" 568 expected := "(1,2)" 569 var actual string 570 571 err := db.QueryRow(sql, expected).Scan(&actual) 572 require.NoError(t, err) 573 require.EqualValues(t, expected, actual) 574 }) 575 } 576 577 func TestConnQueryJSONIntoByteSlice(t *testing.T) { 578 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 579 _, err := db.Exec(` 580 create temporary table docs( 581 body json not null 582 ); 583 584 insert into docs(body) values('{"foo": "bar"}'); 585 `) 586 require.NoError(t, err) 587 588 sql := `select * from docs` 589 expected := []byte(`{"foo": "bar"}`) 590 var actual []byte 591 592 err = db.QueryRow(sql).Scan(&actual) 593 if err != nil { 594 t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) 595 } 596 597 if !bytes.Equal(actual, expected) { 598 t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, 
string(expected), string(actual), sql) 599 } 600 601 _, err = db.Exec(`drop table docs`) 602 require.NoError(t, err) 603 }) 604 } 605 606 func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { 607 // Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data 608 // that needs to escape. No way to know whether the destination is really a text compatible or a bytea. 609 610 db := openDB(t) 611 defer closeDB(t, db) 612 613 _, err := db.Exec(` 614 create temporary table docs( 615 body json not null 616 ); 617 `) 618 require.NoError(t, err) 619 620 expected := []byte(`{"foo": "bar"}`) 621 622 _, err = db.Exec(`insert into docs(body) values($1)`, expected) 623 require.NoError(t, err) 624 625 var actual []byte 626 err = db.QueryRow(`select body from docs`).Scan(&actual) 627 require.NoError(t, err) 628 629 if !bytes.Equal(actual, expected) { 630 t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual)) 631 } 632 633 _, err = db.Exec(`drop table docs`) 634 require.NoError(t, err) 635 } 636 637 func TestTransactionLifeCycle(t *testing.T) { 638 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 639 _, err := db.Exec("create temporary table t(a varchar not null)") 640 require.NoError(t, err) 641 642 tx, err := db.Begin() 643 require.NoError(t, err) 644 645 _, err = tx.Exec("insert into t values('hi')") 646 require.NoError(t, err) 647 648 err = tx.Rollback() 649 require.NoError(t, err) 650 651 var n int64 652 err = db.QueryRow("select count(*) from t").Scan(&n) 653 require.NoError(t, err) 654 require.EqualValues(t, 0, n) 655 656 tx, err = db.Begin() 657 require.NoError(t, err) 658 659 _, err = tx.Exec("insert into t values('hi')") 660 require.NoError(t, err) 661 662 err = tx.Commit() 663 require.NoError(t, err) 664 665 err = db.QueryRow("select count(*) from t").Scan(&n) 666 require.NoError(t, err) 667 require.EqualValues(t, 1, n) 668 }) 669 } 670 671 func TestConnBeginTxIsolation(t *testing.T) { 
672 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 673 skipCockroachDB(t, db, "Server always uses serializable isolation level") 674 675 var defaultIsoLevel string 676 err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) 677 require.NoError(t, err) 678 679 supportedTests := []struct { 680 sqlIso sql.IsolationLevel 681 pgIso string 682 }{ 683 {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, 684 {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, 685 {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, 686 {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, 687 {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, 688 {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, 689 } 690 for i, tt := range supportedTests { 691 func() { 692 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) 693 if err != nil { 694 t.Errorf("%d. BeginTx failed: %v", i, err) 695 return 696 } 697 defer tx.Rollback() 698 699 var pgIso string 700 err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) 701 if err != nil { 702 t.Errorf("%d. QueryRow failed: %v", i, err) 703 } 704 705 if pgIso != tt.pgIso { 706 t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) 707 } 708 }() 709 } 710 711 unsupportedTests := []struct { 712 sqlIso sql.IsolationLevel 713 }{ 714 {sqlIso: sql.LevelWriteCommitted}, 715 {sqlIso: sql.LevelLinearizable}, 716 } 717 for i, tt := range unsupportedTests { 718 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) 719 if err == nil { 720 t.Errorf("%d. 
BeginTx should have failed", i) 721 tx.Rollback() 722 } 723 } 724 }) 725 } 726 727 func TestConnBeginTxReadOnly(t *testing.T) { 728 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 729 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) 730 require.NoError(t, err) 731 defer tx.Rollback() 732 733 var pgReadOnly string 734 err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) 735 if err != nil { 736 t.Errorf("QueryRow failed: %v", err) 737 } 738 739 if pgReadOnly != "on" { 740 t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") 741 } 742 }) 743 } 744 745 func TestBeginTxContextCancel(t *testing.T) { 746 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 747 _, err := db.Exec("drop table if exists t") 748 require.NoError(t, err) 749 750 ctx, cancelFn := context.WithCancel(context.Background()) 751 752 tx, err := db.BeginTx(ctx, nil) 753 require.NoError(t, err) 754 755 _, err = tx.Exec("create table t(id serial)") 756 require.NoError(t, err) 757 758 cancelFn() 759 760 err = tx.Commit() 761 if err != context.Canceled && err != sql.ErrTxDone { 762 t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) 763 } 764 765 var n int 766 err = db.QueryRow("select count(*) from t").Scan(&n) 767 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" { 768 t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) 769 } 770 }) 771 } 772 773 func TestConnRaw(t *testing.T) { 774 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 775 conn, err := db.Conn(context.Background()) 776 require.NoError(t, err) 777 778 var n int 779 err = conn.Raw(func(driverConn any) error { 780 conn := driverConn.(*stdlib.Conn).Conn() 781 return conn.QueryRow(context.Background(), "select 42").Scan(&n) 782 }) 783 require.NoError(t, err) 784 assert.EqualValues(t, 42, n) 785 }) 786 } 787 788 func TestConnPingContextSuccess(t *testing.T) { 789 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 790 
err := db.PingContext(context.Background()) 791 require.NoError(t, err) 792 }) 793 } 794 795 func TestConnPrepareContextSuccess(t *testing.T) { 796 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 797 stmt, err := db.PrepareContext(context.Background(), "select now()") 798 require.NoError(t, err) 799 err = stmt.Close() 800 require.NoError(t, err) 801 }) 802 } 803 804 // https://github.com/jackc/pgx/issues/1753#issuecomment-1746033281 805 // https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634 806 func TestConnMultiplePrepareAndDeallocate(t *testing.T) { 807 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 808 skipCockroachDB(t, db, "Server does not support pg_prepared_statements") 809 810 sql := "select 42" 811 stmt1, err := db.PrepareContext(context.Background(), sql) 812 require.NoError(t, err) 813 stmt2, err := db.PrepareContext(context.Background(), sql) 814 require.NoError(t, err) 815 err = stmt1.Close() 816 require.NoError(t, err) 817 818 var preparedStmtCount int64 819 err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) 820 require.NoError(t, err) 821 require.EqualValues(t, 1, preparedStmtCount) 822 823 err = stmt2.Close() // err isn't as useful as it should be as database/sql will ignore errors from Deallocate. 
824 require.NoError(t, err) 825 826 err = db.QueryRowContext(context.Background(), "select count(*) from pg_prepared_statements where statement = $1", sql).Scan(&preparedStmtCount) 827 require.NoError(t, err) 828 require.EqualValues(t, 0, preparedStmtCount) 829 }) 830 } 831 832 func TestConnExecContextSuccess(t *testing.T) { 833 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 834 _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") 835 require.NoError(t, err) 836 }) 837 } 838 839 func TestConnQueryContextSuccess(t *testing.T) { 840 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 841 rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") 842 require.NoError(t, err) 843 844 for rows.Next() { 845 var n int64 846 err := rows.Scan(&n) 847 require.NoError(t, err) 848 } 849 require.NoError(t, rows.Err()) 850 }) 851 } 852 853 func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { 854 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 855 rows, err := db.Query("select 42::bigint") 856 require.NoError(t, err) 857 858 columnTypes, err := rows.ColumnTypes() 859 require.NoError(t, err) 860 require.Len(t, columnTypes, 1) 861 862 if columnTypes[0].DatabaseTypeName() != "INT8" { 863 t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8") 864 } 865 866 err = rows.Close() 867 require.NoError(t, err) 868 }) 869 } 870 871 func TestStmtExecContextSuccess(t *testing.T) { 872 db := openDB(t) 873 defer closeDB(t, db) 874 875 _, err := db.Exec("create temporary table t(id int primary key)") 876 require.NoError(t, err) 877 878 stmt, err := db.Prepare("insert into t(id) values ($1::int4)") 879 require.NoError(t, err) 880 defer stmt.Close() 881 882 _, err = stmt.ExecContext(context.Background(), 42) 883 require.NoError(t, err) 884 885 ensureDBValid(t, db) 886 } 887 888 func TestStmtExecContextCancel(t 
*testing.T) { 889 db := openDB(t) 890 defer closeDB(t, db) 891 892 _, err := db.Exec("create temporary table t(id int primary key)") 893 require.NoError(t, err) 894 895 stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") 896 require.NoError(t, err) 897 defer stmt.Close() 898 899 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 900 defer cancel() 901 902 _, err = stmt.ExecContext(ctx, 42) 903 if !pgconn.Timeout(err) { 904 t.Errorf("expected timeout error, got %v", err) 905 } 906 907 ensureDBValid(t, db) 908 } 909 910 func TestStmtQueryContextSuccess(t *testing.T) { 911 db := openDB(t) 912 defer closeDB(t, db) 913 914 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") 915 916 stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") 917 require.NoError(t, err) 918 defer stmt.Close() 919 920 rows, err := stmt.QueryContext(context.Background(), 5) 921 require.NoError(t, err) 922 923 for rows.Next() { 924 var n int64 925 if err := rows.Scan(&n); err != nil { 926 t.Error(err) 927 } 928 } 929 930 if rows.Err() != nil { 931 t.Error(rows.Err()) 932 } 933 934 ensureDBValid(t, db) 935 } 936 937 func TestRowsColumnTypes(t *testing.T) { 938 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 939 columnTypesTests := []struct { 940 Name string 941 TypeName string 942 Length struct { 943 Len int64 944 OK bool 945 } 946 DecimalSize struct { 947 Precision int64 948 Scale int64 949 OK bool 950 } 951 ScanType reflect.Type 952 }{ 953 { 954 Name: "a", 955 TypeName: "INT8", 956 Length: struct { 957 Len int64 958 OK bool 959 }{ 960 Len: 0, 961 OK: false, 962 }, 963 DecimalSize: struct { 964 Precision int64 965 Scale int64 966 OK bool 967 }{ 968 Precision: 0, 969 Scale: 0, 970 OK: false, 971 }, 972 ScanType: reflect.TypeOf(int64(0)), 973 }, { 974 Name: "bar", 975 TypeName: "TEXT", 976 Length: struct { 977 Len int64 978 OK bool 979 }{ 980 
Len: math.MaxInt64, 981 OK: true, 982 }, 983 DecimalSize: struct { 984 Precision int64 985 Scale int64 986 OK bool 987 }{ 988 Precision: 0, 989 Scale: 0, 990 OK: false, 991 }, 992 ScanType: reflect.TypeOf(""), 993 }, { 994 Name: "dec", 995 TypeName: "NUMERIC", 996 Length: struct { 997 Len int64 998 OK bool 999 }{ 1000 Len: 0, 1001 OK: false, 1002 }, 1003 DecimalSize: struct { 1004 Precision int64 1005 Scale int64 1006 OK bool 1007 }{ 1008 Precision: 9, 1009 Scale: 2, 1010 OK: true, 1011 }, 1012 ScanType: reflect.TypeOf(float64(0)), 1013 }, { 1014 Name: "d", 1015 TypeName: "1266", 1016 Length: struct { 1017 Len int64 1018 OK bool 1019 }{ 1020 Len: 0, 1021 OK: false, 1022 }, 1023 DecimalSize: struct { 1024 Precision int64 1025 Scale int64 1026 OK bool 1027 }{ 1028 Precision: 0, 1029 Scale: 0, 1030 OK: false, 1031 }, 1032 ScanType: reflect.TypeOf(""), 1033 }, 1034 } 1035 1036 rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d") 1037 require.NoError(t, err) 1038 1039 columns, err := rows.ColumnTypes() 1040 require.NoError(t, err) 1041 assert.Len(t, columns, 4) 1042 1043 for i, tt := range columnTypesTests { 1044 c := columns[i] 1045 if c.Name() != tt.Name { 1046 t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) 1047 } 1048 if c.DatabaseTypeName() != tt.TypeName { 1049 t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) 1050 } 1051 l, ok := c.Length() 1052 if l != tt.Length.Len { 1053 t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) 1054 } 1055 if ok != tt.Length.OK { 1056 t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) 1057 } 1058 p, s, ok := c.DecimalSize() 1059 if p != tt.DecimalSize.Precision { 1060 t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) 1061 } 1062 if s != tt.DecimalSize.Scale { 1063 t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) 1064 } 1065 if ok != tt.DecimalSize.OK { 1066 t.Errorf("(%d) got: %t, want: %t", i, 
ok, tt.DecimalSize.OK) 1067 } 1068 if c.ScanType() != tt.ScanType { 1069 t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) 1070 } 1071 } 1072 }) 1073 } 1074 1075 func TestQueryLifeCycle(t *testing.T) { 1076 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 1077 skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") 1078 1079 rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) 1080 require.NoError(t, err) 1081 1082 rowCount := int64(0) 1083 1084 for rows.Next() { 1085 rowCount++ 1086 var ( 1087 s string 1088 n int64 1089 ) 1090 1091 err := rows.Scan(&s, &n) 1092 require.NoError(t, err) 1093 1094 if s != "foo" { 1095 t.Errorf(`Expected "foo", received "%v"`, s) 1096 } 1097 1098 if n != rowCount { 1099 t.Errorf("Expected %d, received %d", rowCount, n) 1100 } 1101 } 1102 require.NoError(t, rows.Err()) 1103 1104 err = rows.Close() 1105 require.NoError(t, err) 1106 1107 rows, err = db.Query("select 1 where false") 1108 require.NoError(t, err) 1109 1110 rowCount = int64(0) 1111 1112 for rows.Next() { 1113 rowCount++ 1114 } 1115 require.NoError(t, rows.Err()) 1116 require.EqualValues(t, 0, rowCount) 1117 1118 err = rows.Close() 1119 require.NoError(t, err) 1120 }) 1121 } 1122 1123 // https://github.com/jackc/pgx/issues/409 1124 func TestScanJSONIntoJSONRawMessage(t *testing.T) { 1125 testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) { 1126 var msg json.RawMessage 1127 1128 err := db.QueryRow("select '{}'::json").Scan(&msg) 1129 require.NoError(t, err) 1130 require.EqualValues(t, []byte("{}"), []byte(msg)) 1131 }) 1132 } 1133 1134 type testLog struct { 1135 lvl tracelog.LogLevel 1136 msg string 1137 data map[string]any 1138 } 1139 1140 type testLogger struct { 1141 logs []testLog 1142 } 1143 1144 func (l *testLogger) Log(ctx context.Context, lvl tracelog.LogLevel, msg string, data map[string]any) { 1145 l.logs = 
append(l.logs, testLog{lvl: lvl, msg: msg, data: data})
}

// TestRegisterConnConfig verifies three things:
//   - a *pgx.ConnConfig registered with stdlib.RegisterConnConfig can be opened with sql.Open;
//   - queries on that connection go through the config's Tracer;
//   - a connection string is not handed out again after it has been unregistered (issue 947).
func TestRegisterConnConfig(t *testing.T) {
	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	logger := &testLogger{}
	connConfig.Tracer = &tracelog.TraceLog{Logger: logger, LogLevel: tracelog.LogLevelInfo}

	// Issue 947: Register and unregister a ConnConfig and ensure that the returned connection string is not reused.
	//
	// The strings are numbered by a counter inside the stdlib package. This test therefore compares them with each
	// other instead of with fixed values ("registeredConnConfig0", "registeredConnConfig1"). Fixed values only hold
	// when this is the first registration in the process. They break under `go test -count=2`, or if another test
	// registers a config first.
	firstConnStr := stdlib.RegisterConnConfig(connConfig)
	require.Regexp(t, `^registeredConnConfig\d+$`, firstConnStr)
	stdlib.UnregisterConnConfig(firstConnStr)

	connStr := stdlib.RegisterConnConfig(connConfig)
	defer stdlib.UnregisterConnConfig(connStr)
	require.Regexp(t, `^registeredConnConfig\d+$`, connStr)
	require.NotEqual(t, firstConnStr, connStr)

	db, err := sql.Open("pgx", connStr)
	require.NoError(t, err)
	defer closeDB(t, db)

	var n int64
	err = db.QueryRow("select 1").Scan(&n)
	require.NoError(t, err)

	// Fail cleanly, rather than panic on the index below, if the tracer logged nothing.
	require.NotEmpty(t, logger.logs)
	l := logger.logs[len(logger.logs)-1]
	assert.Equal(t, "Query", l.msg)
	assert.Equal(t, "select 1", l.data["sql"])
}

// https://github.com/jackc/pgx/issues/958
//
// TestConnQueryRowConstraintErrors checks that QueryRow(...).Scan returns an error when the inserted row is rejected
// by a deferrable, initially-deferred constraint trigger. It runs under every QueryExecMode.
func TestConnQueryRowConstraintErrors(t *testing.T) {
	testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
		skipPostgreSQLVersionLessThan(t, db, 11)
		skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")

		ctx := context.Background()

		// A temporary table is visible only to the session that created it. Run every statement on one dedicated
		// connection instead of letting *sql.DB pick a possibly different pooled connection for each call.
		conn, err := db.Conn(ctx)
		require.NoError(t, err)
		defer conn.Close()

		_, err = conn.ExecContext(ctx, `create temporary table defer_test (
			id text primary key,
			n int not null, unique (n),
			unique (n) deferrable initially deferred )`)
		require.NoError(t, err)

		_, err = conn.ExecContext(ctx, `drop function if exists test_trigger cascade`)
		require.NoError(t, err)

		_, err = conn.ExecContext(ctx, `create function test_trigger() returns trigger language plpgsql as $$
		begin
		if new.n = 4 then
			raise exception 'n cant be 4!';
		end if;
		return new;
	end$$`)
		require.NoError(t, err)

		_, err = conn.ExecContext(ctx, `create constraint trigger test
			after insert or update on defer_test
			deferrable initially deferred
			for each row
			execute function test_trigger()`)
		require.NoError(t, err)

		_, err = conn.ExecContext(ctx, `insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
		require.NoError(t, err)

		// n = 4 makes the deferred trigger raise; that failure must reach the caller through Scan.
		var id string
		err = conn.QueryRowContext(ctx, `insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
		assert.Error(t, err)
	})
}

// TestOptionBeforeAfterConnect checks the OptionBeforeConnect and OptionAfterConnect hooks:
//   - both run once for each newly opened connection;
//   - BeforeConnect receives a distinct *pgx.ConnConfig each time, not the config passed to OpenDB.
func TestOptionBeforeAfterConnect(t *testing.T) {
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var beforeConnConfigs []*pgx.ConnConfig
	var afterConns []*pgx.Conn
	db := stdlib.OpenDB(*config,
		stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
			beforeConnConfigs = append(beforeConnConfigs, connConfig)
			return nil
		}),
		stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
			afterConns = append(afterConns, conn)
			return nil
		}))
	defer closeDB(t, db)

	// Force it to close and reopen a new connection after each query
	db.SetMaxIdleConns(0)

	_, err = db.Exec("select 1")
	require.NoError(t, err)

	_, err = db.Exec("select 1")
	require.NoError(t, err)

	require.Len(t, beforeConnConfigs, 2)
	require.Len(t, afterConns, 2)

	// Note: BeforeConnect creates a shallow copy, so the config contents will be the same, but we want to ensure they
	// are different objects. require.NotEqual compares contents; require.NotSame compares pointer identity.
	require.NotSame(t, config, beforeConnConfigs[0])
	require.NotSame(t, beforeConnConfigs[0], beforeConnConfigs[1])
}

// TestRandomizeHostOrderFunc checks stdlib.RandomizeHostOrderFunc in two ways:
//   - every randomized config still lists all hosts, either as Host or among Fallbacks;
//   - over repeated calls, each host eventually becomes the primary Host.
func TestRandomizeHostOrderFunc(t *testing.T) {
	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
	require.NoError(t, err)

	allHosts := []string{"host1", "host2", "host3"}

	// Test that at some point we connect to all 3 hosts
	hostsNotSeenYet := make(map[string]struct{}, len(allHosts))
	for _, h := range allHosts {
		hostsNotSeenYet[h] = struct{}{}
	}

	// If we don't succeed within this many iterations, something is certainly wrong
	for i := 0; i < 100000; i++ {
		connCopy := *config
		err = stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)
		require.NoError(t, err)

		// Validate each randomized config before the early return below, so the config from the final iteration is
		// checked too.
		for _, h := range allHosts {
			found := connCopy.Host == h
			for _, f := range connCopy.Fallbacks {
				if f.Host == h {
					found = true
					break
				}
			}
			if !found {
				require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy)
			}
		}

		delete(hostsNotSeenYet, connCopy.Host)
		if len(hostsNotSeenYet) == 0 {
			return
		}
	}

	require.Fail(t, "did not get all hosts as primaries after many randomizations")
}

// TestResetSessionHookCalled checks that the function passed to stdlib.OptionResetSession is invoked. The DB is pinged
// twice because database/sql only resets a session when it reuses a connection that has already been used.
func TestResetSessionHookCalled(t *testing.T) {
	var mockCalled bool

	connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error {
		mockCalled = true

		return nil
	}))

	defer closeDB(t, db)

	err = db.Ping()
	require.NoError(t, err)

	err = db.Ping()
	require.NoError(t, err)

	require.True(t, mockCalled)
}

// TestCheckIdleConn checks that pooled connections terminated on the server while idle are detected and discarded.
// A later Ping should then succeed on a newly established connection whose backend PID differs from all the
// terminated ones.
func TestCheckIdleConn(t *testing.T) {
	controllerConn, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeDB(t, controllerConn)

	skipCockroachDB(t, controllerConn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")

	db, err := sql.Open("pgx", os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeDB(t, db)

	var conns []*sql.Conn
	for i := 0; i < 3; i++ {
		c, err := db.Conn(context.Background())
		require.NoError(t, err)
		conns = append(conns, c)
	}

	require.EqualValues(t, 3, db.Stats().OpenConnections)

	// Record each connection's backend PID, then release the connection back to the pool.
	pids := make([]uint32, 0, len(conns))
	for _, c := range conns {
		err := c.Raw(func(driverConn any) error {
			pids = append(pids, driverConn.(*stdlib.Conn).Conn().PgConn().PID())
			return nil
		})
		require.NoError(t, err)
		err = c.Close()
		require.NoError(t, err)
	}

	// database/sql keeps at most 2 idle connections by default, so one of the 3 is closed on release and this does
	// not hold:
	// require.EqualValues(t, 3, db.Stats().OpenConnections)

	_, err = controllerConn.ExecContext(context.Background(), `select pg_terminate_backend(n) from unnest($1::int[]) n`, pids)
	require.NoError(t, err)

	// All conns are dead, but they don't know it and neither does the pool. Because database/sql automatically closes
	// idle connections, we can't be sure how many should still be open here, so this is not asserted:
	// require.EqualValues(t, 3, db.Stats().OpenConnections)

	// Wait long enough so the pool will realize it needs to check the connections.
	time.Sleep(time.Second)

	// Pool should try all existing connections and find them dead, then create a new connection which should successfully ping.
	err = db.PingContext(context.Background())
	require.NoError(t, err)

	// The original 3 conns should have been terminated and a new conn established for the ping.
	require.EqualValues(t, 1, db.Stats().OpenConnections)
	c, err := db.Conn(context.Background())
	require.NoError(t, err)

	var cPID uint32
	err = c.Raw(func(driverConn any) error {
		cPID = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
		return nil
	})
	require.NoError(t, err)
	err = c.Close()
	require.NoError(t, err)

	require.NotContains(t, pids, cPID)
}