github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmpconn/conn.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // Package cmpconn assists in comparing results from DB connections.
    12  package cmpconn
    13  
import (
	"context"
	gosql "database/sql"
	"fmt"
	"math/rand"
	"sort"
	"strings"
	"time"

	"github.com/cockroachdb/cockroach/pkg/sql/mutations"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
	"github.com/cockroachdb/errors"
	"github.com/jackc/pgx"
	"github.com/lib/pq"
)
    28  
    29  // Conn holds gosql and pgx connections and provides some utility methods.
    30  type Conn interface {
    31  	// DB returns gosql connection.
    32  	DB() *gosql.DB
    33  	// PGX returns pgx connection.
    34  	PGX() *pgx.Conn
    35  	// Values executes prep and exec and returns the results of exec.
    36  	Values(ctx context.Context, prep, exec string) (rows *pgx.Rows, err error)
    37  	// Exec executes s.
    38  	Exec(ctx context.Context, s string) error
    39  	// Ping pings a connection.
    40  	Ping() error
    41  	// Close closes the connections.
    42  	Close()
    43  }
    44  
    45  type conn struct {
    46  	db  *gosql.DB
    47  	pgx *pgx.Conn
    48  }
    49  
    50  var _ Conn = &conn{}
    51  
    52  // DB is part of the Conn interface.
    53  func (c *conn) DB() *gosql.DB {
    54  	return c.db
    55  }
    56  
    57  // PGX is part of the Conn interface.
    58  func (c *conn) PGX() *pgx.Conn {
    59  	return c.pgx
    60  }
    61  
    62  var simpleProtocol = &pgx.QueryExOptions{SimpleProtocol: true}
    63  
    64  // Values executes prep and exec and returns the results of exec.
    65  func (c *conn) Values(ctx context.Context, prep, exec string) (rows *pgx.Rows, err error) {
    66  	if prep != "" {
    67  		rows, err = c.pgx.QueryEx(ctx, prep, simpleProtocol)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		rows.Close()
    72  	}
    73  	return c.pgx.QueryEx(ctx, exec, simpleProtocol)
    74  }
    75  
    76  // Exec executes s.
    77  func (c *conn) Exec(ctx context.Context, s string) error {
    78  	_, err := c.pgx.ExecEx(ctx, s, simpleProtocol)
    79  	return errors.Wrap(err, "exec")
    80  }
    81  
    82  // Ping pings a connection.
    83  func (c *conn) Ping() error {
    84  	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    85  	defer cancel()
    86  	return c.pgx.Ping(ctx)
    87  }
    88  
    89  // Close closes the connections.
    90  func (c *conn) Close() {
    91  	_ = c.db.Close()
    92  	_ = c.pgx.Close()
    93  }
    94  
    95  // NewConn returns a new Conn on the given uri and executes initSQL on it.
    96  func NewConn(uri string, initSQL ...string) (Conn, error) {
    97  	c := conn{}
    98  
    99  	{
   100  		connector, err := pq.NewConnector(uri)
   101  		if err != nil {
   102  			return nil, errors.Wrap(err, "pq conn")
   103  		}
   104  		c.db = gosql.OpenDB(connector)
   105  	}
   106  
   107  	{
   108  		config, err := pgx.ParseURI(uri)
   109  		if err != nil {
   110  			return nil, errors.Wrap(err, "pgx parse")
   111  		}
   112  		conn, err := pgx.Connect(config)
   113  		if err != nil {
   114  			return nil, errors.Wrap(err, "pgx conn")
   115  		}
   116  		c.pgx = conn
   117  	}
   118  
   119  	for _, s := range initSQL {
   120  		if s == "" {
   121  			continue
   122  		}
   123  
   124  		if _, err := c.pgx.Exec(s); err != nil {
   125  			return nil, errors.Wrap(err, "init SQL")
   126  		}
   127  	}
   128  
   129  	return &c, nil
   130  }
   131  
   132  // connWithMutators extends Conn by supporting application of mutations to
   133  // queries before their execution.
   134  type connWithMutators struct {
   135  	Conn
   136  	rng         *rand.Rand
   137  	sqlMutators []sqlbase.Mutator
   138  }
   139  
   140  var _ Conn = &connWithMutators{}
   141  
   142  // NewConnWithMutators returns a new Conn on the given uri and executes initSQL
   143  // on it. The mutators are applied to initSQL and will be applied to all
   144  // queries to be executed in CompareConns.
   145  func NewConnWithMutators(
   146  	uri string, rng *rand.Rand, sqlMutators []sqlbase.Mutator, initSQL ...string,
   147  ) (Conn, error) {
   148  	mutatedInitSQL := make([]string, len(initSQL))
   149  	for i, s := range initSQL {
   150  		mutatedInitSQL[i] = s
   151  		if s == "" {
   152  			continue
   153  		}
   154  
   155  		mutatedInitSQL[i], _ = mutations.ApplyString(rng, s, sqlMutators...)
   156  	}
   157  	conn, err := NewConn(uri, mutatedInitSQL...)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	return &connWithMutators{
   162  		Conn:        conn,
   163  		rng:         rng,
   164  		sqlMutators: sqlMutators,
   165  	}, nil
   166  }
   167  
   168  // CompareConns executes prep and exec on all connections in conns. If any
   169  // differ, an error is returned. ignoreSQLErrors determines whether SQL errors
   170  // are ignored.
   171  // NOTE: exec will be mutated for each connection of type connWithMutators.
   172  func CompareConns(
   173  	ctx context.Context,
   174  	timeout time.Duration,
   175  	conns map[string]Conn,
   176  	prep, exec string,
   177  	ignoreSQLErrors bool,
   178  ) (err error) {
   179  	ctx, cancel := context.WithTimeout(ctx, timeout)
   180  	defer cancel()
   181  	connRows := make(map[string]*pgx.Rows)
   182  	connExecs := make(map[string]string)
   183  	for name, conn := range conns {
   184  		connExecs[name] = exec
   185  		if cwm, withMutators := conn.(*connWithMutators); withMutators {
   186  			connExecs[name], _ = mutations.ApplyString(cwm.rng, exec, cwm.sqlMutators...)
   187  		}
   188  		rows, err := conn.Values(ctx, prep, connExecs[name])
   189  		if err != nil {
   190  			return nil //nolint:returnerrcheck
   191  		}
   192  		defer rows.Close()
   193  		connRows[name] = rows
   194  	}
   195  
   196  	// Annotate our error message with the exec queries since they can be
   197  	// mutated and differ per connection.
   198  	defer func() {
   199  		if err == nil {
   200  			return
   201  		}
   202  		var sb strings.Builder
   203  		prev := ""
   204  		for name, mutated := range connExecs {
   205  			fmt.Fprintf(&sb, "\n%s:", name)
   206  			if prev == mutated {
   207  				sb.WriteString(" [same as previous]\n")
   208  			} else {
   209  				fmt.Fprintf(&sb, "\n%s;\n", mutated)
   210  			}
   211  			prev = mutated
   212  		}
   213  		err = fmt.Errorf("%w%s", err, sb.String())
   214  	}()
   215  
   216  	return compareRows(connRows, ignoreSQLErrors)
   217  }
   218  
   219  // compareRows compares the results of executing of queries on all connections.
   220  // It always returns an error if there are any differences. Additionally,
   221  // ignoreSQLErrors specifies whether SQL errors should be ignored (in which
   222  // case the function returns nil if SQL error occurs).
   223  func compareRows(connRows map[string]*pgx.Rows, ignoreSQLErrors bool) error {
   224  	var first []interface{}
   225  	var firstName string
   226  	var minCount int
   227  	rowCounts := make(map[string]int)
   228  ReadRows:
   229  	for {
   230  		first = nil
   231  		firstName = ""
   232  		for name, rows := range connRows {
   233  			if !rows.Next() {
   234  				minCount = rowCounts[name]
   235  				break ReadRows
   236  			}
   237  			rowCounts[name]++
   238  			vals, err := rows.Values()
   239  			if err != nil {
   240  				if ignoreSQLErrors {
   241  					// This function can fail if, for example,
   242  					// a number doesn't fit into a float64. Ignore
   243  					// them and move along to another query.
   244  					err = nil
   245  				}
   246  				return err
   247  			}
   248  			if firstName == "" {
   249  				firstName = name
   250  				first = vals
   251  			} else {
   252  				if err := CompareVals(first, vals); err != nil {
   253  					return fmt.Errorf("compare %s to %s:\n%v", firstName, name, err)
   254  				}
   255  			}
   256  		}
   257  	}
   258  	// Make sure all are empty.
   259  	for name, rows := range connRows {
   260  		for rows.Next() {
   261  			rowCounts[name]++
   262  		}
   263  		if err := rows.Err(); err != nil {
   264  			if ignoreSQLErrors {
   265  				// Aww someone had a SQL error maybe, so we can't use this
   266  				// query.
   267  				err = nil
   268  			}
   269  			return err
   270  		}
   271  	}
   272  	// Ensure each connection returned the same number of rows.
   273  	for name, count := range rowCounts {
   274  		if minCount != count {
   275  			return fmt.Errorf("%s had %d rows, expected %d", name, count, minCount)
   276  		}
   277  	}
   278  	return nil
   279  }