github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/testutils/opttester/forcing_opt.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package opttester
    12  
    13  import (
    14  	"fmt"
    15  
    16  	"github.com/cockroachdb/cockroach/pkg/sql/opt"
    17  	"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/opt/props/physical"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/opt/xform"
    20  )
    21  
// forcingOptimizer is a wrapper around an Optimizer which adds low-level
// control, like restricting rule application or the expressions that can be
// part of the final expression.
type forcingOptimizer struct {
	// o is the wrapped optimizer.
	o xform.Optimizer

	// groups records the memo groups; newForcingOptimizer registers a callback
	// that adds every newly created memo group to it. It is used to look up
	// the path to an expression (see LookupPath) and is shared with coster.
	groups memoGroups

	// coster is installed as the coster of the wrapped optimizer, which allows
	// groups to be restricted to a specific member (see RestrictToExpr).
	coster forcingCoster

	// remaining is the number of "unused" steps remaining. It is decremented
	// each time a counted rule is matched; once it reaches zero, any further
	// counted rule match is rejected.
	remaining int

	// lastMatched records the name of the rule that was most recently matched
	// by the optimizer.
	lastMatched opt.RuleName

	// lastApplied records the name of the rule that was most recently applied by
	// the optimizer. This is not necessarily the same as lastMatched because
	// normalization rules can run in-between the match and the application of an
	// exploration rule.
	lastApplied opt.RuleName

	// lastAppliedSource is the expression matched by an exploration rule, or is
	// nil for a normalization rule.
	lastAppliedSource opt.Expr

	// lastAppliedTarget is the new expression constructed by a normalization or
	// exploration rule. For an exploration rule, it can be nil if no expressions
	// were constructed, or can have additional expressions beyond the first that
	// are accessible via NextExpr links.
	lastAppliedTarget opt.Expr
}
    55  
    56  // newForcingOptimizer creates a forcing optimizer that stops applying any rules
    57  // after <steps> rules are matched. If ignoreNormRules is true, normalization
    58  // rules don't count against this limit.
    59  func newForcingOptimizer(
    60  	tester *OptTester, steps int, ignoreNormRules bool,
    61  ) (*forcingOptimizer, error) {
    62  	fo := &forcingOptimizer{
    63  		remaining:   steps,
    64  		lastMatched: opt.InvalidRuleName,
    65  	}
    66  	fo.o.Init(&tester.evalCtx, tester.catalog)
    67  	fo.coster.Init(&fo.o, &fo.groups)
    68  	fo.o.SetCoster(&fo.coster)
    69  
    70  	fo.o.NotifyOnMatchedRule(func(ruleName opt.RuleName) bool {
    71  		if ignoreNormRules && ruleName.IsNormalize() {
    72  			return true
    73  		}
    74  		if fo.remaining == 0 {
    75  			return false
    76  		}
    77  		if tester.Flags.DisableRules.Contains(int(ruleName)) {
    78  			return false
    79  		}
    80  		fo.remaining--
    81  		fo.lastMatched = ruleName
    82  		return true
    83  	})
    84  
    85  	// Hook the AppliedRule notification in order to track the portion of the
    86  	// expression tree affected by each transformation rule.
    87  	fo.o.NotifyOnAppliedRule(
    88  		func(ruleName opt.RuleName, source, target opt.Expr) {
    89  			if ignoreNormRules && ruleName.IsNormalize() {
    90  				return
    91  			}
    92  			fo.lastApplied = ruleName
    93  			fo.lastAppliedSource = source
    94  			fo.lastAppliedTarget = target
    95  		},
    96  	)
    97  
    98  	fo.o.Memo().NotifyOnNewGroup(func(expr opt.Expr) {
    99  		fo.groups.AddGroup(expr)
   100  	})
   101  
   102  	if err := tester.buildExpr(fo.o.Factory()); err != nil {
   103  		return nil, err
   104  	}
   105  	return fo, nil
   106  }
   107  
   108  func (fo *forcingOptimizer) Optimize() opt.Expr {
   109  	expr, err := fo.o.Optimize()
   110  	if err != nil {
   111  		// Print the full error (it might contain a stack trace).
   112  		fmt.Printf("%+v\n", err)
   113  		panic(err)
   114  	}
   115  	return expr
   116  }
   117  
   118  // LookupPath returns the path of the given node.
   119  func (fo *forcingOptimizer) LookupPath(target opt.Expr) []memoLoc {
   120  	return fo.groups.FindPath(fo.o.Memo().RootExpr(), target)
   121  }
   122  
   123  // RestrictToExpr sets up the optimizer to restrict the result to only those
   124  // expression trees which include the given expression path.
   125  func (fo *forcingOptimizer) RestrictToExpr(path []memoLoc) {
   126  	for _, l := range path {
   127  		fo.coster.RestrictGroupToMember(l)
   128  	}
   129  }
   130  
// forcingCoster implements the xform.Coster interface so that it can suppress
// expressions in the memo that can't be part of the output tree.
type forcingCoster struct {
	// o is the optimizer that was passed to Init.
	o      *xform.Optimizer
	// groups is used by ComputeCost to look up the memo location of an
	// expression.
	groups *memoGroups

	// inner is the coster obtained from the optimizer during Init. ComputeCost
	// delegates to it for every expression that is not suppressed.
	inner xform.Coster

	// restricted maps a group to the only member of that group which is not
	// suppressed; all other members of the group are assigned memo.MaxCost.
	// The map is nil until RestrictGroupToMember is first called.
	restricted map[groupID]memberOrd
}
   141  
   142  func (fc *forcingCoster) Init(o *xform.Optimizer, groups *memoGroups) {
   143  	fc.o = o
   144  	fc.groups = groups
   145  	fc.inner = o.Coster()
   146  }
   147  
   148  // RestrictGroupToMember forces the expression in the given location to be the
   149  // best expression for its group.
   150  func (fc *forcingCoster) RestrictGroupToMember(loc memoLoc) {
   151  	if fc.restricted == nil {
   152  		fc.restricted = make(map[groupID]memberOrd)
   153  	}
   154  	fc.restricted[loc.group] = loc.member
   155  }
   156  
   157  // ComputeCost is part of the xform.Coster interface.
   158  func (fc *forcingCoster) ComputeCost(e memo.RelExpr, required *physical.Required) memo.Cost {
   159  	if fc.restricted != nil {
   160  		loc := fc.groups.MemoLoc(e)
   161  		if mIdx, ok := fc.restricted[loc.group]; ok && loc.member != mIdx {
   162  			return memo.MaxCost
   163  		}
   164  	}
   165  
   166  	return fc.inner.ComputeCost(e, required)
   167  }