github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/join_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/dolthub/go-mysql-server/memory"
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func TestJoinSchema(t *testing.T) {
    32  	db := memory.NewDatabase("test")
    33  	t1 := plan.NewResolvedTable(memory.NewTable(db.BaseDatabase, "foo", sql.NewPrimaryKeySchema(sql.Schema{
    34  		{Name: "a", Source: "foo", Type: types.Int64},
    35  	}), nil), nil, nil)
    36  
    37  	t2 := plan.NewResolvedTable(memory.NewTable(db.BaseDatabase, "bar", sql.NewPrimaryKeySchema(sql.Schema{
    38  		{Name: "b", Source: "bar", Type: types.Int64},
    39  	}), nil), nil, nil)
    40  
    41  	t.Run("inner", func(t *testing.T) {
    42  		j := plan.NewInnerJoin(t1, t2, nil)
    43  		result := j.Schema()
    44  
    45  		require.Equal(t, sql.Schema{
    46  			{Name: "a", Source: "foo", Type: types.Int64},
    47  			{Name: "b", Source: "bar", Type: types.Int64},
    48  		}, result)
    49  	})
    50  
    51  	t.Run("left", func(t *testing.T) {
    52  		j := plan.NewLeftOuterJoin(t1, t2, nil)
    53  		result := j.Schema()
    54  
    55  		require.Equal(t, sql.Schema{
    56  			{Name: "a", Source: "foo", Type: types.Int64},
    57  			{Name: "b", Source: "bar", Type: types.Int64, Nullable: true},
    58  		}, result)
    59  	})
    60  
    61  	t.Run("right", func(t *testing.T) {
    62  		j := plan.NewRightOuterJoin(t1, t2, nil)
    63  		result := j.Schema()
    64  
    65  		require.Equal(t, sql.Schema{
    66  			{Name: "a", Source: "foo", Type: types.Int64, Nullable: true},
    67  			{Name: "b", Source: "bar", Type: types.Int64},
    68  		}, result)
    69  	})
    70  }
    71  
    72  func TestInnerJoin(t *testing.T) {
    73  	db := memory.NewDatabase("test")
    74  	pro := memory.NewDBProvider(db)
    75  	ctx := newContext(pro)
    76  	testInnerJoin(t, db, ctx)
    77  }
    78  
    79  func TestMultiPassInnerJoin(t *testing.T) {
    80  	db := memory.NewDatabase("test")
    81  	pro := memory.NewDBProvider(db)
    82  	ctx := sql.NewContext(context.TODO(), sql.WithMemoryManager(
    83  		sql.NewMemoryManager(mockReporter{2, 1}),
    84  	), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro)))
    85  	testInnerJoin(t, db, ctx)
    86  }
    87  
    88  func testInnerJoin(t *testing.T, db *memory.Database, ctx *sql.Context) {
    89  	t.Helper()
    90  
    91  	require := require.New(t)
    92  
    93  	ltable := memory.NewTable(db.BaseDatabase, "left", lSchema, nil)
    94  	rtable := memory.NewTable(db.BaseDatabase, "right", rSchema, nil)
    95  	insertData(t, ctx, ltable)
    96  	insertData(t, ctx, rtable)
    97  
    98  	j := plan.NewInnerJoin(
    99  		plan.NewResolvedTable(ltable, nil, nil),
   100  		plan.NewResolvedTable(rtable, nil, nil),
   101  		expression.NewEquals(
   102  			expression.NewGetField(0, types.Text, "lcol1", false),
   103  			expression.NewGetField(4, types.Text, "rcol1", false),
   104  		))
   105  
   106  	rows := collectRows(t, ctx, j)
   107  	require.Len(rows, 2)
   108  
   109  	require.Equal([]sql.Row{
   110  		{"col1_1", "col2_1", int32(1), int64(2), "col1_1", "col2_1", int32(1), int64(2)},
   111  		{"col1_2", "col2_2", int32(3), int64(4), "col1_2", "col2_2", int32(3), int64(4)},
   112  	}, rows)
   113  }
   114  
   115  func TestInnerJoinEmpty(t *testing.T) {
   116  	require := require.New(t)
   117  
   118  	db := memory.NewDatabase("test")
   119  	pro := memory.NewDBProvider(db)
   120  	ctx := newContext(pro)
   121  
   122  	ltable := memory.NewTable(db.BaseDatabase, "left", lSchema, nil)
   123  	rtable := memory.NewTable(db.BaseDatabase, "right", rSchema, nil)
   124  
   125  	j := plan.NewInnerJoin(
   126  		plan.NewResolvedTable(ltable, nil, nil),
   127  		plan.NewResolvedTable(rtable, nil, nil),
   128  		expression.NewEquals(
   129  			expression.NewGetField(0, types.Text, "lcol1", false),
   130  			expression.NewGetField(4, types.Text, "rcol1", false),
   131  		))
   132  
   133  	iter, err := DefaultBuilder.Build(ctx, j, nil)
   134  	require.NoError(err)
   135  
   136  	assertRows(t, ctx, iter, 0)
   137  }
   138  
   139  func BenchmarkInnerJoin(b *testing.B) {
   140  	db := memory.NewDatabase("test")
   141  	pro := memory.NewDBProvider(db)
   142  
   143  	t1 := memory.NewTable(db.BaseDatabase, "foo", sql.NewPrimaryKeySchema(sql.Schema{
   144  		{Name: "a", Source: "foo", Type: types.Int64},
   145  		{Name: "b", Source: "foo", Type: types.Text},
   146  	}), nil)
   147  
   148  	t2 := memory.NewTable(db.BaseDatabase, "bar", sql.NewPrimaryKeySchema(sql.Schema{
   149  		{Name: "a", Source: "bar", Type: types.Int64},
   150  		{Name: "b", Source: "bar", Type: types.Text},
   151  	}), nil)
   152  
   153  	for i := 0; i < 5; i++ {
   154  		t1.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t1_%d", i)))
   155  		t2.Insert(sql.NewEmptyContext(), sql.NewRow(int64(i), fmt.Sprintf("t2_%d", i)))
   156  	}
   157  
   158  	n1 := plan.NewInnerJoin(
   159  		plan.NewResolvedTable(t1, nil, nil),
   160  		plan.NewResolvedTable(t2, nil, nil),
   161  		expression.NewEquals(
   162  			expression.NewGetField(0, types.Int64, "a", false),
   163  			expression.NewGetField(2, types.Int64, "a", false),
   164  		),
   165  	)
   166  
   167  	n2 := plan.NewFilter(
   168  		expression.NewEquals(
   169  			expression.NewGetField(0, types.Int64, "a", false),
   170  			expression.NewGetField(2, types.Int64, "a", false),
   171  		),
   172  		plan.NewCrossJoin(
   173  			plan.NewResolvedTable(t1, nil, nil),
   174  			plan.NewResolvedTable(t2, nil, nil),
   175  		),
   176  	)
   177  
   178  	expected := []sql.Row{
   179  		{int64(0), "t1_0", int64(0), "t2_0"},
   180  		{int64(1), "t1_1", int64(1), "t2_1"},
   181  		{int64(2), "t1_2", int64(2), "t2_2"},
   182  		{int64(3), "t1_3", int64(3), "t2_3"},
   183  		{int64(4), "t1_4", int64(4), "t2_4"},
   184  	}
   185  
   186  	ctx := sql.NewContext(context.Background(), sql.WithMemoryManager(
   187  		sql.NewMemoryManager(mockReporter{1, 5}),
   188  	), sql.WithSession(memory.NewSession(sql.NewBaseSession(), pro)))
   189  
   190  	b.Run("inner join", func(b *testing.B) {
   191  		require := require.New(b)
   192  
   193  		for i := 0; i < b.N; i++ {
   194  			iter, err := DefaultBuilder.Build(ctx, n1, nil)
   195  			require.NoError(err)
   196  
   197  			rows, err := sql.RowIterToRows(ctx, iter)
   198  			require.NoError(err)
   199  
   200  			require.Equal(expected, rows)
   201  		}
   202  	})
   203  
   204  	b.Run("within memory threshold", func(b *testing.B) {
   205  		require := require.New(b)
   206  
   207  		for i := 0; i < b.N; i++ {
   208  			iter, err := DefaultBuilder.Build(ctx, n1, nil)
   209  			require.NoError(err)
   210  
   211  			rows, err := sql.RowIterToRows(ctx, iter)
   212  			require.NoError(err)
   213  
   214  			require.Equal(expected, rows)
   215  		}
   216  	})
   217  
   218  	b.Run("cross join with filter", func(b *testing.B) {
   219  		require := require.New(b)
   220  
   221  		for i := 0; i < b.N; i++ {
   222  			iter, err := DefaultBuilder.Build(ctx, n2, nil)
   223  			require.NoError(err)
   224  
   225  			rows, err := sql.RowIterToRows(ctx, iter)
   226  			require.NoError(err)
   227  
   228  			require.Equal(expected, rows)
   229  		}
   230  	})
   231  }
   232  
   233  func TestLeftJoin(t *testing.T) {
   234  	require := require.New(t)
   235  
   236  	db := memory.NewDatabase("test")
   237  	pro := memory.NewDBProvider(db)
   238  	ctx := newContext(pro)
   239  
   240  	ltable := memory.NewTable(db.BaseDatabase, "left", lSchema, nil)
   241  	rtable := memory.NewTable(db.BaseDatabase, "right", rSchema, nil)
   242  
   243  	insertData(t, newContext(pro), ltable)
   244  	insertData(t, newContext(pro), rtable)
   245  
   246  	j := plan.NewLeftOuterJoin(
   247  		plan.NewResolvedTable(ltable, nil, nil),
   248  		plan.NewResolvedTable(rtable, nil, nil),
   249  		expression.NewEquals(
   250  			expression.NewPlus(
   251  				expression.NewGetField(2, types.Text, "lcol3", false),
   252  				expression.NewLiteral(int32(2), types.Int32),
   253  			),
   254  			expression.NewGetField(6, types.Text, "rcol3", false),
   255  		))
   256  
   257  	iter, err := DefaultBuilder.Build(ctx, j, nil)
   258  	require.NoError(err)
   259  	rows, err := sql.RowIterToRows(ctx, iter)
   260  	require.NoError(err)
   261  	require.ElementsMatch([]sql.Row{
   262  		{"col1_1", "col2_1", int32(1), int64(2), "col1_2", "col2_2", int32(3), int64(4)},
   263  		{"col1_2", "col2_2", int32(3), int64(4), nil, nil, nil, nil},
   264  	}, rows)
   265  }
   266  
   267  type mockReporter struct {
   268  	val uint64
   269  	max uint64
   270  }
   271  
   272  func (m mockReporter) UsedMemory() uint64 { return m.val }
   273  func (m mockReporter) MaxMemory() uint64  { return m.max }