github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/analyzer_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 "testing" 20 21 "github.com/stretchr/testify/require" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation" 26 "github.com/dolthub/go-mysql-server/sql/plan" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 ) 29 30 func TestAddRule(t *testing.T) { 31 require := require.New(t) 32 33 defRulesCount := countRules(NewDefault(nil).Batches) 34 35 a := NewBuilder(nil).AddPostAnalyzeRule(-1, pushFilters).Build() 36 37 require.Equal(countRules(a.Batches), defRulesCount+1) 38 } 39 40 func TestAddPreValidationRule(t *testing.T) { 41 require := require.New(t) 42 43 defRulesCount := countRules(NewDefault(nil).Batches) 44 45 a := NewBuilder(nil).AddPreValidationRule(-1, pushFilters).Build() 46 47 require.Equal(countRules(a.Batches), defRulesCount+1) 48 } 49 50 func TestAddPostValidationRule(t *testing.T) { 51 require := require.New(t) 52 53 defRulesCount := countRules(NewDefault(nil).Batches) 54 55 a := NewBuilder(nil).AddPostValidationRule(-1, pushFilters).Build() 56 57 require.Equal(countRules(a.Batches), defRulesCount+1) 58 } 59 60 func TestRemoveOnceBeforeRule(t *testing.T) { 61 require := require.New(t) 62 63 a := NewBuilder(nil).RemoveOnceBeforeRule(applyDefaultSelectLimitId).Build() 64 65 defRulesCount := countRules(NewDefault(nil).Batches) 66 67 require.Equal(countRules(a.Batches), defRulesCount-1) 68 } 69 70 func TestRemoveDefaultRule(t *testing.T) { 71 require := require.New(t) 72 73 a := NewBuilder(nil).RemoveDefaultRule(resolveSubqueriesId).Build() 74 75 defRulesCount := countRules(NewDefault(nil).Batches) 76 77 require.Equal(countRules(a.Batches), defRulesCount-1) 78 } 79 80 func TestRemoveOnceAfterRule(t *testing.T) { 81 require := require.New(t) 82 83 a := NewBuilder(nil).RemoveOnceAfterRule(loadTriggersId).Build() 84 85 defRulesCount := countRules(NewDefault(nil).Batches) 86 87 require.Equal(countRules(a.Batches), defRulesCount-1) 88 } 89 90 func TestRemoveValidationRule(t *testing.T) { 91 require := require.New(t) 92 93 a := NewBuilder(nil).RemoveValidationRule(validateResolvedId).Build() 94 95 defRulesCount := countRules(NewDefault(nil).Batches) 96 97 require.Equal(countRules(a.Batches), defRulesCount-1) 98 } 99 100 func TestRemoveAfterAllRule(t *testing.T) { 101 require := require.New(t) 102 103 a := NewBuilder(nil).RemoveAfterAllRule(TrackProcessId).Build() 104 105 defRulesCount := countRules(NewDefault(nil).Batches) 106 107 require.Equal(countRules(a.Batches), defRulesCount-1) 108 } 109 110 func countRules(batches []*Batch) int { 111 var count int 112 for _, b := range batches { 113 count = count + len(b.Rules) 114 } 115 return count 116 117 } 118 119 func TestDeepCopyNode(t *testing.T) { 120 tests := []struct { 121 node sql.Node 122 exp sql.Node 123 }{ 124 { 125 node: plan.NewProject( 126 []sql.Expression{ 127 expression.NewLiteral(1, types.Int64), 128 }, 129 plan.NewNaturalJoin( 130 plan.NewInnerJoin( 131 plan.NewUnresolvedTable("mytable", ""), 132 plan.NewUnresolvedTable("mytable2", ""), 133 expression.NewEquals( 134 expression.NewUnresolvedQualifiedColumn("mytable", "i"), 135 expression.NewUnresolvedQualifiedColumn("mytable2", "i2"), 136 ), 137 ), 138 plan.NewFilter( 139 expression.NewEquals( 140 expression.NewBindVar("v1"), 141 expression.NewBindVar("v2"), 142 ), 143 plan.NewUnresolvedTable("mytable3", ""), 144 ), 145 ), 146 ), 147 }, 148 { 149 node: plan.NewProject( 150 []sql.Expression{ 151 expression.NewLiteral(1, types.Int64), 152 }, 153 plan.NewSetOp( 154 plan.UnionType, 155 plan.NewProject( 156 []sql.Expression{ 157 expression.NewLiteral(1, types.Int64), 158 }, 159 plan.NewUnresolvedTable("mytable", ""), 160 ), 161 plan.NewProject( 162 []sql.Expression{ 163 expression.NewBindVar("v1"), 164 expression.NewBindVar("v2"), 165 }, 166 plan.NewUnresolvedTable("mytable", ""), 167 ), 168 false, nil, nil, nil), 169 ), 170 }, 171 { 172 node: plan.NewFilter( 173 expression.NewEquals( 174 expression.NewLiteral(1, types.Int64), 175 expression.NewLiteral(1, types.Int64), 176 ), 177 plan.NewWindow( 178 []sql.Expression{ 179 aggregation.NewSum( 180 expression.NewGetFieldWithTable(0, 0, types.Int64, "db", "a", "x", false), 181 ), 182 expression.NewGetFieldWithTable(0, 1, types.Int64, "db", "a", "x", false), 183 expression.NewBindVar("v1"), 184 }, 185 plan.NewProject( 186 []sql.Expression{ 187 expression.NewBindVar("v2"), 188 }, 189 plan.NewUnresolvedTable("x", ""), 190 ), 191 ), 192 ), 193 }, 194 { 195 node: plan.NewFilter( 196 expression.NewEquals( 197 expression.NewLiteral(1, types.Int64), 198 expression.NewLiteral(1, types.Int64), 199 ), 200 plan.NewSubqueryAlias("cte1", "select x from a", 201 plan.NewProject( 202 []sql.Expression{ 203 expression.NewBindVar("v1"), 204 expression.NewUnresolvedColumn("v2"), 205 }, 206 plan.NewUnresolvedTable("a", ""), 207 ), 208 ), 209 ), 210 }, 211 } 212 213 for i, tt := range tests { 214 t.Run(fmt.Sprintf("DeepCopyTest_%d", i), func(t *testing.T) { 215 cop, err := DeepCopyNode(tt.node) 216 require.NoError(t, err) 217 cop, _, err = plan.ApplyBindings(cop, map[string]sql.Expression{ 218 "v1": expression.NewLiteral(1, types.Int64), 219 "v2": expression.NewLiteral("x", types.Text), 220 }) 221 require.NoError(t, err) 222 require.NotEqual(t, cop, tt.node) 223 }) 224 } 225 }