// Source: github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/sqlutils/sql_runner.go

// Copyright 2016 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package sqlutils

import (
	"context"
	gosql "database/sql"
	"fmt"
	"reflect"
	"strings"
	"testing"

	"github.com/cockroachdb/cockroach/pkg/testutils"
	"github.com/cockroachdb/errors"
)

// SQLRunner wraps a DBHandle and provides convenience functions to run SQL
// statements. The runner itself does not hold a testing.TB; each helper takes
// the testing.TB of the calling test and reports failures through it.
type SQLRunner struct {
	// DB is the handle every statement issued through the runner is sent to.
	DB DBHandle
}

// DBHandle is an interface that applies to *gosql.DB, *gosql.Conn, and
// *gosql.Tx, as well as *RoundRobinDBHandle.
type DBHandle interface {
	// ExecContext executes a statement that does not return rows.
	ExecContext(ctx context.Context, query string, args ...interface{}) (gosql.Result, error)
	// QueryContext executes a statement that returns rows.
	QueryContext(ctx context.Context, query string, args ...interface{}) (*gosql.Rows, error)
	// QueryRowContext executes a statement that is expected to return at most
	// one row; any error is deferred until the returned Row is scanned.
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *gosql.Row
}

// Compile-time checks that the database/sql connection types satisfy DBHandle.
var _ DBHandle = &gosql.DB{}
var _ DBHandle = &gosql.Conn{}
var _ DBHandle = &gosql.Tx{}

// MakeSQLRunner returns a SQLRunner for the given database connection.
// The argument can be a *gosql.DB, *gosql.Conn, or *gosql.Tx object (or any
// other DBHandle implementation).
func MakeSQLRunner(db DBHandle) *SQLRunner {
	return &SQLRunner{DB: db}
}

// MakeRoundRobinSQLRunner returns a SQLRunner that uses a set of database
// connections, in a round-robin fashion.
51 func MakeRoundRobinSQLRunner(dbs ...DBHandle) *SQLRunner { 52 return MakeSQLRunner(MakeRoundRobinDBHandle(dbs...)) 53 } 54 55 // Exec is a wrapper around gosql.Exec that kills the test on error. 56 func (sr *SQLRunner) Exec(t testing.TB, query string, args ...interface{}) gosql.Result { 57 t.Helper() 58 r, err := sr.DB.ExecContext(context.Background(), query, args...) 59 if err != nil { 60 t.Fatalf("error executing '%s': %s", query, err) 61 } 62 return r 63 } 64 65 // ExecSucceedsSoon is a wrapper around gosql.Exec that wraps 66 // the exec in a succeeds soon. 67 func (sr *SQLRunner) ExecSucceedsSoon(t testing.TB, query string, args ...interface{}) { 68 t.Helper() 69 testutils.SucceedsSoon(t, func() error { 70 _, err := sr.DB.ExecContext(context.Background(), query, args...) 71 return err 72 }) 73 } 74 75 // ExecRowsAffected executes the statement and verifies that RowsAffected() 76 // matches the expected value. It kills the test on errors. 77 func (sr *SQLRunner) ExecRowsAffected( 78 t testing.TB, expRowsAffected int, query string, args ...interface{}, 79 ) { 80 t.Helper() 81 r := sr.Exec(t, query, args...) 82 numRows, err := r.RowsAffected() 83 if err != nil { 84 t.Fatal(err) 85 } 86 if numRows != int64(expRowsAffected) { 87 t.Fatalf("expected %d affected rows, got %d on '%s'", expRowsAffected, numRows, query) 88 } 89 } 90 91 // ExpectErr runs the given statement and verifies that it returns an error 92 // matching the given regex. 93 func (sr *SQLRunner) ExpectErr(t testing.TB, errRE string, query string, args ...interface{}) { 94 t.Helper() 95 _, err := sr.DB.ExecContext(context.Background(), query, args...) 96 if !testutils.IsError(err, errRE) { 97 t.Fatalf("expected error '%s', got: %v", errRE, err) 98 } 99 } 100 101 // ExpectErrSucceedsSoon wraps ExpectErr with a SucceedsSoon. 
102 func (sr *SQLRunner) ExpectErrSucceedsSoon( 103 t testing.TB, errRE string, query string, args ...interface{}, 104 ) { 105 t.Helper() 106 testutils.SucceedsSoon(t, func() error { 107 _, err := sr.DB.ExecContext(context.Background(), query, args...) 108 if !testutils.IsError(err, errRE) { 109 return errors.Newf("expected error '%s', got: %v", errRE, err) 110 } 111 return nil 112 }) 113 } 114 115 // Query is a wrapper around gosql.Query that kills the test on error. 116 func (sr *SQLRunner) Query(t testing.TB, query string, args ...interface{}) *gosql.Rows { 117 t.Helper() 118 r, err := sr.DB.QueryContext(context.Background(), query, args...) 119 if err != nil { 120 t.Fatalf("error executing '%s': %s", query, err) 121 } 122 return r 123 } 124 125 // Row is a wrapper around gosql.Row that kills the test on error. 126 type Row struct { 127 testing.TB 128 row *gosql.Row 129 } 130 131 // Scan is a wrapper around (*gosql.Row).Scan that kills the test on error. 132 func (r *Row) Scan(dest ...interface{}) { 133 r.Helper() 134 if err := r.row.Scan(dest...); err != nil { 135 r.Fatalf("error scanning '%v': %+v", r.row, err) 136 } 137 } 138 139 // QueryRow is a wrapper around gosql.QueryRow that kills the test on error. 140 func (sr *SQLRunner) QueryRow(t testing.TB, query string, args ...interface{}) *Row { 141 t.Helper() 142 return &Row{t, sr.DB.QueryRowContext(context.Background(), query, args...)} 143 } 144 145 // QueryStr runs a Query and converts the result using RowsToStrMatrix. Kills 146 // the test on errors. 147 func (sr *SQLRunner) QueryStr(t testing.TB, query string, args ...interface{}) [][]string { 148 t.Helper() 149 rows := sr.Query(t, query, args...) 150 r, err := RowsToStrMatrix(rows) 151 if err != nil { 152 t.Fatal(err) 153 } 154 return r 155 } 156 157 // RowsToStrMatrix converts the given result rows to a string matrix; nulls are 158 // represented as "NULL". Empty results are represented by an empty (but 159 // non-nil) slice. 
160 func RowsToStrMatrix(rows *gosql.Rows) ([][]string, error) { 161 cols, err := rows.Columns() 162 if err != nil { 163 return nil, err 164 } 165 vals := make([]interface{}, len(cols)) 166 for i := range vals { 167 vals[i] = new(interface{}) 168 } 169 res := [][]string{} 170 for rows.Next() { 171 if err := rows.Scan(vals...); err != nil { 172 return nil, err 173 } 174 row := make([]string, len(vals)) 175 for j, v := range vals { 176 if val := *v.(*interface{}); val != nil { 177 switch t := val.(type) { 178 case []byte: 179 row[j] = string(t) 180 default: 181 row[j] = fmt.Sprint(val) 182 } 183 } else { 184 row[j] = "NULL" 185 } 186 } 187 res = append(res, row) 188 } 189 if err := rows.Err(); err != nil { 190 return nil, err 191 } 192 return res, nil 193 } 194 195 // MatrixToStr converts a set of rows into a single string where each row is on 196 // a separate line and the columns with a row are comma separated. 197 func MatrixToStr(rows [][]string) string { 198 res := strings.Builder{} 199 for _, row := range rows { 200 res.WriteString(strings.Join(row, ", ")) 201 res.WriteRune('\n') 202 } 203 return res.String() 204 } 205 206 // CheckQueryResults checks that the rows returned by a query match the expected 207 // response. 208 func (sr *SQLRunner) CheckQueryResults(t testing.TB, query string, expected [][]string) { 209 t.Helper() 210 res := sr.QueryStr(t, query) 211 if !reflect.DeepEqual(res, expected) { 212 t.Errorf("query '%s': expected:\n%v\ngot:\n%v\n", 213 query, MatrixToStr(expected), MatrixToStr(res), 214 ) 215 } 216 } 217 218 // CheckQueryResultsRetry checks that the rows returned by a query match the 219 // expected response. If the results don't match right away, it will retry 220 // using testutils.SucceedsSoon. 
221 func (sr *SQLRunner) CheckQueryResultsRetry(t testing.TB, query string, expected [][]string) { 222 t.Helper() 223 testutils.SucceedsSoon(t, func() error { 224 res := sr.QueryStr(t, query) 225 if !reflect.DeepEqual(res, expected) { 226 return errors.Errorf("query '%s': expected:\n%v\ngot:\n%v\n", 227 query, MatrixToStr(expected), MatrixToStr(res), 228 ) 229 } 230 return nil 231 }) 232 } 233 234 // RoundRobinDBHandle aggregates multiple DBHandles into a single one; each time 235 // a query is issued, a handle is selected in round-robin fashion. 236 type RoundRobinDBHandle struct { 237 handles []DBHandle 238 current int 239 } 240 241 var _ DBHandle = &RoundRobinDBHandle{} 242 243 // MakeRoundRobinDBHandle creates a RoundRobinDBHandle. 244 func MakeRoundRobinDBHandle(handles ...DBHandle) *RoundRobinDBHandle { 245 return &RoundRobinDBHandle{handles: handles} 246 } 247 248 func (rr *RoundRobinDBHandle) next() DBHandle { 249 h := rr.handles[rr.current] 250 rr.current = (rr.current + 1) % len(rr.handles) 251 return h 252 } 253 254 // ExecContext is part of the DBHandle interface. 255 func (rr *RoundRobinDBHandle) ExecContext( 256 ctx context.Context, query string, args ...interface{}, 257 ) (gosql.Result, error) { 258 return rr.next().ExecContext(ctx, query, args...) 259 } 260 261 // QueryContext is part of the DBHandle interface. 262 func (rr *RoundRobinDBHandle) QueryContext( 263 ctx context.Context, query string, args ...interface{}, 264 ) (*gosql.Rows, error) { 265 return rr.next().QueryContext(ctx, query, args...) 266 } 267 268 // QueryRowContext is part of the DBHandle interface. 269 func (rr *RoundRobinDBHandle) QueryRowContext( 270 ctx context.Context, query string, args ...interface{}, 271 ) *gosql.Row { 272 return rr.next().QueryRowContext(ctx, query, args...) 273 }