github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/analyzer_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"testing"
    20  
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  func TestAddRule(t *testing.T) {
    31  	require := require.New(t)
    32  
    33  	defRulesCount := countRules(NewDefault(nil).Batches)
    34  
    35  	a := NewBuilder(nil).AddPostAnalyzeRule(-1, pushFilters).Build()
    36  
    37  	require.Equal(countRules(a.Batches), defRulesCount+1)
    38  }
    39  
    40  func TestAddPreValidationRule(t *testing.T) {
    41  	require := require.New(t)
    42  
    43  	defRulesCount := countRules(NewDefault(nil).Batches)
    44  
    45  	a := NewBuilder(nil).AddPreValidationRule(-1, pushFilters).Build()
    46  
    47  	require.Equal(countRules(a.Batches), defRulesCount+1)
    48  }
    49  
    50  func TestAddPostValidationRule(t *testing.T) {
    51  	require := require.New(t)
    52  
    53  	defRulesCount := countRules(NewDefault(nil).Batches)
    54  
    55  	a := NewBuilder(nil).AddPostValidationRule(-1, pushFilters).Build()
    56  
    57  	require.Equal(countRules(a.Batches), defRulesCount+1)
    58  }
    59  
    60  func TestRemoveOnceBeforeRule(t *testing.T) {
    61  	require := require.New(t)
    62  
    63  	a := NewBuilder(nil).RemoveOnceBeforeRule(applyDefaultSelectLimitId).Build()
    64  
    65  	defRulesCount := countRules(NewDefault(nil).Batches)
    66  
    67  	require.Equal(countRules(a.Batches), defRulesCount-1)
    68  }
    69  
    70  func TestRemoveDefaultRule(t *testing.T) {
    71  	require := require.New(t)
    72  
    73  	a := NewBuilder(nil).RemoveDefaultRule(resolveSubqueriesId).Build()
    74  
    75  	defRulesCount := countRules(NewDefault(nil).Batches)
    76  
    77  	require.Equal(countRules(a.Batches), defRulesCount-1)
    78  }
    79  
    80  func TestRemoveOnceAfterRule(t *testing.T) {
    81  	require := require.New(t)
    82  
    83  	a := NewBuilder(nil).RemoveOnceAfterRule(loadTriggersId).Build()
    84  
    85  	defRulesCount := countRules(NewDefault(nil).Batches)
    86  
    87  	require.Equal(countRules(a.Batches), defRulesCount-1)
    88  }
    89  
    90  func TestRemoveValidationRule(t *testing.T) {
    91  	require := require.New(t)
    92  
    93  	a := NewBuilder(nil).RemoveValidationRule(validateResolvedId).Build()
    94  
    95  	defRulesCount := countRules(NewDefault(nil).Batches)
    96  
    97  	require.Equal(countRules(a.Batches), defRulesCount-1)
    98  }
    99  
   100  func TestRemoveAfterAllRule(t *testing.T) {
   101  	require := require.New(t)
   102  
   103  	a := NewBuilder(nil).RemoveAfterAllRule(TrackProcessId).Build()
   104  
   105  	defRulesCount := countRules(NewDefault(nil).Batches)
   106  
   107  	require.Equal(countRules(a.Batches), defRulesCount-1)
   108  }
   109  
   110  func countRules(batches []*Batch) int {
   111  	var count int
   112  	for _, b := range batches {
   113  		count = count + len(b.Rules)
   114  	}
   115  	return count
   116  
   117  }
   118  
   119  func TestDeepCopyNode(t *testing.T) {
   120  	tests := []struct {
   121  		node sql.Node
   122  		exp  sql.Node
   123  	}{
   124  		{
   125  			node: plan.NewProject(
   126  				[]sql.Expression{
   127  					expression.NewLiteral(1, types.Int64),
   128  				},
   129  				plan.NewNaturalJoin(
   130  					plan.NewInnerJoin(
   131  						plan.NewUnresolvedTable("mytable", ""),
   132  						plan.NewUnresolvedTable("mytable2", ""),
   133  						expression.NewEquals(
   134  							expression.NewUnresolvedQualifiedColumn("mytable", "i"),
   135  							expression.NewUnresolvedQualifiedColumn("mytable2", "i2"),
   136  						),
   137  					),
   138  					plan.NewFilter(
   139  						expression.NewEquals(
   140  							expression.NewBindVar("v1"),
   141  							expression.NewBindVar("v2"),
   142  						),
   143  						plan.NewUnresolvedTable("mytable3", ""),
   144  					),
   145  				),
   146  			),
   147  		},
   148  		{
   149  			node: plan.NewProject(
   150  				[]sql.Expression{
   151  					expression.NewLiteral(1, types.Int64),
   152  				},
   153  				plan.NewSetOp(
   154  					plan.UnionType,
   155  					plan.NewProject(
   156  						[]sql.Expression{
   157  							expression.NewLiteral(1, types.Int64),
   158  						},
   159  						plan.NewUnresolvedTable("mytable", ""),
   160  					),
   161  					plan.NewProject(
   162  						[]sql.Expression{
   163  							expression.NewBindVar("v1"),
   164  							expression.NewBindVar("v2"),
   165  						},
   166  						plan.NewUnresolvedTable("mytable", ""),
   167  					),
   168  					false, nil, nil, nil),
   169  			),
   170  		},
   171  		{
   172  			node: plan.NewFilter(
   173  				expression.NewEquals(
   174  					expression.NewLiteral(1, types.Int64),
   175  					expression.NewLiteral(1, types.Int64),
   176  				),
   177  				plan.NewWindow(
   178  					[]sql.Expression{
   179  						aggregation.NewSum(
   180  							expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false),
   181  						),
   182  						expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "a", "x", false),
   183  						expression.NewBindVar("v1"),
   184  					},
   185  					plan.NewProject(
   186  						[]sql.Expression{
   187  							expression.NewBindVar("v2"),
   188  						},
   189  						plan.NewUnresolvedTable("x", ""),
   190  					),
   191  				),
   192  			),
   193  		},
   194  		{
   195  			node: plan.NewFilter(
   196  				expression.NewEquals(
   197  					expression.NewLiteral(1, types.Int64),
   198  					expression.NewLiteral(1, types.Int64),
   199  				),
   200  				plan.NewSubqueryAlias("cte1", "select x from a",
   201  					plan.NewProject(
   202  						[]sql.Expression{
   203  							expression.NewBindVar("v1"),
   204  							expression.NewUnresolvedColumn("v2"),
   205  						},
   206  						plan.NewUnresolvedTable("a", ""),
   207  					),
   208  				),
   209  			),
   210  		},
   211  	}
   212  
   213  	for i, tt := range tests {
   214  		t.Run(fmt.Sprintf("DeepCopyTest_%d", i), func(t *testing.T) {
   215  			cop, err := DeepCopyNode(tt.node)
   216  			require.NoError(t, err)
   217  			cop, _, err = plan.ApplyBindings(cop, map[string]sql.Expression{
   218  				"v1": expression.NewLiteral(1, types.Int64),
   219  				"v2": expression.NewLiteral("x", types.Text),
   220  			})
   221  			require.NoError(t, err)
   222  			require.NotEqual(t, cop, tt.node)
   223  		})
   224  	}
   225  }