github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmp-protocol/main.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 // cmp-protocol connects to postgres and cockroach servers and compares 12 // the binary and text pgwire encodings of SQL statements. Statements can 13 // be specified in arguments (./cmp-protocol "select 1" "select 2") or will 14 // be generated randomly until a difference is found. 15 package main 16 17 import ( 18 "bytes" 19 "context" 20 "flag" 21 "fmt" 22 "io" 23 "os" 24 "strings" 25 26 "github.com/cockroachdb/cockroach/pkg/cmd/cmp-protocol/pgconnect" 27 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" 28 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 29 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 30 "github.com/cockroachdb/cockroach/pkg/sql/types" 31 "github.com/cockroachdb/cockroach/pkg/util/randutil" 32 "github.com/cockroachdb/errors" 33 ) 34 35 var ( 36 pgAddr = flag.String("pg", "localhost:5432", "postgres address") 37 pgUser = flag.String("pg-user", "postgres", "postgres user") 38 crAddr = flag.String("cr", "localhost:26257", "cockroach address") 39 crUser = flag.String("cr-user", "root", "cockroach user") 40 ) 41 42 func main() { 43 flag.Parse() 44 45 stmtCh := make(chan string) 46 if args := os.Args[1:]; len(args) > 0 { 47 go func() { 48 for _, arg := range os.Args[1:] { 49 stmtCh <- arg 50 } 51 close(stmtCh) 52 }() 53 } else { 54 go func() { 55 rng, _ := randutil.NewPseudoRand() 56 for { 57 typ := sqlbase.RandType(rng) 58 sem := typ.Family() 59 switch sem { 60 case types.DecimalFamily, // trailing zeros differ, ok 61 types.CollatedStringFamily, // pg complains about utf8 62 types.OidFamily, // our 8-byte ints are usually out of range for pg 63 types.FloatFamily, // slight rounding differences at the end 64 types.TimestampTZFamily, // slight timezone differences 65 types.UnknownFamily, 66 // tested manually below: 67 types.ArrayFamily, 68 types.TupleFamily: 69 continue 70 } 71 datum := sqlbase.RandDatum(rng, typ, false /* null ok */) 72 if datum == tree.DNull { 73 continue 74 } 75 for _, format := range []string{ 76 "SELECT %s::%s;", 77 "SELECT ARRAY[%s::%s];", 78 "SELECT (%s::%s, NULL);", 79 } { 80 input := fmt.Sprintf(format, datum, pgTypeName(sem)) 81 stmtCh <- input 82 fmt.Printf("\nTYP: %v, DATUM: %v\n", sem, datum) 83 } 84 } 85 }() 86 } 87 88 for input := range stmtCh { 89 fmt.Println("INPUT", input) 90 if err := compare(os.Stdout, input, *pgAddr, *crAddr, *pgUser, *crUser); err != nil { 91 fmt.Fprintln(os.Stderr, "ERROR:", input) 92 fmt.Fprintf(os.Stderr, "%v\n", err) 93 } else { 94 fmt.Fprintln(os.Stderr, "OK", input) 95 } 96 } 97 } 98 99 func pgTypeName(sem types.Family) string { 100 switch sem { 101 case types.StringFamily: 102 return "TEXT" 103 case types.BytesFamily: 104 return "BYTEA" 105 case types.IntFamily: 106 return "INT8" 107 default: 108 return sem.String() 109 } 110 } 111 112 func compare(w io.Writer, input, pgAddr, crAddr, pgUser, crUser string) error { 113 ctx := context.Background() 114 for _, code := range []pgwirebase.FormatCode{ 115 pgwirebase.FormatText, 116 pgwirebase.FormatBinary, 117 } { 118 // https://github.com/cockroachdb/cockroach/issues/31847 119 if code == pgwirebase.FormatBinary && strings.HasPrefix(input, "SELECT (") { 120 continue 121 } 122 results := map[string][]byte{} 123 for _, s := range []struct { 124 user string 125 addr string 126 }{ 127 {user: pgUser, addr: pgAddr}, 128 {user: crUser, addr: crAddr}, 129 } { 130 user := s.user 131 addr := s.addr 132 res, err := pgconnect.Connect(ctx, input, addr, user, code) 133 if err != nil { 134 return errors.Wrapf(err, "addr: %s, code: %s", addr, code) 135 } 136 fmt.Printf("INPUT: %s, ADDR: %s, CODE: %s, res: %q, res: %v\n", input, addr, code, res, res) 137 for k, v := range results { 138 if !bytes.Equal(res, v) { 139 return errors.Errorf("format: %s\naddr: %s\nstr: %q\nbytes: %[3]v\n!=\naddr: %s\nstr: %q\nbytes: %[5]v\n", code, k, v, addr, res) 140 } 141 } 142 results[addr] = res 143 } 144 } 145 return nil 146 }