github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/replace_cross_joins_test.go (about)

     1  // Copyright 2022 DoltHub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/dolthub/go-mysql-server/memory"
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    24  	"github.com/dolthub/go-mysql-server/sql/plan"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  )
    27  
    28  func TestConvertCrossJoin(t *testing.T) {
    29  	db := memory.NewDatabase("db")
    30  	pro := memory.NewDBProvider(db)
    31  	ctx := newContext(pro)
    32  
    33  	tableA := memory.NewTable(db, "a", sql.NewPrimaryKeySchema(sql.Schema{
    34  		{Name: "x", Type: types.Int64, Source: "a"},
    35  		{Name: "y", Type: types.Int64, Source: "a"},
    36  		{Name: "z", Type: types.Int64, Source: "a"},
    37  	}), nil)
    38  	tableB := memory.NewTable(db, "b", sql.NewPrimaryKeySchema(sql.Schema{
    39  		{Name: "x", Type: types.Int64, Source: "b"},
    40  		{Name: "y", Type: types.Int64, Source: "b"},
    41  		{Name: "z", Type: types.Int64, Source: "b"},
    42  	}), nil)
    43  
    44  	fieldAx := expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false)
    45  	fieldBy := expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "y", false)
    46  	litOne := expression.NewLiteral(1, types.Int64)
    47  
    48  	matching := []sql.Expression{
    49  		expression.NewEquals(fieldAx, fieldBy),
    50  		expression.NewNullSafeEquals(fieldAx, fieldBy),
    51  		expression.NewGreaterThan(fieldAx, fieldBy),
    52  		expression.NewGreaterThanOrEqual(fieldAx, fieldBy),
    53  		expression.NewLessThan(fieldAx, fieldBy),
    54  		expression.NewLessThanOrEqual(fieldAx, fieldBy),
    55  		expression.NewOr(
    56  			expression.NewEquals(fieldAx, fieldBy),
    57  			expression.NewEquals(litOne, litOne),
    58  		),
    59  		expression.NewNot(
    60  			expression.NewEquals(fieldAx, fieldBy),
    61  		),
    62  		expression.NewIsFalse(
    63  			expression.NewEquals(fieldAx, fieldBy),
    64  		),
    65  		expression.NewIsTrue(
    66  			expression.NewEquals(fieldAx, fieldBy),
    67  		),
    68  		expression.NewIsNull(
    69  			expression.NewEquals(fieldAx, fieldBy),
    70  		),
    71  	}
    72  
    73  	nonMatching := []sql.Expression{
    74  		expression.NewEquals(
    75  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false),
    76  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "y", false),
    77  		),
    78  		expression.NewEquals(
    79  			expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false),
    80  			aggregation.NewMax(expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "y", false)),
    81  		),
    82  	}
    83  
    84  	tests := make([]analyzerFnTestCase, 0, len(matching)+len(nonMatching))
    85  	for _, t := range matching {
    86  		new := analyzerFnTestCase{
    87  			name: t.String(),
    88  			node: plan.NewFilter(
    89  				t,
    90  				plan.NewCrossJoin(
    91  					plan.NewResolvedTable(tableA, nil, nil),
    92  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
    93  				),
    94  			),
    95  			expected: plan.NewInnerJoin(
    96  				plan.NewResolvedTable(tableA, nil, nil),
    97  				plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
    98  				t,
    99  			),
   100  		}
   101  		tests = append(tests, new)
   102  	}
   103  	for _, t := range nonMatching {
   104  		new := analyzerFnTestCase{
   105  			name: t.String(),
   106  			node: plan.NewFilter(
   107  				t,
   108  				plan.NewCrossJoin(
   109  					plan.NewResolvedTable(tableA, nil, nil),
   110  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   111  				),
   112  			),
   113  			expected: plan.NewFilter(
   114  				t,
   115  				plan.NewCrossJoin(
   116  					plan.NewResolvedTable(tableA, nil, nil),
   117  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   118  				),
   119  			),
   120  		}
   121  		tests = append(tests, new)
   122  	}
   123  
   124  	nested := []analyzerFnTestCase{
   125  		{
   126  			name: "split AND into predicate leaves",
   127  			node: plan.NewFilter(
   128  				expression.NewAnd(
   129  					expression.NewEquals(fieldAx, fieldBy),
   130  					expression.NewEquals(fieldAx, litOne),
   131  				),
   132  				plan.NewCrossJoin(
   133  					plan.NewResolvedTable(tableA, nil, nil),
   134  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   135  				),
   136  			),
   137  			expected: plan.NewFilter(
   138  				expression.NewEquals(fieldAx, litOne),
   139  				plan.NewInnerJoin(
   140  					plan.NewResolvedTable(tableA, nil, nil),
   141  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   142  					expression.NewEquals(fieldAx, fieldBy),
   143  				),
   144  			),
   145  		},
   146  		{
   147  			name: "carry whole OR expression as join expression",
   148  			node: plan.NewFilter(
   149  				expression.NewAnd(
   150  					expression.NewOr(
   151  						expression.NewEquals(fieldAx, fieldBy),
   152  						expression.NewEquals(fieldAx, litOne),
   153  					),
   154  					expression.NewEquals(fieldAx, litOne),
   155  				),
   156  				plan.NewCrossJoin(
   157  					plan.NewResolvedTable(tableA, nil, nil),
   158  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   159  				),
   160  			),
   161  			expected: plan.NewFilter(
   162  				expression.NewEquals(fieldAx, litOne),
   163  				plan.NewInnerJoin(
   164  					plan.NewResolvedTable(tableA, nil, nil),
   165  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   166  					expression.NewOr(
   167  						expression.NewEquals(fieldAx, fieldBy),
   168  						expression.NewEquals(fieldAx, litOne),
   169  					),
   170  				),
   171  			),
   172  		},
   173  		{
   174  			name: "nested cross joins full conversion",
   175  			node: plan.NewFilter(
   176  				expression.NewAnd(
   177  					expression.NewEquals(fieldAx, fieldBy),
   178  					expression.NewAnd(
   179  						expression.NewEquals(
   180  							expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false),
   181  							expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false),
   182  						),
   183  						expression.NewAnd(
   184  							expression.NewEquals(
   185  								expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false),
   186  								expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false),
   187  							),
   188  							expression.NewEquals(
   189  								expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "c", "x", false),
   190  								expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "d", "y", false),
   191  							),
   192  						),
   193  					),
   194  				),
   195  				plan.NewCrossJoin(
   196  					plan.NewResolvedTable(tableA, nil, nil),
   197  					plan.NewCrossJoin(
   198  						plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   199  						plan.NewCrossJoin(
   200  							plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)),
   201  							plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)),
   202  						),
   203  					),
   204  				),
   205  			),
   206  			expected: plan.NewFilter(
   207  				expression.NewEquals(
   208  					expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false),
   209  					expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false),
   210  				),
   211  				plan.NewInnerJoin(
   212  					plan.NewResolvedTable(tableA, nil, nil),
   213  					plan.NewInnerJoin(
   214  						plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   215  						plan.NewInnerJoin(
   216  							plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)),
   217  							plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)),
   218  							expression.NewEquals(
   219  								expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "c", "x", false),
   220  								expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "d", "y", false),
   221  							),
   222  						),
   223  						expression.NewEquals(
   224  							expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false),
   225  							expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false),
   226  						),
   227  					),
   228  					expression.NewEquals(fieldAx, fieldBy),
   229  				),
   230  			),
   231  		},
   232  		{
   233  			name: "nested cross joins partial conversion",
   234  			node: plan.NewFilter(
   235  				expression.NewAnd(
   236  					expression.NewEquals(fieldAx, fieldBy),
   237  					expression.NewEquals(
   238  						expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false),
   239  						expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false),
   240  					),
   241  				),
   242  				plan.NewCrossJoin(
   243  					plan.NewResolvedTable(tableA, nil, nil),
   244  					plan.NewCrossJoin(
   245  						plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   246  						plan.NewCrossJoin(
   247  							plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)),
   248  							plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)),
   249  						),
   250  					),
   251  				),
   252  			),
   253  			expected: plan.NewInnerJoin(
   254  				plan.NewResolvedTable(tableA, nil, nil),
   255  				plan.NewInnerJoin(
   256  					plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)),
   257  					plan.NewCrossJoin(
   258  						plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)),
   259  						plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)),
   260  					),
   261  					expression.NewEquals(
   262  						expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false),
   263  						expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false),
   264  					),
   265  				),
   266  				expression.NewEquals(fieldAx, fieldBy),
   267  			),
   268  		},
   269  	}
   270  	tests = append(tests, nested...)
   271  
   272  	runTestCases(t, ctx, tests, NewDefault(sql.NewDatabaseProvider()), getRule(replaceCrossJoinsId))
   273  }