github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/parallelize_test.go

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/memory"
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation/window"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/go-mysql-server/sql/transform"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func TestParallelize(t *testing.T) {
    32  	require := require.New(t)
    33  	db := memory.NewDatabase("db")
    34  	pro := memory.NewDBProvider(db)
    35  	ctx := newContext(pro)
    36  
    37  	table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil)
    38  	rule := getRuleFrom(OnceAfterAll, parallelizeId)
    39  	node := plan.NewProject(
    40  		nil,
    41  		plan.NewInnerJoin(
    42  			plan.NewFilter(
    43  				expression.NewLiteral(1, types.Int64),
    44  				plan.NewResolvedTable(table, nil, nil),
    45  			),
    46  			plan.NewFilter(
    47  				expression.NewLiteral(1, types.Int64),
    48  				plan.NewResolvedTable(table, nil, nil),
    49  			),
    50  			expression.NewLiteral(1, types.Int64),
    51  		),
    52  	)
    53  
    54  	expected := plan.NewProject(
    55  		nil,
    56  		plan.NewInnerJoin(
    57  			plan.NewExchange(
    58  				2,
    59  				plan.NewFilter(
    60  					expression.NewLiteral(1, types.Int64),
    61  					plan.NewResolvedTable(table, nil, nil),
    62  				),
    63  			),
    64  			plan.NewExchange(
    65  				2,
    66  				plan.NewFilter(
    67  					expression.NewLiteral(1, types.Int64),
    68  					plan.NewResolvedTable(table, nil, nil),
    69  				),
    70  			),
    71  			expression.NewLiteral(1, types.Int64),
    72  		),
    73  	)
    74  
    75  	result, _, err := rule.Apply(ctx, &Analyzer{Parallelism: 2}, node, nil, DefaultRuleSelector)
    76  	require.NoError(err)
    77  	require.Equal(expected, result)
    78  }
    79  
    80  func TestParallelizeCreateIndex(t *testing.T) {
    81  	require := require.New(t)
    82  	db := memory.NewDatabase("db")
    83  	pro := memory.NewDBProvider(db)
    84  	ctx := newContext(pro)
    85  
    86  	table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil)
    87  	rule := getRuleFrom(OnceAfterAll, parallelizeId)
    88  	node := plan.NewCreateIndex(
    89  		"",
    90  		plan.NewResolvedTable(table, nil, nil),
    91  		nil,
    92  		"",
    93  		nil,
    94  	)
    95  
    96  	result, _, err := rule.Apply(ctx, &Analyzer{Parallelism: 1}, node, nil, DefaultRuleSelector)
    97  	require.NoError(err)
    98  	require.Equal(node, result)
    99  }
   100  
   101  func TestIsParallelizable(t *testing.T) {
   102  	db := memory.NewDatabase("db")
   103  	table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil)
   104  
   105  	testCases := []struct {
   106  		name           string
   107  		node           sql.Node
   108  		parallelizable bool
   109  	}{
   110  		{
   111  			"just table",
   112  			plan.NewResolvedTable(table, nil, nil),
   113  			true,
   114  		},
   115  		{
   116  			"filter",
   117  			plan.NewFilter(
   118  				expression.NewLiteral(1, types.Int64),
   119  				plan.NewResolvedTable(table, nil, nil),
   120  			),
   121  			true,
   122  		},
   123  		{
   124  			"filter with a subquery",
   125  			plan.NewFilter(
   126  				eq(
   127  					lit(1),
   128  					plan.NewSubquery(
   129  						plan.NewProject([]sql.Expression{lit(1)}, plan.NewResolvedTable(table, nil, nil)), "select 1 from table")),
   130  				plan.NewResolvedTable(table, nil, nil),
   131  			),
   132  			true,
   133  		},
   134  		{
   135  			"filter with an incompatible subquery",
   136  			plan.NewFilter(
   137  				eq(
   138  					lit(1),
   139  					plan.NewSubquery(
   140  						plan.NewProject([]sql.Expression{gf(0, "", "row_number()")},
   141  							plan.NewWindow([]sql.Expression{window.NewRowNumber()}, plan.NewResolvedTable(table, nil, nil)),
   142  						),
   143  						"select row_number over () from table",
   144  					),
   145  				),
   146  				plan.NewResolvedTable(table, nil, nil),
   147  			),
   148  			false,
   149  		},
   150  		{
   151  			"project",
   152  			plan.NewProject(
   153  				nil,
   154  				plan.NewFilter(
   155  					expression.NewLiteral(1, types.Int64),
   156  					plan.NewResolvedTable(table, nil, nil),
   157  				),
   158  			),
   159  			true,
   160  		},
   161  		{
   162  			"project with a subquery",
   163  			plan.NewProject([]sql.Expression{
   164  				plan.NewSubquery(
   165  					plan.NewProject([]sql.Expression{lit(1)}, plan.NewResolvedTable(table, nil, nil)), "select 1 from table"),
   166  			},
   167  				plan.NewFilter(
   168  					expression.NewLiteral(1, types.Int64),
   169  					plan.NewResolvedTable(table, nil, nil),
   170  				),
   171  			),
   172  			true,
   173  		},
   174  		{
   175  			"project with an incompatible subquery",
   176  			plan.NewProject([]sql.Expression{
   177  				plan.NewSubquery(
   178  					plan.NewProject([]sql.Expression{gf(0, "", "row_number()")},
   179  						plan.NewWindow([]sql.Expression{window.NewRowNumber()}, plan.NewResolvedTable(table, nil, nil)),
   180  					),
   181  					"select row_number over () from table",
   182  				),
   183  			},
   184  				plan.NewFilter(
   185  					expression.NewLiteral(1, types.Int64),
   186  					plan.NewResolvedTable(table, nil, nil),
   187  				),
   188  			),
   189  			false,
   190  		},
   191  		{
   192  			"join",
   193  			plan.NewInnerJoin(
   194  				plan.NewResolvedTable(table, nil, nil),
   195  				plan.NewResolvedTable(table, nil, nil),
   196  				expression.NewLiteral(1, types.Int64),
   197  			),
   198  			false,
   199  		},
   200  		{
   201  			"group by",
   202  			plan.NewGroupBy(
   203  				nil,
   204  				nil,
   205  				plan.NewResolvedTable(nil, nil, nil),
   206  			),
   207  			false,
   208  		},
   209  		{
   210  			"limit",
   211  			plan.NewLimit(
   212  				expression.NewLiteral(5, types.Int8),
   213  				plan.NewResolvedTable(nil, nil, nil),
   214  			),
   215  			false,
   216  		},
   217  		{
   218  			"offset",
   219  			plan.NewOffset(
   220  				expression.NewLiteral(5, types.Int8),
   221  				plan.NewResolvedTable(nil, nil, nil),
   222  			),
   223  			false,
   224  		},
   225  		{
   226  			"sort",
   227  			plan.NewSort(
   228  				nil,
   229  				plan.NewResolvedTable(nil, nil, nil),
   230  			),
   231  			false,
   232  		},
   233  		{
   234  			"distinct",
   235  			plan.NewDistinct(
   236  				plan.NewResolvedTable(nil, nil, nil),
   237  			),
   238  			false,
   239  		},
   240  		{
   241  			"ordered distinct",
   242  			plan.NewOrderedDistinct(
   243  				plan.NewResolvedTable(nil, nil, nil),
   244  			),
   245  			false,
   246  		},
   247  	}
   248  
   249  	for _, tt := range testCases {
   250  		t.Run(tt.name, func(t *testing.T) {
   251  			require.Equal(t, tt.parallelizable, isParallelizable(tt.node))
   252  		})
   253  	}
   254  }
   255  
   256  func TestRemoveRedundantExchanges(t *testing.T) {
   257  	require := require.New(t)
   258  	db := memory.NewDatabase("db")
   259  
   260  	table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil)
   261  
   262  	node := plan.NewProject(
   263  		nil,
   264  		plan.NewInnerJoin(
   265  			plan.NewExchange(
   266  				1,
   267  				plan.NewFilter(
   268  					expression.NewLiteral(1, types.Int64),
   269  					plan.NewExchange(
   270  						1,
   271  						plan.NewResolvedTable(table, nil, nil),
   272  					),
   273  				),
   274  			),
   275  			plan.NewExchange(
   276  				1,
   277  				plan.NewFilter(
   278  					expression.NewLiteral(1, types.Int64),
   279  					plan.NewExchange(
   280  						1,
   281  						plan.NewResolvedTable(table, nil, nil),
   282  					),
   283  				),
   284  			),
   285  			expression.NewLiteral(1, types.Int64),
   286  		),
   287  	)
   288  
   289  	expected := plan.NewProject(
   290  		nil,
   291  		plan.NewInnerJoin(
   292  			plan.NewExchange(
   293  				1,
   294  				plan.NewFilter(
   295  					expression.NewLiteral(1, types.Int64),
   296  					plan.NewResolvedTable(table, nil, nil),
   297  				),
   298  			),
   299  			plan.NewExchange(
   300  				1,
   301  				plan.NewFilter(
   302  					expression.NewLiteral(1, types.Int64),
   303  					plan.NewResolvedTable(table, nil, nil),
   304  				),
   305  			),
   306  			expression.NewLiteral(1, types.Int64),
   307  		),
   308  	)
   309  
   310  	result, _, err := transform.Node(node, removeRedundantExchanges)
   311  	require.NoError(err)
   312  	require.Equal(expected, result)
   313  }