github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/sum_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 ) 25 26 func TestSum(t *testing.T) { 27 sum := NewSum(expression.NewGetField(0, nil, "", false)) 28 29 testCases := []struct { 30 name string 31 rows []sql.Row 32 expected interface{} 33 }{ 34 { 35 "string int values", 36 []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, 37 float64(10), 38 }, 39 { 40 "string float values", 41 []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, 42 float64(10.5), 43 }, 44 { 45 "string non-int values", 46 []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, 47 float64(0), 48 }, 49 { 50 "float values", 51 []sql.Row{{1.}, {2.5}, {3.}, {4.}}, 52 float64(10.5), 53 }, 54 { 55 "no rows", 56 []sql.Row{}, 57 nil, 58 }, 59 { 60 "nil values", 61 []sql.Row{{nil}, {nil}}, 62 nil, 63 }, 64 { 65 "int64 values", 66 []sql.Row{{int64(1)}, {int64(3)}}, 67 float64(4), 68 }, 69 { 70 "int32 values", 71 []sql.Row{{int32(1)}, {int32(3)}}, 72 float64(4), 73 }, 74 { 75 "int32 and nil values", 76 []sql.Row{{int32(1)}, {int32(3)}, {nil}}, 77 float64(4), 78 }, 79 } 80 81 for _, tt := range testCases { 82 t.Run(tt.name, func(t *testing.T) { 83 require := require.New(t) 84 85 ctx := sql.NewEmptyContext() 86 buf, _ := sum.NewBuffer() 87 for _, row := range tt.rows { 88 require.NoError(buf.Update(ctx, row)) 89 } 90 91 result, err := buf.Eval(sql.NewEmptyContext()) 92 require.NoError(err) 93 require.Equal(tt.expected, result) 94 }) 95 } 96 } 97 98 func TestSumWithDistinct(t *testing.T) { 99 require := require.New(t) 100 101 ad := expression.NewDistinctExpression(expression.NewGetField(0, nil, "myfield", false)) 102 sum := NewSum(ad) 103 104 // first validate that the expression's name is correct 105 require.Equal("SUM(DISTINCT myfield)", sum.String()) 106 107 testCases := []struct { 108 name string 109 rows []sql.Row 110 expected interface{} 111 }{ 112 { 113 "string int values", 114 []sql.Row{{"1"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, 115 float64(10), 116 }, 117 // TODO : DISTINCT returns incorrect result, it currently returns 11.00 118 // https://github.com/dolthub/dolt/issues/4298 119 //{ 120 // "string int values", 121 // []sql.Row{{"1.00"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, 122 // float64(10), 123 //}, 124 { 125 "string float values", 126 []sql.Row{{"1.5"}, {"1.5"}, {"1.5"}, {"1.5"}, {"2"}, {"3"}, {"4"}}, 127 float64(10.5), 128 }, 129 { 130 "string non-int values", 131 []sql.Row{{"a"}, {"b"}, {"b"}, {"c"}, {"c"}, {"d"}}, 132 float64(0), 133 }, 134 { 135 "float values", 136 []sql.Row{{1.}, {2.5}, {3.}, {4.}}, 137 float64(10.5), 138 }, 139 { 140 "no rows", 141 []sql.Row{}, 142 nil, 143 }, 144 { 145 "nil values", 146 []sql.Row{{nil}, {nil}}, 147 nil, 148 }, 149 { 150 "int64 values", 151 []sql.Row{{int64(1)}, {int64(3)}, {int64(3)}, {int64(3)}}, 152 float64(4), 153 }, 154 { 155 "int32 values", 156 []sql.Row{{int32(1)}, {int32(1)}, {int32(1)}, {int32(3)}}, 157 float64(4), 158 }, 159 { 160 "int32 and nil values", 161 []sql.Row{{nil}, {int32(1)}, {int32(1)}, {int32(1)}, {int32(3)}, {nil}, {nil}}, 162 float64(4), 163 }, 164 } 165 166 for _, tt := range testCases { 167 t.Run(tt.name, func(t *testing.T) { 168 ad.Dispose() 169 170 ctx := sql.NewEmptyContext() 171 buf, _ := sum.NewBuffer() 172 for _, row := range tt.rows { 173 require.NoError(buf.Update(ctx, row)) 174 } 175 176 result, err := buf.Eval(sql.NewEmptyContext()) 177 require.NoError(err) 178 require.Equal(tt.expected, result) 179 }) 180 } 181 }