github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmp-sql/main.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // cmp-sql connects to postgres and cockroach servers and compares the
    12  // results of SQL statements. Statements support both random generation
    13  // and random placeholder values. It can thus be used to do correctness or
    14  // compatibility testing.
    15  //
    16  // To use, start a cockroach and postgres server with SSL disabled. cmp-sql
    17  // will connect to both, generate some random SQL, and print an error when
    18  // different results are returned. Currently it tests LIKE, binary operators,
    19  // and unary operators. cmp-sql runs a loop: 1) choose an Input, 2) generate a
    20  // random SQL string from the Input, 3) generate random placeholders, 4)
    21  // execute the SQL + placeholders and compare results.
    22  //
    23  // The Inputs slice determines what SQL is generated. cmp-sql will repeatedly
    24  // generate new kinds of input SQL. The `sql` property of an Input is a
    25  // function that returns a SQL string with possible placeholders. The `args`
    26  // func slice generates the placeholder arguments.
    27  package main
    28  
    29  import (
    30  	"context"
    31  	"flag"
    32  	"fmt"
    33  	"math/rand"
    34  	"regexp"
    35  	"strings"
    36  
    37  	"github.com/cockroachdb/apd"
    38  	"github.com/cockroachdb/cockroach/pkg/util/randutil"
    39  	"github.com/jackc/pgx"
    40  	"github.com/jackc/pgx/pgtype"
    41  )
    42  
    43  var (
    44  	pgAddr = flag.String("pg", "localhost:5432", "postgres address")
    45  	pgUser = flag.String("pg-user", "postgres", "postgres user")
    46  	crAddr = flag.String("cr", "localhost:26257", "cockroach address")
    47  	crUser = flag.String("cr-user", "root", "cockroach user")
    48  	rng, _ = randutil.NewPseudoRand()
    49  )
    50  
    51  func main() {
    52  	flag.Parse()
    53  	ctx := context.Background()
    54  
    55  	// Save the confs so we can print out the used ports later.
    56  	var confs []pgx.ConnConfig
    57  	for addr, user := range map[string]string{
    58  		*pgAddr: *pgUser,
    59  		*crAddr: *crUser,
    60  	} {
    61  		conf, err := pgx.ParseURI(fmt.Sprintf("postgresql://%s@%s?sslmode=disable", user, addr))
    62  		if err != nil {
    63  			panic(err)
    64  		}
    65  		confs = append(confs, conf)
    66  	}
    67  	dbs := make([]*pgx.Conn, len(confs))
    68  	for i, conf := range confs {
    69  		db, err := pgx.Connect(conf)
    70  		if err != nil {
    71  			panic(err)
    72  		}
    73  		dbs[i] = db
    74  	}
    75  
    76  	// Record unique errors and mismatches and only show them once.
    77  	seen := make(map[string]bool)
    78  	sawErr := func(err error) string {
    79  		s := reduceErr(err)
    80  		// Ignore error strings after the first semicolon.
    81  		res := fmt.Sprintf("ERR: %s", s)
    82  		if !seen[res] {
    83  			fmt.Print(err, "\n\n")
    84  			seen[res] = true
    85  		}
    86  		return res
    87  	}
    88  	for {
    89  	Loop:
    90  		for _, input := range Inputs {
    91  			results := map[int]string{}
    92  			sql, args, repro := input.Generate()
    93  			for i, db := range dbs {
    94  				var res, full string
    95  				if rows, err := db.Query(sql, args...); err != nil {
    96  					res = sawErr(err)
    97  					full = err.Error()
    98  				} else {
    99  					if rows.Next() {
   100  						vals, err := rows.Values()
   101  						if err != nil {
   102  							panic(err)
   103  						} else if len(vals) != 1 {
   104  							panic(fmt.Errorf("expected 1 val, got %v", vals))
   105  						} else {
   106  							switch v := vals[0].(type) {
   107  							case *pgtype.Numeric:
   108  								b, err := v.EncodeText(nil, nil)
   109  								if err != nil {
   110  									panic(err)
   111  								}
   112  								// Use a decimal so we can Reduce away the extra zeros. Needed because
   113  								// pg and cr return equivalent but not identical decimal results.
   114  								var d apd.Decimal
   115  								if _, _, err := d.SetString(string(b)); err != nil {
   116  									panic(err)
   117  								}
   118  								d.Reduce(&d)
   119  								res = d.String()
   120  							default:
   121  								res = fmt.Sprint(v)
   122  							}
   123  							full = res
   124  						}
   125  					}
   126  					rows.Close()
   127  					if err := rows.Err(); err != nil {
   128  						res = sawErr(err)
   129  						full = err.Error()
   130  					}
   131  					if res == "" {
   132  						panic("empty")
   133  					}
   134  				}
   135  				// Ping to see if the previous query panic'd the server.
   136  				if err := db.Ping(ctx); err != nil {
   137  					fmt.Print("CRASHER:\n", repro)
   138  					panic(fmt.Errorf("%v is down", confs[i].Port))
   139  				}
   140  				// Check the current result against all previous results. Make sure they are the same.
   141  				for vi, v := range results {
   142  					if verr, reserr := strings.HasPrefix(v, "ERR"), strings.HasPrefix(res, "ERR"); verr && reserr {
   143  						continue
   144  					} else if input.ignoreIfEitherError && (verr || reserr) {
   145  						continue
   146  					}
   147  					if v != res {
   148  						mismatch := fmt.Sprintf("%v: got %s\n%v: saw %s\n",
   149  							confs[i].Port,
   150  							full,
   151  							confs[vi].Port,
   152  							v,
   153  						)
   154  						if !seen[mismatch] {
   155  							seen[mismatch] = true
   156  							fmt.Print("MISMATCH:\n", mismatch)
   157  							fmt.Println(repro)
   158  						}
   159  						continue Loop
   160  					}
   161  				}
   162  				results[i] = res
   163  			}
   164  		}
   165  	}
   166  }
   167  
   168  var reduceErrRE = regexp.MustCompile(` *(ERROR)?[ :]*([A-Za-z ]+?) +`)
   169  
   170  // reduceErr removes any "ERROR:" prefix and returns the first words of an
   171  // error message. This is usually enough to uniquely identify it and remove
   172  // any non-unique (i.e., random string or numeric) values.
   173  func reduceErr(err error) string {
   174  	match := reduceErrRE.FindStringSubmatch(err.Error())
   175  	if match == nil {
   176  		return err.Error()
   177  	}
   178  	return match[2]
   179  }
   180  
   181  // Input defines an SQL statement generator.
   182  type Input struct {
   183  	sql  func() string
   184  	args []func() interface{}
   185  	// ignoreIfEitherError, if true, will only do mismatch comparison if both
   186  	// crdb and pg return non-error results.
   187  	ignoreIfEitherError bool
   188  }
   189  
   190  // Generate returns an instance of input's SQL and arguments, as well as a
   191  // repro string that can be copy-pasted into a SQL console.
   192  func (i Input) Generate() (sql string, args []interface{}, repro string) {
   193  	sql = i.sql()
   194  	args = make([]interface{}, len(i.args))
   195  	for i, fn := range i.args {
   196  		args[i] = fn()
   197  	}
   198  	var b strings.Builder
   199  	fmt.Fprintf(&b, "PREPARE a AS %s;\n", sql)
   200  	b.WriteString("EXECUTE a (")
   201  	for i, fn := range i.args {
   202  		if i > 0 {
   203  			b.WriteString(", ")
   204  		}
   205  		arg := fn()
   206  		switch arg := arg.(type) {
   207  		case int, int64, float64:
   208  			fmt.Fprint(&b, arg)
   209  		case string:
   210  			s := fmt.Sprintf("%q", arg)
   211  			fmt.Fprintf(&b, "e'%s'", s[1:len(s)-1])
   212  		default:
   213  			panic(fmt.Errorf("unknown type: %T", arg))
   214  		}
   215  	}
   216  	b.WriteString(");\n")
   217  	return sql, args, b.String()
   218  }
   219  
   220  // Inputs is the collection of generators that are compared.
   221  var Inputs = []Input{
   222  	{
   223  		sql:  pass("SELECT $1 LIKE $2"),
   224  		args: twoLike,
   225  	},
   226  	{
   227  		sql:  pass("SELECT $1 LIKE $2 ESCAPE $3"),
   228  		args: threeLike,
   229  	},
   230  	{
   231  		sql: fromSlices(
   232  			"SELECT $1::%s %s $2::%s %s $3::%s",
   233  			numTyps,
   234  			binaryNumOps,
   235  			numTyps,
   236  			binaryNumOps,
   237  			numTyps,
   238  		),
   239  		args:                threeNum,
   240  		ignoreIfEitherError: true,
   241  	},
   242  	{
   243  		sql: fromSlices(
   244  			"SELECT %s($1::%s)",
   245  			unaryNumOps,
   246  			numTyps,
   247  		),
   248  		args: oneNum,
   249  	},
   250  }
   251  
   252  var (
   253  	twoLike   = []func() interface{}{likeArg(5), likeArg(5)}
   254  	threeLike = []func() interface{}{likeArg(5), likeArg(5), likeArg(3)}
   255  	oneNum    = []func() interface{}{num}
   256  	threeNum  = []func() interface{}{num, num, num}
   257  
   258  	binaryNumOps = []string{
   259  		"-",
   260  		"+",
   261  		"^",
   262  		"*",
   263  		"/",
   264  		"//",
   265  		"%",
   266  		"<<",
   267  		">>",
   268  		"&",
   269  		"#",
   270  		"|",
   271  	}
   272  	unaryNumOps = []string{
   273  		"-",
   274  		"~",
   275  	}
   276  	numTyps = []string{
   277  		"int8",
   278  		"float8",
   279  		"decimal",
   280  	}
   281  )
   282  
   283  func pass(s string) func() string {
   284  	return func() string {
   285  		return s
   286  	}
   287  }
   288  
   289  // fromSlice generates arguments for and executes fmt.Sprintf by randomly
   290  // selecting elements of args.
   291  func fromSlices(s string, args ...[]string) func() string {
   292  	return func() string {
   293  		gen := make([]interface{}, len(args))
   294  		for i, arg := range args {
   295  			gen[i] = arg[rand.Intn(len(arg))]
   296  		}
   297  		return fmt.Sprintf(s, gen...)
   298  	}
   299  }
   300  
   301  // num generates a random number (int or float64).
   302  func num() interface{} {
   303  	switch rand.Intn(6) {
   304  	case 1:
   305  		return 1
   306  	case 2:
   307  		return 2
   308  	case 3:
   309  		return -1
   310  	case 4:
   311  		return rand.Int() / (rand.Intn(10) + 1)
   312  	case 5:
   313  		return rand.NormFloat64()
   314  	default:
   315  		return 0
   316  	}
   317  }
   318  
   319  func likeArg(n int) func() interface{} {
   320  	return func() interface{} {
   321  		p := make([]byte, rng.Intn(n))
   322  		for i := range p {
   323  			switch rand.Intn(4) {
   324  			case 0:
   325  				p[i] = '_'
   326  			case 1:
   327  				p[i] = '%'
   328  			default:
   329  				p[i] = byte(1 + rng.Intn(127))
   330  			}
   331  		}
   332  		return string(p)
   333  	}
   334  }