github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/cross_join_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"io"
    19  	"testing"
    20  
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/dolthub/go-mysql-server/memory"
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  var lSchema = sql.NewPrimaryKeySchema(sql.Schema{
    30  	{Name: "lcol1", Type: types.Text},
    31  	{Name: "lcol2", Type: types.Text},
    32  	{Name: "lcol3", Type: types.Int32},
    33  	{Name: "lcol4", Type: types.Int64},
    34  })
    35  
    36  var rSchema = sql.NewPrimaryKeySchema(sql.Schema{
    37  	{Name: "rcol1", Type: types.Text},
    38  	{Name: "rcol2", Type: types.Text},
    39  	{Name: "rcol3", Type: types.Int32},
    40  	{Name: "rcol4", Type: types.Int64},
    41  })
    42  
    43  func TestCrossJoin(t *testing.T) {
    44  	require := require.New(t)
    45  
    46  	resultSchema := sql.Schema{
    47  		{Name: "lcol1", Type: types.Text},
    48  		{Name: "lcol2", Type: types.Text},
    49  		{Name: "lcol3", Type: types.Int32},
    50  		{Name: "lcol4", Type: types.Int64},
    51  		{Name: "rcol1", Type: types.Text},
    52  		{Name: "rcol2", Type: types.Text},
    53  		{Name: "rcol3", Type: types.Int32},
    54  		{Name: "rcol4", Type: types.Int64},
    55  	}
    56  
    57  	db := memory.NewDatabase("test")
    58  	pro := memory.NewDBProvider(db)
    59  	ctx := newContext(pro)
    60  
    61  	ltable := memory.NewTable(db.Database(), "left", lSchema, nil)
    62  	rtable := memory.NewTable(db.Database(), "right", rSchema, nil)
    63  	insertData(t, newContext(pro), ltable)
    64  	insertData(t, newContext(pro), rtable)
    65  
    66  	j := plan.NewCrossJoin(
    67  		plan.NewResolvedTable(ltable, nil, nil),
    68  		plan.NewResolvedTable(rtable, nil, nil),
    69  	)
    70  
    71  	require.Equal(resultSchema, j.Schema())
    72  
    73  	iter, err := DefaultBuilder.Build(ctx, j, nil)
    74  	require.NoError(err)
    75  	require.NotNil(iter)
    76  
    77  	row, err := iter.Next(ctx)
    78  	require.NoError(err)
    79  	require.NotNil(row)
    80  
    81  	require.Equal(8, len(row))
    82  
    83  	require.Equal("col1_1", row[0])
    84  	require.Equal("col2_1", row[1])
    85  	require.Equal(int32(1), row[2])
    86  	require.Equal(int64(2), row[3])
    87  	require.Equal("col1_1", row[4])
    88  	require.Equal("col2_1", row[5])
    89  	require.Equal(int32(1), row[6])
    90  	require.Equal(int64(2), row[7])
    91  
    92  	row, err = iter.Next(ctx)
    93  	require.NoError(err)
    94  	require.NotNil(row)
    95  
    96  	require.Equal("col1_1", row[0])
    97  	require.Equal("col2_1", row[1])
    98  	require.Equal(int32(1), row[2])
    99  	require.Equal(int64(2), row[3])
   100  	require.Equal("col1_2", row[4])
   101  	require.Equal("col2_2", row[5])
   102  	require.Equal(int32(3), row[6])
   103  	require.Equal(int64(4), row[7])
   104  
   105  	for i := 0; i < 2; i++ {
   106  		row, err = iter.Next(ctx)
   107  		require.NoError(err)
   108  		require.NotNil(row)
   109  	}
   110  
   111  	// total: 4 rows
   112  	row, err = iter.Next(ctx)
   113  	require.NotNil(err)
   114  	require.Equal(err, io.EOF)
   115  	require.Nil(row)
   116  }
   117  
   118  func TestCrossJoin_Empty(t *testing.T) {
   119  	require := require.New(t)
   120  
   121  	db := memory.NewDatabase("test")
   122  	pro := memory.NewDBProvider(db)
   123  	ctx := newContext(pro)
   124  
   125  	ltable := memory.NewTable(db.Database(), "left", lSchema, nil)
   126  	rtable := memory.NewTable(db.Database(), "right", rSchema, nil)
   127  	insertData(t, newContext(pro), ltable)
   128  
   129  	j := plan.NewCrossJoin(
   130  		plan.NewResolvedTable(ltable, nil, nil),
   131  		plan.NewResolvedTable(rtable, nil, nil),
   132  	)
   133  
   134  	iter, err := DefaultBuilder.Build(ctx, j, nil)
   135  	require.NoError(err)
   136  	require.NotNil(iter)
   137  
   138  	row, err := iter.Next(ctx)
   139  	require.Equal(io.EOF, err)
   140  	require.Nil(row)
   141  
   142  	ltable = memory.NewTable(db.Database(), "left", lSchema, nil)
   143  	rtable = memory.NewTable(db.Database(), "right", rSchema, nil)
   144  	insertData(t, newContext(pro), rtable)
   145  
   146  	j = plan.NewCrossJoin(
   147  		plan.NewResolvedTable(ltable, nil, nil),
   148  		plan.NewResolvedTable(rtable, nil, nil),
   149  	)
   150  
   151  	iter, err = DefaultBuilder.Build(ctx, j, nil)
   152  	require.NoError(err)
   153  	require.NotNil(iter)
   154  
   155  	row, err = iter.Next(ctx)
   156  	require.Equal(io.EOF, err)
   157  	require.Nil(row)
   158  }
   159  
   160  func insertData(t *testing.T, ctx *sql.Context, table *memory.Table) {
   161  	t.Helper()
   162  	require := require.New(t)
   163  
   164  	rows := []sql.Row{
   165  		sql.NewRow("col1_1", "col2_1", int32(1), int64(2)),
   166  		sql.NewRow("col1_2", "col2_2", int32(3), int64(4)),
   167  	}
   168  
   169  	for _, r := range rows {
   170  		require.NoError(table.Insert(ctx, r))
   171  	}
   172  }