github.com/jackc/pgx/v5@v5.5.5/pgxtest/pgxtest.go (about) 1 // Package pgxtest provides utilities for testing pgx and packages that integrate with pgx. 2 package pgxtest 3 4 import ( 5 "context" 6 "fmt" 7 "reflect" 8 "regexp" 9 "strconv" 10 "testing" 11 12 "github.com/jackc/pgx/v5" 13 ) 14 15 var AllQueryExecModes = []pgx.QueryExecMode{ 16 pgx.QueryExecModeCacheStatement, 17 pgx.QueryExecModeCacheDescribe, 18 pgx.QueryExecModeDescribeExec, 19 pgx.QueryExecModeExec, 20 pgx.QueryExecModeSimpleProtocol, 21 } 22 23 // KnownOIDQueryExecModes is a slice of all query exec modes where the param and result OIDs are known before sending the query. 24 var KnownOIDQueryExecModes = []pgx.QueryExecMode{ 25 pgx.QueryExecModeCacheStatement, 26 pgx.QueryExecModeCacheDescribe, 27 pgx.QueryExecModeDescribeExec, 28 } 29 30 // ConnTestRunner controls how a *pgx.Conn is created and closed by tests. All fields are required. Use DefaultConnTestRunner to get a 31 // ConnTestRunner with reasonable default values. 32 type ConnTestRunner struct { 33 // CreateConfig returns a *pgx.ConnConfig suitable for use with pgx.ConnectConfig. 34 CreateConfig func(ctx context.Context, t testing.TB) *pgx.ConnConfig 35 36 // AfterConnect is called after conn is established. It allows for arbitrary connection setup before a test begins. 37 AfterConnect func(ctx context.Context, t testing.TB, conn *pgx.Conn) 38 39 // AfterTest is called after the test is run. It allows for validating the state of the connection before it is closed. 40 AfterTest func(ctx context.Context, t testing.TB, conn *pgx.Conn) 41 42 // CloseConn closes conn. 43 CloseConn func(ctx context.Context, t testing.TB, conn *pgx.Conn) 44 } 45 46 // DefaultConnTestRunner returns a new ConnTestRunner with all fields set to reasonable default values. 47 func DefaultConnTestRunner() ConnTestRunner { 48 return ConnTestRunner{ 49 CreateConfig: func(ctx context.Context, t testing.TB) *pgx.ConnConfig { 50 config, err := pgx.ParseConfig("") 51 if err != nil { 52 t.Fatalf("ParseConfig failed: %v", err) 53 } 54 return config 55 }, 56 AfterConnect: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, 57 AfterTest: func(ctx context.Context, t testing.TB, conn *pgx.Conn) {}, 58 CloseConn: func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 59 err := conn.Close(ctx) 60 if err != nil { 61 t.Errorf("Close failed: %v", err) 62 } 63 }, 64 } 65 } 66 67 func (ctr *ConnTestRunner) RunTest(ctx context.Context, t testing.TB, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { 68 t.Helper() 69 70 config := ctr.CreateConfig(ctx, t) 71 conn, err := pgx.ConnectConfig(ctx, config) 72 if err != nil { 73 t.Fatalf("ConnectConfig failed: %v", err) 74 } 75 defer ctr.CloseConn(ctx, t, conn) 76 77 ctr.AfterConnect(ctx, t, conn) 78 f(ctx, t, conn) 79 ctr.AfterTest(ctx, t, conn) 80 } 81 82 // RunWithQueryExecModes runs a f in a new test for each element of modes with a new connection created using connector. 83 // If modes is nil all pgx.QueryExecModes are tested. 84 func RunWithQueryExecModes(ctx context.Context, t *testing.T, ctr ConnTestRunner, modes []pgx.QueryExecMode, f func(ctx context.Context, t testing.TB, conn *pgx.Conn)) { 85 if modes == nil { 86 modes = AllQueryExecModes 87 } 88 89 for _, mode := range modes { 90 ctrWithMode := ctr 91 ctrWithMode.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { 92 config := ctr.CreateConfig(ctx, t) 93 config.DefaultQueryExecMode = mode 94 return config 95 } 96 97 t.Run(mode.String(), 98 func(t *testing.T) { 99 ctrWithMode.RunTest(ctx, t, f) 100 }, 101 ) 102 } 103 } 104 105 type ValueRoundTripTest struct { 106 Param any 107 Result any 108 Test func(any) bool 109 } 110 111 func RunValueRoundTripTests( 112 ctx context.Context, 113 t testing.TB, 114 ctr ConnTestRunner, 115 modes []pgx.QueryExecMode, 116 pgTypeName string, 117 tests []ValueRoundTripTest, 118 ) { 119 t.Helper() 120 121 if modes == nil { 122 modes = AllQueryExecModes 123 } 124 125 ctr.RunTest(ctx, t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { 126 t.Helper() 127 128 sql := fmt.Sprintf("select $1::%s", pgTypeName) 129 130 for i, tt := range tests { 131 for _, mode := range modes { 132 err := conn.QueryRow(ctx, sql, mode, tt.Param).Scan(tt.Result) 133 if err != nil { 134 t.Errorf("%d. %v: %v", i, mode, err) 135 } 136 137 result := reflect.ValueOf(tt.Result) 138 if result.Kind() == reflect.Ptr { 139 result = result.Elem() 140 } 141 142 if !tt.Test(result.Interface()) { 143 t.Errorf("%d. %v: unexpected result for %v: %v", i, mode, tt.Param, result.Interface()) 144 } 145 } 146 } 147 }) 148 } 149 150 // SkipCockroachDB calls Skip on t with msg if the connection is to a CockroachDB server. 151 func SkipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { 152 if conn.PgConn().ParameterStatus("crdb_version") != "" { 153 t.Skip(msg) 154 } 155 } 156 157 func SkipPostgreSQLVersionLessThan(t testing.TB, conn *pgx.Conn, minVersion int64) { 158 serverVersionStr := conn.PgConn().ParameterStatus("server_version") 159 serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) 160 // if not PostgreSQL do nothing 161 if serverVersionStr == "" { 162 return 163 } 164 165 serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) 166 if err != nil { 167 t.Fatalf("postgres version parsed failed: %s", err) 168 } 169 170 if serverVersion < minVersion { 171 t.Skipf("Test requires PostgreSQL v%d+", minVersion) 172 } 173 }