github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/agg/aggUt/agg_testutil.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggut
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    21  	"github.com/matrixorigin/matrixone/pkg/sql/colexec/agg"
    22  
    23  	"github.com/matrixorigin/matrixone/pkg/container/types"
    24  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    25  	"github.com/matrixorigin/matrixone/pkg/testutil"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
// testCase describes one aggregate-function scenario driven by RunTest.
//
// RunTest overwrites hasDecimalResult and outputTyp for every case before
// running it, so any values given for those two fields in a literal are
// ignored.
type testCase struct {
	// agg type
	// op is the aggregate identifier passed to agg.New and agg.ReturnType
	// (RunTest compares it against agg.Aggregate* constants).
	op int
	// isDistinct is passed to agg.New.
	isDistinct bool
	// hasDecimalResult is derived by RunTest; for decimal outputs it makes
	// CompareResult read expectations as []float64 instead of []int64.
	hasDecimalResult bool
	// inputTyp is the column type used to build input vectors and the aggregate.
	inputTyp types.Type
	// outputTyp is derived by RunTest via agg.ReturnType and selects how
	// CompareResult interprets the Eval result.
	outputTyp types.Type

	// test data for Fill() and Eval()
	// input is a slice whose element type matches inputTyp (see GetVector).
	input any
	// inputNsp is passed to GetVector together with input
	// (presumably the null row positions — confirm against testutil).
	inputNsp []uint64
	// expected is compared with the Eval result after every row of input
	// has been filled into group 0.
	expected any

	// test data for Merge()
	// mergeInput holds the rows filled into a second aggregate that is then
	// merged into the first; nil skips the merge checks.
	mergeInput any
	// mergeNsp is the nsp list intended for mergeInput (see GetVector).
	mergeNsp []uint64
	// mergeExpect is compared with the Eval result after the merge.
	mergeExpect any

	// test Marshal() and Unmarshal() or not
	// When true, RunTest also runs RunBaseMarshalTest for this case.
	testMarshal bool
}
    50  
    51  func RunTest(t *testing.T, testCases []testCase) {
    52  	for _, c := range testCases {
    53  		// update some parameter
    54  		switch c.op {
    55  		case agg.AggregateAvg, agg.AggregateVariance, agg.AggregateStdDevPop, agg.AggregateMedian:
    56  			c.hasDecimalResult = true
    57  		default:
    58  			c.hasDecimalResult = false
    59  		}
    60  		c.outputTyp, _ = agg.ReturnType(c.op, c.inputTyp)
    61  
    62  		RunBaseTest(t, &c)
    63  		if c.testMarshal {
    64  			RunBaseMarshalTest(t, &c)
    65  		}
    66  	}
    67  }
    68  
    69  func RunBaseTest(t *testing.T, c *testCase) {
    70  	// base test: Grows(), Fill(), Eval() and Merge()
    71  	m := mpool.MustNewZero()
    72  
    73  	// Grows(), Fill() and Eval() test
    74  	{
    75  		// New()
    76  		agg0, newErr := agg.New(c.op, c.isDistinct, c.inputTyp)
    77  		require.NoError(t, newErr)
    78  
    79  		// Grows()
    80  		growsErr := agg0.Grows(1, m)
    81  		require.NoError(t, growsErr)
    82  
    83  		// Fill()
    84  		vec, l := GetVector(c.inputTyp, c.input, c.inputNsp)
    85  		if l > 0 && vec != nil {
    86  			for i := 0; i < l; i++ {
    87  				agg0.Fill(0, int64(i), 1, []*vector.Vector{vec})
    88  			}
    89  		}
    90  
    91  		// Eval()
    92  		v, err := agg0.Eval(m)
    93  		require.NoError(t, err)
    94  		CompareResult(t, c.outputTyp, c.expected, v, c.hasDecimalResult)
    95  		if vec != nil {
    96  			vec.Free(m)
    97  		}
    98  		v.Free(m)
    99  	}
   100  
   101  	// Merge() Test
   102  	if c.mergeInput != nil {
   103  		// New()
   104  		agg0, newErr := agg.New(c.op, c.isDistinct, c.inputTyp)
   105  		require.NoError(t, newErr)
   106  
   107  		// Grows()
   108  		growsErr := agg0.Grows(1, m)
   109  		require.NoError(t, growsErr)
   110  
   111  		// Fill()
   112  		vec, l := GetVector(c.inputTyp, c.input, c.inputNsp)
   113  		if l > 0 && vec != nil {
   114  			for i := 0; i < l; i++ {
   115  				agg0.Fill(0, int64(i), 1, []*vector.Vector{vec})
   116  			}
   117  		}
   118  
   119  		// create another agg for merge
   120  		agg1, _ := agg.New(c.op, c.isDistinct, c.inputTyp)
   121  		agg1.Grows(1, m)
   122  		vec2, l2 := GetVector(c.inputTyp, c.mergeInput, c.inputNsp)
   123  		if l2 > 0 && vec2 != nil {
   124  			for i := 0; i < l2; i++ {
   125  				agg1.Fill(0, int64(i), 1, []*vector.Vector{vec2})
   126  			}
   127  		}
   128  
   129  		// Merge()
   130  		agg0.Merge(agg1, 0, 0)
   131  
   132  		// Eval()
   133  		v, err := agg0.Eval(m)
   134  		require.NoError(t, err)
   135  		CompareResult(t, c.outputTyp, c.mergeExpect, v, c.hasDecimalResult)
   136  
   137  		// release
   138  		if vec != nil {
   139  			vec.Free(m)
   140  		}
   141  		if vec2 != nil {
   142  			vec2.Free(m)
   143  		}
   144  		v.Free(m)
   145  	}
   146  	//require.Equal(t, int64(0), m.Size())
   147  }
   148  
   149  func RunBaseMarshalTest(t *testing.T, c *testCase) {
   150  	// base test: Grows(), Fill() and Eval()
   151  	m := mpool.MustNewZero()
   152  	{
   153  		// New()
   154  		agg0, newErr := agg.New(c.op, c.isDistinct, c.inputTyp)
   155  		require.NoError(t, newErr)
   156  
   157  		// Grows()
   158  		growsErr := agg0.Grows(1, m)
   159  		require.NoError(t, growsErr)
   160  
   161  		// Fill()
   162  		vec, l := GetVector(c.inputTyp, c.input, c.inputNsp)
   163  		mid := l / 2
   164  		if mid > 0 && vec != nil {
   165  			for i := 0; i < mid; i++ {
   166  				agg0.Fill(0, int64(i), 1, []*vector.Vector{vec})
   167  			}
   168  		}
   169  
   170  		// Marshal and Unmarshal()
   171  		d, marshalErr := agg0.MarshalBinary()
   172  		require.NoError(t, marshalErr)
   173  		agg1, _ := agg.New(c.op, c.isDistinct, c.inputTyp)
   174  		unmarshalErr := agg1.UnmarshalBinary(d)
   175  		require.NoError(t, unmarshalErr)
   176  		agg1.WildAggReAlloc(m)
   177  
   178  		// Fill() after marshal and unmarshal
   179  		if l > 0 && vec != nil {
   180  			for i := mid; i < l; i++ {
   181  				agg1.Fill(0, int64(i), 1, []*vector.Vector{vec})
   182  			}
   183  		}
   184  
   185  		// Eval() after marshal and unmarshal
   186  		v, err := agg1.Eval(m)
   187  		require.NoError(t, err)
   188  		CompareResult(t, c.outputTyp, c.expected, v, c.hasDecimalResult)
   189  		if vec != nil {
   190  			vec.Free(m)
   191  		}
   192  		v.Free(m)
   193  	}
   194  
   195  	// Merge() Test
   196  	if c.mergeInput != nil {
   197  		// create an agg for marshal and unmarshal
   198  		agg0, _ := agg.New(c.op, c.isDistinct, c.inputTyp)
   199  		agg0.Grows(1, m)
   200  		vec, l := GetVector(c.inputTyp, c.input, c.inputNsp)
   201  		if l != 0 && vec != nil {
   202  			for i := 0; i < l; i++ {
   203  				agg0.Fill(0, int64(i), 1, []*vector.Vector{vec})
   204  			}
   205  		}
   206  
   207  		// create another agg for merge
   208  		agg1, _ := agg.New(c.op, c.isDistinct, c.inputTyp)
   209  		agg1.Grows(1, m)
   210  		vec2, l2 := GetVector(c.inputTyp, c.mergeInput, c.mergeNsp)
   211  		if l2 != 0 && vec2 != nil {
   212  			for i := 0; i < l2; i++ {
   213  				agg1.Fill(0, int64(i), 1, []*vector.Vector{vec2})
   214  			}
   215  		}
   216  
   217  		// Marshal and Unmarshal()
   218  		d, marshalErr := agg0.MarshalBinary()
   219  		require.NoError(t, marshalErr)
   220  		mAgg, _ := agg.New(c.op, c.isDistinct, c.inputTyp)
   221  		unmarshalErr := mAgg.UnmarshalBinary(d)
   222  		require.NoError(t, unmarshalErr)
   223  		mAgg.WildAggReAlloc(m)
   224  
   225  		// Merge()
   226  		mAgg.Merge(agg1, 0, 0)
   227  
   228  		// Eval()
   229  		v, err := mAgg.Eval(m)
   230  		require.NoError(t, err)
   231  		CompareResult(t, c.outputTyp, c.mergeExpect, v, c.hasDecimalResult)
   232  		if vec != nil {
   233  			vec.Free(m)
   234  		}
   235  		if vec2 != nil {
   236  			vec2.Free(m)
   237  		}
   238  		v.Free(m)
   239  	}
   240  	//require.Equal(t, int64(0), m.Size())
   241  }
   242  
   243  func GetVector(typ types.Type, input any, nsp []uint64) (*vector.Vector, int) {
   244  	switch typ.Oid {
   245  	case types.T_bool:
   246  		return testutil.MakeBoolVector(input.([]bool)), len(input.([]bool))
   247  	case types.T_int8:
   248  		return testutil.MakeInt8Vector(input.([]int8), nsp), len(input.([]int8))
   249  	case types.T_int16:
   250  		return testutil.MakeInt16Vector(input.([]int16), nsp), len(input.([]int16))
   251  	case types.T_int32:
   252  		return testutil.MakeInt32Vector(input.([]int32), nsp), len(input.([]int32))
   253  	case types.T_int64:
   254  		return testutil.MakeInt64Vector(input.([]int64), nsp), len(input.([]int64))
   255  	case types.T_uint8:
   256  		return testutil.MakeUint8Vector(input.([]uint8), nsp), len(input.([]uint8))
   257  	case types.T_uint16:
   258  		return testutil.MakeUint16Vector(input.([]uint16), nsp), len(input.([]uint16))
   259  	case types.T_uint32:
   260  		return testutil.MakeUint32Vector(input.([]uint32), nsp), len(input.([]uint32))
   261  	case types.T_uint64:
   262  		return testutil.MakeUint64Vector(input.([]uint64), nsp), len(input.([]uint64))
   263  	case types.T_float32:
   264  		return testutil.MakeFloat32Vector(input.([]float32), nsp), len(input.([]float32))
   265  	case types.T_float64:
   266  		return testutil.MakeFloat64Vector(input.([]float64), nsp), len(input.([]float64))
   267  	case types.T_char:
   268  		return testutil.MakeVarcharVector(input.([]string), nsp), len(input.([]string))
   269  	case types.T_varchar:
   270  		return testutil.MakeVarcharVector(input.([]string), nsp), len(input.([]string))
   271  	case types.T_date:
   272  		return testutil.MakeDateVector(input.([]string), nsp), len(input.([]string))
   273  	case types.T_time:
   274  		return testutil.MakeTimeVector(input.([]string), nsp), len(input.([]string))
   275  	case types.T_datetime:
   276  		return testutil.MakeDateTimeVector(input.([]string), nsp), len(input.([]string))
   277  	case types.T_timestamp:
   278  		return testutil.MakeTimeStampVector(input.([]string), nsp), len(input.([]string))
   279  	case types.T_decimal64:
   280  		return testutil.MakeDecimal64Vector(input.([]int64), nsp, typ), len(input.([]int64))
   281  	case types.T_decimal128:
   282  		return testutil.MakeDecimal128Vector(input.([]int64), nsp, typ), len(input.([]int64))
   283  	case types.T_uuid:
   284  		// Make vector by string.
   285  		// There is another function which can make uuid by uuid directly
   286  		return testutil.MakeUuidVectorByString(input.([]string), nsp), len(input.([]string))
   287  	}
   288  
   289  	return nil, 0
   290  }
   291  
   292  func CompareResult(t *testing.T, typ types.Type, expected any, vec *vector.Vector, hasDecimalResult bool) bool {
   293  	switch typ.Oid {
   294  	case types.T_bool:
   295  		require.Equal(t, expected.([]bool), vector.GetColumn[bool](vec))
   296  	case types.T_int8:
   297  		require.Equal(t, expected.([]int8), vector.GetColumn[int8](vec))
   298  	case types.T_int16:
   299  		require.Equal(t, expected.([]int16), vector.GetColumn[int16](vec))
   300  	case types.T_int32:
   301  		require.Equal(t, expected.([]int32), vector.GetColumn[int32](vec))
   302  	case types.T_int64:
   303  		require.Equal(t, expected.([]int64), vector.GetColumn[int64](vec))
   304  	case types.T_uint8:
   305  		require.Equal(t, expected.([]uint8), vector.GetColumn[uint8](vec))
   306  	case types.T_uint16:
   307  		require.Equal(t, expected.([]uint16), vector.GetColumn[uint16](vec))
   308  	case types.T_uint32:
   309  		require.Equal(t, expected.([]uint32), vector.GetColumn[uint32](vec))
   310  	case types.T_uint64:
   311  		require.Equal(t, expected.([]uint64), vector.GetColumn[uint64](vec))
   312  	case types.T_float32:
   313  		require.Equal(t, expected.([]float32), vector.GetColumn[float32](vec))
   314  	case types.T_float64:
   315  		require.Equal(t, expected.([]float64), vector.GetColumn[float64](vec))
   316  	case types.T_char:
   317  		require.Equal(t, expected.([]string), vector.GetStrColumn(vec))
   318  	case types.T_varchar:
   319  		require.Equal(t, expected.([]string), vector.GetStrColumn(vec))
   320  	case types.T_date:
   321  		require.Equal(t, expected.([]types.Date), vector.GetColumn[types.Date](vec))
   322  	case types.T_time:
   323  		require.Equal(t, expected.([]types.Time), vector.GetColumn[types.Datetime](vec))
   324  	case types.T_datetime:
   325  		require.Equal(t, expected.([]types.Datetime), vector.GetColumn[types.Datetime](vec))
   326  	case types.T_timestamp:
   327  		require.Equal(t, expected.([]types.Timestamp), vector.GetColumn[types.Timestamp](vec))
   328  	case types.T_decimal64:
   329  		if hasDecimalResult {
   330  			result := testutil.MakeDecimal64ArrByFloat64Arr(expected.([]float64))
   331  			require.Equal(t, result, vector.GetColumn[types.Decimal64](vec))
   332  		} else {
   333  			result := testutil.MakeDecimal64ArrByInt64Arr(expected.([]int64))
   334  			require.Equal(t, result, vector.GetColumn[types.Decimal64](vec))
   335  		}
   336  	case types.T_decimal128:
   337  		if hasDecimalResult {
   338  			result := testutil.MakeDecimal128ArrByFloat64Arr(expected.([]float64))
   339  			require.Equal(t, result, vector.GetColumn[types.Decimal128](vec))
   340  		} else {
   341  			result := testutil.MakeDecimal128ArrByInt64Arr(expected.([]int64))
   342  			require.Equal(t, result, vector.GetColumn[types.Decimal128](vec))
   343  		}
   344  	case types.T_uuid:
   345  		result := vector.GetColumn[types.Uuid](testutil.MakeUuidVectorByString(expected.([]string), nil))
   346  		require.Equal(t, result, vector.GetColumn[types.Uuid](vec))
   347  	default:
   348  		return false
   349  	}
   350  	return true
   351  
   352  }