github.com/dolthub/go-mysql-server@v0.18.0/sql/memo/join_order_builder_test.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package memo 16 17 import ( 18 "context" 19 "fmt" 20 "strings" 21 "testing" 22 23 "github.com/stretchr/testify/require" 24 25 "github.com/dolthub/go-mysql-server/memory" 26 "github.com/dolthub/go-mysql-server/sql" 27 "github.com/dolthub/go-mysql-server/sql/expression" 28 "github.com/dolthub/go-mysql-server/sql/plan" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 func TestJoinOrderBuilder(t *testing.T) { 33 db := memory.NewDatabase("test") 34 pro := memory.NewDBProvider(db) 35 36 tests := []struct { 37 in sql.Node 38 name string 39 plans string 40 forceFastReorder bool 41 }{ 42 { 43 name: "inner joins", 44 in: plan.NewInnerJoin( 45 plan.NewInnerJoin( 46 plan.NewInnerJoin( 47 tableNode(db, "a"), 48 tableNode(db, "b"), 49 newEq("a.x = b.x"), 50 ), 51 tableNode(db, "c"), 52 newEq("b.x = c.x"), 53 ), 54 tableNode(db, "d"), 55 newEq("c.x = d.x"), 56 ), 57 plans: `memo: 58 ├── G1: (tablescan: a) 59 ├── G2: (tablescan: b) 60 ├── G3: (innerjoin 2 1) (innerjoin 1 2) 61 ├── G4: (tablescan: c) 62 ├── G5: (innerjoin 4 3) (innerjoin 8 2) (innerjoin 2 8) (innerjoin 9 1) (innerjoin 1 9) (innerjoin 3 4) 63 ├── G6: (tablescan: d) 64 ├── G7: (innerjoin 6 5) (innerjoin 10 9) (innerjoin 9 10) (innerjoin 11 8) (innerjoin 8 11) (innerjoin 12 4) (innerjoin 4 12) (innerjoin 13 3) (innerjoin 3 13) (innerjoin 14 2) (innerjoin 2 14) (innerjoin 15 1) (innerjoin 1 15) (innerjoin 5 6) 65 ├── G8: (innerjoin 4 1) (innerjoin 1 4) 66 ├── G9: (innerjoin 4 2) (innerjoin 2 4) 67 ├── G10: (innerjoin 6 1) (innerjoin 1 6) 68 ├── G11: (innerjoin 6 2) (innerjoin 2 6) 69 ├── G12: (innerjoin 6 3) (innerjoin 3 6) (innerjoin 10 2) (innerjoin 2 10) (innerjoin 11 1) (innerjoin 1 11) 70 ├── G13: (innerjoin 6 4) (innerjoin 4 6) 71 ├── G14: (innerjoin 6 8) (innerjoin 8 6) (innerjoin 10 4) (innerjoin 4 10) (innerjoin 13 1) (innerjoin 1 13) 72 └── G15: (innerjoin 6 9) (innerjoin 9 6) (innerjoin 11 4) (innerjoin 4 11) (innerjoin 13 2) (innerjoin 2 13) 73 `, 74 }, 75 { 76 name: "non-inner joins", 77 in: plan.NewInnerJoin( 78 plan.NewInnerJoin( 79 plan.NewLeftOuterJoin( 80 tableNode(db, "a"), 81 tableNode(db, "b"), 82 newEq("a.x = b.x"), 83 ), 84 plan.NewLeftOuterJoin( 85 plan.NewFullOuterJoin( 86 tableNode(db, "c"), 87 tableNode(db, "d"), 88 newEq("c.x = d.x"), 89 ), 90 tableNode(db, "e"), 91 newEq("c.x = e.x"), 92 ), 93 newEq("a.x = e.x"), 94 ), 95 plan.NewInnerJoin( 96 tableNode(db, "f"), 97 tableNode(db, "g"), 98 newEq("f.x = g.x"), 99 ), 100 newEq("e.x = g.x"), 101 ), 102 plans: `memo: 103 ├── G1: (tablescan: a) 104 ├── G2: (tablescan: b) 105 ├── G3: (leftjoin 1 2) 106 ├── G4: (tablescan: c) 107 ├── G5: (tablescan: d) 108 ├── G6: (fullouterjoin 4 5) 109 ├── G7: (tablescan: e) 110 ├── G8: (leftjoin 6 7) 111 ├── G9: (innerjoin 8 3) (leftjoin 14 2) (innerjoin 3 8) 112 ├── G10: (tablescan: f) 113 ├── G11: (tablescan: g) 114 ├── G12: (innerjoin 11 10) (innerjoin 10 11) 115 ├── G13: (innerjoin 11 19) (innerjoin 19 11) (innerjoin 21 17) (innerjoin 17 21) (innerjoin 22 16) (innerjoin 16 22) (innerjoin 24 10) (innerjoin 10 24) (innerjoin 12 9) (innerjoin 26 8) (innerjoin 8 26) (innerjoin 27 3) (innerjoin 3 27) (leftjoin 28 2) (innerjoin 9 12) 116 ├── G14: (innerjoin 8 1) (innerjoin 1 8) 117 ├── G15: (innerjoin 10 1) (innerjoin 1 10) 118 ├── G16: (innerjoin 10 3) (innerjoin 3 10) (leftjoin 15 2) 119 ├── G17: (innerjoin 10 8) (innerjoin 8 10) 120 ├── G18: (innerjoin 10 14) (innerjoin 14 10) (innerjoin 15 8) (innerjoin 8 15) (innerjoin 17 1) (innerjoin 1 17) 121 ├── G19: (innerjoin 10 9) (innerjoin 9 10) (innerjoin 16 8) (innerjoin 8 16) (innerjoin 17 3) (innerjoin 3 17) (leftjoin 18 2) 122 ├── G20: (innerjoin 11 1) (innerjoin 1 11) 123 ├── G21: (innerjoin 11 3) (innerjoin 3 11) (leftjoin 20 2) 124 ├── G22: (innerjoin 11 8) (innerjoin 8 11) 125 ├── G23: (innerjoin 11 14) (innerjoin 14 11) (innerjoin 20 8) (innerjoin 8 20) (innerjoin 22 1) (innerjoin 1 22) 126 ├── G24: (innerjoin 11 9) (innerjoin 9 11) (innerjoin 21 8) (innerjoin 8 21) (innerjoin 22 3) (innerjoin 3 22) (leftjoin 23 2) 127 ├── G25: (innerjoin 11 15) (innerjoin 15 11) (innerjoin 20 10) (innerjoin 10 20) (innerjoin 12 1) (innerjoin 1 12) 128 ├── G26: (innerjoin 11 16) (innerjoin 16 11) (innerjoin 21 10) (innerjoin 10 21) (innerjoin 12 3) (innerjoin 3 12) (leftjoin 25 2) 129 ├── G27: (innerjoin 11 17) (innerjoin 17 11) (innerjoin 22 10) (innerjoin 10 22) (innerjoin 12 8) (innerjoin 8 12) 130 └── G28: (innerjoin 11 18) (innerjoin 18 11) (innerjoin 20 17) (innerjoin 17 20) (innerjoin 22 15) (innerjoin 15 22) (innerjoin 23 10) (innerjoin 10 23) (innerjoin 12 14) (innerjoin 14 12) (innerjoin 25 8) (innerjoin 8 25) (innerjoin 27 1) (innerjoin 1 27) 131 `, 132 }, 133 { 134 name: "test fast reordering algorithm", 135 // Optimized plan appears as G11 - (innerjoin 1 12) 136 in: plan.NewInnerJoin( 137 plan.NewCrossJoin( 138 tableNode(db, "a"), 139 tableNode(db, "c"), 140 ), 141 tableNode(db, "b"), 142 expression.NewAnd(newEq("a.x = b.z"), newEq("b.x = c.z")), 143 ), 144 145 forceFastReorder: true, 146 plans: `memo: 147 ├── G1: (tablescan: a) 148 ├── G2: (tablescan: c) 149 ├── G3: (crossjoin 1 2) 150 ├── G4: (tablescan: b) 151 ├── G5: (innerjoin 1 6) (innerjoin 6 1) (innerjoin 3 4) 152 └── G6: (innerjoin 4 2) (innerjoin 2 4) 153 `, 154 }, 155 { 156 name: "test fast reordering algorithm on bushy join", 157 // Optimized plan appears as G16: (innerjoin 7 17) 158 in: plan.NewInnerJoin( 159 plan.NewInnerJoin( 160 tableNode(db, "c"), 161 tableNode(db, "d"), 162 newEq("c.x = d.z"), 163 ), 164 plan.NewInnerJoin( 165 tableNode(db, "a"), 166 tableNode(db, "b"), 167 newEq("a.x = b.z"), 168 ), 169 newEq("b.x = c.z"), 170 ), 171 172 forceFastReorder: true, 173 plans: `memo: 174 ├── G1: (tablescan: c) 175 ├── G2: (tablescan: d) 176 ├── G3: (innerjoin 1 2) (innerjoin 2 1) (innerjoin 1 2) 177 ├── G4: (tablescan: a) 178 ├── G5: (tablescan: b) 179 ├── G6: (innerjoin 4 5) 180 ├── G7: (innerjoin 4 8) (innerjoin 8 4) (innerjoin 3 6) 181 └── G8: (innerjoin 5 3) (innerjoin 3 5) 182 `, 183 }, 184 } 185 186 for _, tt := range tests { 187 t.Run(tt.name, func(t *testing.T) { 188 j := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster())) 189 j.forceFastDFSLookupForTest = tt.forceFastReorder 190 j.ReorderJoin(tt.in) 191 require.Equal(t, tt.plans, j.m.String()) 192 }) 193 } 194 } 195 196 func newContext(provider *memory.DbProvider) *sql.Context { 197 return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider))) 198 } 199 200 func TestJoinOrderBuilder_populateSubgraph(t *testing.T) { 201 db := memory.NewDatabase("test") 202 pro := memory.NewDBProvider(db) 203 204 tests := []struct { 205 name string 206 join sql.Node 207 expEdges []edge 208 }{ 209 { 210 name: "cross join", 211 join: plan.NewCrossJoin( 212 tableNode(db, "a"), 213 plan.NewInnerJoin( 214 tableNode(db, "b"), 215 plan.NewLeftOuterJoin( 216 tableNode(db, "c"), 217 tableNode(db, "d"), 218 newEq("c.x=d.x"), 219 ), 220 newEq("b.y=d.y"), 221 ), 222 ), 223 expEdges: []edge{ 224 newEdge2(plan.JoinTypeLeftOuter, "0011", "0011", "0010", "0001", nil, 225 newEq("c.x=d.x"), 226 ""), // C x D 227 newEdge2(plan.JoinTypeInner, "0101", "0111", "0100", "0011", nil, 228 newEq("b.y=d.y"), 229 ""), // B x (CD) 230 newEdge2(plan.JoinTypeCross, "0000", "1111", "1000", "0111", nil, nil, ""), // A x (BCD) 231 }, 232 }, 233 { 234 name: "right deep left join", 235 join: plan.NewInnerJoin( 236 tableNode(db, "a"), 237 plan.NewInnerJoin( 238 tableNode(db, "b"), 239 plan.NewLeftOuterJoin( 240 tableNode(db, "c"), 241 tableNode(db, "d"), 242 newEq("c.x=d.x"), 243 ), 244 newEq("b.y=d.y"), 245 ), 246 newEq("a.z=b.z"), 247 ), 248 expEdges: []edge{ 249 newEdge2(plan.JoinTypeLeftOuter, "0011", "0011", "0010", "0001", nil, 250 newEq("c.x=d.x"), 251 ""), // C x D 252 newEdge2(plan.JoinTypeInner, "0101", "0111", "0100", "0011", nil, 253 newEq("b.y=d.y"), 254 255 ""), // B x (CD) 256 newEdge2(plan.JoinTypeInner, "1100", "1100", "1000", "0111", []conflictRule{{from: newVertexSet("0001"), to: newVertexSet("0010")}}, 257 newEq("a.z=b.z"), 258 259 ""), // A x (BCD) 260 }, 261 }, 262 { 263 name: "bushy left joins", 264 join: plan.NewLeftOuterJoin( 265 plan.NewLeftOuterJoin( 266 tableNode(db, "a"), 267 tableNode(db, "b"), 268 newEq("a.x=b.x"), 269 ), 270 plan.NewLeftOuterJoin( 271 tableNode(db, "c"), 272 tableNode(db, "d"), 273 newEq("c.x=d.x"), 274 ), 275 newEq("b.y=c.y"), 276 ), 277 expEdges: []edge{ 278 newEdge2(plan.JoinTypeLeftOuter, "1100", "1100", "1000", "0100", nil, 279 newEq("a.x=b.x"), 280 ""), // A x B 281 newEdge2(plan.JoinTypeLeftOuter, "0011", "0011", "0010", "0001", nil, 282 newEq("c.x=d.x"), // offset by filters 283 ""), // C x D 284 newEdge2(plan.JoinTypeLeftOuter, "0110", "1111", "1100", "0011", nil, 285 newEq("b.y=c.y"), 286 ""), // (AB) x (CD) 287 }, 288 }, 289 { 290 // SELECT * 291 // FROM (SELECT * FROM A CROSS JOIN B) 292 // LEFT JOIN C 293 // ON B.x = C.x 294 name: "degenerate inner join", 295 join: plan.NewLeftOuterJoin( 296 plan.NewCrossJoin( 297 tableNode(db, "a"), 298 tableNode(db, "b"), 299 ), 300 tableNode(db, "c"), 301 newEq("b.x=c.x"), 302 ), 303 expEdges: []edge{ 304 newEdge2(plan.JoinTypeCross, "000", "110", "100", "010", nil, nil, ""), // A X B 305 newEdge2(plan.JoinTypeLeftOuter, "011", "111", "110", "001", nil, 306 newEq("b.x=c.x"), 307 308 ""), // (AB) x C 309 }, 310 }, 311 { 312 // SELECT * 313 // FROM (SELECT * FROM A INNER JOIN B ON True) 314 // FULL JOIN (SELECT * FROM C INNER JOIN D ON True) 315 // ON A.x = C.x 316 name: "degenerate inner join", 317 join: plan.NewFullOuterJoin( 318 plan.NewInnerJoin( 319 tableNode(db, "a"), 320 tableNode(db, "b"), 321 expression.NewLiteral(true, types.Boolean), 322 ), 323 plan.NewInnerJoin( 324 tableNode(db, "c"), 325 tableNode(db, "d"), 326 expression.NewLiteral(true, types.Boolean), 327 ), 328 newEq("a.x=c.x"), 329 ), 330 expEdges: []edge{ 331 newEdge2(plan.JoinTypeInner, "0000", "1100", "1000", "0100", nil, expression.NewLiteral(true, types.Boolean), ""), // A x B 332 newEdge2(plan.JoinTypeInner, "0000", "0011", "0010", "0001", nil, expression.NewLiteral(true, types.Boolean), ""), // C x D 333 newEdge2(plan.JoinTypeFullOuter, "1010", "1111", "1100", "0011", nil, 334 newEq("a.x=c.x"), 335 ""), // (AB) x (CD) 336 }, 337 }, 338 { 339 // SELECT * FROM A 340 // WHERE EXISTS 341 // ( 342 // SELECT * FROM B 343 // LEFT JOIN C ON B.x = C.x 344 // WHERE A.y = B.y 345 // ) 346 // note: left join is the right child 347 name: "semi join", 348 join: plan.NewSemiJoin( 349 plan.NewLeftOuterJoin( 350 tableNode(db, "b"), 351 tableNode(db, "c"), 352 newEq("b.x=c.x"), 353 ), 354 tableNode(db, "a"), 355 newEq("a.y=b.y"), 356 ), 357 expEdges: []edge{ 358 newEdge2(plan.JoinTypeLeftOuter, "110", "110", "100", "010", nil, 359 newEq("b.x=c.x"), 360 ""), // B x C 361 newEdge2(plan.JoinTypeSemi, "101", "101", "110", "001", nil, 362 newEq("a.y=b.y"), 363 ""), // A x (BC) 364 }, 365 }, 366 } 367 368 for _, tt := range tests { 369 t.Run(tt.name, func(t *testing.T) { 370 b := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster())) 371 b.populateSubgraph(tt.join) 372 edgesEq(t, tt.expEdges, b.edges) 373 }) 374 } 375 } 376 377 func newEq(eq string) sql.Expression { 378 vars := strings.Split(strings.Replace(eq, " ", "", -1), "=") 379 if len(vars) > 2 { 380 panic("invalid equal expression") 381 } 382 left := strings.Split(vars[0], ".") 383 right := strings.Split(vars[1], ".") 384 leftTabId, leftColId := getIds(left) 385 rightTabId, rightColId := getIds(right) 386 return expression.NewEquals( 387 expression.NewGetFieldWithTable(leftColId, leftTabId, types.Int64, "", left[0], left[1], false), 388 expression.NewGetFieldWithTable(rightColId, rightTabId, types.Int64, "", right[0], right[1], false), 389 ) 390 } 391 392 func getIds(s []string) (tabId int, colId int) { 393 switch s[0] { 394 case "a": 395 tabId = 1 396 case "b": 397 tabId = 2 398 case "c": 399 tabId = 3 400 case "d": 401 tabId = 4 402 case "e": 403 tabId = 5 404 case "f": 405 tabId = 6 406 case "g": 407 tabId = 7 408 case "xy": 409 tabId = 1 410 case "uv": 411 tabId = 2 412 case "ab": 413 tabId = 3 414 case "pq": 415 tabId = 4 416 } 417 switch s[1] { 418 case "x": 419 colId = (tabId-1)*3 + 1 420 case "y": 421 colId = (tabId-1)*3 + 2 422 case "z": 423 colId = (tabId-1)*3 + 3 424 } 425 return 426 } 427 428 func TestAssociativeTransforms(t *testing.T) { 429 // Sourced from Figure 3 430 // each test has a reversible pair test which is a product of its transform 431 validTests := []struct { 432 name string 433 eA *edge 434 eB *edge 435 transform assocTransform 436 rev bool 437 }{ 438 { 439 name: "assoc(a,b)", 440 eA: newEdge(plan.JoinTypeInner, "110", "010", "100"), 441 eB: newEdge(plan.JoinTypeInner, "101", "110", "001"), 442 transform: assoc, 443 }, 444 { 445 name: "assoc(b,a)", 446 eA: newEdge(plan.JoinTypeInner, "010", "101", "010"), 447 eB: newEdge(plan.JoinTypeInner, "101", "001", "100"), 448 transform: assoc, 449 rev: true, 450 }, 451 { 452 name: "r-asscom(a,b)", 453 eA: newEdge(plan.JoinTypeInner, "110", "010", "100"), 454 eB: newEdge(plan.JoinTypeInner, "101", "001", "110"), 455 transform: rightAsscom, 456 }, 457 { 458 name: "r-asscom(b,a)", 459 eA: newEdge(plan.JoinTypeInner, "110", "010", "101"), 460 eB: newEdge(plan.JoinTypeInner, "101", "001", "100"), 461 transform: rightAsscom, 462 rev: true, 463 }, 464 { 465 name: "l-asscom(a,b)", 466 eA: newEdge(plan.JoinTypeInner, "110", "100", "010"), 467 eB: newEdge(plan.JoinTypeInner, "101", "110", "001"), 468 transform: leftAsscom, 469 }, 470 { 471 name: "l-asscom(b,a)", 472 eA: newEdge(plan.JoinTypeInner, "110", "101", "010"), 473 eB: newEdge(plan.JoinTypeInner, "101", "100", "001"), 474 transform: leftAsscom, 475 rev: true, 476 }, 477 { 478 name: "assoc(a,b)", 479 eA: newEdge(plan.JoinTypeInner, "110", "010", "100"), 480 eB: newEdge(plan.JoinTypeLeftOuter, "101", "110", "001"), 481 transform: assoc, 482 }, 483 // l-asscom is OK with everything but full outerjoin w/ null rejecting A(e1). 484 // Refer to rule table. 485 { 486 name: "l-asscom(a,b)", 487 eA: newEdge(plan.JoinTypeLeftOuter, "110", "100", "010"), 488 eB: newEdge(plan.JoinTypeInner, "101", "110", "001"), 489 transform: leftAsscom, 490 }, 491 { 492 name: "l-asscom(b,a)", 493 eA: newEdge(plan.JoinTypeLeftOuter, "110", "101", "010"), 494 eB: newEdge(plan.JoinTypeLeftOuter, "101", "100", "001"), 495 transform: leftAsscom, 496 rev: true, 497 }, 498 // TODO special case operators 499 } 500 501 for _, tt := range validTests { 502 t.Run(fmt.Sprintf("OK %s", tt.name), func(t *testing.T) { 503 var res bool 504 if tt.rev { 505 res = tt.transform(tt.eB, tt.eA) 506 } else { 507 res = tt.transform(tt.eA, tt.eB) 508 } 509 require.True(t, res) 510 }) 511 } 512 513 invalidTests := []struct { 514 name string 515 eA *edge 516 eB *edge 517 transform assocTransform 518 rev bool 519 }{ 520 // most transforms are invalid, these are also from Figure 3 521 { 522 name: "assoc(a,b)", 523 eA: newEdge(plan.JoinTypeInner, "110", "010", "100"), 524 eB: newEdge(plan.JoinTypeInner, "101", "001", "100"), 525 transform: assoc, 526 }, 527 { 528 name: "r-asscom(a,b)", 529 eA: newEdge(plan.JoinTypeInner, "110", "010", "100"), 530 eB: newEdge(plan.JoinTypeInner, "101", "100", "010"), 531 transform: rightAsscom, 532 }, 533 { 534 name: "l-asscom(a,b)", 535 eA: newEdge(plan.JoinTypeInner, "110", "010", "100"), 536 eB: newEdge(plan.JoinTypeInner, "101", "001", "100"), 537 transform: leftAsscom, 538 }, 539 // these are correct transforms with cross or inner joins, but invalid 540 // with other operators 541 { 542 name: "assoc(a,b)", 543 eA: newEdge(plan.JoinTypeLeftOuter, "110", "010", "100"), 544 eB: newEdge(plan.JoinTypeInner, "101", "110", "001"), 545 transform: assoc, 546 }, 547 { 548 // this one depends on rejecting nulls on A(e2) 549 name: "left join assoc(b,a)", 550 eA: newEdge(plan.JoinTypeLeftOuter, "010", "101", "010"), 551 eB: newEdge(plan.JoinTypeLeftOuter, "101", "001", "100"), 552 transform: assoc, 553 rev: true, 554 }, 555 { 556 name: "left join r-asscom(a,b)", 557 eA: newEdge(plan.JoinTypeLeftOuter, "110", "010", "100"), 558 eB: newEdge(plan.JoinTypeInner, "101", "001", "110"), 559 transform: rightAsscom, 560 }, 561 { 562 name: "left join r-asscom(b,a)", 563 eA: newEdge(plan.JoinTypeInner, "110", "010", "101"), 564 eB: newEdge(plan.JoinTypeLeftOuter, "101", "001", "100"), 565 transform: rightAsscom, 566 rev: true, 567 }, 568 { 569 name: "left join l-asscom(a,b)", 570 eA: newEdge(plan.JoinTypeFullOuter, "110", "100", "010"), 571 eB: newEdge(plan.JoinTypeInner, "101", "110", "001"), 572 transform: leftAsscom, 573 }, 574 } 575 576 for _, tt := range invalidTests { 577 t.Run(fmt.Sprintf("Invalid %s", tt.name), func(t *testing.T) { 578 var res bool 579 if tt.rev { 580 res = tt.transform(tt.eB, tt.eA) 581 } else { 582 res = tt.transform(tt.eA, tt.eB) 583 } 584 require.False(t, res) 585 }) 586 } 587 } 588 589 func TestEnsureClosure(t *testing.T) { 590 db := memory.NewDatabase("test") 591 pro := memory.NewDBProvider(db) 592 593 tests := []struct { 594 in sql.Node 595 name string 596 expEdges []edge 597 }{ 598 { 599 name: "inner joins", 600 in: plan.NewInnerJoin( 601 plan.NewInnerJoin( 602 plan.NewInnerJoin( 603 tableNode(db, "a"), 604 tableNode(db, "b"), 605 newEq("a.x = b.x"), 606 ), 607 tableNode(db, "c"), 608 newEq("b.x = c.x"), 609 ), 610 tableNode(db, "d"), 611 newEq("c.x = d.x"), 612 ), 613 expEdges: []edge{ 614 newEdge2(plan.JoinTypeInner, "1010", "1010", "1100", "0010", nil, 615 newEq("a.x=c.x"), 616 617 ""), // (A)B x (C) 618 newEdge2(plan.JoinTypeInner, "1001", "1001", "1110", "0001", []conflictRule{{from: 4, to: 2}}, 619 newEq("a.x=d.x"), 620 621 ""), // (A)BC x (D) 622 newEdge2(plan.JoinTypeInner, "0101", "0101", "1110", "0001", nil, 623 newEq("b.x=d.x"), 624 625 ""), // A(B)C x (D) 626 }, 627 }, 628 { 629 name: "left joins", 630 in: plan.NewLeftOuterJoin( 631 plan.NewInnerJoin( 632 plan.NewInnerJoin( 633 tableNode(db, "a"), 634 tableNode(db, "b"), 635 newEq("a.x = b.x"), 636 ), 637 tableNode(db, "c"), 638 newEq("b.x = c.x"), 639 ), 640 tableNode(db, "d"), 641 newEq("c.x = d.x"), 642 ), 643 expEdges: []edge{ 644 newEdge2(plan.JoinTypeInner, "1010", "1010", "1100", "0010", nil, 645 newEq("a.x=c.x"), 646 ""), // (A)B x (C) 647 }, 648 }, 649 { 650 name: "left join equivalence doesn't hold", 651 in: plan.NewLeftOuterJoin( 652 plan.NewInnerJoin( 653 plan.NewInnerJoin( 654 tableNode(db, "a"), 655 tableNode(db, "b"), 656 newEq("a.x = b.x"), 657 ), 658 tableNode(db, "c"), 659 newEq("b.x = c.x"), 660 ), 661 tableNode(db, "d"), 662 newEq("c.x = d.x"), 663 ), 664 expEdges: []edge{ 665 newEdge2(plan.JoinTypeInner, "1010", "1010", "1100", "0010", nil, 666 newEq("a.x=c.x"), 667 ""), // (A)B x (C) 668 }, 669 }, 670 } 671 672 for _, tt := range tests { 673 t.Run(tt.name, func(t *testing.T) { 674 b := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster())) 675 b.populateSubgraph(tt.in) 676 beforeLen := len(b.edges) 677 b.ensureClosure(b.m.Root()) 678 newEdges := b.edges[beforeLen:] 679 edgesEq(t, tt.expEdges, newEdges) 680 }) 681 } 682 } 683 684 func childSchema(source string) sql.PrimaryKeySchema { 685 return sql.NewPrimaryKeySchema(sql.Schema{ 686 {Name: "x", Source: source, Type: types.Int64, Nullable: false}, 687 {Name: "y", Source: source, Type: types.Text, Nullable: true}, 688 {Name: "z", Source: source, Type: types.Int64, Nullable: true}, 689 }, 0) 690 } 691 692 func tableNode(db *memory.Database, name string) sql.Node { 693 t := memory.NewTable(db, name, childSchema(name), nil) 694 t.EnablePrimaryKeyIndexes() 695 tabId, colId := getIds([]string{name, "x"}) 696 colset := sql.NewColSet(sql.ColumnId(colId), sql.ColumnId(colId+1), sql.ColumnId(colId+2)) 697 return plan.NewResolvedTable(t, db, nil).WithId(sql.TableId(tabId)).WithColumns(colset) 698 } 699 700 func newVertexSet(s string) vertexSet { 701 v := vertexSet(0) 702 for i, c := range s { 703 if string(c) == "1" { 704 v = v.add(uint64(i)) 705 } 706 } 707 return v 708 } 709 710 func newEdge(op plan.JoinType, ses, leftV, rightV string) *edge { 711 return &edge{ 712 op: &operator{ 713 joinType: op, 714 rightVertices: newVertexSet(rightV), 715 leftVertices: newVertexSet(leftV), 716 }, 717 ses: newVertexSet(ses), 718 } 719 } 720 721 func newEdge2(op plan.JoinType, ses, tes, leftV, rightV string, rules []conflictRule, filter sql.Expression, nullRej string) edge { 722 var filters []sql.Expression 723 if filter != nil { 724 filters = []sql.Expression{filter} 725 } 726 return edge{ 727 op: &operator{ 728 joinType: op, 729 rightVertices: newVertexSet(rightV), 730 leftVertices: newVertexSet(leftV), 731 }, 732 ses: newVertexSet(ses), 733 tes: newVertexSet(tes), 734 rules: rules, 735 filters: filters, 736 nullRejectedRels: newVertexSet(nullRej), 737 } 738 } 739 740 func edgesEq(t *testing.T, edges1, edges2 []edge) bool { 741 if len(edges1) != len(edges2) { 742 return false 743 } 744 for i := range edges1 { 745 e1 := edges1[i] 746 e2 := edges2[i] 747 require.Equal(t, e1.op.joinType, e2.op.joinType) 748 require.Equal(t, e1.op.leftVertices.String(), e2.op.leftVertices.String()) 749 require.Equal(t, e1.op.rightVertices.String(), e2.op.rightVertices.String()) 750 require.Equal(t, len(e1.filters), len(e2.filters)) 751 for i := range e1.filters { 752 assertScalarEq(t, e1.filters[i], e2.filters[i]) 753 } 754 require.Equal(t, e1.nullRejectedRels, e2.nullRejectedRels) 755 require.Equal(t, e1.tes, e2.tes) 756 require.Equal(t, e1.ses, e2.ses) 757 require.Equal(t, e1.rules, e2.rules) 758 } 759 return true 760 } 761 762 func assertScalarEq(t *testing.T, exp, cmp sql.Expression) { 763 switch cmp := cmp.(type) { 764 case *expression.Equals: 765 exp, ok := exp.(*expression.Equals) 766 require.True(t, ok) 767 assertScalarEq(t, exp.Left(), cmp.Left()) 768 assertScalarEq(t, exp.Right(), cmp.Right()) 769 case *expression.Literal: 770 exp, ok := exp.(*expression.Literal) 771 require.True(t, ok) 772 require.Equal(t, exp.Value(), cmp.Value()) 773 case *expression.GetField: 774 exp, ok := exp.(*expression.GetField) 775 require.True(t, ok) 776 require.Equal(t, exp.Table(), cmp.Table()) 777 require.Equal(t, exp.Name(), cmp.Name()) 778 require.Equal(t, exp.String(), cmp.String()) 779 } 780 }