github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/indexed_joins_test.go (about)

     1  package analyzer
     2  
     3  import (
     4  	"strings"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"github.com/dolthub/go-mysql-server/memory"
    10  	"github.com/dolthub/go-mysql-server/sql"
    11  	"github.com/dolthub/go-mysql-server/sql/expression"
    12  	"github.com/dolthub/go-mysql-server/sql/memo"
    13  	"github.com/dolthub/go-mysql-server/sql/plan"
    14  	"github.com/dolthub/go-mysql-server/sql/types"
    15  )
    16  
    17  func TestHashJoins(t *testing.T) {
    18  	db := memory.NewDatabase("db")
    19  
    20  	tests := []struct {
    21  		name string
    22  		plan sql.Node
    23  		memo string
    24  	}{
    25  		{
    26  			name: "hash join 1",
    27  			plan: plan.NewInnerJoin(
    28  				plan.NewInnerJoin(
    29  					plan.NewInnerJoin(
    30  						ab(db),
    31  						xy(db),
    32  						newEq("ab.a=xy.x"),
    33  					),
    34  					pq(db),
    35  					newEq("xy.x=pq.p"),
    36  				),
    37  				uv(db),
    38  				newEq("pq.q=uv.u"),
    39  			),
    40  			memo: `memo:
    41  ├── G1: (tablescan: ab)
    42  ├── G2: (tablescan: xy)
    43  ├── G3: (hashjoin 1 2) (hashjoin 2 1) (innerjoin 2 1) (innerjoin 1 2)
    44  ├── G4: (tablescan: pq)
    45  ├── G5: (hashjoin 3 4) (hashjoin 1 9) (hashjoin 9 1) (hashjoin 2 8) (hashjoin 8 2) (hashjoin 4 3) (innerjoin 4 3) (innerjoin 8 2) (innerjoin 2 8) (innerjoin 9 1) (innerjoin 1 9) (innerjoin 3 4)
    46  ├── G6: (tablescan: uv)
    47  ├── G7: (hashjoin 5 6) (hashjoin 1 12) (hashjoin 12 1) (hashjoin 2 11) (hashjoin 11 2) (hashjoin 3 10) (hashjoin 10 3) (hashjoin 6 5) (innerjoin 6 5) (innerjoin 10 3) (innerjoin 3 10) (innerjoin 11 2) (innerjoin 2 11) (innerjoin 12 1) (innerjoin 1 12) (innerjoin 5 6)
    48  ├── G8: (hashjoin 1 4) (hashjoin 4 1) (innerjoin 4 1) (innerjoin 1 4)
    49  ├── G9: (hashjoin 2 4) (hashjoin 4 2) (innerjoin 4 2) (innerjoin 2 4)
    50  ├── G10: (hashjoin 4 6) (hashjoin 6 4) (innerjoin 6 4) (innerjoin 4 6)
    51  ├── G11: (hashjoin 1 10) (hashjoin 10 1) (hashjoin 8 6) (hashjoin 6 8) (innerjoin 6 8) (innerjoin 8 6) (innerjoin 10 1) (innerjoin 1 10)
    52  └── G12: (hashjoin 2 10) (hashjoin 10 2) (hashjoin 9 6) (hashjoin 6 9) (innerjoin 6 9) (innerjoin 9 6) (innerjoin 10 2) (innerjoin 2 10)
    53  `,
    54  		},
    55  	}
    56  
    57  	pro := memory.NewDBProvider(db)
    58  	ctx := newContext(pro)
    59  
    60  	for _, tt := range tests {
    61  		t.Run(tt.name, func(t *testing.T) {
    62  			m := memo.NewMemo(ctx, newTestCatalog(db), nil, 0, memo.NewDefaultCoster())
    63  			j := memo.NewJoinOrderBuilder(m)
    64  			j.ReorderJoin(tt.plan)
    65  			addHashJoins(m)
    66  			require.Equal(t, tt.memo, m.String())
    67  		})
    68  	}
    69  }
    70  
    71  var childSchema = sql.NewPrimaryKeySchema(sql.Schema{
    72  	{Name: "i", Type: types.Int64, Nullable: true},
    73  	{Name: "s", Type: types.Text, Nullable: true},
    74  })
    75  
    76  func uv(db *memory.Database) sql.Node {
    77  	t := memory.NewTable(db, "uv", sql.NewPrimaryKeySchema(sql.Schema{
    78  		{Name: "u", Type: types.Int64, Nullable: true},
    79  		{Name: "v", Type: types.Text, Nullable: true},
    80  	}, 0), nil)
    81  	return plan.NewResolvedTable(t, db, nil).WithId(4).WithColumns(sql.NewColSet(7, 8))
    82  }
    83  
    84  func xy(db *memory.Database) sql.Node {
    85  	t := memory.NewTable(db, "xy", sql.NewPrimaryKeySchema(sql.Schema{
    86  		{Name: "x", Type: types.Int64, Nullable: true},
    87  		{Name: "y", Type: types.Text, Nullable: true},
    88  	}, 0), nil)
    89  	return plan.NewResolvedTable(t, db, nil).WithId(1).WithColumns(sql.NewColSet(1, 2))
    90  }
    91  
    92  func ab(db *memory.Database) sql.Node {
    93  	t := memory.NewTable(db, "ab", sql.NewPrimaryKeySchema(sql.Schema{
    94  		{Name: "a", Type: types.Int64, Nullable: true},
    95  		{Name: "b", Type: types.Text, Nullable: true},
    96  	}, 0), nil)
    97  	return plan.NewResolvedTable(t, db, nil).WithId(2).WithColumns(sql.NewColSet(3, 4))
    98  }
    99  
   100  func pq(db *memory.Database) sql.Node {
   101  	t := memory.NewTable(db, "pq", sql.NewPrimaryKeySchema(sql.Schema{
   102  		{Name: "p", Type: types.Int64, Nullable: true},
   103  		{Name: "q", Type: types.Text, Nullable: true},
   104  	}, 0), nil)
   105  	return plan.NewResolvedTable(t, db, nil).WithId(3).WithColumns(sql.NewColSet(5, 6))
   106  }
   107  
   108  func newEq(eq string) sql.Expression {
   109  	vars := strings.Split(eq, "=")
   110  	if len(vars) > 2 {
   111  		panic("invalid equal expression")
   112  	}
   113  	left := strings.Split(vars[0], ".")
   114  	right := strings.Split(vars[1], ".")
   115  	return expression.NewEquals(
   116  		expression.NewGetFieldWithTable(colId(left[1]), tabId(left[0]), types.Int64, "db", left[0], left[1], false),
   117  		expression.NewGetFieldWithTable(colId(right[1]), tabId(right[0]), types.Int64, "db", right[0], right[1], false),
   118  	)
   119  }
   120  
   121  func colId(n string) int {
   122  	switch n {
   123  	case "x":
   124  		return 1
   125  	case "y":
   126  		return 2
   127  	case "a":
   128  		return 3
   129  	case "b":
   130  		return 4
   131  	case "p":
   132  		return 5
   133  	case "q":
   134  		return 6
   135  	case "u":
   136  		return 7
   137  	case "v":
   138  		return 8
   139  	default:
   140  		panic("unknown col")
   141  	}
   142  }
   143  
   144  func tabId(n string) int {
   145  	switch n {
   146  	case "xy":
   147  		return 1
   148  	case "ab":
   149  		return 2
   150  	case "pq":
   151  		return 3
   152  	case "uv":
   153  		return 4
   154  	default:
   155  		panic("unknown table")
   156  	}
   157  }