github.com/dolthub/go-mysql-server@v0.18.0/sql/memo/join_order_builder_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memo
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"github.com/dolthub/go-mysql-server/memory"
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  	"github.com/dolthub/go-mysql-server/sql/plan"
    29  	"github.com/dolthub/go-mysql-server/sql/types"
    30  )
    31  
    32  func TestJoinOrderBuilder(t *testing.T) {
    33  	db := memory.NewDatabase("test")
    34  	pro := memory.NewDBProvider(db)
    35  
    36  	tests := []struct {
    37  		in               sql.Node
    38  		name             string
    39  		plans            string
    40  		forceFastReorder bool
    41  	}{
    42  		{
    43  			name: "inner joins",
    44  			in: plan.NewInnerJoin(
    45  				plan.NewInnerJoin(
    46  					plan.NewInnerJoin(
    47  						tableNode(db, "a"),
    48  						tableNode(db, "b"),
    49  						newEq("a.x = b.x"),
    50  					),
    51  					tableNode(db, "c"),
    52  					newEq("b.x = c.x"),
    53  				),
    54  				tableNode(db, "d"),
    55  				newEq("c.x = d.x"),
    56  			),
    57  			plans: `memo:
    58  ├── G1: (tablescan: a)
    59  ├── G2: (tablescan: b)
    60  ├── G3: (innerjoin 2 1) (innerjoin 1 2)
    61  ├── G4: (tablescan: c)
    62  ├── G5: (innerjoin 4 3) (innerjoin 8 2) (innerjoin 2 8) (innerjoin 9 1) (innerjoin 1 9) (innerjoin 3 4)
    63  ├── G6: (tablescan: d)
    64  ├── G7: (innerjoin 6 5) (innerjoin 10 9) (innerjoin 9 10) (innerjoin 11 8) (innerjoin 8 11) (innerjoin 12 4) (innerjoin 4 12) (innerjoin 13 3) (innerjoin 3 13) (innerjoin 14 2) (innerjoin 2 14) (innerjoin 15 1) (innerjoin 1 15) (innerjoin 5 6)
    65  ├── G8: (innerjoin 4 1) (innerjoin 1 4)
    66  ├── G9: (innerjoin 4 2) (innerjoin 2 4)
    67  ├── G10: (innerjoin 6 1) (innerjoin 1 6)
    68  ├── G11: (innerjoin 6 2) (innerjoin 2 6)
    69  ├── G12: (innerjoin 6 3) (innerjoin 3 6) (innerjoin 10 2) (innerjoin 2 10) (innerjoin 11 1) (innerjoin 1 11)
    70  ├── G13: (innerjoin 6 4) (innerjoin 4 6)
    71  ├── G14: (innerjoin 6 8) (innerjoin 8 6) (innerjoin 10 4) (innerjoin 4 10) (innerjoin 13 1) (innerjoin 1 13)
    72  └── G15: (innerjoin 6 9) (innerjoin 9 6) (innerjoin 11 4) (innerjoin 4 11) (innerjoin 13 2) (innerjoin 2 13)
    73  `,
    74  		},
    75  		{
    76  			name: "non-inner joins",
    77  			in: plan.NewInnerJoin(
    78  				plan.NewInnerJoin(
    79  					plan.NewLeftOuterJoin(
    80  						tableNode(db, "a"),
    81  						tableNode(db, "b"),
    82  						newEq("a.x = b.x"),
    83  					),
    84  					plan.NewLeftOuterJoin(
    85  						plan.NewFullOuterJoin(
    86  							tableNode(db, "c"),
    87  							tableNode(db, "d"),
    88  							newEq("c.x = d.x"),
    89  						),
    90  						tableNode(db, "e"),
    91  						newEq("c.x = e.x"),
    92  					),
    93  					newEq("a.x = e.x"),
    94  				),
    95  				plan.NewInnerJoin(
    96  					tableNode(db, "f"),
    97  					tableNode(db, "g"),
    98  					newEq("f.x = g.x"),
    99  				),
   100  				newEq("e.x = g.x"),
   101  			),
   102  			plans: `memo:
   103  ├── G1: (tablescan: a)
   104  ├── G2: (tablescan: b)
   105  ├── G3: (leftjoin 1 2)
   106  ├── G4: (tablescan: c)
   107  ├── G5: (tablescan: d)
   108  ├── G6: (fullouterjoin 4 5)
   109  ├── G7: (tablescan: e)
   110  ├── G8: (leftjoin 6 7)
   111  ├── G9: (innerjoin 8 3) (leftjoin 14 2) (innerjoin 3 8)
   112  ├── G10: (tablescan: f)
   113  ├── G11: (tablescan: g)
   114  ├── G12: (innerjoin 11 10) (innerjoin 10 11)
   115  ├── G13: (innerjoin 11 19) (innerjoin 19 11) (innerjoin 21 17) (innerjoin 17 21) (innerjoin 22 16) (innerjoin 16 22) (innerjoin 24 10) (innerjoin 10 24) (innerjoin 12 9) (innerjoin 26 8) (innerjoin 8 26) (innerjoin 27 3) (innerjoin 3 27) (leftjoin 28 2) (innerjoin 9 12)
   116  ├── G14: (innerjoin 8 1) (innerjoin 1 8)
   117  ├── G15: (innerjoin 10 1) (innerjoin 1 10)
   118  ├── G16: (innerjoin 10 3) (innerjoin 3 10) (leftjoin 15 2)
   119  ├── G17: (innerjoin 10 8) (innerjoin 8 10)
   120  ├── G18: (innerjoin 10 14) (innerjoin 14 10) (innerjoin 15 8) (innerjoin 8 15) (innerjoin 17 1) (innerjoin 1 17)
   121  ├── G19: (innerjoin 10 9) (innerjoin 9 10) (innerjoin 16 8) (innerjoin 8 16) (innerjoin 17 3) (innerjoin 3 17) (leftjoin 18 2)
   122  ├── G20: (innerjoin 11 1) (innerjoin 1 11)
   123  ├── G21: (innerjoin 11 3) (innerjoin 3 11) (leftjoin 20 2)
   124  ├── G22: (innerjoin 11 8) (innerjoin 8 11)
   125  ├── G23: (innerjoin 11 14) (innerjoin 14 11) (innerjoin 20 8) (innerjoin 8 20) (innerjoin 22 1) (innerjoin 1 22)
   126  ├── G24: (innerjoin 11 9) (innerjoin 9 11) (innerjoin 21 8) (innerjoin 8 21) (innerjoin 22 3) (innerjoin 3 22) (leftjoin 23 2)
   127  ├── G25: (innerjoin 11 15) (innerjoin 15 11) (innerjoin 20 10) (innerjoin 10 20) (innerjoin 12 1) (innerjoin 1 12)
   128  ├── G26: (innerjoin 11 16) (innerjoin 16 11) (innerjoin 21 10) (innerjoin 10 21) (innerjoin 12 3) (innerjoin 3 12) (leftjoin 25 2)
   129  ├── G27: (innerjoin 11 17) (innerjoin 17 11) (innerjoin 22 10) (innerjoin 10 22) (innerjoin 12 8) (innerjoin 8 12)
   130  └── G28: (innerjoin 11 18) (innerjoin 18 11) (innerjoin 20 17) (innerjoin 17 20) (innerjoin 22 15) (innerjoin 15 22) (innerjoin 23 10) (innerjoin 10 23) (innerjoin 12 14) (innerjoin 14 12) (innerjoin 25 8) (innerjoin 8 25) (innerjoin 27 1) (innerjoin 1 27)
   131  `,
   132  		},
   133  		{
   134  			name: "test fast reordering algorithm",
   135  			// Optimized plan appears as G11 - (innerjoin 1 12)
   136  			in: plan.NewInnerJoin(
   137  				plan.NewCrossJoin(
   138  					tableNode(db, "a"),
   139  					tableNode(db, "c"),
   140  				),
   141  				tableNode(db, "b"),
   142  				expression.NewAnd(newEq("a.x = b.z"), newEq("b.x = c.z")),
   143  			),
   144  
   145  			forceFastReorder: true,
   146  			plans: `memo:
   147  ├── G1: (tablescan: a)
   148  ├── G2: (tablescan: c)
   149  ├── G3: (crossjoin 1 2)
   150  ├── G4: (tablescan: b)
   151  ├── G5: (innerjoin 1 6) (innerjoin 6 1) (innerjoin 3 4)
   152  └── G6: (innerjoin 4 2) (innerjoin 2 4)
   153  `,
   154  		},
   155  		{
   156  			name: "test fast reordering algorithm on bushy join",
   157  			// Optimized plan appears as G16: (innerjoin 7 17)
   158  			in: plan.NewInnerJoin(
   159  				plan.NewInnerJoin(
   160  					tableNode(db, "c"),
   161  					tableNode(db, "d"),
   162  					newEq("c.x = d.z"),
   163  				),
   164  				plan.NewInnerJoin(
   165  					tableNode(db, "a"),
   166  					tableNode(db, "b"),
   167  					newEq("a.x = b.z"),
   168  				),
   169  				newEq("b.x = c.z"),
   170  			),
   171  
   172  			forceFastReorder: true,
   173  			plans: `memo:
   174  ├── G1: (tablescan: c)
   175  ├── G2: (tablescan: d)
   176  ├── G3: (innerjoin 1 2) (innerjoin 2 1) (innerjoin 1 2)
   177  ├── G4: (tablescan: a)
   178  ├── G5: (tablescan: b)
   179  ├── G6: (innerjoin 4 5)
   180  ├── G7: (innerjoin 4 8) (innerjoin 8 4) (innerjoin 3 6)
   181  └── G8: (innerjoin 5 3) (innerjoin 3 5)
   182  `,
   183  		},
   184  	}
   185  
   186  	for _, tt := range tests {
   187  		t.Run(tt.name, func(t *testing.T) {
   188  			j := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster()))
   189  			j.forceFastDFSLookupForTest = tt.forceFastReorder
   190  			j.ReorderJoin(tt.in)
   191  			require.Equal(t, tt.plans, j.m.String())
   192  		})
   193  	}
   194  }
   195  
   196  func newContext(provider *memory.DbProvider) *sql.Context {
   197  	return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider)))
   198  }
   199  
   200  func TestJoinOrderBuilder_populateSubgraph(t *testing.T) {
   201  	db := memory.NewDatabase("test")
   202  	pro := memory.NewDBProvider(db)
   203  
   204  	tests := []struct {
   205  		name     string
   206  		join     sql.Node
   207  		expEdges []edge
   208  	}{
   209  		{
   210  			name: "cross join",
   211  			join: plan.NewCrossJoin(
   212  				tableNode(db, "a"),
   213  				plan.NewInnerJoin(
   214  					tableNode(db, "b"),
   215  					plan.NewLeftOuterJoin(
   216  						tableNode(db, "c"),
   217  						tableNode(db, "d"),
   218  						newEq("c.x=d.x"),
   219  					),
   220  					newEq("b.y=d.y"),
   221  				),
   222  			),
   223  			expEdges: []edge{
   224  				newEdge2(plan.JoinTypeLeftOuter, "0011", "0011", "0010", "0001", nil,
   225  					newEq("c.x=d.x"),
   226  					""), // C x D
   227  				newEdge2(plan.JoinTypeInner, "0101", "0111", "0100", "0011", nil,
   228  					newEq("b.y=d.y"),
   229  					""), // B x (CD)
   230  				newEdge2(plan.JoinTypeCross, "0000", "1111", "1000", "0111", nil, nil, ""), // A x (BCD)
   231  			},
   232  		},
   233  		{
   234  			name: "right deep left join",
   235  			join: plan.NewInnerJoin(
   236  				tableNode(db, "a"),
   237  				plan.NewInnerJoin(
   238  					tableNode(db, "b"),
   239  					plan.NewLeftOuterJoin(
   240  						tableNode(db, "c"),
   241  						tableNode(db, "d"),
   242  						newEq("c.x=d.x"),
   243  					),
   244  					newEq("b.y=d.y"),
   245  				),
   246  				newEq("a.z=b.z"),
   247  			),
   248  			expEdges: []edge{
   249  				newEdge2(plan.JoinTypeLeftOuter, "0011", "0011", "0010", "0001", nil,
   250  					newEq("c.x=d.x"),
   251  					""), // C x D
   252  				newEdge2(plan.JoinTypeInner, "0101", "0111", "0100", "0011", nil,
   253  					newEq("b.y=d.y"),
   254  
   255  					""), // B x (CD)
   256  				newEdge2(plan.JoinTypeInner, "1100", "1100", "1000", "0111", []conflictRule{{from: newVertexSet("0001"), to: newVertexSet("0010")}},
   257  					newEq("a.z=b.z"),
   258  
   259  					""), // A x (BCD)
   260  			},
   261  		},
   262  		{
   263  			name: "bushy left joins",
   264  			join: plan.NewLeftOuterJoin(
   265  				plan.NewLeftOuterJoin(
   266  					tableNode(db, "a"),
   267  					tableNode(db, "b"),
   268  					newEq("a.x=b.x"),
   269  				),
   270  				plan.NewLeftOuterJoin(
   271  					tableNode(db, "c"),
   272  					tableNode(db, "d"),
   273  					newEq("c.x=d.x"),
   274  				),
   275  				newEq("b.y=c.y"),
   276  			),
   277  			expEdges: []edge{
   278  				newEdge2(plan.JoinTypeLeftOuter, "1100", "1100", "1000", "0100", nil,
   279  					newEq("a.x=b.x"),
   280  					""), // A x B
   281  				newEdge2(plan.JoinTypeLeftOuter, "0011", "0011", "0010", "0001", nil,
   282  					newEq("c.x=d.x"), // offset by filters
   283  					""),              // C x D
   284  				newEdge2(plan.JoinTypeLeftOuter, "0110", "1111", "1100", "0011", nil,
   285  					newEq("b.y=c.y"),
   286  					""), // (AB) x (CD)
   287  			},
   288  		},
   289  		{
   290  			// SELECT *
   291  			// FROM (SELECT * FROM A CROSS JOIN B)
   292  			// LEFT JOIN C
   293  			// ON B.x = C.x
   294  			name: "degenerate inner join",
   295  			join: plan.NewLeftOuterJoin(
   296  				plan.NewCrossJoin(
   297  					tableNode(db, "a"),
   298  					tableNode(db, "b"),
   299  				),
   300  				tableNode(db, "c"),
   301  				newEq("b.x=c.x"),
   302  			),
   303  			expEdges: []edge{
   304  				newEdge2(plan.JoinTypeCross, "000", "110", "100", "010", nil, nil, ""), // A X B
   305  				newEdge2(plan.JoinTypeLeftOuter, "011", "111", "110", "001", nil,
   306  					newEq("b.x=c.x"),
   307  
   308  					""), // (AB) x C
   309  			},
   310  		},
   311  		{
   312  			// SELECT *
   313  			// FROM (SELECT * FROM A INNER JOIN B ON True)
   314  			// FULL JOIN (SELECT * FROM C INNER JOIN D ON True)
   315  			// ON A.x = C.x
   316  			name: "degenerate inner join",
   317  			join: plan.NewFullOuterJoin(
   318  				plan.NewInnerJoin(
   319  					tableNode(db, "a"),
   320  					tableNode(db, "b"),
   321  					expression.NewLiteral(true, types.Boolean),
   322  				),
   323  				plan.NewInnerJoin(
   324  					tableNode(db, "c"),
   325  					tableNode(db, "d"),
   326  					expression.NewLiteral(true, types.Boolean),
   327  				),
   328  				newEq("a.x=c.x"),
   329  			),
   330  			expEdges: []edge{
   331  				newEdge2(plan.JoinTypeInner, "0000", "1100", "1000", "0100", nil, expression.NewLiteral(true, types.Boolean), ""), // A x B
   332  				newEdge2(plan.JoinTypeInner, "0000", "0011", "0010", "0001", nil, expression.NewLiteral(true, types.Boolean), ""), // C x D
   333  				newEdge2(plan.JoinTypeFullOuter, "1010", "1111", "1100", "0011", nil,
   334  					newEq("a.x=c.x"),
   335  					""), // (AB) x (CD)
   336  			},
   337  		},
   338  		{
   339  			// SELECT * FROM A
   340  			// WHERE EXISTS
   341  			// (
   342  			//   SELECT * FROM B
   343  			//   LEFT JOIN C ON B.x = C.x
   344  			//   WHERE A.y = B.y
   345  			// )
   346  			// note: left join is the right child
   347  			name: "semi join",
   348  			join: plan.NewSemiJoin(
   349  				plan.NewLeftOuterJoin(
   350  					tableNode(db, "b"),
   351  					tableNode(db, "c"),
   352  					newEq("b.x=c.x"),
   353  				),
   354  				tableNode(db, "a"),
   355  				newEq("a.y=b.y"),
   356  			),
   357  			expEdges: []edge{
   358  				newEdge2(plan.JoinTypeLeftOuter, "110", "110", "100", "010", nil,
   359  					newEq("b.x=c.x"),
   360  					""), // B x C
   361  				newEdge2(plan.JoinTypeSemi, "101", "101", "110", "001", nil,
   362  					newEq("a.y=b.y"),
   363  					""), // A x (BC)
   364  			},
   365  		},
   366  	}
   367  
   368  	for _, tt := range tests {
   369  		t.Run(tt.name, func(t *testing.T) {
   370  			b := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster()))
   371  			b.populateSubgraph(tt.join)
   372  			edgesEq(t, tt.expEdges, b.edges)
   373  		})
   374  	}
   375  }
   376  
   377  func newEq(eq string) sql.Expression {
   378  	vars := strings.Split(strings.Replace(eq, " ", "", -1), "=")
   379  	if len(vars) > 2 {
   380  		panic("invalid equal expression")
   381  	}
   382  	left := strings.Split(vars[0], ".")
   383  	right := strings.Split(vars[1], ".")
   384  	leftTabId, leftColId := getIds(left)
   385  	rightTabId, rightColId := getIds(right)
   386  	return expression.NewEquals(
   387  		expression.NewGetFieldWithTable(leftColId, leftTabId, types.Int64, "", left[0], left[1], false),
   388  		expression.NewGetFieldWithTable(rightColId, rightTabId, types.Int64, "", right[0], right[1], false),
   389  	)
   390  }
   391  
   392  func getIds(s []string) (tabId int, colId int) {
   393  	switch s[0] {
   394  	case "a":
   395  		tabId = 1
   396  	case "b":
   397  		tabId = 2
   398  	case "c":
   399  		tabId = 3
   400  	case "d":
   401  		tabId = 4
   402  	case "e":
   403  		tabId = 5
   404  	case "f":
   405  		tabId = 6
   406  	case "g":
   407  		tabId = 7
   408  	case "xy":
   409  		tabId = 1
   410  	case "uv":
   411  		tabId = 2
   412  	case "ab":
   413  		tabId = 3
   414  	case "pq":
   415  		tabId = 4
   416  	}
   417  	switch s[1] {
   418  	case "x":
   419  		colId = (tabId-1)*3 + 1
   420  	case "y":
   421  		colId = (tabId-1)*3 + 2
   422  	case "z":
   423  		colId = (tabId-1)*3 + 3
   424  	}
   425  	return
   426  }
   427  
   428  func TestAssociativeTransforms(t *testing.T) {
   429  	// Sourced from Figure 3
   430  	// each test has a reversible pair test which is a product of its transform
   431  	validTests := []struct {
   432  		name      string
   433  		eA        *edge
   434  		eB        *edge
   435  		transform assocTransform
   436  		rev       bool
   437  	}{
   438  		{
   439  			name:      "assoc(a,b)",
   440  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "100"),
   441  			eB:        newEdge(plan.JoinTypeInner, "101", "110", "001"),
   442  			transform: assoc,
   443  		},
   444  		{
   445  			name:      "assoc(b,a)",
   446  			eA:        newEdge(plan.JoinTypeInner, "010", "101", "010"),
   447  			eB:        newEdge(plan.JoinTypeInner, "101", "001", "100"),
   448  			transform: assoc,
   449  			rev:       true,
   450  		},
   451  		{
   452  			name:      "r-asscom(a,b)",
   453  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "100"),
   454  			eB:        newEdge(plan.JoinTypeInner, "101", "001", "110"),
   455  			transform: rightAsscom,
   456  		},
   457  		{
   458  			name:      "r-asscom(b,a)",
   459  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "101"),
   460  			eB:        newEdge(plan.JoinTypeInner, "101", "001", "100"),
   461  			transform: rightAsscom,
   462  			rev:       true,
   463  		},
   464  		{
   465  			name:      "l-asscom(a,b)",
   466  			eA:        newEdge(plan.JoinTypeInner, "110", "100", "010"),
   467  			eB:        newEdge(plan.JoinTypeInner, "101", "110", "001"),
   468  			transform: leftAsscom,
   469  		},
   470  		{
   471  			name:      "l-asscom(b,a)",
   472  			eA:        newEdge(plan.JoinTypeInner, "110", "101", "010"),
   473  			eB:        newEdge(plan.JoinTypeInner, "101", "100", "001"),
   474  			transform: leftAsscom,
   475  			rev:       true,
   476  		},
   477  		{
   478  			name:      "assoc(a,b)",
   479  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "100"),
   480  			eB:        newEdge(plan.JoinTypeLeftOuter, "101", "110", "001"),
   481  			transform: assoc,
   482  		},
   483  		// l-asscom is OK with everything but full outerjoin w/ null rejecting A(e1).
   484  		// Refer to rule table.
   485  		{
   486  			name:      "l-asscom(a,b)",
   487  			eA:        newEdge(plan.JoinTypeLeftOuter, "110", "100", "010"),
   488  			eB:        newEdge(plan.JoinTypeInner, "101", "110", "001"),
   489  			transform: leftAsscom,
   490  		},
   491  		{
   492  			name:      "l-asscom(b,a)",
   493  			eA:        newEdge(plan.JoinTypeLeftOuter, "110", "101", "010"),
   494  			eB:        newEdge(plan.JoinTypeLeftOuter, "101", "100", "001"),
   495  			transform: leftAsscom,
   496  			rev:       true,
   497  		},
   498  		// TODO special case operators
   499  	}
   500  
   501  	for _, tt := range validTests {
   502  		t.Run(fmt.Sprintf("OK %s", tt.name), func(t *testing.T) {
   503  			var res bool
   504  			if tt.rev {
   505  				res = tt.transform(tt.eB, tt.eA)
   506  			} else {
   507  				res = tt.transform(tt.eA, tt.eB)
   508  			}
   509  			require.True(t, res)
   510  		})
   511  	}
   512  
   513  	invalidTests := []struct {
   514  		name      string
   515  		eA        *edge
   516  		eB        *edge
   517  		transform assocTransform
   518  		rev       bool
   519  	}{
   520  		// most transforms are invalid, these are also from Figure 3
   521  		{
   522  			name:      "assoc(a,b)",
   523  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "100"),
   524  			eB:        newEdge(plan.JoinTypeInner, "101", "001", "100"),
   525  			transform: assoc,
   526  		},
   527  		{
   528  			name:      "r-asscom(a,b)",
   529  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "100"),
   530  			eB:        newEdge(plan.JoinTypeInner, "101", "100", "010"),
   531  			transform: rightAsscom,
   532  		},
   533  		{
   534  			name:      "l-asscom(a,b)",
   535  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "100"),
   536  			eB:        newEdge(plan.JoinTypeInner, "101", "001", "100"),
   537  			transform: leftAsscom,
   538  		},
   539  		// these are correct transforms with cross or inner joins, but invalid
   540  		// with other operators
   541  		{
   542  			name:      "assoc(a,b)",
   543  			eA:        newEdge(plan.JoinTypeLeftOuter, "110", "010", "100"),
   544  			eB:        newEdge(plan.JoinTypeInner, "101", "110", "001"),
   545  			transform: assoc,
   546  		},
   547  		{
   548  			// this one depends on rejecting nulls on A(e2)
   549  			name:      "left join assoc(b,a)",
   550  			eA:        newEdge(plan.JoinTypeLeftOuter, "010", "101", "010"),
   551  			eB:        newEdge(plan.JoinTypeLeftOuter, "101", "001", "100"),
   552  			transform: assoc,
   553  			rev:       true,
   554  		},
   555  		{
   556  			name:      "left join r-asscom(a,b)",
   557  			eA:        newEdge(plan.JoinTypeLeftOuter, "110", "010", "100"),
   558  			eB:        newEdge(plan.JoinTypeInner, "101", "001", "110"),
   559  			transform: rightAsscom,
   560  		},
   561  		{
   562  			name:      "left join r-asscom(b,a)",
   563  			eA:        newEdge(plan.JoinTypeInner, "110", "010", "101"),
   564  			eB:        newEdge(plan.JoinTypeLeftOuter, "101", "001", "100"),
   565  			transform: rightAsscom,
   566  			rev:       true,
   567  		},
   568  		{
   569  			name:      "left join l-asscom(a,b)",
   570  			eA:        newEdge(plan.JoinTypeFullOuter, "110", "100", "010"),
   571  			eB:        newEdge(plan.JoinTypeInner, "101", "110", "001"),
   572  			transform: leftAsscom,
   573  		},
   574  	}
   575  
   576  	for _, tt := range invalidTests {
   577  		t.Run(fmt.Sprintf("Invalid %s", tt.name), func(t *testing.T) {
   578  			var res bool
   579  			if tt.rev {
   580  				res = tt.transform(tt.eB, tt.eA)
   581  			} else {
   582  				res = tt.transform(tt.eA, tt.eB)
   583  			}
   584  			require.False(t, res)
   585  		})
   586  	}
   587  }
   588  
   589  func TestEnsureClosure(t *testing.T) {
   590  	db := memory.NewDatabase("test")
   591  	pro := memory.NewDBProvider(db)
   592  
   593  	tests := []struct {
   594  		in       sql.Node
   595  		name     string
   596  		expEdges []edge
   597  	}{
   598  		{
   599  			name: "inner joins",
   600  			in: plan.NewInnerJoin(
   601  				plan.NewInnerJoin(
   602  					plan.NewInnerJoin(
   603  						tableNode(db, "a"),
   604  						tableNode(db, "b"),
   605  						newEq("a.x = b.x"),
   606  					),
   607  					tableNode(db, "c"),
   608  					newEq("b.x = c.x"),
   609  				),
   610  				tableNode(db, "d"),
   611  				newEq("c.x = d.x"),
   612  			),
   613  			expEdges: []edge{
   614  				newEdge2(plan.JoinTypeInner, "1010", "1010", "1100", "0010", nil,
   615  					newEq("a.x=c.x"),
   616  
   617  					""), // (A)B x (C)
   618  				newEdge2(plan.JoinTypeInner, "1001", "1001", "1110", "0001", []conflictRule{{from: 4, to: 2}},
   619  					newEq("a.x=d.x"),
   620  
   621  					""), // (A)BC x (D)
   622  				newEdge2(plan.JoinTypeInner, "0101", "0101", "1110", "0001", nil,
   623  					newEq("b.x=d.x"),
   624  
   625  					""), // A(B)C x (D)
   626  			},
   627  		},
   628  		{
   629  			name: "left joins",
   630  			in: plan.NewLeftOuterJoin(
   631  				plan.NewInnerJoin(
   632  					plan.NewInnerJoin(
   633  						tableNode(db, "a"),
   634  						tableNode(db, "b"),
   635  						newEq("a.x = b.x"),
   636  					),
   637  					tableNode(db, "c"),
   638  					newEq("b.x = c.x"),
   639  				),
   640  				tableNode(db, "d"),
   641  				newEq("c.x = d.x"),
   642  			),
   643  			expEdges: []edge{
   644  				newEdge2(plan.JoinTypeInner, "1010", "1010", "1100", "0010", nil,
   645  					newEq("a.x=c.x"),
   646  					""), // (A)B x (C)
   647  			},
   648  		},
   649  		{
   650  			name: "left join equivalence doesn't hold",
   651  			in: plan.NewLeftOuterJoin(
   652  				plan.NewInnerJoin(
   653  					plan.NewInnerJoin(
   654  						tableNode(db, "a"),
   655  						tableNode(db, "b"),
   656  						newEq("a.x = b.x"),
   657  					),
   658  					tableNode(db, "c"),
   659  					newEq("b.x = c.x"),
   660  				),
   661  				tableNode(db, "d"),
   662  				newEq("c.x = d.x"),
   663  			),
   664  			expEdges: []edge{
   665  				newEdge2(plan.JoinTypeInner, "1010", "1010", "1100", "0010", nil,
   666  					newEq("a.x=c.x"),
   667  					""), // (A)B x (C)
   668  			},
   669  		},
   670  	}
   671  
   672  	for _, tt := range tests {
   673  		t.Run(tt.name, func(t *testing.T) {
   674  			b := NewJoinOrderBuilder(NewMemo(newContext(pro), nil, nil, 0, NewDefaultCoster()))
   675  			b.populateSubgraph(tt.in)
   676  			beforeLen := len(b.edges)
   677  			b.ensureClosure(b.m.Root())
   678  			newEdges := b.edges[beforeLen:]
   679  			edgesEq(t, tt.expEdges, newEdges)
   680  		})
   681  	}
   682  }
   683  
   684  func childSchema(source string) sql.PrimaryKeySchema {
   685  	return sql.NewPrimaryKeySchema(sql.Schema{
   686  		{Name: "x", Source: source, Type: types.Int64, Nullable: false},
   687  		{Name: "y", Source: source, Type: types.Text, Nullable: true},
   688  		{Name: "z", Source: source, Type: types.Int64, Nullable: true},
   689  	}, 0)
   690  }
   691  
   692  func tableNode(db *memory.Database, name string) sql.Node {
   693  	t := memory.NewTable(db, name, childSchema(name), nil)
   694  	t.EnablePrimaryKeyIndexes()
   695  	tabId, colId := getIds([]string{name, "x"})
   696  	colset := sql.NewColSet(sql.ColumnId(colId), sql.ColumnId(colId+1), sql.ColumnId(colId+2))
   697  	return plan.NewResolvedTable(t, db, nil).WithId(sql.TableId(tabId)).WithColumns(colset)
   698  }
   699  
   700  func newVertexSet(s string) vertexSet {
   701  	v := vertexSet(0)
   702  	for i, c := range s {
   703  		if string(c) == "1" {
   704  			v = v.add(uint64(i))
   705  		}
   706  	}
   707  	return v
   708  }
   709  
   710  func newEdge(op plan.JoinType, ses, leftV, rightV string) *edge {
   711  	return &edge{
   712  		op: &operator{
   713  			joinType:      op,
   714  			rightVertices: newVertexSet(rightV),
   715  			leftVertices:  newVertexSet(leftV),
   716  		},
   717  		ses: newVertexSet(ses),
   718  	}
   719  }
   720  
   721  func newEdge2(op plan.JoinType, ses, tes, leftV, rightV string, rules []conflictRule, filter sql.Expression, nullRej string) edge {
   722  	var filters []sql.Expression
   723  	if filter != nil {
   724  		filters = []sql.Expression{filter}
   725  	}
   726  	return edge{
   727  		op: &operator{
   728  			joinType:      op,
   729  			rightVertices: newVertexSet(rightV),
   730  			leftVertices:  newVertexSet(leftV),
   731  		},
   732  		ses:              newVertexSet(ses),
   733  		tes:              newVertexSet(tes),
   734  		rules:            rules,
   735  		filters:          filters,
   736  		nullRejectedRels: newVertexSet(nullRej),
   737  	}
   738  }
   739  
   740  func edgesEq(t *testing.T, edges1, edges2 []edge) bool {
   741  	if len(edges1) != len(edges2) {
   742  		return false
   743  	}
   744  	for i := range edges1 {
   745  		e1 := edges1[i]
   746  		e2 := edges2[i]
   747  		require.Equal(t, e1.op.joinType, e2.op.joinType)
   748  		require.Equal(t, e1.op.leftVertices.String(), e2.op.leftVertices.String())
   749  		require.Equal(t, e1.op.rightVertices.String(), e2.op.rightVertices.String())
   750  		require.Equal(t, len(e1.filters), len(e2.filters))
   751  		for i := range e1.filters {
   752  			assertScalarEq(t, e1.filters[i], e2.filters[i])
   753  		}
   754  		require.Equal(t, e1.nullRejectedRels, e2.nullRejectedRels)
   755  		require.Equal(t, e1.tes, e2.tes)
   756  		require.Equal(t, e1.ses, e2.ses)
   757  		require.Equal(t, e1.rules, e2.rules)
   758  	}
   759  	return true
   760  }
   761  
   762  func assertScalarEq(t *testing.T, exp, cmp sql.Expression) {
   763  	switch cmp := cmp.(type) {
   764  	case *expression.Equals:
   765  		exp, ok := exp.(*expression.Equals)
   766  		require.True(t, ok)
   767  		assertScalarEq(t, exp.Left(), cmp.Left())
   768  		assertScalarEq(t, exp.Right(), cmp.Right())
   769  	case *expression.Literal:
   770  		exp, ok := exp.(*expression.Literal)
   771  		require.True(t, ok)
   772  		require.Equal(t, exp.Value(), cmp.Value())
   773  	case *expression.GetField:
   774  		exp, ok := exp.(*expression.GetField)
   775  		require.True(t, ok)
   776  		require.Equal(t, exp.Table(), cmp.Table())
   777  		require.Equal(t, exp.Name(), cmp.Name())
   778  		require.Equal(t, exp.String(), cmp.String())
   779  	}
   780  }