github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmp-protocol/main.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
// cmp-protocol connects to postgres and cockroach servers and compares
// the binary and text pgwire encodings of SQL statements. Statements can
// be specified in arguments (./cmp-protocol "select 1" "select 2"); if
// none are given, statements are generated randomly without end. Any
// difference found is reported on stderr.
    15  package main
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"flag"
    21  	"fmt"
    22  	"io"
    23  	"os"
    24  	"strings"
    25  
    26  	"github.com/cockroachdb/cockroach/pkg/cmd/cmp-protocol/pgconnect"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    29  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    30  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    31  	"github.com/cockroachdb/cockroach/pkg/util/randutil"
    32  	"github.com/cockroachdb/errors"
    33  )
    34  
// Command-line flags selecting the address and user of the postgres and
// cockroach servers; both servers are handed to compare for each statement.
var (
	pgAddr = flag.String("pg", "localhost:5432", "postgres address")
	pgUser = flag.String("pg-user", "postgres", "postgres user")
	crAddr = flag.String("cr", "localhost:26257", "cockroach address")
	crUser = flag.String("cr-user", "root", "cockroach user")
)
    41  
    42  func main() {
    43  	flag.Parse()
    44  
    45  	stmtCh := make(chan string)
    46  	if args := os.Args[1:]; len(args) > 0 {
    47  		go func() {
    48  			for _, arg := range os.Args[1:] {
    49  				stmtCh <- arg
    50  			}
    51  			close(stmtCh)
    52  		}()
    53  	} else {
    54  		go func() {
    55  			rng, _ := randutil.NewPseudoRand()
    56  			for {
    57  				typ := sqlbase.RandType(rng)
    58  				sem := typ.Family()
    59  				switch sem {
    60  				case types.DecimalFamily, // trailing zeros differ, ok
    61  					types.CollatedStringFamily, // pg complains about utf8
    62  					types.OidFamily,            // our 8-byte ints are usually out of range for pg
    63  					types.FloatFamily,          // slight rounding differences at the end
    64  					types.TimestampTZFamily,    // slight timezone differences
    65  					types.UnknownFamily,
    66  					// tested manually below:
    67  					types.ArrayFamily,
    68  					types.TupleFamily:
    69  					continue
    70  				}
    71  				datum := sqlbase.RandDatum(rng, typ, false /* null ok */)
    72  				if datum == tree.DNull {
    73  					continue
    74  				}
    75  				for _, format := range []string{
    76  					"SELECT %s::%s;",
    77  					"SELECT ARRAY[%s::%s];",
    78  					"SELECT (%s::%s, NULL);",
    79  				} {
    80  					input := fmt.Sprintf(format, datum, pgTypeName(sem))
    81  					stmtCh <- input
    82  					fmt.Printf("\nTYP: %v, DATUM: %v\n", sem, datum)
    83  				}
    84  			}
    85  		}()
    86  	}
    87  
    88  	for input := range stmtCh {
    89  		fmt.Println("INPUT", input)
    90  		if err := compare(os.Stdout, input, *pgAddr, *crAddr, *pgUser, *crUser); err != nil {
    91  			fmt.Fprintln(os.Stderr, "ERROR:", input)
    92  			fmt.Fprintf(os.Stderr, "%v\n", err)
    93  		} else {
    94  			fmt.Fprintln(os.Stderr, "OK", input)
    95  		}
    96  	}
    97  }
    98  
    99  func pgTypeName(sem types.Family) string {
   100  	switch sem {
   101  	case types.StringFamily:
   102  		return "TEXT"
   103  	case types.BytesFamily:
   104  		return "BYTEA"
   105  	case types.IntFamily:
   106  		return "INT8"
   107  	default:
   108  		return sem.String()
   109  	}
   110  }
   111  
   112  func compare(w io.Writer, input, pgAddr, crAddr, pgUser, crUser string) error {
   113  	ctx := context.Background()
   114  	for _, code := range []pgwirebase.FormatCode{
   115  		pgwirebase.FormatText,
   116  		pgwirebase.FormatBinary,
   117  	} {
   118  		// https://github.com/cockroachdb/cockroach/issues/31847
   119  		if code == pgwirebase.FormatBinary && strings.HasPrefix(input, "SELECT (") {
   120  			continue
   121  		}
   122  		results := map[string][]byte{}
   123  		for _, s := range []struct {
   124  			user string
   125  			addr string
   126  		}{
   127  			{user: pgUser, addr: pgAddr},
   128  			{user: crUser, addr: crAddr},
   129  		} {
   130  			user := s.user
   131  			addr := s.addr
   132  			res, err := pgconnect.Connect(ctx, input, addr, user, code)
   133  			if err != nil {
   134  				return errors.Wrapf(err, "addr: %s, code: %s", addr, code)
   135  			}
   136  			fmt.Printf("INPUT: %s, ADDR: %s, CODE: %s, res: %q, res: %v\n", input, addr, code, res, res)
   137  			for k, v := range results {
   138  				if !bytes.Equal(res, v) {
   139  					return errors.Errorf("format: %s\naddr: %s\nstr: %q\nbytes: %[3]v\n!=\naddr: %s\nstr: %q\nbytes: %[5]v\n", code, k, v, addr, res)
   140  				}
   141  			}
   142  			results[addr] = res
   143  		}
   144  	}
   145  	return nil
   146  }