github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/parallelize_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/memory" 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/expression/function/aggregation/window" 26 "github.com/dolthub/go-mysql-server/sql/plan" 27 "github.com/dolthub/go-mysql-server/sql/transform" 28 "github.com/dolthub/go-mysql-server/sql/types" 29 ) 30 31 func TestParallelize(t *testing.T) { 32 require := require.New(t) 33 db := memory.NewDatabase("db") 34 pro := memory.NewDBProvider(db) 35 ctx := newContext(pro) 36 37 table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil) 38 rule := getRuleFrom(OnceAfterAll, parallelizeId) 39 node := plan.NewProject( 40 nil, 41 plan.NewInnerJoin( 42 plan.NewFilter( 43 expression.NewLiteral(1, types.Int64), 44 plan.NewResolvedTable(table, nil, nil), 45 ), 46 plan.NewFilter( 47 expression.NewLiteral(1, types.Int64), 48 plan.NewResolvedTable(table, nil, nil), 49 ), 50 expression.NewLiteral(1, types.Int64), 51 ), 52 ) 53 54 expected := plan.NewProject( 55 nil, 56 plan.NewInnerJoin( 57 plan.NewExchange( 58 2, 59 plan.NewFilter( 60 expression.NewLiteral(1, types.Int64), 61 plan.NewResolvedTable(table, nil, nil), 62 ), 63 ), 64 plan.NewExchange( 65 2, 66 plan.NewFilter( 67 expression.NewLiteral(1, types.Int64), 68 plan.NewResolvedTable(table, nil, nil), 69 ), 70 ), 71 expression.NewLiteral(1, types.Int64), 72 ), 73 ) 74 75 result, _, err := rule.Apply(ctx, &Analyzer{Parallelism: 2}, node, nil, DefaultRuleSelector) 76 require.NoError(err) 77 require.Equal(expected, result) 78 } 79 80 func TestParallelizeCreateIndex(t *testing.T) { 81 require := require.New(t) 82 db := memory.NewDatabase("db") 83 pro := memory.NewDBProvider(db) 84 ctx := newContext(pro) 85 86 table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil) 87 rule := getRuleFrom(OnceAfterAll, parallelizeId) 88 node := plan.NewCreateIndex( 89 "", 90 plan.NewResolvedTable(table, nil, nil), 91 nil, 92 "", 93 nil, 94 ) 95 96 result, _, err := rule.Apply(ctx, &Analyzer{Parallelism: 1}, node, nil, DefaultRuleSelector) 97 require.NoError(err) 98 require.Equal(node, result) 99 } 100 101 func TestIsParallelizable(t *testing.T) { 102 db := memory.NewDatabase("db") 103 table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil) 104 105 testCases := []struct { 106 name string 107 node sql.Node 108 parallelizable bool 109 }{ 110 { 111 "just table", 112 plan.NewResolvedTable(table, nil, nil), 113 true, 114 }, 115 { 116 "filter", 117 plan.NewFilter( 118 expression.NewLiteral(1, types.Int64), 119 plan.NewResolvedTable(table, nil, nil), 120 ), 121 true, 122 }, 123 { 124 "filter with a subquery", 125 plan.NewFilter( 126 eq( 127 lit(1), 128 plan.NewSubquery( 129 plan.NewProject([]sql.Expression{lit(1)}, plan.NewResolvedTable(table, nil, nil)), "select 1 from table")), 130 plan.NewResolvedTable(table, nil, nil), 131 ), 132 true, 133 }, 134 { 135 "filter with an incompatible subquery", 136 plan.NewFilter( 137 eq( 138 lit(1), 139 plan.NewSubquery( 140 plan.NewProject([]sql.Expression{gf(0, "", "row_number()")}, 141 plan.NewWindow([]sql.Expression{window.NewRowNumber()}, plan.NewResolvedTable(table, nil, nil)), 142 ), 143 "select row_number over () from table", 144 ), 145 ), 146 plan.NewResolvedTable(table, nil, nil), 147 ), 148 false, 149 }, 150 { 151 "project", 152 plan.NewProject( 153 nil, 154 plan.NewFilter( 155 expression.NewLiteral(1, types.Int64), 156 plan.NewResolvedTable(table, nil, nil), 157 ), 158 ), 159 true, 160 }, 161 { 162 "project with a subquery", 163 plan.NewProject([]sql.Expression{ 164 plan.NewSubquery( 165 plan.NewProject([]sql.Expression{lit(1)}, plan.NewResolvedTable(table, nil, nil)), "select 1 from table"), 166 }, 167 plan.NewFilter( 168 expression.NewLiteral(1, types.Int64), 169 plan.NewResolvedTable(table, nil, nil), 170 ), 171 ), 172 true, 173 }, 174 { 175 "project with an incompatible subquery", 176 plan.NewProject([]sql.Expression{ 177 plan.NewSubquery( 178 plan.NewProject([]sql.Expression{gf(0, "", "row_number()")}, 179 plan.NewWindow([]sql.Expression{window.NewRowNumber()}, plan.NewResolvedTable(table, nil, nil)), 180 ), 181 "select row_number over () from table", 182 ), 183 }, 184 plan.NewFilter( 185 expression.NewLiteral(1, types.Int64), 186 plan.NewResolvedTable(table, nil, nil), 187 ), 188 ), 189 false, 190 }, 191 { 192 "join", 193 plan.NewInnerJoin( 194 plan.NewResolvedTable(table, nil, nil), 195 plan.NewResolvedTable(table, nil, nil), 196 expression.NewLiteral(1, types.Int64), 197 ), 198 false, 199 }, 200 { 201 "group by", 202 plan.NewGroupBy( 203 nil, 204 nil, 205 plan.NewResolvedTable(nil, nil, nil), 206 ), 207 false, 208 }, 209 { 210 "limit", 211 plan.NewLimit( 212 expression.NewLiteral(5, types.Int8), 213 plan.NewResolvedTable(nil, nil, nil), 214 ), 215 false, 216 }, 217 { 218 "offset", 219 plan.NewOffset( 220 expression.NewLiteral(5, types.Int8), 221 plan.NewResolvedTable(nil, nil, nil), 222 ), 223 false, 224 }, 225 { 226 "sort", 227 plan.NewSort( 228 nil, 229 plan.NewResolvedTable(nil, nil, nil), 230 ), 231 false, 232 }, 233 { 234 "distinct", 235 plan.NewDistinct( 236 plan.NewResolvedTable(nil, nil, nil), 237 ), 238 false, 239 }, 240 { 241 "ordered distinct", 242 plan.NewOrderedDistinct( 243 plan.NewResolvedTable(nil, nil, nil), 244 ), 245 false, 246 }, 247 } 248 249 for _, tt := range testCases { 250 t.Run(tt.name, func(t *testing.T) { 251 require.Equal(t, tt.parallelizable, isParallelizable(tt.node)) 252 }) 253 } 254 } 255 256 func TestRemoveRedundantExchanges(t *testing.T) { 257 require := require.New(t) 258 db := memory.NewDatabase("db") 259 260 table := memory.NewTable(db, "t", sql.PrimaryKeySchema{}, nil) 261 262 node := plan.NewProject( 263 nil, 264 plan.NewInnerJoin( 265 plan.NewExchange( 266 1, 267 plan.NewFilter( 268 expression.NewLiteral(1, types.Int64), 269 plan.NewExchange( 270 1, 271 plan.NewResolvedTable(table, nil, nil), 272 ), 273 ), 274 ), 275 plan.NewExchange( 276 1, 277 plan.NewFilter( 278 expression.NewLiteral(1, types.Int64), 279 plan.NewExchange( 280 1, 281 plan.NewResolvedTable(table, nil, nil), 282 ), 283 ), 284 ), 285 expression.NewLiteral(1, types.Int64), 286 ), 287 ) 288 289 expected := plan.NewProject( 290 nil, 291 plan.NewInnerJoin( 292 plan.NewExchange( 293 1, 294 plan.NewFilter( 295 expression.NewLiteral(1, types.Int64), 296 plan.NewResolvedTable(table, nil, nil), 297 ), 298 ), 299 plan.NewExchange( 300 1, 301 plan.NewFilter( 302 expression.NewLiteral(1, types.Int64), 303 plan.NewResolvedTable(table, nil, nil), 304 ), 305 ), 306 expression.NewLiteral(1, types.Int64), 307 ), 308 ) 309 310 result, _, err := transform.Node(node, removeRedundantExchanges) 311 require.NoError(err) 312 require.Equal(expected, result) 313 }