github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/col/coldata/testutils.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package coldata
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  
    17  	"github.com/cockroachdb/cockroachdb-parser/pkg/col/typeconv"
    18  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/types"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  // testingT is a private interface that mirrors the testing.TB methods used.
    23  // testing.TB cannot be used directly since testing is an illegal import.
    24  // TODO(asubiotto): Remove AssertEquivalentBatches' dependency on testing.TB by
    25  //
    26  //	checking for equality and returning a diff string instead of operating on
    27  //	testing.TB.
    28  type testingT interface {
    29  	Helper()
    30  	Errorf(format string, args ...interface{})
    31  	Fatal(args ...interface{})
    32  	Fatalf(format string, args ...interface{})
    33  	FailNow()
    34  }
    35  
    36  // AssertEquivalentBatches is a testing function that asserts that expected and
    37  // actual are equivalent.
    38  func AssertEquivalentBatches(t testingT, expected, actual Batch) {
    39  	t.Helper()
    40  
    41  	if actual.Selection() != nil {
    42  		t.Fatal("violated invariant that batches have no selection vectors")
    43  	}
    44  	require.Equal(t, expected.Length(), actual.Length())
    45  	if expected.Length() == 0 {
    46  		// The schema of a zero-length batch is undefined, so the rest of the check
    47  		// is not required.
    48  		return
    49  	}
    50  	require.Equal(t, expected.Width(), actual.Width())
    51  	for colIdx := 0; colIdx < expected.Width(); colIdx++ {
    52  		// Verify equality of ColVecs (this includes nulls). Since the coldata.Vec
    53  		// backing array is always of coldata.BatchSize() due to the scratch batch
    54  		// that the converter keeps around, the coldata.Vec needs to be sliced to
    55  		// the first length elements to match on length, otherwise the check will
    56  		// fail.
    57  		expectedVec := expected.ColVec(colIdx)
    58  		actualVec := actual.ColVec(colIdx)
    59  		require.Equal(t, expectedVec.Type(), actualVec.Type())
    60  		// Check whether the nulls bitmaps are the same. Note that we don't
    61  		// track precisely the fact whether nulls are present or not in
    62  		// 'maybeHasNulls' field, so we override it manually to be 'true' for
    63  		// both nulls vectors if it is 'true' for at least one of them. This is
    64  		// acceptable since we still check the bitmaps precisely.
    65  		expectedNulls := expectedVec.Nulls()
    66  		actualNulls := actualVec.Nulls()
    67  		oldExpMaybeHasNulls, oldActMaybeHasNulls := expectedNulls.maybeHasNulls, actualNulls.maybeHasNulls
    68  		defer func() {
    69  			expectedNulls.maybeHasNulls, actualNulls.maybeHasNulls = oldExpMaybeHasNulls, oldActMaybeHasNulls
    70  		}()
    71  		expectedNulls.maybeHasNulls = expectedNulls.maybeHasNulls || actualNulls.maybeHasNulls
    72  		actualNulls.maybeHasNulls = expectedNulls.maybeHasNulls || actualNulls.maybeHasNulls
    73  		require.Equal(t, expectedNulls.Slice(0, expected.Length()), actualNulls.Slice(0, actual.Length()))
    74  
    75  		canonicalTypeFamily := expectedVec.CanonicalTypeFamily()
    76  		if canonicalTypeFamily == types.BytesFamily {
    77  			expectedBytes := expectedVec.Bytes().Window(0, expected.Length())
    78  			resultBytes := actualVec.Bytes().Window(0, actual.Length())
    79  			require.Equal(t, expectedBytes.Len(), resultBytes.Len())
    80  			for i := 0; i < expectedBytes.Len(); i++ {
    81  				if !expectedNulls.NullAt(i) {
    82  					if !bytes.Equal(expectedBytes.Get(i), resultBytes.Get(i)) {
    83  						t.Fatalf("bytes mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedBytes, resultBytes)
    84  					}
    85  				}
    86  			}
    87  		} else if canonicalTypeFamily == types.DecimalFamily {
    88  			expectedDecimal := expectedVec.Decimal()[0:expected.Length()]
    89  			resultDecimal := actualVec.Decimal()[0:actual.Length()]
    90  			require.Equal(t, len(expectedDecimal), len(resultDecimal))
    91  			for i := range expectedDecimal {
    92  				if !expectedNulls.NullAt(i) {
    93  					if expectedDecimal[i].Cmp(&resultDecimal[i]) != 0 {
    94  						t.Fatalf("Decimal mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, &expectedDecimal[i], &resultDecimal[i])
    95  					}
    96  				}
    97  			}
    98  		} else if canonicalTypeFamily == types.TimestampTZFamily {
    99  			expectedTimestamp := expectedVec.Timestamp()[0:expected.Length()]
   100  			resultTimestamp := actualVec.Timestamp()[0:actual.Length()]
   101  			require.Equal(t, len(expectedTimestamp), len(resultTimestamp))
   102  			for i := range expectedTimestamp {
   103  				if !expectedNulls.NullAt(i) {
   104  					if !expectedTimestamp[i].Equal(resultTimestamp[i]) {
   105  						t.Fatalf("Timestamp mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedTimestamp[i], resultTimestamp[i])
   106  					}
   107  				}
   108  			}
   109  		} else if canonicalTypeFamily == types.IntervalFamily {
   110  			expectedInterval := expectedVec.Interval()[0:expected.Length()]
   111  			resultInterval := actualVec.Interval()[0:actual.Length()]
   112  			require.Equal(t, len(expectedInterval), len(resultInterval))
   113  			for i := range expectedInterval {
   114  				if !expectedNulls.NullAt(i) {
   115  					if expectedInterval[i].Compare(resultInterval[i]) != 0 {
   116  						t.Fatalf("Interval mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedInterval[i], resultInterval[i])
   117  					}
   118  				}
   119  			}
   120  		} else if canonicalTypeFamily == types.JsonFamily {
   121  			expectedJSON := expectedVec.JSON().Window(0, expected.Length())
   122  			resultJSON := actualVec.JSON().Window(0, actual.Length())
   123  			require.Equal(t, expectedJSON.Len(), resultJSON.Len())
   124  			for i := 0; i < expectedJSON.Len(); i++ {
   125  				if !expectedNulls.NullAt(i) {
   126  					cmp, err := expectedJSON.Get(i).Compare(resultJSON.Get(i))
   127  					if err != nil {
   128  						t.Fatal(err)
   129  					}
   130  					if cmp != 0 {
   131  						t.Fatalf("json mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedJSON, resultJSON)
   132  					}
   133  				}
   134  			}
   135  		} else if canonicalTypeFamily == typeconv.DatumVecCanonicalTypeFamily {
   136  			expectedDatum := expectedVec.Datum().Window(0 /* start */, expected.Length())
   137  			resultDatum := actualVec.Datum().Window(0 /* start */, actual.Length())
   138  			require.Equal(t, expectedDatum.Len(), resultDatum.Len())
   139  			for i := 0; i < expectedDatum.Len(); i++ {
   140  				if !expectedNulls.NullAt(i) {
   141  					expected := expectedDatum.Get(i).(fmt.Stringer).String()
   142  					actual := resultDatum.Get(i).(fmt.Stringer).String()
   143  					if expected != actual {
   144  						t.Fatalf("Datum mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedDatum.Get(i), resultDatum.Get(i))
   145  					}
   146  				}
   147  			}
   148  		} else {
   149  			require.Equal(
   150  				t,
   151  				expectedVec.Window(0, expected.Length()),
   152  				actualVec.Window(0, actual.Length()),
   153  			)
   154  		}
   155  	}
   156  }