github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/replace_cross_joins_test.go (about) 1 // Copyright 2022 DoltHub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "testing" 19 20 "github.com/dolthub/go-mysql-server/memory" 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/expression" 23 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 24 "github.com/dolthub/go-mysql-server/sql/plan" 25 "github.com/dolthub/go-mysql-server/sql/types" 26 ) 27 28 func TestConvertCrossJoin(t *testing.T) { 29 db := memory.NewDatabase("db") 30 pro := memory.NewDBProvider(db) 31 ctx := newContext(pro) 32 33 tableA := memory.NewTable(db, "a", sql.NewPrimaryKeySchema(sql.Schema{ 34 {Name: "x", Type: types.Int64, Source: "a"}, 35 {Name: "y", Type: types.Int64, Source: "a"}, 36 {Name: "z", Type: types.Int64, Source: "a"}, 37 }), nil) 38 tableB := memory.NewTable(db, "b", sql.NewPrimaryKeySchema(sql.Schema{ 39 {Name: "x", Type: types.Int64, Source: "b"}, 40 {Name: "y", Type: types.Int64, Source: "b"}, 41 {Name: "z", Type: types.Int64, Source: "b"}, 42 }), nil) 43 44 fieldAx := expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false) 45 fieldBy := expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "y", false) 46 litOne := expression.NewLiteral(1, types.Int64) 47 48 matching := []sql.Expression{ 49 expression.NewEquals(fieldAx, fieldBy), 50 expression.NewNullSafeEquals(fieldAx, fieldBy), 51 expression.NewGreaterThan(fieldAx, fieldBy), 52 expression.NewGreaterThanOrEqual(fieldAx, fieldBy), 53 expression.NewLessThan(fieldAx, fieldBy), 54 expression.NewLessThanOrEqual(fieldAx, fieldBy), 55 expression.NewOr( 56 expression.NewEquals(fieldAx, fieldBy), 57 expression.NewEquals(litOne, litOne), 58 ), 59 expression.NewNot( 60 expression.NewEquals(fieldAx, fieldBy), 61 ), 62 expression.NewIsFalse( 63 expression.NewEquals(fieldAx, fieldBy), 64 ), 65 expression.NewIsTrue( 66 expression.NewEquals(fieldAx, fieldBy), 67 ), 68 expression.NewIsNull( 69 expression.NewEquals(fieldAx, fieldBy), 70 ), 71 } 72 73 nonMatching := []sql.Expression{ 74 expression.NewEquals( 75 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false), 76 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "y", false), 77 ), 78 expression.NewEquals( 79 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false), 80 aggregation.NewMax(expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "y", false)), 81 ), 82 } 83 84 tests := make([]analyzerFnTestCase, 0, len(matching)+len(nonMatching)) 85 for _, t := range matching { 86 new := analyzerFnTestCase{ 87 name: t.String(), 88 node: plan.NewFilter( 89 t, 90 plan.NewCrossJoin( 91 plan.NewResolvedTable(tableA, nil, nil), 92 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 93 ), 94 ), 95 expected: plan.NewInnerJoin( 96 plan.NewResolvedTable(tableA, nil, nil), 97 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 98 t, 99 ), 100 } 101 tests = append(tests, new) 102 } 103 for _, t := range nonMatching { 104 new := analyzerFnTestCase{ 105 name: t.String(), 106 node: plan.NewFilter( 107 t, 108 plan.NewCrossJoin( 109 plan.NewResolvedTable(tableA, nil, nil), 110 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 111 ), 112 ), 113 expected: plan.NewFilter( 114 t, 115 plan.NewCrossJoin( 116 plan.NewResolvedTable(tableA, nil, nil), 117 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 118 ), 119 ), 120 } 121 tests = append(tests, new) 122 } 123 124 nested := []analyzerFnTestCase{ 125 { 126 name: "split AND into predicate leaves", 127 node: plan.NewFilter( 128 expression.NewAnd( 129 expression.NewEquals(fieldAx, fieldBy), 130 expression.NewEquals(fieldAx, litOne), 131 ), 132 plan.NewCrossJoin( 133 plan.NewResolvedTable(tableA, nil, nil), 134 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 135 ), 136 ), 137 expected: plan.NewFilter( 138 expression.NewEquals(fieldAx, litOne), 139 plan.NewInnerJoin( 140 plan.NewResolvedTable(tableA, nil, nil), 141 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 142 expression.NewEquals(fieldAx, fieldBy), 143 ), 144 ), 145 }, 146 { 147 name: "carry whole OR expression as join expression", 148 node: plan.NewFilter( 149 expression.NewAnd( 150 expression.NewOr( 151 expression.NewEquals(fieldAx, fieldBy), 152 expression.NewEquals(fieldAx, litOne), 153 ), 154 expression.NewEquals(fieldAx, litOne), 155 ), 156 plan.NewCrossJoin( 157 plan.NewResolvedTable(tableA, nil, nil), 158 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 159 ), 160 ), 161 expected: plan.NewFilter( 162 expression.NewEquals(fieldAx, litOne), 163 plan.NewInnerJoin( 164 plan.NewResolvedTable(tableA, nil, nil), 165 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 166 expression.NewOr( 167 expression.NewEquals(fieldAx, fieldBy), 168 expression.NewEquals(fieldAx, litOne), 169 ), 170 ), 171 ), 172 }, 173 { 174 name: "nested cross joins full conversion", 175 node: plan.NewFilter( 176 expression.NewAnd( 177 expression.NewEquals(fieldAx, fieldBy), 178 expression.NewAnd( 179 expression.NewEquals( 180 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false), 181 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false), 182 ), 183 expression.NewAnd( 184 expression.NewEquals( 185 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false), 186 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false), 187 ), 188 expression.NewEquals( 189 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "c", "x", false), 190 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "d", "y", false), 191 ), 192 ), 193 ), 194 ), 195 plan.NewCrossJoin( 196 plan.NewResolvedTable(tableA, nil, nil), 197 plan.NewCrossJoin( 198 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 199 plan.NewCrossJoin( 200 plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)), 201 plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)), 202 ), 203 ), 204 ), 205 ), 206 expected: plan.NewFilter( 207 expression.NewEquals( 208 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false), 209 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false), 210 ), 211 plan.NewInnerJoin( 212 plan.NewResolvedTable(tableA, nil, nil), 213 plan.NewInnerJoin( 214 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 215 plan.NewInnerJoin( 216 plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)), 217 plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)), 218 expression.NewEquals( 219 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "c", "x", false), 220 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "d", "y", false), 221 ), 222 ), 223 expression.NewEquals( 224 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false), 225 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false), 226 ), 227 ), 228 expression.NewEquals(fieldAx, fieldBy), 229 ), 230 ), 231 }, 232 { 233 name: "nested cross joins partial conversion", 234 node: plan.NewFilter( 235 expression.NewAnd( 236 expression.NewEquals(fieldAx, fieldBy), 237 expression.NewEquals( 238 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false), 239 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false), 240 ), 241 ), 242 plan.NewCrossJoin( 243 plan.NewResolvedTable(tableA, nil, nil), 244 plan.NewCrossJoin( 245 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 246 plan.NewCrossJoin( 247 plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)), 248 plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)), 249 ), 250 ), 251 ), 252 ), 253 expected: plan.NewInnerJoin( 254 plan.NewResolvedTable(tableA, nil, nil), 255 plan.NewInnerJoin( 256 plan.NewTableAlias("b", plan.NewResolvedTable(tableB, nil, nil)), 257 plan.NewCrossJoin( 258 plan.NewTableAlias("c", plan.NewResolvedTable(tableB, nil, nil)), 259 plan.NewTableAlias("d", plan.NewResolvedTable(tableB, nil, nil)), 260 ), 261 expression.NewEquals( 262 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "b", "x", false), 263 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "c", "y", false), 264 ), 265 ), 266 expression.NewEquals(fieldAx, fieldBy), 267 ), 268 }, 269 } 270 tests = append(tests, nested...) 271 272 runTestCases(t, ctx, tests, NewDefault(sql.NewDatabaseProvider()), getRule(replaceCrossJoinsId)) 273 }