github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/join_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "context" 19 "fmt" 20 "testing" 21 22 "github.com/stretchr/testify/require" 23 24 "github.com/dolthub/go-mysql-server/memory" 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/expression" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func TestJoinSchema(t *testing.T) { 32 db := memory.NewDatabase("test") 33 t1 := plan.NewResolvedTable(memory.NewTable(db.BaseDatabase, "foo", sql.NewPrimaryKeySchema(sql.Schema{ 34 {Name: "a", Source: "foo", Type: types.Int64}, 35 }), nil), nil, nil) 36 37 t2 := plan.NewResolvedTable(memory.NewTable(db.BaseDatabase, "bar", sql.NewPrimaryKeySchema(sql.Schema{ 38 {Name: "b", Source: "bar", Type: types.Int64}, 39 }), nil), nil, nil) 40 41 t.Run("inner", func(t *testing.T) { 42 j := plan.NewInnerJoin(t1, t2, nil) 43 result := j.Schema() 44 45 require.Equal(t, sql.Schema{ 46 {Name: "a", Source: "foo", Type: types.Int64}, 47 {Name: "b", Source: "bar", Type: types.Int64}, 48 }, result) 49 }) 50 51 t.Run("left", func(t *testing.T) { 52 j := plan.NewLeftOuterJoin(t1, t2, nil) 53 result := j.Schema() 54 55 require.Equal(t, sql.Schema{ 56 {Name: "a", Source: "foo", Type: types.Int64}, 57 {Name: "b", Source: "bar", Type: types.Int64, Nullable: true}, 58 }, result) 59 }) 60 61 t.Run("right", func(t *testing.T) { 62 j := plan.NewRightOuterJoin(t1, t2, nil) 63 result := j.Schema() 64 65 require.Equal(t, sql.Schema{ 66 {Name: "a", Source: "foo", Type: types.Int64, Nullable: true}, 67 {Name: "b", Source: "bar", Type: types.Int64}, 68 }, result) 69 }) 70 } 71 72 func TestInnerJoin(t *testing.T) { 73 db := memory.NewDatabase("test") 74 pro := memory.NewDBProvider(db) 75 ctx := newContext(pro) 76 testInnerJoin(t, db, ctx) 77 } 78 79 func TestMultiPassInnerJoin(t *testing.T) { 80 db := memory.NewDatabase("test") 81 pro := memory.NewDBProvider(db) 82 ctx := sql.NewContext(context.TODO(), sql.WithMemoryManager( 83 sql.NewMemoryManager(mockReporter{2, 1}), 84 ), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro))) 85 testInnerJoin(t, db, ctx) 86 } 87 88 func testInnerJoin(t *testing.T, db *memory.Database, ctx *sql.Context) { 89 t.Helper() 90 91 require := require.New(t) 92 93 ltable := memory.NewTable(db.BaseDatabase, "left", lSchema, nil) 94 rtable := memory.NewTable(db.BaseDatabase, "right", rSchema, nil) 95 insertData(t, ctx, ltable) 96 insertData(t, ctx, rtable) 97 98 j := plan.NewInnerJoin( 99 plan.NewResolvedTable(ltable, nil, nil), 100 plan.NewResolvedTable(rtable, nil, nil), 101 expression.NewEquals( 102 expression.NewGetField(0, types.Text, "lcol1", false), 103 expression.NewGetField(4, types.Text, "rcol1", false), 104 )) 105 106 rows := collectRows(t, ctx, j) 107 require.Len(rows, 2) 108 109 require.Equal([]sql.Row{ 110 {"col1_1", "col2_1", int32(1), int64(2), "col1_1", "col2_1", int32(1), int64(2)}, 111 {"col1_2", "col2_2", int32(3), int64(4), "col1_2", "col2_2", int32(3), int64(4)}, 112 }, rows) 113 } 114 115 func TestInnerJoinEmpty(t *testing.T) { 116 require := require.New(t) 117 118 db := memory.NewDatabase("test") 119 pro := memory.NewDBProvider(db) 120 ctx := newContext(pro) 121 122 ltable := memory.NewTable(db.BaseDatabase, "left", lSchema, nil) 123 rtable := memory.NewTable(db.BaseDatabase, "right", rSchema, nil) 124 125 j := plan.NewInnerJoin( 126 plan.NewResolvedTable(ltable, nil, nil), 127 plan.NewResolvedTable(rtable, nil, nil), 128 expression.NewEquals( 129 expression.NewGetField(0, types.Text, "lcol1", false), 130 expression.NewGetField(4, types.Text, "rcol1", false), 131 )) 132 133 iter, err := DefaultBuilder.Build(ctx, j, nil) 134 require.NoError(err) 135 136 assertRows(t, ctx, iter, 0) 137 } 138 139 func BenchmarkInnerJoin(b *testing.B) { 140 db := memory.NewDatabase("test") 141 pro := memory.NewDBProvider(db) 142 143 t1 := memory.NewTable(db.BaseDatabase, "foo", sql.NewPrimaryKeySchema(sql.Schema{ 144 {Name: "a", Source: "foo", Type: types.Int64}, 145 {Name: "b", Source: "foo", Type: types.Text}, 146 }), nil) 147 148 t2 := memory.NewTable(db.BaseDatabase, "bar", sql.NewPrimaryKeySchema(sql.Schema{ 149 {Name: "a", Source: "bar", Type: types.Int64}, 150 {Name: "b", Source: "bar", Type: types.Text}, 151 }), nil) 152 153 for i := 0; i < 5; i++ { 154 t1.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t1_%d", i))) 155 t2.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t2_%d", i))) 156 } 157 158 n1 := plan.NewInnerJoin( 159 plan.NewResolvedTable(t1, nil, nil), 160 plan.NewResolvedTable(t2, nil, nil), 161 expression.NewEquals( 162 expression.NewGetField(0, types.Int64, "a", false), 163 expression.NewGetField(2, types.Int64, "a", false), 164 ), 165 ) 166 167 n2 := plan.NewFilter( 168 expression.NewEquals( 169 expression.NewGetField(0, types.Int64, "a", false), 170 expression.NewGetField(2, types.Int64, "a", false), 171 ), 172 plan.NewCrossJoin( 173 plan.NewResolvedTable(t1, nil, nil), 174 plan.NewResolvedTable(t2, nil, nil), 175 ), 176 ) 177 178 expected := []sql.Row{ 179 {int64(0), "t1_0", int64(0), "t2_0"}, 180 {int64(1), "t1_1", int64(1), "t2_1"}, 181 {int64(2), "t1_2", int64(2), "t2_2"}, 182 {int64(3), "t1_3", int64(3), "t2_3"}, 183 {int64(4), "t1_4", int64(4), "t2_4"}, 184 } 185 186 ctx := sql.NewContext(context.Background(), sql.WithMemoryManager( 187 sql.NewMemoryManager(mockReporter{1, 5}), 188 ), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro))) 189 190 b.Run("inner join", func(b *testing.B) { 191 require := require.New(b) 192 193 for i := 0; i < b.N; i++ { 194 iter, err := DefaultBuilder.Build(ctx, n1, nil) 195 require.NoError(err) 196 197 rows, err := sql.RowIterToRows(ctx, iter) 198 require.NoError(err) 199 200 require.Equal(expected, rows) 201 } 202 }) 203 204 b.Run("within memory threshold", func(b *testing.B) { 205 require := require.New(b) 206 207 for i := 0; i < b.N; i++ { 208 iter, err := DefaultBuilder.Build(ctx, n1, nil) 209 require.NoError(err) 210 211 rows, err := sql.RowIterToRows(ctx, iter) 212 require.NoError(err) 213 214 require.Equal(expected, rows) 215 } 216 }) 217 218 b.Run("cross join with filter", func(b *testing.B) { 219 require := require.New(b) 220 221 for i := 0; i < b.N; i++ { 222 iter, err := DefaultBuilder.Build(ctx, n2, nil) 223 require.NoError(err) 224 225 rows, err := sql.RowIterToRows(ctx, iter) 226 require.NoError(err) 227 228 require.Equal(expected, rows) 229 } 230 }) 231 } 232 233 func TestLeftJoin(t *testing.T) { 234 require := require.New(t) 235 236 db := memory.NewDatabase("test") 237 pro := memory.NewDBProvider(db) 238 ctx := newContext(pro) 239 240 ltable := memory.NewTable(db.BaseDatabase, "left", lSchema, nil) 241 rtable := memory.NewTable(db.BaseDatabase, "right", rSchema, nil) 242 243 insertData(t, newContext(pro), ltable) 244 insertData(t, newContext(pro), rtable) 245 246 j := plan.NewLeftOuterJoin( 247 plan.NewResolvedTable(ltable, nil, nil), 248 plan.NewResolvedTable(rtable, nil, nil), 249 expression.NewEquals( 250 expression.NewPlus( 251 expression.NewGetField(2, types.Text, "lcol3", false), 252 expression.NewLiteral(int32(2), types.Int32), 253 ), 254 expression.NewGetField(6, types.Text, "rcol3", false), 255 )) 256 257 iter, err := DefaultBuilder.Build(ctx, j, nil) 258 require.NoError(err) 259 rows, err := sql.RowIterToRows(ctx, iter) 260 require.NoError(err) 261 require.ElementsMatch([]sql.Row{ 262 {"col1_1", "col2_1", int32(1), int64(2), "col1_2", "col2_2", int32(3), int64(4)}, 263 {"col1_2", "col2_2", int32(3), int64(4), nil, nil, nil, nil}, 264 }, rows) 265 } 266 267 type mockReporter struct { 268 val uint64 269 max uint64 270 } 271 272 func (m mockReporter) UsedMemory() uint64 { return m.val } 273 func (m mockReporter) MaxMemory() uint64 { return m.max }