github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/avg_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 _ "github.com/dolthub/go-mysql-server/sql/variables" 26 ) 27 28 func TestAvg_String(t *testing.T) { 29 require := require.New(t) 30 31 avg := NewAvg(expression.NewGetField(0, types.Int32, "col1", true)) 32 require.Equal("AVG(col1)", avg.String()) 33 } 34 35 func TestAvg_Float64(t *testing.T) { 36 require := require.New(t) 37 ctx := sql.NewEmptyContext() 38 39 avg := NewAvg(expression.NewGetField(0, types.Float64, "col1", true)) 40 buffer, _ := avg.NewBuffer() 41 buffer.Update(ctx, sql.NewRow(float64(23.2220000))) 42 43 require.Equal(float64(23.222), evalBuffer(t, buffer)) 44 } 45 46 func TestAvg_Eval_INT32(t *testing.T) { 47 require := require.New(t) 48 ctx := sql.NewEmptyContext() 49 50 avgNode := NewAvg(expression.NewGetField(0, types.Int32, "col1", true)) 51 buffer, _ := avgNode.NewBuffer() 52 require.Equal(nil, evalBuffer(t, buffer)) 53 54 buffer.Update(ctx, sql.NewRow(int32(1))) 55 require.Equal(float64(1), evalBuffer(t, buffer)) 56 57 buffer.Update(ctx, sql.NewRow(int32(2))) 58 require.Equal(float64(1.5), evalBuffer(t, buffer)) 59 } 60 61 func TestAvg_Eval_UINT64(t *testing.T) { 62 require := require.New(t) 63 ctx := sql.NewEmptyContext() 64 65 avgNode := NewAvg(expression.NewGetField(0, types.Uint64, "col1", true)) 66 buffer, _ := avgNode.NewBuffer() 67 require.Equal(nil, evalBuffer(t, buffer)) 68 69 err := buffer.Update(ctx, sql.NewRow(uint64(1))) 70 require.NoError(err) 71 require.Equal(float64(1), evalBuffer(t, buffer)) 72 73 err = buffer.Update(ctx, sql.NewRow(uint64(2))) 74 require.NoError(err) 75 require.Equal(float64(1.5), evalBuffer(t, buffer)) 76 } 77 78 func TestAvg_Eval_String(t *testing.T) { 79 require := require.New(t) 80 ctx := sql.NewEmptyContext() 81 82 avgNode := NewAvg(expression.NewGetField(0, types.Text, "col1", true)) 83 buffer, _ := avgNode.NewBuffer() 84 require.Equal(nil, evalBuffer(t, buffer)) 85 86 err := buffer.Update(ctx, sql.NewRow("foo")) 87 require.NoError(err) 88 require.Equal(float64(0), evalBuffer(t, buffer)) 89 90 err = buffer.Update(ctx, sql.NewRow("2")) 91 require.NoError(err) 92 require.Equal(float64(1), evalBuffer(t, buffer)) 93 } 94 95 func TestAvg_NULL(t *testing.T) { 96 require := require.New(t) 97 ctx := sql.NewEmptyContext() 98 99 avgNode := NewAvg(expression.NewGetField(0, types.Uint64, "col1", true)) 100 buffer, _ := avgNode.NewBuffer() 101 require.Zero(evalBuffer(t, buffer)) 102 103 err := buffer.Update(ctx, sql.NewRow(nil)) 104 require.NoError(err) 105 require.Equal(nil, evalBuffer(t, buffer)) 106 } 107 108 func TestAvg_NUMS_AND_NULLS(t *testing.T) { 109 require := require.New(t) 110 ctx := sql.NewEmptyContext() 111 112 avgNode := NewAvg(expression.NewGetField(0, types.Uint64, "col1", true)) 113 114 testCases := []struct { 115 name string 116 rows []sql.Row 117 expected interface{} 118 }{ 119 { 120 "float values with nil", 121 []sql.Row{{2.0}, {2.0}, {3.}, {4.}, {nil}}, 122 float64(2.75), 123 }, 124 { 125 "float values with nil", 126 []sql.Row{{1}, {2}, {3}, {nil}, {nil}}, 127 float64(2.0), 128 }, 129 { 130 "no rows", 131 []sql.Row{}, 132 nil, 133 }, 134 { 135 "nil values", 136 []sql.Row{{nil}, {nil}}, 137 nil, 138 }, 139 } 140 141 for _, tt := range testCases { 142 t.Run(tt.name, func(t *testing.T) { 143 buf, _ := avgNode.NewBuffer() 144 for _, row := range tt.rows { 145 require.NoError(buf.Update(ctx, row)) 146 } 147 148 require.Equal(tt.expected, evalBuffer(t, buf)) 149 }) 150 } 151 } 152 153 func TestAvg_Distinct(t *testing.T) { 154 require := require.New(t) 155 ctx := sql.NewEmptyContext() 156 157 ad := expression.NewDistinctExpression(expression.NewGetField(0, nil, "myfield", false)) 158 avg := NewAvg(ad) 159 160 // first validate that the expression's name is correct 161 require.Equal("AVG(DISTINCT myfield)", avg.String()) 162 163 testCases := []struct { 164 name string 165 rows []sql.Row 166 expected interface{} 167 }{ 168 { 169 "string int values", 170 []sql.Row{{"1"}, {"1"}, {"2"}, {"2"}, {"3"}, {"3"}, {"4"}, {"4"}}, 171 float64(2.5), 172 }, 173 { 174 "string float values", 175 []sql.Row{{"2.0"}, {"2.0"}, {"3.0"}, {"4.0"}, {"4.0"}}, 176 float64(3.0), 177 }, 178 { 179 "string float values", 180 []sql.Row{{"2.0"}, {"2.0"}, {"3.0"}, {"4.0"}, {"4.0"}}, 181 float64(3.0), 182 }, 183 { 184 "float values", 185 []sql.Row{{2.0}, {2.0}, {3.}, {4.}}, 186 float64(3.0), 187 }, 188 { 189 "float values with nil", 190 []sql.Row{{2.0}, {2.0}, {3.}, {4.}, {nil}}, 191 float64(3.0), 192 }, 193 { 194 "no rows", 195 []sql.Row{}, 196 nil, 197 }, 198 { 199 "nil values", 200 []sql.Row{{nil}, {nil}}, 201 nil, 202 }, 203 } 204 205 for _, tt := range testCases { 206 t.Run(tt.name, func(t *testing.T) { 207 buf, _ := avg.NewBuffer() 208 for _, row := range tt.rows { 209 require.NoError(buf.Update(ctx, row)) 210 } 211 212 require.Equal(tt.expected, evalBuffer(t, buf)) 213 }) 214 } 215 }