github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/indexed_joins_test.go (about) 1 package analyzer 2 3 import ( 4 "strings" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 9 "github.com/dolthub/go-mysql-server/memory" 10 "github.com/dolthub/go-mysql-server/sql" 11 "github.com/dolthub/go-mysql-server/sql/expression" 12 "github.com/dolthub/go-mysql-server/sql/memo" 13 "github.com/dolthub/go-mysql-server/sql/plan" 14 "github.com/dolthub/go-mysql-server/sql/types" 15 ) 16 17 func TestHashJoins(t *testing.T) { 18 db := memory.NewDatabase("db") 19 20 tests := []struct { 21 name string 22 plan sql.Node 23 memo string 24 }{ 25 { 26 name: "hash join 1", 27 plan: plan.NewInnerJoin( 28 plan.NewInnerJoin( 29 plan.NewInnerJoin( 30 ab(db), 31 xy(db), 32 newEq("ab.a=xy.x"), 33 ), 34 pq(db), 35 newEq("xy.x=pq.p"), 36 ), 37 uv(db), 38 newEq("pq.q=uv.u"), 39 ), 40 memo: `memo: 41 ├── G1: (tablescan: ab) 42 ├── G2: (tablescan: xy) 43 ├── G3: (hashjoin 1 2) (hashjoin 2 1) (innerjoin 2 1) (innerjoin 1 2) 44 ├── G4: (tablescan: pq) 45 ├── G5: (hashjoin 3 4) (hashjoin 1 9) (hashjoin 9 1) (hashjoin 2 8) (hashjoin 8 2) (hashjoin 4 3) (innerjoin 4 3) (innerjoin 8 2) (innerjoin 2 8) (innerjoin 9 1) (innerjoin 1 9) (innerjoin 3 4) 46 ├── G6: (tablescan: uv) 47 ├── G7: (hashjoin 5 6) (hashjoin 1 12) (hashjoin 12 1) (hashjoin 2 11) (hashjoin 11 2) (hashjoin 3 10) (hashjoin 10 3) (hashjoin 6 5) (innerjoin 6 5) (innerjoin 10 3) (innerjoin 3 10) (innerjoin 11 2) (innerjoin 2 11) (innerjoin 12 1) (innerjoin 1 12) (innerjoin 5 6) 48 ├── G8: (hashjoin 1 4) (hashjoin 4 1) (innerjoin 4 1) (innerjoin 1 4) 49 ├── G9: (hashjoin 2 4) (hashjoin 4 2) (innerjoin 4 2) (innerjoin 2 4) 50 ├── G10: (hashjoin 4 6) (hashjoin 6 4) (innerjoin 6 4) (innerjoin 4 6) 51 ├── G11: (hashjoin 1 10) (hashjoin 10 1) (hashjoin 8 6) (hashjoin 6 8) (innerjoin 6 8) (innerjoin 8 6) (innerjoin 10 1) (innerjoin 1 10) 52 └── G12: (hashjoin 2 10) (hashjoin 10 2) (hashjoin 9 6) (hashjoin 6 9) (innerjoin 6 9) (innerjoin 9 6) (innerjoin 10 2) (innerjoin 2 10) 53 `, 54 }, 55 } 56 57 pro := memory.NewDBProvider(db) 58 ctx := newContext(pro) 59 60 for _, tt := range tests { 61 t.Run(tt.name, func(t *testing.T) { 62 m := memo.NewMemo(ctx, newTestCatalog(db), nil, 0, memo.NewDefaultCoster()) 63 j := memo.NewJoinOrderBuilder(m) 64 j.ReorderJoin(tt.plan) 65 addHashJoins(m) 66 require.Equal(t, tt.memo, m.String()) 67 }) 68 } 69 } 70 71 var childSchema = sql.NewPrimaryKeySchema(sql.Schema{ 72 {Name: "i", Type: types.Int64, Nullable: true}, 73 {Name: "s", Type: types.Text, Nullable: true}, 74 }) 75 76 func uv(db *memory.Database) sql.Node { 77 t := memory.NewTable(db, "uv", sql.NewPrimaryKeySchema(sql.Schema{ 78 {Name: "u", Type: types.Int64, Nullable: true}, 79 {Name: "v", Type: types.Text, Nullable: true}, 80 }, 0), nil) 81 return plan.NewResolvedTable(t, db, nil).WithId(4).WithColumns(sql.NewColSet(7, 8)) 82 } 83 84 func xy(db *memory.Database) sql.Node { 85 t := memory.NewTable(db, "xy", sql.NewPrimaryKeySchema(sql.Schema{ 86 {Name: "x", Type: types.Int64, Nullable: true}, 87 {Name: "y", Type: types.Text, Nullable: true}, 88 }, 0), nil) 89 return plan.NewResolvedTable(t, db, nil).WithId(1).WithColumns(sql.NewColSet(1, 2)) 90 } 91 92 func ab(db *memory.Database) sql.Node { 93 t := memory.NewTable(db, "ab", sql.NewPrimaryKeySchema(sql.Schema{ 94 {Name: "a", Type: types.Int64, Nullable: true}, 95 {Name: "b", Type: types.Text, Nullable: true}, 96 }, 0), nil) 97 return plan.NewResolvedTable(t, db, nil).WithId(2).WithColumns(sql.NewColSet(3, 4)) 98 } 99 100 func pq(db *memory.Database) sql.Node { 101 t := memory.NewTable(db, "pq", sql.NewPrimaryKeySchema(sql.Schema{ 102 {Name: "p", Type: types.Int64, Nullable: true}, 103 {Name: "q", Type: types.Text, Nullable: true}, 104 }, 0), nil) 105 return plan.NewResolvedTable(t, db, nil).WithId(3).WithColumns(sql.NewColSet(5, 6)) 106 } 107 108 func newEq(eq string) sql.Expression { 109 vars := strings.Split(eq, "=") 110 if len(vars) > 2 { 111 panic("invalid equal expression") 112 } 113 left := strings.Split(vars[0], ".") 114 right := strings.Split(vars[1], ".") 115 return expression.NewEquals( 116 expression.NewGetFieldWithTable(colId(left[1]), tabId(left[0]), types.Int64, "db", left[0], left[1], false), 117 expression.NewGetFieldWithTable(colId(right[1]), tabId(right[0]), types.Int64, "db", right[0], right[1], false), 118 ) 119 } 120 121 func colId(n string) int { 122 switch n { 123 case "x": 124 return 1 125 case "y": 126 return 2 127 case "a": 128 return 3 129 case "b": 130 return 4 131 case "p": 132 return 5 133 case "q": 134 return 6 135 case "u": 136 return 7 137 case "v": 138 return 8 139 default: 140 panic("unknown col") 141 } 142 } 143 144 func tabId(n string) int { 145 switch n { 146 case "xy": 147 return 1 148 case "ab": 149 return 2 150 case "pq": 151 return 3 152 case "uv": 153 return 4 154 default: 155 panic("unknown table") 156 } 157 }