github.com/jackc/pgx/v5@v5.5.5/helper_test.go (about) 1 package pgx_test 2 3 import ( 4 "context" 5 "os" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 10 "github.com/jackc/pgx/v5" 11 "github.com/jackc/pgx/v5/pgconn" 12 "github.com/jackc/pgx/v5/pgxtest" 13 "github.com/stretchr/testify/require" 14 ) 15 16 var defaultConnTestRunner pgxtest.ConnTestRunner 17 18 func init() { 19 defaultConnTestRunner = pgxtest.DefaultConnTestRunner() 20 defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { 21 config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 22 require.NoError(t, err) 23 return config 24 } 25 } 26 27 func mustConnectString(t testing.TB, connString string) *pgx.Conn { 28 conn, err := pgx.Connect(context.Background(), connString) 29 if err != nil { 30 t.Fatalf("Unable to establish connection: %v", err) 31 } 32 return conn 33 } 34 35 func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig { 36 config, err := pgx.ParseConfig(connString) 37 require.Nil(t, err) 38 return config 39 } 40 41 func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn { 42 conn, err := pgx.ConnectConfig(context.Background(), config) 43 if err != nil { 44 t.Fatalf("Unable to establish connection: %v", err) 45 } 46 return conn 47 } 48 49 func closeConn(t testing.TB, conn *pgx.Conn) { 50 err := conn.Close(context.Background()) 51 if err != nil { 52 t.Fatalf("conn.Close unexpectedly failed: %v", err) 53 } 54 } 55 56 func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) { 57 var err error 58 if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { 59 t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) 60 } 61 return 62 } 63 64 // Do a simple query to ensure the connection is still usable 65 func ensureConnValid(t testing.TB, conn *pgx.Conn) { 66 var sum, rowCount int32 67 68 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) 69 if err != nil { 70 t.Fatalf("conn.Query failed: %v", err) 71 } 72 defer rows.Close() 73 74 for rows.Next() { 75 var n int32 76 rows.Scan(&n) 77 sum += n 78 rowCount++ 79 } 80 81 if rows.Err() != nil { 82 t.Fatalf("conn.Query failed: %v", rows.Err()) 83 } 84 85 if rowCount != 10 { 86 t.Error("Select called onDataRow wrong number of times") 87 } 88 if sum != 55 { 89 t.Error("Wrong values returned") 90 } 91 } 92 93 func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { 94 if !assert.NotNil(t, expected) { 95 return 96 } 97 if !assert.NotNil(t, actual) { 98 return 99 } 100 101 assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName) 102 assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) 103 assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName) 104 assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName) 105 assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName) 106 assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) 107 assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) 108 assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) 109 assert.Equalf(t, expected.User, actual.User, "%s - User", testName) 110 assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) 111 assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) 112 assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) 113 114 // Can't test function equality, so just test that they are set or not. 115 assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) 116 assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) 117 118 if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { 119 if expected.TLSConfig != nil { 120 assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) 121 assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) 122 } 123 } 124 125 if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { 126 for i := range expected.Fallbacks { 127 assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) 128 assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) 129 130 if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { 131 if expected.Fallbacks[i].TLSConfig != nil { 132 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) 133 assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) 134 } 135 } 136 } 137 } 138 }