github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/agg/aggUt/agg_testutil.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggut 16 17 import ( 18 "testing" 19 20 "github.com/matrixorigin/matrixone/pkg/common/mpool" 21 "github.com/matrixorigin/matrixone/pkg/sql/colexec/agg" 22 23 "github.com/matrixorigin/matrixone/pkg/container/types" 24 "github.com/matrixorigin/matrixone/pkg/container/vector" 25 "github.com/matrixorigin/matrixone/pkg/testutil" 26 "github.com/stretchr/testify/require" 27 ) 28 29 type testCase struct { 30 // agg type 31 op int 32 isDistinct bool 33 hasDecimalResult bool 34 inputTyp types.Type 35 outputTyp types.Type 36 37 // test data for Fill() and Eval() 38 input any 39 inputNsp []uint64 40 expected any 41 42 // test data for Merge() 43 mergeInput any 44 mergeNsp []uint64 45 mergeExpect any 46 47 // test Marshal() and Unmarshal() or not 48 testMarshal bool 49 } 50 51 func RunTest(t *testing.T, testCases []testCase) { 52 for _, c := range testCases { 53 // update some parameter 54 switch c.op { 55 case agg.AggregateAvg, agg.AggregateVariance, agg.AggregateStdDevPop, agg.AggregateMedian: 56 c.hasDecimalResult = true 57 default: 58 c.hasDecimalResult = false 59 } 60 c.outputTyp, _ = agg.ReturnType(c.op, c.inputTyp) 61 62 RunBaseTest(t, &c) 63 if c.testMarshal { 64 RunBaseMarshalTest(t, &c) 65 } 66 } 67 } 68 69 func RunBaseTest(t *testing.T, c *testCase) { 70 // base test: Grows(), Fill(), Eval() and Merge() 71 m := mpool.MustNewZero() 72 73 // Grows(), Fill() and Eval() test 74 { 75 // New() 76 agg0, newErr := agg.New(c.op, c.isDistinct, c.inputTyp) 77 require.NoError(t, newErr) 78 79 // Grows() 80 growsErr := agg0.Grows(1, m) 81 require.NoError(t, growsErr) 82 83 // Fill() 84 vec, l := GetVector(c.inputTyp, c.input, c.inputNsp) 85 if l > 0 && vec != nil { 86 for i := 0; i < l; i++ { 87 agg0.Fill(0, int64(i), 1, []*vector.Vector{vec}) 88 } 89 } 90 91 // Eval() 92 v, err := agg0.Eval(m) 93 require.NoError(t, err) 94 CompareResult(t, c.outputTyp, c.expected, v, c.hasDecimalResult) 95 if vec != nil { 96 vec.Free(m) 97 } 98 v.Free(m) 99 } 100 101 // Merge() Test 102 if c.mergeInput != nil { 103 // New() 104 agg0, newErr := agg.New(c.op, c.isDistinct, c.inputTyp) 105 require.NoError(t, newErr) 106 107 // Grows() 108 growsErr := agg0.Grows(1, m) 109 require.NoError(t, growsErr) 110 111 // Fill() 112 vec, l := GetVector(c.inputTyp, c.input, c.inputNsp) 113 if l > 0 && vec != nil { 114 for i := 0; i < l; i++ { 115 agg0.Fill(0, int64(i), 1, []*vector.Vector{vec}) 116 } 117 } 118 119 // create another agg for merge 120 agg1, _ := agg.New(c.op, c.isDistinct, c.inputTyp) 121 agg1.Grows(1, m) 122 vec2, l2 := GetVector(c.inputTyp, c.mergeInput, c.inputNsp) 123 if l2 > 0 && vec2 != nil { 124 for i := 0; i < l2; i++ { 125 agg1.Fill(0, int64(i), 1, []*vector.Vector{vec2}) 126 } 127 } 128 129 // Merge() 130 agg0.Merge(agg1, 0, 0) 131 132 // Eval() 133 v, err := agg0.Eval(m) 134 require.NoError(t, err) 135 CompareResult(t, c.outputTyp, c.mergeExpect, v, c.hasDecimalResult) 136 137 // release 138 if vec != nil { 139 vec.Free(m) 140 } 141 if vec2 != nil { 142 vec2.Free(m) 143 } 144 v.Free(m) 145 } 146 //require.Equal(t, int64(0), m.Size()) 147 } 148 149 func RunBaseMarshalTest(t *testing.T, c *testCase) { 150 // base test: Grows(), Fill() and Eval() 151 m := mpool.MustNewZero() 152 { 153 // New() 154 agg0, newErr := agg.New(c.op, c.isDistinct, c.inputTyp) 155 require.NoError(t, newErr) 156 157 // Grows() 158 growsErr := agg0.Grows(1, m) 159 require.NoError(t, growsErr) 160 161 // Fill() 162 vec, l := GetVector(c.inputTyp, c.input, c.inputNsp) 163 mid := l / 2 164 if mid > 0 && vec != nil { 165 for i := 0; i < mid; i++ { 166 agg0.Fill(0, int64(i), 1, []*vector.Vector{vec}) 167 } 168 } 169 170 // Marshal and Unmarshal() 171 d, marshalErr := agg0.MarshalBinary() 172 require.NoError(t, marshalErr) 173 agg1, _ := agg.New(c.op, c.isDistinct, c.inputTyp) 174 unmarshalErr := agg1.UnmarshalBinary(d) 175 require.NoError(t, unmarshalErr) 176 agg1.WildAggReAlloc(m) 177 178 // Fill() after marshal and unmarshal 179 if l > 0 && vec != nil { 180 for i := mid; i < l; i++ { 181 agg1.Fill(0, int64(i), 1, []*vector.Vector{vec}) 182 } 183 } 184 185 // Eval() after marshal and unmarshal 186 v, err := agg1.Eval(m) 187 require.NoError(t, err) 188 CompareResult(t, c.outputTyp, c.expected, v, c.hasDecimalResult) 189 if vec != nil { 190 vec.Free(m) 191 } 192 v.Free(m) 193 } 194 195 // Merge() Test 196 if c.mergeInput != nil { 197 // create an agg for marshal and unmarshal 198 agg0, _ := agg.New(c.op, c.isDistinct, c.inputTyp) 199 agg0.Grows(1, m) 200 vec, l := GetVector(c.inputTyp, c.input, c.inputNsp) 201 if l != 0 && vec != nil { 202 for i := 0; i < l; i++ { 203 agg0.Fill(0, int64(i), 1, []*vector.Vector{vec}) 204 } 205 } 206 207 // create another agg for merge 208 agg1, _ := agg.New(c.op, c.isDistinct, c.inputTyp) 209 agg1.Grows(1, m) 210 vec2, l2 := GetVector(c.inputTyp, c.mergeInput, c.mergeNsp) 211 if l2 != 0 && vec2 != nil { 212 for i := 0; i < l2; i++ { 213 agg1.Fill(0, int64(i), 1, []*vector.Vector{vec2}) 214 } 215 } 216 217 // Marshal and Unmarshal() 218 d, marshalErr := agg0.MarshalBinary() 219 require.NoError(t, marshalErr) 220 mAgg, _ := agg.New(c.op, c.isDistinct, c.inputTyp) 221 unmarshalErr := mAgg.UnmarshalBinary(d) 222 require.NoError(t, unmarshalErr) 223 mAgg.WildAggReAlloc(m) 224 225 // Merge() 226 mAgg.Merge(agg1, 0, 0) 227 228 // Eval() 229 v, err := mAgg.Eval(m) 230 require.NoError(t, err) 231 CompareResult(t, c.outputTyp, c.mergeExpect, v, c.hasDecimalResult) 232 if vec != nil { 233 vec.Free(m) 234 } 235 if vec2 != nil { 236 vec2.Free(m) 237 } 238 v.Free(m) 239 } 240 //require.Equal(t, int64(0), m.Size()) 241 } 242 243 func GetVector(typ types.Type, input any, nsp []uint64) (*vector.Vector, int) { 244 switch typ.Oid { 245 case types.T_bool: 246 return testutil.MakeBoolVector(input.([]bool)), len(input.([]bool)) 247 case types.T_int8: 248 return testutil.MakeInt8Vector(input.([]int8), nsp), len(input.([]int8)) 249 case types.T_int16: 250 return testutil.MakeInt16Vector(input.([]int16), nsp), len(input.([]int16)) 251 case types.T_int32: 252 return testutil.MakeInt32Vector(input.([]int32), nsp), len(input.([]int32)) 253 case types.T_int64: 254 return testutil.MakeInt64Vector(input.([]int64), nsp), len(input.([]int64)) 255 case types.T_uint8: 256 return testutil.MakeUint8Vector(input.([]uint8), nsp), len(input.([]uint8)) 257 case types.T_uint16: 258 return testutil.MakeUint16Vector(input.([]uint16), nsp), len(input.([]uint16)) 259 case types.T_uint32: 260 return testutil.MakeUint32Vector(input.([]uint32), nsp), len(input.([]uint32)) 261 case types.T_uint64: 262 return testutil.MakeUint64Vector(input.([]uint64), nsp), len(input.([]uint64)) 263 case types.T_float32: 264 return testutil.MakeFloat32Vector(input.([]float32), nsp), len(input.([]float32)) 265 case types.T_float64: 266 return testutil.MakeFloat64Vector(input.([]float64), nsp), len(input.([]float64)) 267 case types.T_char: 268 return testutil.MakeVarcharVector(input.([]string), nsp), len(input.([]string)) 269 case types.T_varchar: 270 return testutil.MakeVarcharVector(input.([]string), nsp), len(input.([]string)) 271 case types.T_date: 272 return testutil.MakeDateVector(input.([]string), nsp), len(input.([]string)) 273 case types.T_time: 274 return testutil.MakeTimeVector(input.([]string), nsp), len(input.([]string)) 275 case types.T_datetime: 276 return testutil.MakeDateTimeVector(input.([]string), nsp), len(input.([]string)) 277 case types.T_timestamp: 278 return testutil.MakeTimeStampVector(input.([]string), nsp), len(input.([]string)) 279 case types.T_decimal64: 280 return testutil.MakeDecimal64Vector(input.([]int64), nsp, typ), len(input.([]int64)) 281 case types.T_decimal128: 282 return testutil.MakeDecimal128Vector(input.([]int64), nsp, typ), len(input.([]int64)) 283 case types.T_uuid: 284 // Make vector by string. 285 // There is another function which can make uuid by uuid directly 286 return testutil.MakeUuidVectorByString(input.([]string), nsp), len(input.([]string)) 287 } 288 289 return nil, 0 290 } 291 292 func CompareResult(t *testing.T, typ types.Type, expected any, vec *vector.Vector, hasDecimalResult bool) bool { 293 switch typ.Oid { 294 case types.T_bool: 295 require.Equal(t, expected.([]bool), vector.GetColumn[bool](vec)) 296 case types.T_int8: 297 require.Equal(t, expected.([]int8), vector.GetColumn[int8](vec)) 298 case types.T_int16: 299 require.Equal(t, expected.([]int16), vector.GetColumn[int16](vec)) 300 case types.T_int32: 301 require.Equal(t, expected.([]int32), vector.GetColumn[int32](vec)) 302 case types.T_int64: 303 require.Equal(t, expected.([]int64), vector.GetColumn[int64](vec)) 304 case types.T_uint8: 305 require.Equal(t, expected.([]uint8), vector.GetColumn[uint8](vec)) 306 case types.T_uint16: 307 require.Equal(t, expected.([]uint16), vector.GetColumn[uint16](vec)) 308 case types.T_uint32: 309 require.Equal(t, expected.([]uint32), vector.GetColumn[uint32](vec)) 310 case types.T_uint64: 311 require.Equal(t, expected.([]uint64), vector.GetColumn[uint64](vec)) 312 case types.T_float32: 313 require.Equal(t, expected.([]float32), vector.GetColumn[float32](vec)) 314 case types.T_float64: 315 require.Equal(t, expected.([]float64), vector.GetColumn[float64](vec)) 316 case types.T_char: 317 require.Equal(t, expected.([]string), vector.GetStrColumn(vec)) 318 case types.T_varchar: 319 require.Equal(t, expected.([]string), vector.GetStrColumn(vec)) 320 case types.T_date: 321 require.Equal(t, expected.([]types.Date), vector.GetColumn[types.Date](vec)) 322 case types.T_time: 323 require.Equal(t, expected.([]types.Time), vector.GetColumn[types.Datetime](vec)) 324 case types.T_datetime: 325 require.Equal(t, expected.([]types.Datetime), vector.GetColumn[types.Datetime](vec)) 326 case types.T_timestamp: 327 require.Equal(t, expected.([]types.Timestamp), vector.GetColumn[types.Timestamp](vec)) 328 case types.T_decimal64: 329 if hasDecimalResult { 330 result := testutil.MakeDecimal64ArrByFloat64Arr(expected.([]float64)) 331 require.Equal(t, result, vector.GetColumn[types.Decimal64](vec)) 332 } else { 333 result := testutil.MakeDecimal64ArrByInt64Arr(expected.([]int64)) 334 require.Equal(t, result, vector.GetColumn[types.Decimal64](vec)) 335 } 336 case types.T_decimal128: 337 if hasDecimalResult { 338 result := testutil.MakeDecimal128ArrByFloat64Arr(expected.([]float64)) 339 require.Equal(t, result, vector.GetColumn[types.Decimal128](vec)) 340 } else { 341 result := testutil.MakeDecimal128ArrByInt64Arr(expected.([]int64)) 342 require.Equal(t, result, vector.GetColumn[types.Decimal128](vec)) 343 } 344 case types.T_uuid: 345 result := vector.GetColumn[types.Uuid](testutil.MakeUuidVectorByString(expected.([]string), nil)) 346 require.Equal(t, result, vector.GetColumn[types.Uuid](vec)) 347 default: 348 return false 349 } 350 return true 351 352 }