github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/sum_test.go

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregation
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  )
    25  
    26  func TestSum(t *testing.T) {
    27  	sum := NewSum(expression.NewGetField(0, nil, "", false))
    28  
    29  	testCases := []struct {
    30  		name     string
    31  		rows     []sql.Row
    32  		expected interface{}
    33  	}{
    34  		{
    35  			"string int values",
    36  			[]sql.Row{{"1"}, {"2"}, {"3"}, {"4"}},
    37  			float64(10),
    38  		},
    39  		{
    40  			"string float values",
    41  			[]sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}},
    42  			float64(10.5),
    43  		},
    44  		{
    45  			"string non-int values",
    46  			[]sql.Row{{"a"}, {"b"}, {"c"}, {"d"}},
    47  			float64(0),
    48  		},
    49  		{
    50  			"float values",
    51  			[]sql.Row{{1.}, {2.5}, {3.}, {4.}},
    52  			float64(10.5),
    53  		},
    54  		{
    55  			"no rows",
    56  			[]sql.Row{},
    57  			nil,
    58  		},
    59  		{
    60  			"nil values",
    61  			[]sql.Row{{nil}, {nil}},
    62  			nil,
    63  		},
    64  		{
    65  			"int64 values",
    66  			[]sql.Row{{int64(1)}, {int64(3)}},
    67  			float64(4),
    68  		},
    69  		{
    70  			"int32 values",
    71  			[]sql.Row{{int32(1)}, {int32(3)}},
    72  			float64(4),
    73  		},
    74  		{
    75  			"int32 and nil values",
    76  			[]sql.Row{{int32(1)}, {int32(3)}, {nil}},
    77  			float64(4),
    78  		},
    79  	}
    80  
    81  	for _, tt := range testCases {
    82  		t.Run(tt.name, func(t *testing.T) {
    83  			require := require.New(t)
    84  
    85  			ctx := sql.NewEmptyContext()
    86  			buf, _ := sum.NewBuffer()
    87  			for _, row := range tt.rows {
    88  				require.NoError(buf.Update(ctx, row))
    89  			}
    90  
    91  			result, err := buf.Eval(sql.NewEmptyContext())
    92  			require.NoError(err)
    93  			require.Equal(tt.expected, result)
    94  		})
    95  	}
    96  }
    97  
    98  func TestSumWithDistinct(t *testing.T) {
    99  	require := require.New(t)
   100  
   101  	ad := expression.NewDistinctExpression(expression.NewGetField(0, nil, "myfield", false))
   102  	sum := NewSum(ad)
   103  
   104  	// first validate that the expression's name is correct
   105  	require.Equal("SUM(DISTINCT myfield)", sum.String())
   106  
   107  	testCases := []struct {
   108  		name     string
   109  		rows     []sql.Row
   110  		expected interface{}
   111  	}{
   112  		{
   113  			"string int values",
   114  			[]sql.Row{{"1"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}},
   115  			float64(10),
   116  		},
   117  		// TODO : DISTINCT returns incorrect result, it currently returns 11.00
   118  		//        https://github.com/dolthub/dolt/issues/4298
   119  		//{
   120  		//	"string int values",
   121  		//	[]sql.Row{{"1.00"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}},
   122  		//	float64(10),
   123  		//},
   124  		{
   125  			"string float values",
   126  			[]sql.Row{{"1.5"}, {"1.5"}, {"1.5"}, {"1.5"}, {"2"}, {"3"}, {"4"}},
   127  			float64(10.5),
   128  		},
   129  		{
   130  			"string non-int values",
   131  			[]sql.Row{{"a"}, {"b"}, {"b"}, {"c"}, {"c"}, {"d"}},
   132  			float64(0),
   133  		},
   134  		{
   135  			"float values",
   136  			[]sql.Row{{1.}, {2.5}, {3.}, {4.}},
   137  			float64(10.5),
   138  		},
   139  		{
   140  			"no rows",
   141  			[]sql.Row{},
   142  			nil,
   143  		},
   144  		{
   145  			"nil values",
   146  			[]sql.Row{{nil}, {nil}},
   147  			nil,
   148  		},
   149  		{
   150  			"int64 values",
   151  			[]sql.Row{{int64(1)}, {int64(3)}, {int64(3)}, {int64(3)}},
   152  			float64(4),
   153  		},
   154  		{
   155  			"int32 values",
   156  			[]sql.Row{{int32(1)}, {int32(1)}, {int32(1)}, {int32(3)}},
   157  			float64(4),
   158  		},
   159  		{
   160  			"int32 and nil values",
   161  			[]sql.Row{{nil}, {int32(1)}, {int32(1)}, {int32(1)}, {int32(3)}, {nil}, {nil}},
   162  			float64(4),
   163  		},
   164  	}
   165  
   166  	for _, tt := range testCases {
   167  		t.Run(tt.name, func(t *testing.T) {
   168  			ad.Dispose()
   169  
   170  			ctx := sql.NewEmptyContext()
   171  			buf, _ := sum.NewBuffer()
   172  			for _, row := range tt.rows {
   173  				require.NoError(buf.Update(ctx, row))
   174  			}
   175  
   176  			result, err := buf.Eval(sql.NewEmptyContext())
   177  			require.NoError(err)
   178  			require.Equal(tt.expected, result)
   179  		})
   180  	}
   181  }