github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/optimization_rules_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/dolthub/go-mysql-server/memory"
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    28  	"github.com/dolthub/go-mysql-server/sql/plan"
    29  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    30  	"github.com/dolthub/go-mysql-server/sql/types"
    31  )
    32  
    33  func TestEvalFilter(t *testing.T) {
    34  	db := memory.NewDatabase("db")
    35  	pro := memory.NewDBProvider(db)
    36  	ctx := newContext(pro)
    37  
    38  	inner := memory.NewTable(db, "foo", sql.PrimaryKeySchema{}, nil)
    39  	rule := getRule(simplifyFiltersId)
    40  
    41  	testCases := []struct {
    42  		filter   sql.Expression
    43  		expected sql.Node
    44  	}{
    45  		{
    46  			and(
    47  				eq(lit(5), lit(5)),
    48  				eq(
    49  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
    50  					lit(5)),
    51  			),
    52  			plan.NewFilter(
    53  				eq(
    54  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
    55  					lit(5)),
    56  				plan.NewResolvedTable(inner, nil, nil),
    57  			),
    58  		},
    59  		{
    60  			and(
    61  				eq(
    62  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
    63  					lit(5)),
    64  				eq(lit(5), lit(5)),
    65  			),
    66  			plan.NewFilter(
    67  				eq(
    68  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
    69  					lit(5)),
    70  				plan.NewResolvedTable(inner, nil, nil),
    71  			),
    72  		},
    73  		{
    74  			and(
    75  				eq(lit(5), lit(4)),
    76  				eq(
    77  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
    78  					lit(5)),
    79  			),
    80  			plan.NewEmptyTableWithSchema(inner.Schema()),
    81  		},
    82  		{
    83  			and(
    84  				eq(
    85  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
    86  					lit(5)),
    87  				eq(lit(5), lit(4)),
    88  			),
    89  			plan.NewEmptyTableWithSchema(inner.Schema()),
    90  		},
    91  		{
    92  			and(
    93  				eq(lit(4), lit(4)),
    94  				eq(lit(5), lit(5)),
    95  			),
    96  			plan.NewResolvedTable(inner, nil, nil),
    97  		},
    98  		{
    99  			or(
   100  				eq(lit(5), lit(4)),
   101  				eq(
   102  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
   103  					lit(5)),
   104  			),
   105  			plan.NewFilter(
   106  				eq(
   107  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
   108  					lit(5)),
   109  				plan.NewResolvedTable(inner, nil, nil),
   110  			),
   111  		},
   112  		{
   113  			or(
   114  				eq(
   115  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
   116  					lit(5)),
   117  				eq(lit(5), lit(4)),
   118  			),
   119  			plan.NewFilter(
   120  				eq(
   121  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
   122  					lit(5)),
   123  				plan.NewResolvedTable(inner, nil, nil),
   124  			),
   125  		},
   126  		{
   127  			or(
   128  				eq(lit(5), lit(5)),
   129  				eq(
   130  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
   131  					lit(5)),
   132  			),
   133  			plan.NewResolvedTable(inner, nil, nil),
   134  		},
   135  		{
   136  			or(
   137  				eq(
   138  					expression.NewGetFieldWithTable(0, 0, types.Int64, "", "foo", "bar", false),
   139  					lit(5)),
   140  				eq(lit(5), lit(5)),
   141  			),
   142  			plan.NewResolvedTable(inner, nil, nil),
   143  		},
   144  		{
   145  			or(
   146  				eq(lit(5), lit(4)),
   147  				eq(lit(5), lit(4)),
   148  			),
   149  			plan.NewEmptyTableWithSchema(inner.Schema()),
   150  		},
   151  	}
   152  
   153  	for _, tt := range testCases {
   154  		t.Run(tt.filter.String(), func(t *testing.T) {
   155  			require := require.New(t)
   156  			node := plan.NewFilter(tt.filter, plan.NewResolvedTable(inner, nil, nil))
   157  			result, _, err := rule.Apply(ctx, NewDefault(nil), node, nil, DefaultRuleSelector)
   158  			require.NoError(err)
   159  			require.Equal(tt.expected, result)
   160  		})
   161  	}
   162  }
   163  
   164  func TestPushNotFilters(t *testing.T) {
   165  	tests := []struct {
   166  		in  string
   167  		exp string
   168  	}{
   169  		{
   170  			in:  "NOT(NOT(x IS NULL))",
   171  			exp: "xy.x IS NULL",
   172  		},
   173  		{
   174  			in:  "NOT(x BETWEEN 0 AND 5)",
   175  			exp: "((xy.x < 0) OR (xy.x > 5))",
   176  		},
   177  		{
   178  			in:  "NOT(x <= 0)",
   179  			exp: "(xy.x > 0)",
   180  		},
   181  		{
   182  			in:  "NOT(x < 0)",
   183  			exp: "(xy.x >= 0)",
   184  		},
   185  		{
   186  			in:  "NOT(x > 0)",
   187  			exp: "(xy.x <= 0)",
   188  		},
   189  		{
   190  			in:  "NOT(x >= 0)",
   191  			exp: "(xy.x < 0)",
   192  		},
   193  		// TODO this isn't correct for join filters
   194  		//{
   195  		//	in:  "NOT(y IS NULL)",
   196  		//	exp: "((xy.x < NULL) OR (xy.x > NULL))",
   197  		//},
   198  		{
   199  			in:  "NOT (x > 2 AND y > 2)",
   200  			exp: "((xy.x <= 2) OR (xy.y <= 2))",
   201  		},
   202  		{
   203  			in:  "NOT (x > 2 AND NOT(y > 2))",
   204  			exp: "((xy.x <= 2) OR (xy.y > 2))",
   205  		},
   206  		{
   207  			in:  "((NOT(x > 1 AND NOT((x > 0) OR (y < 2))) OR (y > 1)) OR NOT(y < 3))",
   208  			exp: "((((xy.x <= 1) OR ((xy.x > 0) OR (xy.y < 2))) OR (xy.y > 1)) OR (xy.y >= 3))",
   209  		},
   210  	}
   211  
   212  	// todo dummy catalog and table
   213  	db := memory.NewDatabase("mydb")
   214  	cat := newTestCatalog(db)
   215  	pro := memory.NewDBProvider(db)
   216  	sess := memory.NewSession(sql.NewBaseSession(), pro)
   217  
   218  	ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
   219  	ctx.SetCurrentDatabase("mydb")
   220  
   221  	b := planbuilder.New(ctx, cat)
   222  
   223  	for _, tt := range tests {
   224  		t.Run(tt.in, func(t *testing.T) {
   225  			q := fmt.Sprintf("SELECT 1 from xy WHERE %s", tt.in)
   226  			node, err := b.ParseOne(q)
   227  			require.NoError(t, err)
   228  
   229  			cmp, _, err := pushNotFilters(ctx, nil, node, nil, nil)
   230  			require.NoError(t, err)
   231  
   232  			cmpF := cmp.(*plan.Project).Child.(*plan.Filter).Expression
   233  			cmpStr := cmpF.String()
   234  
   235  			require.Equal(t, tt.exp, cmpStr, fmt.Sprintf("\nexpected: %s\nfound:%s\n", tt.exp, cmpStr))
   236  		})
   237  	}
   238  }
   239  
   240  func newTestCatalog(db *memory.Database) *sql.MapCatalog {
   241  	cat := &sql.MapCatalog{
   242  		Databases: make(map[string]sql.Database),
   243  		Tables:    make(map[string]sql.Table),
   244  	}
   245  
   246  	cat.Tables["xy"] = memory.NewTable(db, "xy", sql.NewPrimaryKeySchema(sql.Schema{
   247  		{Name: "x", Type: types.Int64},
   248  		{Name: "y", Type: types.Int64},
   249  		{Name: "z", Type: types.Int64},
   250  	}, 0), nil)
   251  	cat.Tables["uv"] = memory.NewTable(db, "uv", sql.NewPrimaryKeySchema(sql.Schema{
   252  		{Name: "u", Type: types.Int64},
   253  		{Name: "v", Type: types.Int64},
   254  		{Name: "w", Type: types.Int64},
   255  	}, 0), nil)
   256  
   257  	db.AddTable("xy", cat.Tables["xy"].(memory.MemTable))
   258  	db.AddTable("uv", cat.Tables["uv"].(memory.MemTable))
   259  	cat.Databases["mydb"] = db
   260  	cat.Funcs = function.NewRegistry()
   261  	return cat
   262  }