github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/col/coldata/testutils.go (about) 1 // Copyright 2020 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package coldata 12 13 import ( 14 "bytes" 15 "fmt" 16 17 "github.com/cockroachdb/cockroachdb-parser/pkg/col/typeconv" 18 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/types" 19 "github.com/stretchr/testify/require" 20 ) 21 22 // testingT is a private interface that mirrors the testing.TB methods used. 23 // testing.TB cannot be used directly since testing is an illegal import. 24 // TODO(asubiotto): Remove AssertEquivalentBatches' dependency on testing.TB by 25 // 26 // checking for equality and returning a diff string instead of operating on 27 // testing.TB. 28 type testingT interface { 29 Helper() 30 Errorf(format string, args ...interface{}) 31 Fatal(args ...interface{}) 32 Fatalf(format string, args ...interface{}) 33 FailNow() 34 } 35 36 // AssertEquivalentBatches is a testing function that asserts that expected and 37 // actual are equivalent. 38 func AssertEquivalentBatches(t testingT, expected, actual Batch) { 39 t.Helper() 40 41 if actual.Selection() != nil { 42 t.Fatal("violated invariant that batches have no selection vectors") 43 } 44 require.Equal(t, expected.Length(), actual.Length()) 45 if expected.Length() == 0 { 46 // The schema of a zero-length batch is undefined, so the rest of the check 47 // is not required. 48 return 49 } 50 require.Equal(t, expected.Width(), actual.Width()) 51 for colIdx := 0; colIdx < expected.Width(); colIdx++ { 52 // Verify equality of ColVecs (this includes nulls). Since the coldata.Vec 53 // backing array is always of coldata.BatchSize() due to the scratch batch 54 // that the converter keeps around, the coldata.Vec needs to be sliced to 55 // the first length elements to match on length, otherwise the check will 56 // fail. 57 expectedVec := expected.ColVec(colIdx) 58 actualVec := actual.ColVec(colIdx) 59 require.Equal(t, expectedVec.Type(), actualVec.Type()) 60 // Check whether the nulls bitmaps are the same. Note that we don't 61 // track precisely the fact whether nulls are present or not in 62 // 'maybeHasNulls' field, so we override it manually to be 'true' for 63 // both nulls vectors if it is 'true' for at least one of them. This is 64 // acceptable since we still check the bitmaps precisely. 65 expectedNulls := expectedVec.Nulls() 66 actualNulls := actualVec.Nulls() 67 oldExpMaybeHasNulls, oldActMaybeHasNulls := expectedNulls.maybeHasNulls, actualNulls.maybeHasNulls 68 defer func() { 69 expectedNulls.maybeHasNulls, actualNulls.maybeHasNulls = oldExpMaybeHasNulls, oldActMaybeHasNulls 70 }() 71 expectedNulls.maybeHasNulls = expectedNulls.maybeHasNulls || actualNulls.maybeHasNulls 72 actualNulls.maybeHasNulls = expectedNulls.maybeHasNulls || actualNulls.maybeHasNulls 73 require.Equal(t, expectedNulls.Slice(0, expected.Length()), actualNulls.Slice(0, actual.Length())) 74 75 canonicalTypeFamily := expectedVec.CanonicalTypeFamily() 76 if canonicalTypeFamily == types.BytesFamily { 77 expectedBytes := expectedVec.Bytes().Window(0, expected.Length()) 78 resultBytes := actualVec.Bytes().Window(0, actual.Length()) 79 require.Equal(t, expectedBytes.Len(), resultBytes.Len()) 80 for i := 0; i < expectedBytes.Len(); i++ { 81 if !expectedNulls.NullAt(i) { 82 if !bytes.Equal(expectedBytes.Get(i), resultBytes.Get(i)) { 83 t.Fatalf("bytes mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedBytes, resultBytes) 84 } 85 } 86 } 87 } else if canonicalTypeFamily == types.DecimalFamily { 88 expectedDecimal := expectedVec.Decimal()[0:expected.Length()] 89 resultDecimal := actualVec.Decimal()[0:actual.Length()] 90 require.Equal(t, len(expectedDecimal), len(resultDecimal)) 91 for i := range expectedDecimal { 92 if !expectedNulls.NullAt(i) { 93 if expectedDecimal[i].Cmp(&resultDecimal[i]) != 0 { 94 t.Fatalf("Decimal mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, &expectedDecimal[i], &resultDecimal[i]) 95 } 96 } 97 } 98 } else if canonicalTypeFamily == types.TimestampTZFamily { 99 expectedTimestamp := expectedVec.Timestamp()[0:expected.Length()] 100 resultTimestamp := actualVec.Timestamp()[0:actual.Length()] 101 require.Equal(t, len(expectedTimestamp), len(resultTimestamp)) 102 for i := range expectedTimestamp { 103 if !expectedNulls.NullAt(i) { 104 if !expectedTimestamp[i].Equal(resultTimestamp[i]) { 105 t.Fatalf("Timestamp mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedTimestamp[i], resultTimestamp[i]) 106 } 107 } 108 } 109 } else if canonicalTypeFamily == types.IntervalFamily { 110 expectedInterval := expectedVec.Interval()[0:expected.Length()] 111 resultInterval := actualVec.Interval()[0:actual.Length()] 112 require.Equal(t, len(expectedInterval), len(resultInterval)) 113 for i := range expectedInterval { 114 if !expectedNulls.NullAt(i) { 115 if expectedInterval[i].Compare(resultInterval[i]) != 0 { 116 t.Fatalf("Interval mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedInterval[i], resultInterval[i]) 117 } 118 } 119 } 120 } else if canonicalTypeFamily == types.JsonFamily { 121 expectedJSON := expectedVec.JSON().Window(0, expected.Length()) 122 resultJSON := actualVec.JSON().Window(0, actual.Length()) 123 require.Equal(t, expectedJSON.Len(), resultJSON.Len()) 124 for i := 0; i < expectedJSON.Len(); i++ { 125 if !expectedNulls.NullAt(i) { 126 cmp, err := expectedJSON.Get(i).Compare(resultJSON.Get(i)) 127 if err != nil { 128 t.Fatal(err) 129 } 130 if cmp != 0 { 131 t.Fatalf("json mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedJSON, resultJSON) 132 } 133 } 134 } 135 } else if canonicalTypeFamily == typeconv.DatumVecCanonicalTypeFamily { 136 expectedDatum := expectedVec.Datum().Window(0 /* start */, expected.Length()) 137 resultDatum := actualVec.Datum().Window(0 /* start */, actual.Length()) 138 require.Equal(t, expectedDatum.Len(), resultDatum.Len()) 139 for i := 0; i < expectedDatum.Len(); i++ { 140 if !expectedNulls.NullAt(i) { 141 expected := expectedDatum.Get(i).(fmt.Stringer).String() 142 actual := resultDatum.Get(i).(fmt.Stringer).String() 143 if expected != actual { 144 t.Fatalf("Datum mismatch at index %d:\nexpected:\n%s\nactual:\n%s", i, expectedDatum.Get(i), resultDatum.Get(i)) 145 } 146 } 147 } 148 } else { 149 require.Equal( 150 t, 151 expectedVec.Window(0, expected.Length()), 152 actualVec.Window(0, actual.Length()), 153 ) 154 } 155 } 156 }