github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmpconn/compare.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package cmpconn
    12  
    13  import (
    14  	"math/big"
    15  	"strings"
    16  
    17  	"github.com/cockroachdb/apd"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    19  	"github.com/cockroachdb/cockroach/pkg/util/duration"
    20  	"github.com/cockroachdb/errors"
    21  	"github.com/google/go-cmp/cmp"
    22  	"github.com/google/go-cmp/cmp/cmpopts"
    23  	"github.com/jackc/pgx/pgtype"
    24  )
    25  
    26  // CompareVals returns an error if a and b differ, specifying what the
    27  // difference is. It is designed to compare SQL results from a query
    28  // executed on two different servers or configurations (i.e., cockroach and
    29  // postgres). Postgres and Cockroach have subtle differences in their result
    30  // types and OIDs. This function is aware of those and is able to correctly
    31  // compare those values.
    32  func CompareVals(a, b []interface{}) error {
    33  	if len(a) != len(b) {
    34  		return errors.Errorf("size difference: %d != %d", len(a), len(b))
    35  	}
    36  	if len(a) == 0 {
    37  		return nil
    38  	}
    39  	if diff := cmp.Diff(a, b, cmpOptions...); diff != "" {
    40  		return errors.Newf("unexpected diff:\n%s", diff)
    41  	}
    42  	return nil
    43  }
    44  
    45  var (
    46  	cmpOptions = []cmp.Option{
    47  		cmp.Transformer("", func(x []interface{}) []interface{} {
    48  			out := make([]interface{}, len(x))
    49  			for i, v := range x {
    50  				switch t := v.(type) {
    51  				case *pgtype.TextArray:
    52  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    53  						v = ""
    54  					}
    55  				case *pgtype.BPCharArray:
    56  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    57  						v = ""
    58  					}
    59  				case *pgtype.VarcharArray:
    60  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    61  						v = ""
    62  					}
    63  				case *pgtype.Int8Array:
    64  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    65  						v = &pgtype.Int8Array{}
    66  					}
    67  				case *pgtype.Float8Array:
    68  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    69  						v = &pgtype.Float8Array{}
    70  					}
    71  				case *pgtype.UUIDArray:
    72  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    73  						v = &pgtype.UUIDArray{}
    74  					}
    75  				case *pgtype.ByteaArray:
    76  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    77  						v = &pgtype.ByteaArray{}
    78  					}
    79  				case *pgtype.InetArray:
    80  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    81  						v = &pgtype.InetArray{}
    82  					}
    83  				case *pgtype.TimestampArray:
    84  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    85  						v = &pgtype.TimestampArray{}
    86  					}
    87  				case *pgtype.BoolArray:
    88  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    89  						v = &pgtype.BoolArray{}
    90  					}
    91  				case *pgtype.DateArray:
    92  					if t.Status == pgtype.Present && len(t.Elements) == 0 {
    93  						v = &pgtype.BoolArray{}
    94  					}
    95  				case *pgtype.Varbit:
    96  					if t.Status == pgtype.Present {
    97  						s, _ := t.EncodeText(nil, nil)
    98  						v = string(s)
    99  					}
   100  				case *pgtype.Bit:
   101  					vb := pgtype.Varbit(*t)
   102  					v = &vb
   103  				case *pgtype.Interval:
   104  					if t.Status == pgtype.Present {
   105  						v = duration.DecodeDuration(int64(t.Months), int64(t.Days), t.Microseconds*1000)
   106  					}
   107  				case string:
   108  					// Postgres sometimes adds spaces to the end of a string.
   109  					t = strings.TrimSpace(t)
   110  					v = strings.Replace(t, "T00:00:00+00:00", "T00:00:00Z", 1)
   111  					v = strings.Replace(t, ":00+00:00", ":00", 1)
   112  				case *pgtype.Numeric:
   113  					if t.Status == pgtype.Present {
   114  						v = apd.NewWithBigInt(t.Int, t.Exp)
   115  					}
   116  				case int64:
   117  					v = apd.New(t, 0)
   118  				}
   119  				out[i] = v
   120  			}
   121  			return out
   122  		}),
   123  
   124  		cmpopts.EquateEmpty(),
   125  		cmpopts.EquateNaNs(),
   126  		cmpopts.EquateApprox(0.00001, 0),
   127  		cmp.Comparer(func(x, y *big.Int) bool {
   128  			return x.Cmp(y) == 0
   129  		}),
   130  		cmp.Comparer(func(x, y *apd.Decimal) bool {
   131  			var a, b, min, sub apd.Decimal
   132  			a.Abs(x)
   133  			b.Abs(y)
   134  			if a.Cmp(&b) > 1 {
   135  				min.Set(&b)
   136  			} else {
   137  				min.Set(&a)
   138  			}
   139  			ctx := tree.DecimalCtx
   140  			_, _ = ctx.Mul(&min, &min, decimalCloseness)
   141  			_, _ = ctx.Sub(&sub, x, y)
   142  			sub.Abs(&sub)
   143  			return sub.Cmp(&min) <= 0
   144  		}),
   145  		cmp.Comparer(func(x, y duration.Duration) bool {
   146  			return x.Compare(y) == 0
   147  		}),
   148  	}
   149  	decimalCloseness = apd.New(1, -6)
   150  )