github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/filters_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/expression/function"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  func TestFiltersMerge(t *testing.T) {
    30  	f1 := filtersByTable{
    31  		"1": []sql.Expression{
    32  			expression.NewLiteral("1", types.LongText),
    33  		},
    34  		"2": []sql.Expression{
    35  			expression.NewLiteral("2", types.LongText),
    36  		},
    37  	}
    38  
    39  	f2 := filtersByTable{
    40  		"2": []sql.Expression{
    41  			expression.NewLiteral("2.2", types.LongText),
    42  		},
    43  		"3": []sql.Expression{
    44  			expression.NewLiteral("3", types.LongText),
    45  		},
    46  	}
    47  
    48  	f1.merge(f2)
    49  
    50  	require.Equal(t,
    51  		filtersByTable{
    52  			"1": []sql.Expression{
    53  				expression.NewLiteral("1", types.LongText),
    54  			},
    55  			"2": []sql.Expression{
    56  				expression.NewLiteral("2", types.LongText),
    57  				expression.NewLiteral("2.2", types.LongText),
    58  			},
    59  			"3": []sql.Expression{
    60  				expression.NewLiteral("3", types.LongText),
    61  			},
    62  		},
    63  		f1,
    64  	)
    65  }
    66  
    67  func TestSplitExpression(t *testing.T) {
    68  	e := expression.NewAnd(
    69  		expression.NewAnd(
    70  			expression.NewIsNull(expression.NewUnresolvedColumn("foo")),
    71  			expression.NewNot(expression.NewUnresolvedColumn("foo")),
    72  		),
    73  		expression.NewAnd(
    74  			expression.NewOr(
    75  				expression.NewIsNull(expression.NewUnresolvedColumn("bar")),
    76  				expression.NewNot(expression.NewUnresolvedColumn("bar")),
    77  			),
    78  			expression.NewEquals(
    79  				expression.NewUnresolvedColumn("foo"),
    80  				expression.NewLiteral("foo", types.LongText),
    81  			),
    82  		),
    83  	)
    84  
    85  	expected := []sql.Expression{
    86  		expression.NewIsNull(expression.NewUnresolvedColumn("foo")),
    87  		expression.NewNot(expression.NewUnresolvedColumn("foo")),
    88  		expression.NewOr(
    89  			expression.NewIsNull(expression.NewUnresolvedColumn("bar")),
    90  			expression.NewNot(expression.NewUnresolvedColumn("bar")),
    91  		),
    92  		expression.NewEquals(
    93  			expression.NewUnresolvedColumn("foo"),
    94  			expression.NewLiteral("foo", types.LongText),
    95  		),
    96  	}
    97  
    98  	require.Equal(t,
    99  		expected,
   100  		expression.SplitConjunction(e),
   101  	)
   102  }
   103  
   104  func TestSubtractExprSet(t *testing.T) {
   105  	filters := []sql.Expression{
   106  		expression.NewIsNull(nil),
   107  		expression.NewNot(nil),
   108  		expression.NewEquals(nil, nil),
   109  		expression.NewGreaterThan(nil, nil),
   110  	}
   111  
   112  	handled := []sql.Expression{
   113  		filters[1],
   114  		filters[3],
   115  	}
   116  
   117  	unhandled := subtractExprSet(filters, handled)
   118  
   119  	require.Equal(t,
   120  		[]sql.Expression{filters[0], filters[2]},
   121  		unhandled,
   122  	)
   123  }
   124  
   125  func TestExprToTableFilters(t *testing.T) {
   126  	expr := expression.NewAnd(
   127  		expression.NewAnd(
   128  			expression.NewAnd(
   129  				expression.NewEquals(
   130  					expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   131  					expression.NewLiteral(3.14, types.Float64),
   132  				),
   133  				expression.NewGreaterThan(
   134  					expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   135  					expression.NewLiteral(3., types.Float64),
   136  				),
   137  			),
   138  			expression.NewIsNull(
   139  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable2", "i2", false),
   140  			),
   141  		),
   142  		expression.NewOr(
   143  			expression.NewEquals(
   144  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   145  				expression.NewLiteral(3.14, types.Float64),
   146  			),
   147  			expression.NewGreaterThan(
   148  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   149  				expression.NewLiteral(3., types.Float64),
   150  			),
   151  		),
   152  	)
   153  
   154  	expected := filtersByTable{
   155  		"mytable": []sql.Expression{
   156  			expression.NewEquals(
   157  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   158  				expression.NewLiteral(3.14, types.Float64),
   159  			),
   160  			expression.NewGreaterThan(
   161  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   162  				expression.NewLiteral(3., types.Float64),
   163  			),
   164  			expression.NewOr(
   165  				expression.NewEquals(
   166  					expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   167  					expression.NewLiteral(3.14, types.Float64),
   168  				),
   169  				expression.NewGreaterThan(
   170  					expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   171  					expression.NewLiteral(3., types.Float64),
   172  				),
   173  			),
   174  		},
   175  		"mytable2": []sql.Expression{
   176  			expression.NewIsNull(
   177  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable2", "i2", false),
   178  			),
   179  		},
   180  	}
   181  
   182  	filters := exprToTableFilters(expr)
   183  	assert.Equal(t, expected, filters)
   184  
   185  	// Test various complex conditions -- anytime we can't neatly split the expressions into tables
   186  	filters = exprToTableFilters(expression.NewAnd(
   187  		lit(0),
   188  		expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   189  	))
   190  	expected = filtersByTable{
   191  		"mytable": []sql.Expression{
   192  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   193  		},
   194  	}
   195  	assert.Equal(t, expected, filters)
   196  
   197  	filters = exprToTableFilters(expression.NewAnd(
   198  		expression.NewLiteral(nil, types.Null),
   199  		expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   200  	))
   201  	expected = filtersByTable{
   202  		"mytable": []sql.Expression{
   203  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   204  		},
   205  	}
   206  	assert.Equal(t, expected, filters)
   207  
   208  	filters = exprToTableFilters(expression.NewAnd(
   209  		expression.NewEquals(lit(1), mustExpr(function.NewRand())),
   210  		expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   211  	))
   212  	expected = filtersByTable{
   213  		"mytable": []sql.Expression{
   214  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   215  		},
   216  	}
   217  	assert.Equal(t, expected, filters)
   218  
   219  	filters = exprToTableFilters(expression.NewOr(
   220  		expression.NewLiteral(nil, types.Null),
   221  		expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   222  	))
   223  	expected = filtersByTable{
   224  		"mytable": []sql.Expression{
   225  			expression.NewOr(
   226  				expression.NewLiteral(nil, types.Null),
   227  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   228  			),
   229  		},
   230  	}
   231  	assert.Equal(t, expected, filters)
   232  
   233  	filters = exprToTableFilters(expression.NewAnd(
   234  		eq(
   235  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "a", false),
   236  			lit(1),
   237  		),
   238  		eq(
   239  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "f", false),
   240  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable2", "i", false),
   241  		),
   242  	))
   243  	expected = filtersByTable{
   244  		"mytable": []sql.Expression{
   245  			eq(
   246  				expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "mytable", "a", false),
   247  				lit(1),
   248  			),
   249  		},
   250  	}
   251  	assert.Equal(t, expected, filters)
   252  }