github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/visit_plan.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"context"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    21  )
    22  
    23  type VisitPlanRule interface {
    24  	MatchNode(*Node) bool
    25  	IsApplyExpr() bool
    26  	ApplyNode(*Node) error
    27  	ApplyExpr(*Expr) (*Expr, error)
    28  }
    29  
    30  type VisitPlan struct {
    31  	plan         *Plan
    32  	isUpdatePlan bool
    33  	rules        []VisitPlanRule
    34  }
    35  
    36  func NewVisitPlan(pl *Plan, rules []VisitPlanRule) *VisitPlan {
    37  	return &VisitPlan{
    38  		plan:         pl,
    39  		isUpdatePlan: false,
    40  		rules:        rules,
    41  	}
    42  }
    43  
    44  func (vq *VisitPlan) visitNode(ctx context.Context, qry *Query, node *Node, idx int32) error {
    45  	for i := range node.Children {
    46  		if err := vq.visitNode(ctx, qry, qry.Nodes[node.Children[i]], node.Children[i]); err != nil {
    47  			return err
    48  		}
    49  	}
    50  
    51  	for _, rule := range vq.rules {
    52  		if rule.MatchNode(node) {
    53  			err := rule.ApplyNode(node)
    54  			if err != nil {
    55  				return err
    56  			}
    57  		} else if rule.IsApplyExpr() {
    58  			err := vq.exploreNode(ctx, rule, node, idx)
    59  			if err != nil {
    60  				return err
    61  			}
    62  		}
    63  	}
    64  
    65  	return nil
    66  }
    67  
    68  func (vq *VisitPlan) exploreNode(ctx context.Context, rule VisitPlanRule, node *Node, idx int32) error {
    69  	var err error
    70  	if node.Limit != nil {
    71  		node.Limit, err = rule.ApplyExpr(node.Limit)
    72  		if err != nil {
    73  			return err
    74  		}
    75  	}
    76  
    77  	if node.Offset != nil {
    78  		node.Offset, err = rule.ApplyExpr(node.Offset)
    79  		if err != nil {
    80  			return err
    81  		}
    82  	}
    83  
    84  	for i := range node.OnList {
    85  		node.OnList[i], err = rule.ApplyExpr(node.OnList[i])
    86  		if err != nil {
    87  			return err
    88  		}
    89  	}
    90  
    91  	for i := range node.FilterList {
    92  		node.FilterList[i], err = rule.ApplyExpr(node.FilterList[i])
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  
    98  	if node.RowsetData != nil {
    99  		for i := range node.RowsetData.Cols {
   100  			for j := range node.RowsetData.Cols[i].Data {
   101  				node.RowsetData.Cols[i].Data[j], err = rule.ApplyExpr(node.RowsetData.Cols[i].Data[j])
   102  				if err != nil {
   103  					return err
   104  				}
   105  			}
   106  		}
   107  	}
   108  
   109  	for i := range node.ProjectList {
   110  		// if prepare statement is:   update set decimal_col = decimal_col + ? ;
   111  		// and then: 'set @a=12.22; execute stmt1 using @a;'  decimal_col + ? will be float64
   112  		if vq.isUpdatePlan {
   113  			pl, _ := vq.plan.Plan.(*Plan_Query)
   114  			num := len(pl.Query.Nodes) - int(idx)
   115  			// last project Node
   116  			if num == 2 {
   117  				oldType := DeepCopyTyp(node.ProjectList[i].Typ)
   118  				node.ProjectList[i], err = rule.ApplyExpr(node.ProjectList[i])
   119  				if node.ProjectList[i].Typ.Id != oldType.Id {
   120  					node.ProjectList[i], err = makePlan2CastExpr(ctx, node.ProjectList[i], oldType)
   121  				}
   122  			} else {
   123  				node.ProjectList[i], err = rule.ApplyExpr(node.ProjectList[i])
   124  			}
   125  		} else {
   126  			node.ProjectList[i], err = rule.ApplyExpr(node.ProjectList[i])
   127  		}
   128  
   129  		if err != nil {
   130  			return err
   131  		}
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  func (vq *VisitPlan) Visit(ctx context.Context) error {
   138  	switch pl := vq.plan.Plan.(type) {
   139  	case *Plan_Query:
   140  		qry := pl.Query
   141  		vq.isUpdatePlan = (pl.Query.StmtType == plan.Query_UPDATE)
   142  
   143  		if len(qry.Steps) == 0 {
   144  			return nil
   145  		}
   146  
   147  		for _, step := range qry.Steps {
   148  			err := vq.visitNode(ctx, qry, qry.Nodes[step], step)
   149  			if err != nil {
   150  				return err
   151  			}
   152  		}
   153  
   154  	default:
   155  		// do nothing
   156  
   157  	}
   158  
   159  	return nil
   160  }