github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/optimization_rules_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "context" 19 "fmt" 20 "testing" 21 22 "github.com/stretchr/testify/require" 23 24 "github.com/dolthub/go-mysql-server/memory" 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/expression" 27 "github.com/dolthub/go-mysql-server/sql/expression/function" 28 "github.com/dolthub/go-mysql-server/sql/plan" 29 "github.com/dolthub/go-mysql-server/sql/planbuilder" 30 "github.com/dolthub/go-mysql-server/sql/types" 31 ) 32 33 func TestEvalFilter(t *testing.T) { 34 db := memory.NewDatabase("db") 35 pro := memory.NewDBProvider(db) 36 ctx := newContext(pro) 37 38 inner := memory.NewTable(db, "foo", sql.PrimaryKeySchema{}, nil) 39 rule := getRule(simplifyFiltersId) 40 41 testCases := []struct { 42 filter sql.Expression 43 expected sql.Node 44 }{ 45 { 46 and( 47 eq(lit(5), lit(5)), 48 eq( 49 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 50 lit(5)), 51 ), 52 plan.NewFilter( 53 eq( 54 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 55 lit(5)), 56 plan.NewResolvedTable(inner, nil, nil), 57 ), 58 }, 59 { 60 and( 61 eq( 62 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 63 lit(5)), 64 eq(lit(5), lit(5)), 65 ), 66 plan.NewFilter( 67 eq( 68 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 69 lit(5)), 70 plan.NewResolvedTable(inner, nil, nil), 71 ), 72 }, 73 { 74 and( 75 eq(lit(5), lit(4)), 76 eq( 77 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 78 lit(5)), 79 ), 80 plan.NewEmptyTableWithSchema(inner.Schema()), 81 }, 82 { 83 and( 84 eq( 85 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 86 lit(5)), 87 eq(lit(5), lit(4)), 88 ), 89 plan.NewEmptyTableWithSchema(inner.Schema()), 90 }, 91 { 92 and( 93 eq(lit(4), lit(4)), 94 eq(lit(5), lit(5)), 95 ), 96 plan.NewResolvedTable(inner, nil, nil), 97 }, 98 { 99 or( 100 eq(lit(5), lit(4)), 101 eq( 102 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 103 lit(5)), 104 ), 105 plan.NewFilter( 106 eq( 107 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 108 lit(5)), 109 plan.NewResolvedTable(inner, nil, nil), 110 ), 111 }, 112 { 113 or( 114 eq( 115 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 116 lit(5)), 117 eq(lit(5), lit(4)), 118 ), 119 plan.NewFilter( 120 eq( 121 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 122 lit(5)), 123 plan.NewResolvedTable(inner, nil, nil), 124 ), 125 }, 126 { 127 or( 128 eq(lit(5), lit(5)), 129 eq( 130 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 131 lit(5)), 132 ), 133 plan.NewResolvedTable(inner, nil, nil), 134 }, 135 { 136 or( 137 eq( 138 expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false), 139 lit(5)), 140 eq(lit(5), lit(5)), 141 ), 142 plan.NewResolvedTable(inner, nil, nil), 143 }, 144 { 145 or( 146 eq(lit(5), lit(4)), 147 eq(lit(5), lit(4)), 148 ), 149 plan.NewEmptyTableWithSchema(inner.Schema()), 150 }, 151 } 152 153 for _, tt := range testCases { 154 t.Run(tt.filter.String(), func(t *testing.T) { 155 require := require.New(t) 156 node := plan.NewFilter(tt.filter, plan.NewResolvedTable(inner, nil, nil)) 157 result, _, err := rule.Apply(ctx, NewDefault(nil), node, nil, DefaultRuleSelector) 158 require.NoError(err) 159 require.Equal(tt.expected, result) 160 }) 161 } 162 } 163 164 func TestPushNotFilters(t *testing.T) { 165 tests := []struct { 166 in string 167 exp string 168 }{ 169 { 170 in: "NOT(NOT(x IS NULL))", 171 exp: "xy.x IS NULL", 172 }, 173 { 174 in: "NOT(x BETWEEN 0 AND 5)", 175 exp: "((xy.x < 0) OR (xy.x > 5))", 176 }, 177 { 178 in: "NOT(x <= 0)", 179 exp: "(xy.x > 0)", 180 }, 181 { 182 in: "NOT(x < 0)", 183 exp: "(xy.x >= 0)", 184 }, 185 { 186 in: "NOT(x > 0)", 187 exp: "(xy.x <= 0)", 188 }, 189 { 190 in: "NOT(x >= 0)", 191 exp: "(xy.x < 0)", 192 }, 193 // TODO this isn't correct for join filters 194 //{ 195 // in: "NOT(y IS NULL)", 196 // exp: "((xy.x < NULL) OR (xy.x > NULL))", 197 //}, 198 { 199 in: "NOT (x > 2 AND y > 2)", 200 exp: "((xy.x <= 2) OR (xy.y <= 2))", 201 }, 202 { 203 in: "NOT (x > 2 AND NOT(y > 2))", 204 exp: "((xy.x <= 2) OR (xy.y > 2))", 205 }, 206 { 207 in: "((NOT(x > 1 AND NOT((x > 0) OR (y < 2))) OR (y > 1)) OR NOT(y < 3))", 208 exp: "((((xy.x <= 1) OR ((xy.x > 0) OR (xy.y < 2))) OR (xy.y > 1)) OR (xy.y >= 3))", 209 }, 210 } 211 212 // todo dummy catalog and table 213 db := memory.NewDatabase("mydb") 214 cat := newTestCatalog(db) 215 pro := memory.NewDBProvider(db) 216 sess := memory.NewSession(sql.NewBaseSession(), pro) 217 218 ctx := sql.NewContext(context.Background(), sql.WithSession(sess)) 219 ctx.SetCurrentDatabase("mydb") 220 221 b := planbuilder.New(ctx, cat) 222 223 for _, tt := range tests { 224 t.Run(tt.in, func(t *testing.T) { 225 q := fmt.Sprintf("SELECT 1 from xy WHERE %s", tt.in) 226 node, err := b.ParseOne(q) 227 require.NoError(t, err) 228 229 cmp, _, err := pushNotFilters(ctx, nil, node, nil, nil) 230 require.NoError(t, err) 231 232 cmpF := cmp.(*plan.Project).Child.(*plan.Filter).Expression 233 cmpStr := cmpF.String() 234 235 require.Equal(t, tt.exp, cmpStr, fmt.Sprintf("\nexpected: %s\nfound:%s\n", tt.exp, cmpStr)) 236 }) 237 } 238 } 239 240 func newTestCatalog(db *memory.Database) *sql.MapCatalog { 241 cat := &sql.MapCatalog{ 242 Databases: make(map[string]sql.Database), 243 Tables: make(map[string]sql.Table), 244 } 245 246 cat.Tables["xy"] = memory.NewTable(db, "xy", sql.NewPrimaryKeySchema(sql.Schema{ 247 {Name: "x", Type: types.Int64}, 248 {Name: "y", Type: types.Int64}, 249 {Name: "z", Type: types.Int64}, 250 }, 0), nil) 251 cat.Tables["uv"] = memory.NewTable(db, "uv", sql.NewPrimaryKeySchema(sql.Schema{ 252 {Name: "u", Type: types.Int64}, 253 {Name: "v", Type: types.Int64}, 254 {Name: "w", Type: types.Int64}, 255 }, 0), nil) 256 257 db.AddTable("xy", cat.Tables["xy"].(memory.MemTable)) 258 db.AddTable("uv", cat.Tables["uv"].(memory.MemTable)) 259 cat.Databases["mydb"] = db 260 cat.Funcs = function.NewRegistry() 261 return cat 262 }