vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/scalar_aggregation_test.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  
    27  	"vitess.io/vitess/go/sqltypes"
    28  )
    29  
    30  func TestEmptyRows(outer *testing.T) {
    31  	testCases := []struct {
    32  		opcode      AggregateOpcode
    33  		origOpcode  AggregateOpcode
    34  		expectedVal string
    35  		expectedTyp string
    36  	}{{
    37  		opcode:      AggregateCountDistinct,
    38  		expectedVal: "0",
    39  		expectedTyp: "int64",
    40  	}, {
    41  		opcode:      AggregateCount,
    42  		expectedVal: "0",
    43  		expectedTyp: "int64",
    44  	}, {
    45  		opcode:      AggregateSumDistinct,
    46  		expectedVal: "null",
    47  		expectedTyp: "decimal",
    48  	}, {
    49  		opcode:      AggregateSum,
    50  		expectedVal: "null",
    51  		expectedTyp: "int64",
    52  	}, {
    53  		opcode:      AggregateSum,
    54  		expectedVal: "0",
    55  		expectedTyp: "int64",
    56  		origOpcode:  AggregateCount,
    57  	}, {
    58  		opcode:      AggregateMax,
    59  		expectedVal: "null",
    60  		expectedTyp: "int64",
    61  	}, {
    62  		opcode:      AggregateMin,
    63  		expectedVal: "null",
    64  		expectedTyp: "int64",
    65  	}}
    66  
    67  	for _, test := range testCases {
    68  		outer.Run(test.opcode.String(), func(t *testing.T) {
    69  			assert := assert.New(t)
    70  			fp := &fakePrimitive{
    71  				results: []*sqltypes.Result{sqltypes.MakeTestResult(
    72  					sqltypes.MakeTestFields(
    73  						test.opcode.String(),
    74  						"int64",
    75  					),
    76  					// Empty input table
    77  				)},
    78  			}
    79  
    80  			oa := &ScalarAggregate{
    81  				PreProcess: true,
    82  				Aggregates: []*AggregateParams{{
    83  					Opcode:     test.opcode,
    84  					Col:        0,
    85  					Alias:      test.opcode.String(),
    86  					OrigOpcode: test.origOpcode,
    87  				}},
    88  				Input: fp,
    89  			}
    90  
    91  			result, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
    92  			assert.NoError(err)
    93  
    94  			wantResult := sqltypes.MakeTestResult(
    95  				sqltypes.MakeTestFields(
    96  					test.opcode.String(),
    97  					test.expectedTyp,
    98  				),
    99  				test.expectedVal,
   100  			)
   101  			assert.Equal(wantResult, result)
   102  		})
   103  	}
   104  }
   105  
   106  func TestScalarAggregateStreamExecute(t *testing.T) {
   107  	assert := assert.New(t)
   108  	fields := sqltypes.MakeTestFields(
   109  		"col|weight_string(col)",
   110  		"uint64|varbinary",
   111  	)
   112  	fp := &fakePrimitive{
   113  		allResultsInOneCall: true,
   114  		results: []*sqltypes.Result{
   115  			sqltypes.MakeTestResult(fields,
   116  				"1|null",
   117  			), sqltypes.MakeTestResult(fields,
   118  				"3|null",
   119  			)},
   120  	}
   121  
   122  	oa := &ScalarAggregate{
   123  		Aggregates: []*AggregateParams{{
   124  			Opcode: AggregateSum,
   125  			Col:    0,
   126  		}},
   127  		Input:               fp,
   128  		TruncateColumnCount: 1,
   129  		PreProcess:          true,
   130  	}
   131  
   132  	var results []*sqltypes.Result
   133  	err := oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
   134  		results = append(results, qr)
   135  		return nil
   136  	})
   137  	assert.NoError(err)
   138  	// one for the fields, and one for the actual aggregation result
   139  	require.EqualValues(t, 2, len(results), "number of results")
   140  
   141  	got := fmt.Sprintf("%v", results[1].Rows)
   142  	assert.Equal("[[UINT64(4)]]", got)
   143  }
   144  
   145  // TestScalarAggregateExecuteTruncate checks if truncate works
   146  func TestScalarAggregateExecuteTruncate(t *testing.T) {
   147  	assert := assert.New(t)
   148  	fields := sqltypes.MakeTestFields(
   149  		"col|weight_string(col)",
   150  		"uint64|varbinary",
   151  	)
   152  
   153  	fp := &fakePrimitive{
   154  		allResultsInOneCall: true,
   155  		results: []*sqltypes.Result{
   156  			sqltypes.MakeTestResult(fields,
   157  				"1|null", "3|null",
   158  			)},
   159  	}
   160  
   161  	oa := &ScalarAggregate{
   162  		Aggregates: []*AggregateParams{{
   163  			Opcode: AggregateSum,
   164  			Col:    0,
   165  		}},
   166  		Input:               fp,
   167  		TruncateColumnCount: 1,
   168  		PreProcess:          true,
   169  	}
   170  
   171  	qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, true)
   172  	assert.NoError(err)
   173  	assert.Equal("[[UINT64(4)]]", fmt.Sprintf("%v", qr.Rows))
   174  }