github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/visit_plan.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "context" 19 20 "github.com/matrixorigin/matrixone/pkg/pb/plan" 21 ) 22 23 type VisitPlanRule interface { 24 MatchNode(*Node) bool 25 IsApplyExpr() bool 26 ApplyNode(*Node) error 27 ApplyExpr(*Expr) (*Expr, error) 28 } 29 30 type VisitPlan struct { 31 plan *Plan 32 isUpdatePlan bool 33 rules []VisitPlanRule 34 } 35 36 func NewVisitPlan(pl *Plan, rules []VisitPlanRule) *VisitPlan { 37 return &VisitPlan{ 38 plan: pl, 39 isUpdatePlan: false, 40 rules: rules, 41 } 42 } 43 44 func (vq *VisitPlan) visitNode(ctx context.Context, qry *Query, node *Node, idx int32) error { 45 for i := range node.Children { 46 if err := vq.visitNode(ctx, qry, qry.Nodes[node.Children[i]], node.Children[i]); err != nil { 47 return err 48 } 49 } 50 51 for _, rule := range vq.rules { 52 if rule.MatchNode(node) { 53 err := rule.ApplyNode(node) 54 if err != nil { 55 return err 56 } 57 } else if rule.IsApplyExpr() { 58 err := vq.exploreNode(ctx, rule, node, idx) 59 if err != nil { 60 return err 61 } 62 } 63 } 64 65 return nil 66 } 67 68 func (vq *VisitPlan) exploreNode(ctx context.Context, rule VisitPlanRule, node *Node, idx int32) error { 69 var err error 70 if node.Limit != nil { 71 node.Limit, err = rule.ApplyExpr(node.Limit) 72 if err != nil { 73 return err 74 } 75 } 76 77 if node.Offset != nil { 78 node.Offset, err = rule.ApplyExpr(node.Offset) 79 if err != nil { 80 return err 81 } 82 } 83 84 for i := range node.OnList { 85 node.OnList[i], err = rule.ApplyExpr(node.OnList[i]) 86 if err != nil { 87 return err 88 } 89 } 90 91 for i := range node.FilterList { 92 node.FilterList[i], err = rule.ApplyExpr(node.FilterList[i]) 93 if err != nil { 94 return err 95 } 96 } 97 98 if node.RowsetData != nil { 99 for i := range node.RowsetData.Cols { 100 for j := range node.RowsetData.Cols[i].Data { 101 node.RowsetData.Cols[i].Data[j], err = rule.ApplyExpr(node.RowsetData.Cols[i].Data[j]) 102 if err != nil { 103 return err 104 } 105 } 106 } 107 } 108 109 for i := range node.ProjectList { 110 // if prepare statement is: update set decimal_col = decimal_col + ? ; 111 // and then: 'set @a=12.22; execute stmt1 using @a;' decimal_col + ? will be float64 112 if vq.isUpdatePlan { 113 pl, _ := vq.plan.Plan.(*Plan_Query) 114 num := len(pl.Query.Nodes) - int(idx) 115 // last project Node 116 if num == 2 { 117 oldType := DeepCopyTyp(node.ProjectList[i].Typ) 118 node.ProjectList[i], err = rule.ApplyExpr(node.ProjectList[i]) 119 if node.ProjectList[i].Typ.Id != oldType.Id { 120 node.ProjectList[i], err = makePlan2CastExpr(ctx, node.ProjectList[i], oldType) 121 } 122 } else { 123 node.ProjectList[i], err = rule.ApplyExpr(node.ProjectList[i]) 124 } 125 } else { 126 node.ProjectList[i], err = rule.ApplyExpr(node.ProjectList[i]) 127 } 128 129 if err != nil { 130 return err 131 } 132 } 133 134 return nil 135 } 136 137 func (vq *VisitPlan) Visit(ctx context.Context) error { 138 switch pl := vq.plan.Plan.(type) { 139 case *Plan_Query: 140 qry := pl.Query 141 vq.isUpdatePlan = (pl.Query.StmtType == plan.Query_UPDATE) 142 143 if len(qry.Steps) == 0 { 144 return nil 145 } 146 147 for _, step := range qry.Steps { 148 err := vq.visitNode(ctx, qry, qry.Nodes[step], step) 149 if err != nil { 150 return err 151 } 152 } 153 154 default: 155 // do nothing 156 157 } 158 159 return nil 160 }