github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmpconn/compare.go (about) 1 // Copyright 2020 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package cmpconn 12 13 import ( 14 "math/big" 15 "strings" 16 17 "github.com/cockroachdb/apd" 18 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 19 "github.com/cockroachdb/cockroach/pkg/util/duration" 20 "github.com/cockroachdb/errors" 21 "github.com/google/go-cmp/cmp" 22 "github.com/google/go-cmp/cmp/cmpopts" 23 "github.com/jackc/pgx/pgtype" 24 ) 25 26 // CompareVals returns an error if a and b differ, specifying what the 27 // difference is. It is designed to compare SQL results from a query 28 // executed on two different servers or configurations (i.e., cockroach and 29 // postgres). Postgres and Cockroach have subtle differences in their result 30 // types and OIDs. This function is aware of those and is able to correctly 31 // compare those values. 32 func CompareVals(a, b []interface{}) error { 33 if len(a) != len(b) { 34 return errors.Errorf("size difference: %d != %d", len(a), len(b)) 35 } 36 if len(a) == 0 { 37 return nil 38 } 39 if diff := cmp.Diff(a, b, cmpOptions...); diff != "" { 40 return errors.Newf("unexpected diff:\n%s", diff) 41 } 42 return nil 43 } 44 45 var ( 46 cmpOptions = []cmp.Option{ 47 cmp.Transformer("", func(x []interface{}) []interface{} { 48 out := make([]interface{}, len(x)) 49 for i, v := range x { 50 switch t := v.(type) { 51 case *pgtype.TextArray: 52 if t.Status == pgtype.Present && len(t.Elements) == 0 { 53 v = "" 54 } 55 case *pgtype.BPCharArray: 56 if t.Status == pgtype.Present && len(t.Elements) == 0 { 57 v = "" 58 } 59 case *pgtype.VarcharArray: 60 if t.Status == pgtype.Present && len(t.Elements) == 0 { 61 v = "" 62 } 63 case *pgtype.Int8Array: 64 if t.Status == pgtype.Present && len(t.Elements) == 0 { 65 v = &pgtype.Int8Array{} 66 } 67 case *pgtype.Float8Array: 68 if t.Status == pgtype.Present && len(t.Elements) == 0 { 69 v = &pgtype.Float8Array{} 70 } 71 case *pgtype.UUIDArray: 72 if t.Status == pgtype.Present && len(t.Elements) == 0 { 73 v = &pgtype.UUIDArray{} 74 } 75 case *pgtype.ByteaArray: 76 if t.Status == pgtype.Present && len(t.Elements) == 0 { 77 v = &pgtype.ByteaArray{} 78 } 79 case *pgtype.InetArray: 80 if t.Status == pgtype.Present && len(t.Elements) == 0 { 81 v = &pgtype.InetArray{} 82 } 83 case *pgtype.TimestampArray: 84 if t.Status == pgtype.Present && len(t.Elements) == 0 { 85 v = &pgtype.TimestampArray{} 86 } 87 case *pgtype.BoolArray: 88 if t.Status == pgtype.Present && len(t.Elements) == 0 { 89 v = &pgtype.BoolArray{} 90 } 91 case *pgtype.DateArray: 92 if t.Status == pgtype.Present && len(t.Elements) == 0 { 93 v = &pgtype.BoolArray{} 94 } 95 case *pgtype.Varbit: 96 if t.Status == pgtype.Present { 97 s, _ := t.EncodeText(nil, nil) 98 v = string(s) 99 } 100 case *pgtype.Bit: 101 vb := pgtype.Varbit(*t) 102 v = &vb 103 case *pgtype.Interval: 104 if t.Status == pgtype.Present { 105 v = duration.DecodeDuration(int64(t.Months), int64(t.Days), t.Microseconds*1000) 106 } 107 case string: 108 // Postgres sometimes adds spaces to the end of a string. 109 t = strings.TrimSpace(t) 110 v = strings.Replace(t, "T00:00:00+00:00", "T00:00:00Z", 1) 111 v = strings.Replace(t, ":00+00:00", ":00", 1) 112 case *pgtype.Numeric: 113 if t.Status == pgtype.Present { 114 v = apd.NewWithBigInt(t.Int, t.Exp) 115 } 116 case int64: 117 v = apd.New(t, 0) 118 } 119 out[i] = v 120 } 121 return out 122 }), 123 124 cmpopts.EquateEmpty(), 125 cmpopts.EquateNaNs(), 126 cmpopts.EquateApprox(0.00001, 0), 127 cmp.Comparer(func(x, y *big.Int) bool { 128 return x.Cmp(y) == 0 129 }), 130 cmp.Comparer(func(x, y *apd.Decimal) bool { 131 var a, b, min, sub apd.Decimal 132 a.Abs(x) 133 b.Abs(y) 134 if a.Cmp(&b) > 1 { 135 min.Set(&b) 136 } else { 137 min.Set(&a) 138 } 139 ctx := tree.DecimalCtx 140 _, _ = ctx.Mul(&min, &min, decimalCloseness) 141 _, _ = ctx.Sub(&sub, x, y) 142 sub.Abs(&sub) 143 return sub.Cmp(&min) <= 0 144 }), 145 cmp.Comparer(func(x, y duration.Duration) bool { 146 return x.Compare(y) == 0 147 }), 148 } 149 decimalCloseness = apd.New(1, -6) 150 )