github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/biased_coster_test.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 "testing" 20 21 "github.com/stretchr/testify/require" 22 23 "github.com/dolthub/go-mysql-server/memory" 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/memo" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/transform" 29 "github.com/dolthub/go-mysql-server/sql/types" 30 ) 31 32 func TestBiasedCoster(t *testing.T) { 33 db := memory.NewDatabase("mydb") 34 db.EnablePrimaryKeyIndexes() 35 pro := memory.NewDBProvider(db) 36 ctx := newContext(pro) 37 38 xy := memory.NewFilteredTable(db, "xy", sql.NewPrimaryKeySchema(sql.Schema{ 39 {Name: "x", Type: types.Int64, Source: "xy"}, 40 {Name: "y", Type: types.Int64, Source: "xy"}, 41 }, 0), nil) 42 43 ab := memory.NewFilteredTable(db, "ab", sql.NewPrimaryKeySchema(sql.Schema{ 44 {Name: "a", Type: types.Int64, Source: "ab"}, 45 {Name: "b", Type: types.Int64, Source: "ab"}, 46 }, 0), nil) 47 48 xy.EnablePrimaryKeyIndexes() 49 ab.EnablePrimaryKeyIndexes() 50 51 ab.Insert(ctx, sql.Row{int64(0), int64(0)}) 52 ab.Insert(ctx, sql.Row{int64(1), int64(1)}) 53 ab.Insert(ctx, sql.Row{int64(2), int64(2)}) 54 xy.Insert(ctx, sql.Row{int64(0), int64(0)}) 55 xy.Insert(ctx, sql.Row{int64(1), int64(1)}) 56 57 db.AddTable("xy", xy) 58 db.AddTable("ab", ab) 59 60 a := NewDefault(sql.NewDatabaseProvider(db)) 61 62 n := plan.NewInnerJoin( 63 plan.NewResolvedTable(xy, db, nil).WithId(1).WithColumns(sql.NewColSet(1, 2)), 64 plan.NewResolvedTable(ab, db, nil).WithId(2).WithColumns(sql.NewColSet(3, 4)), 65 expression.NewEquals( 66 expression.NewGetFieldWithTable(1, 1, types.Int64, "db", "xy", "x", false), 67 expression.NewGetFieldWithTable(3, 2, types.Int64, "db", "ab", "a", false), 68 ), 69 ) 70 71 tests := []struct { 72 name string 73 c func() memo.Coster 74 exp plan.JoinType 75 }{ 76 { 77 name: "inner", 78 c: memo.NewInnerBiasedCoster, 79 exp: plan.JoinTypeInner, 80 }, 81 { 82 name: "lookup", 83 c: memo.NewLookupBiasedCoster, 84 exp: plan.JoinTypeLookup, 85 }, 86 { 87 name: "hash", 88 c: memo.NewHashBiasedCoster, 89 exp: plan.JoinTypeHash, 90 }, 91 { 92 name: "merge", 93 c: memo.NewMergeBiasedCoster, 94 exp: plan.JoinTypeMerge, 95 }, 96 } 97 98 for _, tt := range tests { 99 t.Run(fmt.Sprintf("%s biased coster", tt.name), func(t *testing.T) { 100 a.Coster = tt.c() 101 cmp, err := replanJoin(ctx, n, a, nil) 102 require.NoError(t, err) 103 types := collectJoinTypes(cmp)[0] 104 require.Equal(t, tt.exp, types) 105 }) 106 } 107 } 108 109 func collectJoinTypes(n sql.Node) []plan.JoinType { 110 var types []plan.JoinType 111 transform.Inspect(n, func(n sql.Node) bool { 112 if n == nil { 113 return true 114 } 115 j, ok := n.(*plan.JoinNode) 116 if ok { 117 types = append(types, j.Op) 118 } 119 120 if ex, ok := n.(sql.Expressioner); ok { 121 for _, e := range ex.Expressions() { 122 transform.InspectExpr(e, func(e sql.Expression) bool { 123 sq, ok := e.(*plan.Subquery) 124 if !ok { 125 return false 126 } 127 types = append(types, collectJoinTypes(sq.Query)...) 128 return false 129 }) 130 } 131 } 132 return true 133 }) 134 return types 135 }