github.com/jackc/pgx/v5@v5.5.5/tx_test.go (about) 1 package pgx_test 2 3 import ( 4 "context" 5 "errors" 6 "os" 7 "testing" 8 "time" 9 10 "github.com/jackc/pgx/v5" 11 "github.com/jackc/pgx/v5/pgconn" 12 "github.com/jackc/pgx/v5/pgxtest" 13 "github.com/stretchr/testify/require" 14 ) 15 16 func TestTransactionSuccessfulCommit(t *testing.T) { 17 t.Parallel() 18 19 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 20 defer closeConn(t, conn) 21 22 createSql := ` 23 create temporary table foo( 24 id integer, 25 unique (id) 26 ); 27 ` 28 29 if _, err := conn.Exec(context.Background(), createSql); err != nil { 30 t.Fatalf("Failed to create table: %v", err) 31 } 32 33 tx, err := conn.Begin(context.Background()) 34 if err != nil { 35 t.Fatalf("conn.Begin failed: %v", err) 36 } 37 38 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") 39 if err != nil { 40 t.Fatalf("tx.Exec failed: %v", err) 41 } 42 43 err = tx.Commit(context.Background()) 44 if err != nil { 45 t.Fatalf("tx.Commit failed: %v", err) 46 } 47 48 var n int64 49 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 50 if err != nil { 51 t.Fatalf("QueryRow Scan failed: %v", err) 52 } 53 if n != 1 { 54 t.Fatalf("Did not receive correct number of rows: %v", n) 55 } 56 } 57 58 func TestTxCommitWhenTxBroken(t *testing.T) { 59 t.Parallel() 60 61 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 62 defer closeConn(t, conn) 63 64 createSql := ` 65 create temporary table foo( 66 id integer, 67 unique (id) 68 ); 69 ` 70 71 if _, err := conn.Exec(context.Background(), createSql); err != nil { 72 t.Fatalf("Failed to create table: %v", err) 73 } 74 75 tx, err := conn.Begin(context.Background()) 76 if err != nil { 77 t.Fatalf("conn.Begin failed: %v", err) 78 } 79 80 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { 81 t.Fatalf("tx.Exec failed: %v", err) 82 } 83 84 // Purposely break transaction 85 if _, err := tx.Exec(context.Background(), "syntax error"); err == nil { 86 t.Fatal("Unexpected success") 87 } 88 89 err = tx.Commit(context.Background()) 90 if err != pgx.ErrTxCommitRollback { 91 t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) 92 } 93 94 var n int64 95 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 96 if err != nil { 97 t.Fatalf("QueryRow Scan failed: %v", err) 98 } 99 if n != 0 { 100 t.Fatalf("Did not receive correct number of rows: %v", n) 101 } 102 } 103 104 func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) { 105 t.Parallel() 106 107 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 108 defer closeConn(t, conn) 109 110 pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") 111 112 createSql := ` 113 create temporary table foo( 114 id integer, 115 unique (id) initially deferred 116 ); 117 ` 118 119 if _, err := conn.Exec(context.Background(), createSql); err != nil { 120 t.Fatalf("Failed to create table: %v", err) 121 } 122 123 tx, err := conn.Begin(context.Background()) 124 if err != nil { 125 t.Fatalf("conn.Begin failed: %v", err) 126 } 127 128 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { 129 t.Fatalf("tx.Exec failed: %v", err) 130 } 131 132 if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { 133 t.Fatalf("tx.Exec failed: %v", err) 134 } 135 136 err = tx.Commit(context.Background()) 137 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" { 138 t.Fatalf("Expected unique constraint violation 23505, got %#v", err) 139 } 140 141 var n int64 142 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 143 if err != nil { 144 t.Fatalf("QueryRow Scan failed: %v", err) 145 } 146 if n != 0 { 147 t.Fatalf("Did not receive correct number of rows: %v", n) 148 } 149 } 150 151 func TestTxCommitSerializationFailure(t *testing.T) { 152 t.Parallel() 153 154 c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 155 defer closeConn(t, c1) 156 157 if c1.PgConn().ParameterStatus("crdb_version") != "" { 158 t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)") 159 } 160 161 c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 162 defer closeConn(t, c2) 163 164 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 165 defer cancel() 166 167 c1.Exec(ctx, `drop table if exists tx_serializable_sums`) 168 _, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`) 169 if err != nil { 170 t.Fatalf("Unable to create temporary table: %v", err) 171 } 172 defer c1.Exec(ctx, `drop table tx_serializable_sums`) 173 174 tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) 175 if err != nil { 176 t.Fatalf("Begin failed: %v", err) 177 } 178 defer tx1.Rollback(ctx) 179 180 tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) 181 if err != nil { 182 t.Fatalf("Begin failed: %v", err) 183 } 184 defer tx2.Rollback(ctx) 185 186 _, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) 187 if err != nil { 188 t.Fatalf("Exec failed: %v", err) 189 } 190 191 _, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) 192 if err != nil { 193 t.Fatalf("Exec failed: %v", err) 194 } 195 196 err = tx1.Commit(ctx) 197 if err != nil { 198 t.Fatalf("Commit failed: %v", err) 199 } 200 201 err = tx2.Commit(ctx) 202 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { 203 t.Fatalf("Expected serialization error 40001, got %#v", err) 204 } 205 206 ensureConnValid(t, c1) 207 ensureConnValid(t, c2) 208 } 209 210 func TestTransactionSuccessfulRollback(t *testing.T) { 211 t.Parallel() 212 213 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 214 defer closeConn(t, conn) 215 216 createSql := ` 217 create temporary table foo( 218 id integer, 219 unique (id) 220 ); 221 ` 222 223 if _, err := conn.Exec(context.Background(), createSql); err != nil { 224 t.Fatalf("Failed to create table: %v", err) 225 } 226 227 tx, err := conn.Begin(context.Background()) 228 if err != nil { 229 t.Fatalf("conn.Begin failed: %v", err) 230 } 231 232 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") 233 if err != nil { 234 t.Fatalf("tx.Exec failed: %v", err) 235 } 236 237 err = tx.Rollback(context.Background()) 238 if err != nil { 239 t.Fatalf("tx.Rollback failed: %v", err) 240 } 241 242 var n int64 243 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 244 if err != nil { 245 t.Fatalf("QueryRow Scan failed: %v", err) 246 } 247 if n != 0 { 248 t.Fatalf("Did not receive correct number of rows: %v", n) 249 } 250 } 251 252 func TestTransactionRollbackFailsClosesConnection(t *testing.T) { 253 t.Parallel() 254 255 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 256 defer closeConn(t, conn) 257 258 ctx, cancel := context.WithCancel(context.Background()) 259 260 tx, err := conn.Begin(ctx) 261 require.NoError(t, err) 262 263 cancel() 264 265 err = tx.Rollback(ctx) 266 require.Error(t, err) 267 268 require.True(t, conn.IsClosed()) 269 } 270 271 func TestBeginIsoLevels(t *testing.T) { 272 t.Parallel() 273 274 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 275 defer closeConn(t, conn) 276 277 pgxtest.SkipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") 278 279 isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} 280 for _, iso := range isoLevels { 281 tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso}) 282 if err != nil { 283 t.Fatalf("conn.Begin failed: %v", err) 284 } 285 286 var level pgx.TxIsoLevel 287 conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level) 288 if level != iso { 289 t.Errorf("Expected to be in isolation level %v but was %v", iso, level) 290 } 291 292 err = tx.Rollback(context.Background()) 293 if err != nil { 294 t.Fatalf("tx.Rollback failed: %v", err) 295 } 296 } 297 } 298 299 func TestBeginFunc(t *testing.T) { 300 t.Parallel() 301 302 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 303 defer closeConn(t, conn) 304 305 createSql := ` 306 create temporary table foo( 307 id integer, 308 unique (id) 309 ); 310 ` 311 312 _, err := conn.Exec(context.Background(), createSql) 313 require.NoError(t, err) 314 315 err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { 316 _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") 317 require.NoError(t, err) 318 return nil 319 }) 320 require.NoError(t, err) 321 322 var n int64 323 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 324 require.NoError(t, err) 325 require.EqualValues(t, 1, n) 326 } 327 328 func TestBeginFuncRollbackOnError(t *testing.T) { 329 t.Parallel() 330 331 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 332 defer closeConn(t, conn) 333 334 createSql := ` 335 create temporary table foo( 336 id integer, 337 unique (id) 338 ); 339 ` 340 341 _, err := conn.Exec(context.Background(), createSql) 342 require.NoError(t, err) 343 344 err = pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error { 345 _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") 346 require.NoError(t, err) 347 return errors.New("some error") 348 }) 349 require.EqualError(t, err, "some error") 350 351 var n int64 352 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 353 require.NoError(t, err) 354 require.EqualValues(t, 0, n) 355 } 356 357 func TestBeginReadOnly(t *testing.T) { 358 t.Parallel() 359 360 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 361 defer closeConn(t, conn) 362 363 tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly}) 364 if err != nil { 365 t.Fatalf("conn.Begin failed: %v", err) 366 } 367 defer tx.Rollback(context.Background()) 368 369 _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)") 370 if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" { 371 t.Errorf("Expected error SQLSTATE 25006, but got %#v", err) 372 } 373 } 374 375 func TestBeginTxBeginQuery(t *testing.T) { 376 t.Parallel() 377 378 ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) 379 defer cancel() 380 381 pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 382 tx, err := conn.BeginTx(ctx, pgx.TxOptions{BeginQuery: "begin read only"}) 383 require.NoError(t, err) 384 defer tx.Rollback(ctx) 385 386 var readOnly bool 387 conn.QueryRow(ctx, "select current_setting('transaction_read_only')::bool").Scan(&readOnly) 388 require.True(t, readOnly) 389 390 err = tx.Rollback(ctx) 391 require.NoError(t, err) 392 }) 393 } 394 395 func TestTxNestedTransactionCommit(t *testing.T) { 396 t.Parallel() 397 398 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 399 defer closeConn(t, conn) 400 401 createSql := ` 402 create temporary table foo( 403 id integer, 404 unique (id) 405 ); 406 ` 407 408 if _, err := conn.Exec(context.Background(), createSql); err != nil { 409 t.Fatalf("Failed to create table: %v", err) 410 } 411 412 tx, err := conn.Begin(context.Background()) 413 if err != nil { 414 t.Fatal(err) 415 } 416 417 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") 418 if err != nil { 419 t.Fatalf("tx.Exec failed: %v", err) 420 } 421 422 nestedTx, err := tx.Begin(context.Background()) 423 if err != nil { 424 t.Fatal(err) 425 } 426 427 _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") 428 if err != nil { 429 t.Fatalf("nestedTx.Exec failed: %v", err) 430 } 431 432 doubleNestedTx, err := nestedTx.Begin(context.Background()) 433 if err != nil { 434 t.Fatal(err) 435 } 436 437 _, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)") 438 if err != nil { 439 t.Fatalf("doubleNestedTx.Exec failed: %v", err) 440 } 441 442 err = doubleNestedTx.Commit(context.Background()) 443 if err != nil { 444 t.Fatalf("doubleNestedTx.Commit failed: %v", err) 445 } 446 447 err = nestedTx.Commit(context.Background()) 448 if err != nil { 449 t.Fatalf("nestedTx.Commit failed: %v", err) 450 } 451 452 err = tx.Commit(context.Background()) 453 if err != nil { 454 t.Fatalf("tx.Commit failed: %v", err) 455 } 456 457 var n int64 458 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 459 if err != nil { 460 t.Fatalf("QueryRow Scan failed: %v", err) 461 } 462 if n != 3 { 463 t.Fatalf("Did not receive correct number of rows: %v", n) 464 } 465 } 466 467 func TestTxNestedTransactionRollback(t *testing.T) { 468 t.Parallel() 469 470 conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 471 defer closeConn(t, conn) 472 473 createSql := ` 474 create temporary table foo( 475 id integer, 476 unique (id) 477 ); 478 ` 479 480 if _, err := conn.Exec(context.Background(), createSql); err != nil { 481 t.Fatalf("Failed to create table: %v", err) 482 } 483 484 tx, err := conn.Begin(context.Background()) 485 if err != nil { 486 t.Fatal(err) 487 } 488 489 _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") 490 if err != nil { 491 t.Fatalf("tx.Exec failed: %v", err) 492 } 493 494 nestedTx, err := tx.Begin(context.Background()) 495 if err != nil { 496 t.Fatal(err) 497 } 498 499 _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") 500 if err != nil { 501 t.Fatalf("nestedTx.Exec failed: %v", err) 502 } 503 504 err = nestedTx.Rollback(context.Background()) 505 if err != nil { 506 t.Fatalf("nestedTx.Rollback failed: %v", err) 507 } 508 509 _, err = tx.Exec(context.Background(), "insert into foo(id) values (3)") 510 if err != nil { 511 t.Fatalf("tx.Exec failed: %v", err) 512 } 513 514 err = tx.Commit(context.Background()) 515 if err != nil { 516 t.Fatalf("tx.Commit failed: %v", err) 517 } 518 519 var n int64 520 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 521 if err != nil { 522 t.Fatalf("QueryRow Scan failed: %v", err) 523 } 524 if n != 2 { 525 t.Fatalf("Did not receive correct number of rows: %v", n) 526 } 527 } 528 529 func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { 530 t.Parallel() 531 532 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 533 defer closeConn(t, db) 534 535 createSql := ` 536 create temporary table foo( 537 id integer, 538 unique (id) 539 ); 540 ` 541 542 _, err := db.Exec(context.Background(), createSql) 543 require.NoError(t, err) 544 545 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { 546 _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") 547 require.NoError(t, err) 548 549 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { 550 _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") 551 require.NoError(t, err) 552 553 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { 554 _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") 555 require.NoError(t, err) 556 return nil 557 }) 558 require.NoError(t, err) 559 560 return nil 561 }) 562 require.NoError(t, err) 563 return nil 564 }) 565 require.NoError(t, err) 566 567 var n int64 568 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 569 require.NoError(t, err) 570 require.EqualValues(t, 3, n) 571 } 572 573 func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { 574 t.Parallel() 575 576 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 577 defer closeConn(t, db) 578 579 createSql := ` 580 create temporary table foo( 581 id integer, 582 unique (id) 583 ); 584 ` 585 586 _, err := db.Exec(context.Background(), createSql) 587 require.NoError(t, err) 588 589 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { 590 _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") 591 require.NoError(t, err) 592 593 err = pgx.BeginFunc(context.Background(), db, func(db pgx.Tx) error { 594 _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") 595 require.NoError(t, err) 596 return errors.New("do a rollback") 597 }) 598 require.EqualError(t, err, "do a rollback") 599 600 _, err = db.Exec(context.Background(), "insert into foo(id) values (3)") 601 require.NoError(t, err) 602 603 return nil 604 }) 605 require.NoError(t, err) 606 607 var n int64 608 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) 609 require.NoError(t, err) 610 require.EqualValues(t, 2, n) 611 } 612 613 func TestTxSendBatchClosed(t *testing.T) { 614 t.Parallel() 615 616 db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 617 defer closeConn(t, db) 618 619 tx, err := db.Begin(context.Background()) 620 require.NoError(t, err) 621 defer tx.Rollback(context.Background()) 622 623 err = tx.Commit(context.Background()) 624 require.NoError(t, err) 625 626 batch := &pgx.Batch{} 627 batch.Queue("select 1") 628 batch.Queue("select 2") 629 batch.Queue("select 3") 630 631 br := tx.SendBatch(context.Background(), batch) 632 defer br.Close() 633 634 var n int 635 636 _, err = br.Exec() 637 require.Error(t, err) 638 639 err = br.QueryRow().Scan(&n) 640 require.Error(t, err) 641 642 _, err = br.Query() 643 require.Error(t, err) 644 }