github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/sqlutils/sql_runner.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sqlutils
    12  
    13  import (
    14  	"context"
    15  	gosql "database/sql"
    16  	"fmt"
    17  	"reflect"
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/testutils"
    22  	"github.com/cockroachdb/errors"
    23  )
    24  
// SQLRunner wraps a DBHandle (such as a *gosql.DB, *gosql.Conn or *gosql.Tx)
// and provides convenience functions to run SQL statements and fail the test
// on any errors. The testing.TB to fail is passed to each method rather than
// stored in the runner.
type SQLRunner struct {
	// DB is the handle through which every statement issued by the runner goes.
	DB DBHandle
}
    30  
    31  // DBHandle is an interface that applies to *gosql.DB, *gosql.Conn, and
    32  // *gosql.Tx, as well as *RoundRobinDBHandle.
    33  type DBHandle interface {
    34  	ExecContext(ctx context.Context, query string, args ...interface{}) (gosql.Result, error)
    35  	QueryContext(ctx context.Context, query string, args ...interface{}) (*gosql.Rows, error)
    36  	QueryRowContext(ctx context.Context, query string, args ...interface{}) *gosql.Row
    37  }
    38  
    39  var _ DBHandle = &gosql.DB{}
    40  var _ DBHandle = &gosql.Conn{}
    41  var _ DBHandle = &gosql.Tx{}
    42  
    43  // MakeSQLRunner returns a SQLRunner for the given database connection.
    44  // The argument can be a *gosql.DB, *gosql.Conn, or *gosql.Tx object.
    45  func MakeSQLRunner(db DBHandle) *SQLRunner {
    46  	return &SQLRunner{DB: db}
    47  }
    48  
    49  // MakeRoundRobinSQLRunner returns a SQLRunner that uses a set of database
    50  // connections, in a round-robin fashion.
    51  func MakeRoundRobinSQLRunner(dbs ...DBHandle) *SQLRunner {
    52  	return MakeSQLRunner(MakeRoundRobinDBHandle(dbs...))
    53  }
    54  
    55  // Exec is a wrapper around gosql.Exec that kills the test on error.
    56  func (sr *SQLRunner) Exec(t testing.TB, query string, args ...interface{}) gosql.Result {
    57  	t.Helper()
    58  	r, err := sr.DB.ExecContext(context.Background(), query, args...)
    59  	if err != nil {
    60  		t.Fatalf("error executing '%s': %s", query, err)
    61  	}
    62  	return r
    63  }
    64  
    65  // ExecSucceedsSoon is a wrapper around gosql.Exec that wraps
    66  // the exec in a succeeds soon.
    67  func (sr *SQLRunner) ExecSucceedsSoon(t testing.TB, query string, args ...interface{}) {
    68  	t.Helper()
    69  	testutils.SucceedsSoon(t, func() error {
    70  		_, err := sr.DB.ExecContext(context.Background(), query, args...)
    71  		return err
    72  	})
    73  }
    74  
    75  // ExecRowsAffected executes the statement and verifies that RowsAffected()
    76  // matches the expected value. It kills the test on errors.
    77  func (sr *SQLRunner) ExecRowsAffected(
    78  	t testing.TB, expRowsAffected int, query string, args ...interface{},
    79  ) {
    80  	t.Helper()
    81  	r := sr.Exec(t, query, args...)
    82  	numRows, err := r.RowsAffected()
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	if numRows != int64(expRowsAffected) {
    87  		t.Fatalf("expected %d affected rows, got %d on '%s'", expRowsAffected, numRows, query)
    88  	}
    89  }
    90  
    91  // ExpectErr runs the given statement and verifies that it returns an error
    92  // matching the given regex.
    93  func (sr *SQLRunner) ExpectErr(t testing.TB, errRE string, query string, args ...interface{}) {
    94  	t.Helper()
    95  	_, err := sr.DB.ExecContext(context.Background(), query, args...)
    96  	if !testutils.IsError(err, errRE) {
    97  		t.Fatalf("expected error '%s', got: %v", errRE, err)
    98  	}
    99  }
   100  
   101  // ExpectErrSucceedsSoon wraps ExpectErr with a SucceedsSoon.
   102  func (sr *SQLRunner) ExpectErrSucceedsSoon(
   103  	t testing.TB, errRE string, query string, args ...interface{},
   104  ) {
   105  	t.Helper()
   106  	testutils.SucceedsSoon(t, func() error {
   107  		_, err := sr.DB.ExecContext(context.Background(), query, args...)
   108  		if !testutils.IsError(err, errRE) {
   109  			return errors.Newf("expected error '%s', got: %v", errRE, err)
   110  		}
   111  		return nil
   112  	})
   113  }
   114  
   115  // Query is a wrapper around gosql.Query that kills the test on error.
   116  func (sr *SQLRunner) Query(t testing.TB, query string, args ...interface{}) *gosql.Rows {
   117  	t.Helper()
   118  	r, err := sr.DB.QueryContext(context.Background(), query, args...)
   119  	if err != nil {
   120  		t.Fatalf("error executing '%s': %s", query, err)
   121  	}
   122  	return r
   123  }
   124  
// Row is a wrapper around gosql.Row that kills the test on error.
//
// The embedded testing.TB is the test that is failed when Scan returns an
// error; row is the underlying *gosql.Row being wrapped.
type Row struct {
	testing.TB
	row *gosql.Row
}
   130  
   131  // Scan is a wrapper around (*gosql.Row).Scan that kills the test on error.
   132  func (r *Row) Scan(dest ...interface{}) {
   133  	r.Helper()
   134  	if err := r.row.Scan(dest...); err != nil {
   135  		r.Fatalf("error scanning '%v': %+v", r.row, err)
   136  	}
   137  }
   138  
   139  // QueryRow is a wrapper around gosql.QueryRow that kills the test on error.
   140  func (sr *SQLRunner) QueryRow(t testing.TB, query string, args ...interface{}) *Row {
   141  	t.Helper()
   142  	return &Row{t, sr.DB.QueryRowContext(context.Background(), query, args...)}
   143  }
   144  
   145  // QueryStr runs a Query and converts the result using RowsToStrMatrix. Kills
   146  // the test on errors.
   147  func (sr *SQLRunner) QueryStr(t testing.TB, query string, args ...interface{}) [][]string {
   148  	t.Helper()
   149  	rows := sr.Query(t, query, args...)
   150  	r, err := RowsToStrMatrix(rows)
   151  	if err != nil {
   152  		t.Fatal(err)
   153  	}
   154  	return r
   155  }
   156  
   157  // RowsToStrMatrix converts the given result rows to a string matrix; nulls are
   158  // represented as "NULL". Empty results are represented by an empty (but
   159  // non-nil) slice.
   160  func RowsToStrMatrix(rows *gosql.Rows) ([][]string, error) {
   161  	cols, err := rows.Columns()
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  	vals := make([]interface{}, len(cols))
   166  	for i := range vals {
   167  		vals[i] = new(interface{})
   168  	}
   169  	res := [][]string{}
   170  	for rows.Next() {
   171  		if err := rows.Scan(vals...); err != nil {
   172  			return nil, err
   173  		}
   174  		row := make([]string, len(vals))
   175  		for j, v := range vals {
   176  			if val := *v.(*interface{}); val != nil {
   177  				switch t := val.(type) {
   178  				case []byte:
   179  					row[j] = string(t)
   180  				default:
   181  					row[j] = fmt.Sprint(val)
   182  				}
   183  			} else {
   184  				row[j] = "NULL"
   185  			}
   186  		}
   187  		res = append(res, row)
   188  	}
   189  	if err := rows.Err(); err != nil {
   190  		return nil, err
   191  	}
   192  	return res, nil
   193  }
   194  
   195  // MatrixToStr converts a set of rows into a single string where each row is on
   196  // a separate line and the columns with a row are comma separated.
   197  func MatrixToStr(rows [][]string) string {
   198  	res := strings.Builder{}
   199  	for _, row := range rows {
   200  		res.WriteString(strings.Join(row, ", "))
   201  		res.WriteRune('\n')
   202  	}
   203  	return res.String()
   204  }
   205  
   206  // CheckQueryResults checks that the rows returned by a query match the expected
   207  // response.
   208  func (sr *SQLRunner) CheckQueryResults(t testing.TB, query string, expected [][]string) {
   209  	t.Helper()
   210  	res := sr.QueryStr(t, query)
   211  	if !reflect.DeepEqual(res, expected) {
   212  		t.Errorf("query '%s': expected:\n%v\ngot:\n%v\n",
   213  			query, MatrixToStr(expected), MatrixToStr(res),
   214  		)
   215  	}
   216  }
   217  
   218  // CheckQueryResultsRetry checks that the rows returned by a query match the
   219  // expected response. If the results don't match right away, it will retry
   220  // using testutils.SucceedsSoon.
   221  func (sr *SQLRunner) CheckQueryResultsRetry(t testing.TB, query string, expected [][]string) {
   222  	t.Helper()
   223  	testutils.SucceedsSoon(t, func() error {
   224  		res := sr.QueryStr(t, query)
   225  		if !reflect.DeepEqual(res, expected) {
   226  			return errors.Errorf("query '%s': expected:\n%v\ngot:\n%v\n",
   227  				query, MatrixToStr(expected), MatrixToStr(res),
   228  			)
   229  		}
   230  		return nil
   231  	})
   232  }
   233  
   234  // RoundRobinDBHandle aggregates multiple DBHandles into a single one; each time
   235  // a query is issued, a handle is selected in round-robin fashion.
   236  type RoundRobinDBHandle struct {
   237  	handles []DBHandle
   238  	current int
   239  }
   240  
   241  var _ DBHandle = &RoundRobinDBHandle{}
   242  
   243  // MakeRoundRobinDBHandle creates a RoundRobinDBHandle.
   244  func MakeRoundRobinDBHandle(handles ...DBHandle) *RoundRobinDBHandle {
   245  	return &RoundRobinDBHandle{handles: handles}
   246  }
   247  
   248  func (rr *RoundRobinDBHandle) next() DBHandle {
   249  	h := rr.handles[rr.current]
   250  	rr.current = (rr.current + 1) % len(rr.handles)
   251  	return h
   252  }
   253  
   254  // ExecContext is part of the DBHandle interface.
   255  func (rr *RoundRobinDBHandle) ExecContext(
   256  	ctx context.Context, query string, args ...interface{},
   257  ) (gosql.Result, error) {
   258  	return rr.next().ExecContext(ctx, query, args...)
   259  }
   260  
   261  // QueryContext is part of the DBHandle interface.
   262  func (rr *RoundRobinDBHandle) QueryContext(
   263  	ctx context.Context, query string, args ...interface{},
   264  ) (*gosql.Rows, error) {
   265  	return rr.next().QueryContext(ctx, query, args...)
   266  }
   267  
   268  // QueryRowContext is part of the DBHandle interface.
   269  func (rr *RoundRobinDBHandle) QueryRowContext(
   270  	ctx context.Context, query string, args ...interface{},
   271  ) *gosql.Row {
   272  	return rr.next().QueryRowContext(ctx, query, args...)
   273  }