github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/testutils/opttester/opt_tester.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package opttester
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	gosql "database/sql"
    17  	"encoding/json"
    18  	"flag"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"math"
    22  	"path/filepath"
    23  	"runtime"
    24  	"sort"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  	"text/tabwriter"
    29  
    30  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    31  	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
    32  	"github.com/cockroachdb/cockroach/pkg/sql/opt"
    33  	"github.com/cockroachdb/cockroach/pkg/sql/opt/cat"
    34  	_ "github.com/cockroachdb/cockroach/pkg/sql/opt/exec/execbuilder" // for ExprFmtHideScalars.
    35  	"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
    36  	"github.com/cockroachdb/cockroach/pkg/sql/opt/norm"
    37  	"github.com/cockroachdb/cockroach/pkg/sql/opt/optbuilder"
    38  	"github.com/cockroachdb/cockroach/pkg/sql/opt/optgen/exprgen"
    39  	"github.com/cockroachdb/cockroach/pkg/sql/opt/testutils/testcat"
    40  	"github.com/cockroachdb/cockroach/pkg/sql/opt/xform"
    41  	"github.com/cockroachdb/cockroach/pkg/sql/parser"
    42  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    43  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    44  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    45  	"github.com/cockroachdb/cockroach/pkg/sql/stats"
    46  	"github.com/cockroachdb/cockroach/pkg/util"
    47  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    48  	"github.com/cockroachdb/cockroach/pkg/util/treeprinter"
    49  	"github.com/cockroachdb/datadriven"
    50  	"github.com/cockroachdb/errors"
    51  	"github.com/pmezard/go-difflib/difflib"
    52  )
    53  
    54  const rewriteActualFlag = "rewrite-actual-stats"
    55  
    56  var (
    57  	rewriteActualStats = flag.Bool(
    58  		rewriteActualFlag, false,
    59  		"used to update the actual statistics for statistics quality tests. If true, the opttester "+
    60  			"will actually run the test queries to calculate actual statistics for comparison with the "+
    61  			"estimated stats.",
    62  	)
    63  	pgurl = flag.String(
    64  		"pgurl", "postgresql://localhost:26257/?sslmode=disable&user=root",
    65  		"the database url to connect to",
    66  	)
    67  
    68  	formatFlags = map[string]memo.ExprFmtFlags{
    69  		"miscprops":   memo.ExprFmtHideMiscProps,
    70  		"constraints": memo.ExprFmtHideConstraints,
    71  		"funcdeps":    memo.ExprFmtHideFuncDeps,
    72  		"ruleprops":   memo.ExprFmtHideRuleProps,
    73  		"stats":       memo.ExprFmtHideStats,
    74  		"cost":        memo.ExprFmtHideCost,
    75  		"qual":        memo.ExprFmtHideQualifications,
    76  		"scalars":     memo.ExprFmtHideScalars,
    77  		"physprops":   memo.ExprFmtHidePhysProps,
    78  		"types":       memo.ExprFmtHideTypes,
    79  		"notnull":     memo.ExprFmtHideNotNull,
    80  		"columns":     memo.ExprFmtHideColumns,
    81  		"all":         memo.ExprFmtHideAll,
    82  	}
    83  )
    84  
    85  // RuleSet efficiently stores an unordered set of RuleNames.
    86  type RuleSet = util.FastIntSet
    87  
    88  // OptTester is a helper for testing the various optimizer components. It
    89  // contains the boiler-plate code for the following useful tasks:
    90  //   - Build an unoptimized opt expression tree
    91  //   - Build an optimized opt expression tree
    92  //   - Format the optimizer memo structure
    93  //   - Create a diff showing the optimizer's work, step-by-step
    94  //   - Build the exec node tree
    95  //   - Execute the exec node tree
    96  //
    97  // The OptTester is used by tests in various sub-packages of the opt package.
    98  type OptTester struct {
    99  	Flags Flags
   100  
   101  	catalog   cat.Catalog
   102  	sql       string
   103  	ctx       context.Context
   104  	semaCtx   tree.SemaContext
   105  	evalCtx   tree.EvalContext
   106  	seenRules RuleSet
   107  
   108  	builder strings.Builder
   109  }
   110  
   111  // Flags are control knobs for tests. Note that specific testcases can
   112  // override these defaults.
   113  type Flags struct {
   114  	// ExprFormat controls the output detail of build / opt/ optsteps command
   115  	// directives.
   116  	ExprFormat memo.ExprFmtFlags
   117  
   118  	// MemoFormat controls the output detail of memo command directives.
   119  	MemoFormat xform.FmtFlags
   120  
   121  	// FullyQualifyNames if set: when building a query, the optbuilder fully
   122  	// qualifies all column names before adding them to the metadata. This flag
   123  	// allows us to test that name resolution works correctly, and avoids
   124  	// cluttering test output with schema and catalog names in the general case.
   125  	FullyQualifyNames bool
   126  
   127  	// Verbose indicates whether verbose test debugging information will be
   128  	// output to stdout when commands run. Only certain commands support this.
   129  	Verbose bool
   130  
   131  	// DisableRules is a set of rules that are not allowed to run.
   132  	DisableRules RuleSet
   133  
   134  	// ExploreTraceRule restricts the ExploreTrace output to only show the effects
   135  	// of a specific rule.
   136  	ExploreTraceRule opt.RuleName
   137  
   138  	// ExploreTraceSkipNoop hides the ExploreTrace output for instances of rules
   139  	// that fire but don't add any new expressions to the memo.
   140  	ExploreTraceSkipNoop bool
   141  
   142  	// ExpectedRules is a set of rules which must be exercised for the test to
   143  	// pass.
   144  	ExpectedRules RuleSet
   145  
   146  	// UnexpectedRules is a set of rules which must not be exercised for the test
   147  	// to pass.
   148  	UnexpectedRules RuleSet
   149  
   150  	// ColStats is a list of ColSets for which a column statistic is requested.
   151  	ColStats []opt.ColSet
   152  
   153  	// PerturbCost indicates how much to randomly perturb the cost. It is used
   154  	// to generate alternative plans for testing. For example, if PerturbCost is
   155  	// 0.5, and the estimated cost of an expression is c, the cost returned by
   156  	// the coster will be in the range [c - 0.5 * c, c + 0.5 * c).
   157  	PerturbCost float64
   158  
   159  	// ReorderJoinsLimit is the maximum number of joins in a query which the optimizer
   160  	// should attempt to reorder.
   161  	JoinLimit int
   162  
   163  	// Locality specifies the location of the planning node as a set of user-
   164  	// defined key/value pairs, ordered from most inclusive to least inclusive.
   165  	// If there are no tiers, then the node's location is not known. Examples:
   166  	//
   167  	//   [region=eu]
   168  	//   [region=us,dc=east]
   169  	//
   170  	Locality roachpb.Locality
   171  
   172  	// Database specifies the current database to use for the query. This field
   173  	// is only used by the save-tables command when rewriteActualFlag=true.
   174  	Database string
   175  
   176  	// Table specifies the current table to use for the command. This field
   177  	// is only used by the stats and inject-stats commands.
   178  	Table string
   179  
   180  	// SaveTablesPrefix specifies the prefix of the table to create or print
   181  	// for each subexpression in the query.
   182  	SaveTablesPrefix string
   183  
   184  	// File specifies the name of the file to import. This field is only used by
   185  	// the import command.
   186  	File string
   187  
   188  	// CascadeLevels limits the depth of recursive cascades for build-cascades.
   189  	CascadeLevels int
   190  }
   191  
   192  // New constructs a new instance of the OptTester for the given SQL statement.
   193  // Metadata used by the SQL query is accessed via the catalog.
   194  func New(catalog cat.Catalog, sql string) *OptTester {
   195  	ctx := context.Background()
   196  	ot := &OptTester{
   197  		catalog: catalog,
   198  		sql:     sql,
   199  		ctx:     ctx,
   200  		semaCtx: tree.MakeSemaContext(),
   201  		evalCtx: tree.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()),
   202  	}
   203  
   204  	// Set any OptTester-wide session flags here.
   205  
   206  	ot.evalCtx.SessionData.ZigzagJoinEnabled = true
   207  	ot.evalCtx.SessionData.OptimizerFKChecks = true
   208  	ot.evalCtx.SessionData.OptimizerFKCascades = true
   209  	ot.evalCtx.SessionData.OptimizerUseHistograms = true
   210  	ot.evalCtx.SessionData.OptimizerUseMultiColStats = true
   211  	ot.evalCtx.SessionData.ReorderJoinsLimit = opt.DefaultJoinOrderLimit
   212  	ot.evalCtx.SessionData.InsertFastPath = true
   213  
   214  	return ot
   215  }
   216  
   217  // RunCommand implements commands that are used by most tests:
   218  //
   219  //  - exec-ddl
   220  //
   221  //    Runs a SQL DDL statement to build the test catalog. Only a small number
   222  //    of DDL statements are supported, and those not fully. This is only
   223  //    available when using a TestCatalog.
   224  //
   225  //  - build [flags]
   226  //
   227  //    Builds an expression tree from a SQL query and outputs it without any
   228  //    optimizations applied to it.
   229  //
   230  //  - norm [flags]
   231  //
   232  //    Builds an expression tree from a SQL query, applies normalization
   233  //    optimizations, and outputs it without any exploration optimizations
   234  //    applied to it.
   235  //
   236  //  - opt [flags]
   237  //
   238  //    Builds an expression tree from a SQL query, fully optimizes it using the
   239  //    memo, and then outputs the lowest cost tree.
   240  //
   241  //  - build-cascades [flags]
   242  //
   243  //    Builds a query and then recursively builds cascading queries. Outputs all
   244  //    unoptimized plans.
   245  //
   246  //  - optsteps [flags]
   247  //
   248  //    Outputs the lowest cost tree for each step in optimization using the
   249  //    standard unified diff format. Used for debugging the optimizer.
   250  //
   251  //  - exploretrace [flags]
   252  //
   253  //    Outputs information about exploration rule application. Used for debugging
   254  //    the optimizer.
   255  //
   256  //  - memo [flags]
   257  //
   258  //    Builds an expression tree from a SQL query, fully optimizes it using the
   259  //    memo, and then outputs the memo containing the forest of trees.
   260  //
   261  //  - rulestats [flags]
   262  //
   263  //    Performs the optimization and outputs statistics about applied rules.
   264  //
   265  //  - expr
   266  //
   267  //    Builds an expression directly from an opt-gen-like string; see
   268  //    exprgen.Build.
   269  //
   270  //  - exprnorm
   271  //
   272  //    Builds an expression directly from an opt-gen-like string (see
   273  //    exprgen.Build), applies normalization optimizations, and outputs the tree
   274  //    without any exploration optimizations applied to it.
   275  //
   276  //  - save-tables [flags]
   277  //
   278  //    Fully optimizes the given query and saves the subexpressions as tables
   279  //    in the test catalog with their estimated statistics injected.
   280  //    If rewriteActualFlag=true, also executes the given query against a
   281  //    running database and saves the intermediate results as tables.
   282  //
   283  //  - stats table=... [flags]
   284  //
   285  //    Compares estimated statistics for a relational expression with the actual
   286  //    statistics calculated by calling CREATE STATISTICS on the output of the
   287  //    expression. save-tables must have been called previously to save the
   288  //    target expression as a table. The name of this table must be provided
   289  //    with the table flag.
   290  //
   291  //  - import file=...
   292  //
   293  //    Imports a file containing exec-ddl commands in order to add tables and/or
   294  //    stats to the catalog. This allows commonly-used schemas such as TPC-C or
   295  //    TPC-H to be used by multiple test files without copying the schemas and
   296  //    stats multiple times. The file name must be provided with the file flag.
   297  //    The path of the file should be relative to
   298  //    testutils/opttester/testfixtures.
   299  //
   300  //  - inject-stats file=... table=...
   301  //
   302  //    Injects table statistics from a json file.
   303  //
   304  // Supported flags:
   305  //
   306  //  - format: controls the formatting of expressions for build, opt, and
   307  //    optsteps commands. Format flags are of the form
   308  //      (show|hide)-(all|miscprops|constraints|scalars|types|...)
   309  //    See formatFlags for all flags. Multiple flags can be specified; each flag
   310  //    modifies the existing set of the flags.
   311  //
   312  //  - fully-qualify-names: fully qualify all column names in the test output.
   313  //
   314  //  - expect: fail the test if the rules specified by name do not match.
   315  //
   316  //  - expect-not: fail the test if the rules specified by name match.
   317  //
   318  //  - disable: disables optimizer rules by name. Examples:
   319  //      opt disable=ConstrainScan
   320  //      norm disable=(NegateOr,NegateAnd)
   321  //
   322  //  - rule: used with exploretrace; the value is the name of a rule. When
   323  //    specified, the exploretrace output is filtered to only show expression
   324  //    changes due to that specific rule.
   325  //
   326  //  - skip-no-op: used with exploretrace; hide instances of rules that don't
   327  //    generate any new expressions.
   328  //
   329  //  - colstat: requests the calculation of a column statistic on the top-level
   330  //    expression. The value is a column or a list of columns. The flag can
   331  //    be used multiple times to request different statistics.
   332  //
   333  //  - perturb-cost: used to randomly perturb the estimated cost of each
   334  //    expression in the query tree for the purpose of creating alternate query
   335  //    plans in the optimizer.
   336  //
   337  //  - locality: used to set the locality of the node that plans the query. This
   338  //    can affect costing when there are multiple possible indexes to choose
   339  //    from, each in different localities.
   340  //
   341  //  - database: used to set the current database used by the query. This is
   342  //    used by the save-tables command when rewriteActualFlag=true.
   343  //
   344  //  - table: used to set the current table used by the command. This is used by
   345  //    the stats command.
   346  //
   347  //  - save-tables-prefix: must be used with the save-tables command. If
   348  //    rewriteActualFlag=true, indicates that a table should be created with the
   349  //    given prefix for the output of each subexpression in the query.
   350  //    Otherwise, outputs the name of the table that would be created for each
   351  //    subexpression.
   352  //
   353  //  - file: specifies a file, used for the following commands:
   354  //     - import: the file path is relative to opttester/testfixtures;
   355  //     - inject-stats: the file path is relative to the test file.
   356  //
   357  //  - cascade-levels: used to limit the depth of recursive cascades for
   358  //    build-cascades.
   359  //
   360  func (ot *OptTester) RunCommand(tb testing.TB, d *datadriven.TestData) string {
   361  	// Allow testcases to override the flags.
   362  	for _, a := range d.CmdArgs {
   363  		if err := ot.Flags.Set(a); err != nil {
   364  			d.Fatalf(tb, "%+v", err)
   365  		}
   366  	}
   367  
   368  	if ot.Flags.JoinLimit != 0 {
   369  		defer func(oldValue int) {
   370  			ot.evalCtx.SessionData.ReorderJoinsLimit = oldValue
   371  		}(ot.evalCtx.SessionData.ReorderJoinsLimit)
   372  		ot.evalCtx.SessionData.ReorderJoinsLimit = ot.Flags.JoinLimit
   373  	}
   374  
   375  	ot.Flags.Verbose = testing.Verbose()
   376  	ot.evalCtx.TestingKnobs.OptimizerCostPerturbation = ot.Flags.PerturbCost
   377  	ot.evalCtx.Locality = ot.Flags.Locality
   378  	ot.evalCtx.SessionData.SaveTablesPrefix = ot.Flags.SaveTablesPrefix
   379  
   380  	switch d.Cmd {
   381  	case "exec-ddl":
   382  		testCatalog, ok := ot.catalog.(*testcat.Catalog)
   383  		if !ok {
   384  			d.Fatalf(tb, "exec-ddl can only be used with TestCatalog")
   385  		}
   386  		s, err := testCatalog.ExecuteDDL(d.Input)
   387  		if err != nil {
   388  			d.Fatalf(tb, "%v", err)
   389  		}
   390  		return s
   391  
   392  	case "build":
   393  		e, err := ot.OptBuild()
   394  		if err != nil {
   395  			if errors.HasAssertionFailure(err) {
   396  				d.Fatalf(tb, "%+v", err)
   397  			}
   398  			pgerr := pgerror.Flatten(err)
   399  			text := strings.TrimSpace(pgerr.Error())
   400  			if pgerr.Code != pgcode.Uncategorized {
   401  				// Output Postgres error code if it's available.
   402  				return fmt.Sprintf("error (%s): %s\n", pgerr.Code, text)
   403  			}
   404  			return fmt.Sprintf("error: %s\n", text)
   405  		}
   406  		ot.postProcess(tb, d, e)
   407  		return ot.FormatExpr(e)
   408  
   409  	case "norm":
   410  		e, err := ot.OptNorm()
   411  		if err != nil {
   412  			if errors.HasAssertionFailure(err) {
   413  				d.Fatalf(tb, "%+v", err)
   414  			}
   415  			pgerr := pgerror.Flatten(err)
   416  			text := strings.TrimSpace(pgerr.Error())
   417  			if pgerr.Code != pgcode.Uncategorized {
   418  				// Output Postgres error code if it's available.
   419  				return fmt.Sprintf("error (%s): %s\n", pgerr.Code, text)
   420  			}
   421  			return fmt.Sprintf("error: %s\n", text)
   422  		}
   423  		ot.postProcess(tb, d, e)
   424  		return ot.FormatExpr(e)
   425  
   426  	case "opt":
   427  		e, err := ot.Optimize()
   428  		if err != nil {
   429  			d.Fatalf(tb, "%+v", err)
   430  		}
   431  		ot.postProcess(tb, d, e)
   432  		return ot.FormatExpr(e)
   433  
   434  	case "build-cascades":
   435  		o := ot.makeOptimizer()
   436  		o.DisableOptimizations()
   437  		if err := ot.buildExpr(o.Factory()); err != nil {
   438  			d.Fatalf(tb, "%+v", err)
   439  		}
   440  		e := o.Memo().RootExpr()
   441  
   442  		var buildCascades func(e opt.Expr, tp treeprinter.Node, level int)
   443  		buildCascades = func(e opt.Expr, tp treeprinter.Node, level int) {
   444  			if ot.Flags.CascadeLevels != 0 && level > ot.Flags.CascadeLevels {
   445  				return
   446  			}
   447  			if opt.IsMutationOp(e) {
   448  				p := e.Private().(*memo.MutationPrivate)
   449  
   450  				for _, c := range p.FKCascades {
   451  					// We use the same memo to build the cascade. This makes the entire
   452  					// tree easier to read (e.g. the column IDs won't overlap).
   453  					cascade, err := c.Builder.Build(
   454  						context.Background(),
   455  						&ot.semaCtx,
   456  						&ot.evalCtx,
   457  						ot.catalog,
   458  						o.Factory(),
   459  						c.WithID,
   460  						e.Child(0).(memo.RelExpr).Relational(),
   461  						c.OldValues,
   462  						c.NewValues,
   463  					)
   464  					if err != nil {
   465  						d.Fatalf(tb, "error building cascade: %+v", err)
   466  					}
   467  					n := tp.Child("cascade")
   468  					n.Child(strings.TrimRight(ot.FormatExpr(cascade), "\n"))
   469  					buildCascades(cascade, n, level+1)
   470  				}
   471  			}
   472  			for i := 0; i < e.ChildCount(); i++ {
   473  				buildCascades(e.Child(i), tp, level)
   474  			}
   475  		}
   476  		tp := treeprinter.New()
   477  		root := tp.Child("root")
   478  		root.Child(strings.TrimRight(ot.FormatExpr(e), "\n"))
   479  		buildCascades(e, root, 1)
   480  
   481  		return tp.String()
   482  
   483  	case "optsteps":
   484  		result, err := ot.OptSteps()
   485  		if err != nil {
   486  			d.Fatalf(tb, "%+v", err)
   487  		}
   488  		return result
   489  
   490  	case "exploretrace":
   491  		result, err := ot.ExploreTrace()
   492  		if err != nil {
   493  			d.Fatalf(tb, "%+v", err)
   494  		}
   495  		return result
   496  
   497  	case "rulestats":
   498  		result, err := ot.RuleStats()
   499  		if err != nil {
   500  			d.Fatalf(tb, "%+v", err)
   501  		}
   502  		return result
   503  
   504  	case "memo":
   505  		result, err := ot.Memo()
   506  		if err != nil {
   507  			d.Fatalf(tb, "%+v", err)
   508  		}
   509  		return result
   510  
   511  	case "expr":
   512  		e, err := ot.Expr()
   513  		if err != nil {
   514  			d.Fatalf(tb, "%+v", err)
   515  		}
   516  		ot.postProcess(tb, d, e)
   517  		return ot.FormatExpr(e)
   518  
   519  	case "exprnorm":
   520  		e, err := ot.ExprNorm()
   521  		if err != nil {
   522  			d.Fatalf(tb, "%+v", err)
   523  		}
   524  		ot.postProcess(tb, d, e)
   525  		return ot.FormatExpr(e)
   526  
   527  	case "save-tables":
   528  		e, err := ot.SaveTables()
   529  		if err != nil {
   530  			d.Fatalf(tb, "%+v", err)
   531  		}
   532  		ot.postProcess(tb, d, e)
   533  		return ot.FormatExpr(e)
   534  
   535  	case "stats":
   536  		result, err := ot.Stats(tb, d)
   537  		if err != nil {
   538  			d.Fatalf(tb, "%+v", err)
   539  		}
   540  		return result
   541  
   542  	case "import":
   543  		ot.Import(tb)
   544  		return ""
   545  
   546  	case "inject-stats":
   547  		ot.InjectStats(tb, d)
   548  		return ""
   549  
   550  	default:
   551  		d.Fatalf(tb, "unsupported command: %s", d.Cmd)
   552  		return ""
   553  	}
   554  }
   555  
   556  // FormatExpr is a convenience wrapper for memo.FormatExpr.
   557  func (ot *OptTester) FormatExpr(e opt.Expr) string {
   558  	var mem *memo.Memo
   559  	if rel, ok := e.(memo.RelExpr); ok {
   560  		mem = rel.Memo()
   561  	}
   562  	return memo.FormatExpr(e, ot.Flags.ExprFormat, mem, ot.catalog)
   563  }
   564  
   565  func formatRuleSet(r RuleSet) string {
   566  	var buf bytes.Buffer
   567  	comma := false
   568  	for i, ok := r.Next(0); ok; i, ok = r.Next(i + 1) {
   569  		if comma {
   570  			buf.WriteString(", ")
   571  		}
   572  		comma = true
   573  		fmt.Fprintf(&buf, "%v", opt.RuleName(i))
   574  	}
   575  	return buf.String()
   576  }
   577  
   578  func (ot *OptTester) postProcess(tb testing.TB, d *datadriven.TestData, e opt.Expr) {
   579  	fillInLazyProps(e)
   580  
   581  	if rel, ok := e.(memo.RelExpr); ok {
   582  		for _, cols := range ot.Flags.ColStats {
   583  			memo.RequestColStat(&ot.evalCtx, rel, cols)
   584  		}
   585  	}
   586  
   587  	if !ot.Flags.ExpectedRules.SubsetOf(ot.seenRules) {
   588  		unseen := ot.Flags.ExpectedRules.Difference(ot.seenRules)
   589  		d.Fatalf(tb, "expected to see %s, but was not triggered. Did see %s",
   590  			formatRuleSet(unseen), formatRuleSet(ot.seenRules))
   591  	}
   592  
   593  	if ot.Flags.UnexpectedRules.Intersects(ot.seenRules) {
   594  		seen := ot.Flags.UnexpectedRules.Intersection(ot.seenRules)
   595  		d.Fatalf(tb, "expected not to see %s, but it was triggered", formatRuleSet(seen))
   596  	}
   597  }
   598  
   599  // Fills in lazily-derived properties (for display).
   600  func fillInLazyProps(e opt.Expr) {
   601  	if rel, ok := e.(memo.RelExpr); ok {
   602  		// Derive columns that are candidates for pruning.
   603  		norm.DerivePruneCols(rel)
   604  
   605  		// Derive columns that are candidates for null rejection.
   606  		norm.DeriveRejectNullCols(rel)
   607  
   608  		// Make sure the interesting orderings are calculated.
   609  		xform.DeriveInterestingOrderings(rel)
   610  
   611  		// Make sure the multiplicity is populated.
   612  		memo.DeriveJoinMultiplicity(rel)
   613  	}
   614  
   615  	for i, n := 0, e.ChildCount(); i < n; i++ {
   616  		fillInLazyProps(e.Child(i))
   617  	}
   618  }
   619  
   620  func ruleNamesToRuleSet(args []string) (RuleSet, error) {
   621  	var result RuleSet
   622  	for _, r := range args {
   623  		rn, err := ruleFromString(r)
   624  		if err != nil {
   625  			return result, err
   626  		}
   627  		result.Add(int(rn))
   628  	}
   629  	return result, nil
   630  }
   631  
   632  // Set parses an argument that refers to a flag.
   633  // See OptTester.RunCommand for supported flags.
   634  func (f *Flags) Set(arg datadriven.CmdArg) error {
   635  	switch arg.Key {
   636  	case "format":
   637  		if len(arg.Vals) == 0 {
   638  			return fmt.Errorf("format flag requires value(s)")
   639  		}
   640  		for _, v := range arg.Vals {
   641  			// Format values are of the form (hide|show)-(flag). These flags modify
   642  			// the default flags for the test and multiple flags are applied in order.
   643  			parts := strings.SplitN(v, "-", 2)
   644  			if len(parts) != 2 ||
   645  				(parts[0] != "show" && parts[0] != "hide") ||
   646  				formatFlags[parts[1]] == 0 {
   647  				return fmt.Errorf("unknown format value %s", v)
   648  			}
   649  			if parts[0] == "hide" {
   650  				f.ExprFormat |= formatFlags[parts[1]]
   651  			} else {
   652  				f.ExprFormat &= ^formatFlags[parts[1]]
   653  			}
   654  		}
   655  
   656  	case "fully-qualify-names":
   657  		f.FullyQualifyNames = true
   658  		// Hiding qualifications defeats the purpose.
   659  		f.ExprFormat &= ^memo.ExprFmtHideQualifications
   660  
   661  	case "disable":
   662  		if len(arg.Vals) == 0 {
   663  			return fmt.Errorf("disable requires arguments")
   664  		}
   665  		for _, s := range arg.Vals {
   666  			r, err := ruleFromString(s)
   667  			if err != nil {
   668  				return err
   669  			}
   670  			f.DisableRules.Add(int(r))
   671  		}
   672  
   673  	case "join-limit":
   674  		if len(arg.Vals) != 1 {
   675  			return fmt.Errorf("join-limit requires a single argument")
   676  		}
   677  		limit, err := strconv.ParseInt(arg.Vals[0], 10, 64)
   678  		if err != nil {
   679  			return errors.Wrap(err, "join-limit")
   680  		}
   681  		f.JoinLimit = int(limit)
   682  
   683  	case "rule":
   684  		if len(arg.Vals) != 1 {
   685  			return fmt.Errorf("rule requires one argument")
   686  		}
   687  		var err error
   688  		f.ExploreTraceRule, err = ruleFromString(arg.Vals[0])
   689  		if err != nil {
   690  			return err
   691  		}
   692  
   693  	case "skip-no-op":
   694  		f.ExploreTraceSkipNoop = true
   695  
   696  	case "expect":
   697  		ruleset, err := ruleNamesToRuleSet(arg.Vals)
   698  		if err != nil {
   699  			return err
   700  		}
   701  		f.ExpectedRules.UnionWith(ruleset)
   702  
   703  	case "expect-not":
   704  		ruleset, err := ruleNamesToRuleSet(arg.Vals)
   705  		if err != nil {
   706  			return err
   707  		}
   708  		f.UnexpectedRules.UnionWith(ruleset)
   709  
   710  	case "colstat":
   711  		if len(arg.Vals) == 0 {
   712  			return fmt.Errorf("colstat requires arguments")
   713  		}
   714  		var cols opt.ColSet
   715  		for _, v := range arg.Vals {
   716  			col, err := strconv.Atoi(v)
   717  			if err != nil {
   718  				return fmt.Errorf("invalid colstat column %v", v)
   719  			}
   720  			cols.Add(opt.ColumnID(col))
   721  		}
   722  		f.ColStats = append(f.ColStats, cols)
   723  
   724  	case "perturb-cost":
   725  		if len(arg.Vals) != 1 {
   726  			return fmt.Errorf("perturb-cost requires one argument")
   727  		}
   728  		var err error
   729  		f.PerturbCost, err = strconv.ParseFloat(arg.Vals[0], 64)
   730  		if err != nil {
   731  			return err
   732  		}
   733  
   734  	case "locality":
   735  		// Recombine multiple arguments, separated by commas.
   736  		locality := strings.Join(arg.Vals, ",")
   737  		err := f.Locality.Set(locality)
   738  		if err != nil {
   739  			return err
   740  		}
   741  
   742  	case "database":
   743  		if len(arg.Vals) != 1 {
   744  			return fmt.Errorf("database requires one argument")
   745  		}
   746  		f.Database = arg.Vals[0]
   747  
   748  	case "table":
   749  		if len(arg.Vals) != 1 {
   750  			return fmt.Errorf("table requires one argument")
   751  		}
   752  		f.Table = arg.Vals[0]
   753  
   754  	case "save-tables-prefix":
   755  		if len(arg.Vals) != 1 {
   756  			return fmt.Errorf("save-tables-prefix requires one argument")
   757  		}
   758  		f.SaveTablesPrefix = arg.Vals[0]
   759  
   760  	case "file":
   761  		if len(arg.Vals) != 1 {
   762  			return fmt.Errorf("file requires one argument")
   763  		}
   764  		f.File = arg.Vals[0]
   765  
   766  	case "cascade-levels":
   767  		if len(arg.Vals) != 1 {
   768  			return fmt.Errorf("cascade-levels requires a single argument")
   769  		}
   770  		levels, err := strconv.ParseInt(arg.Vals[0], 10, 64)
   771  		if err != nil {
   772  			return errors.Wrap(err, "cascade-levels")
   773  		}
   774  		f.CascadeLevels = int(levels)
   775  
   776  	default:
   777  		return fmt.Errorf("unknown argument: %s", arg.Key)
   778  	}
   779  	return nil
   780  }
   781  
   782  // OptBuild constructs an opt expression tree for the SQL query, with no
   783  // transformations applied to it. The untouched output of the optbuilder is the
   784  // final expression tree.
   785  func (ot *OptTester) OptBuild() (opt.Expr, error) {
   786  	o := ot.makeOptimizer()
   787  	o.DisableOptimizations()
   788  	return ot.optimizeExpr(o)
   789  }
   790  
   791  // OptNorm constructs an opt expression tree for the SQL query, with all
   792  // normalization transformations applied to it. The normalized output of the
   793  // optbuilder is the final expression tree.
   794  func (ot *OptTester) OptNorm() (opt.Expr, error) {
   795  	o := ot.makeOptimizer()
   796  	o.NotifyOnMatchedRule(func(ruleName opt.RuleName) bool {
   797  		if !ruleName.IsNormalize() {
   798  			return false
   799  		}
   800  		if ot.Flags.DisableRules.Contains(int(ruleName)) {
   801  			return false
   802  		}
   803  		ot.seenRules.Add(int(ruleName))
   804  		return true
   805  	})
   806  	return ot.optimizeExpr(o)
   807  }
   808  
   809  // Optimize constructs an opt expression tree for the SQL query, with all
   810  // transformations applied to it. The result is the memo expression tree with
   811  // the lowest estimated cost.
   812  func (ot *OptTester) Optimize() (opt.Expr, error) {
   813  	o := ot.makeOptimizer()
   814  	o.NotifyOnMatchedRule(func(ruleName opt.RuleName) bool {
   815  		if ot.Flags.DisableRules.Contains(int(ruleName)) {
   816  			return false
   817  		}
   818  		ot.seenRules.Add(int(ruleName))
   819  		return true
   820  	})
   821  	return ot.optimizeExpr(o)
   822  }
   823  
   824  // Memo returns a string that shows the memo data structure that is constructed
   825  // by the optimizer.
   826  func (ot *OptTester) Memo() (string, error) {
   827  	var o xform.Optimizer
   828  	o.Init(&ot.evalCtx, ot.catalog)
   829  	if _, err := ot.optimizeExpr(&o); err != nil {
   830  		return "", err
   831  	}
   832  	return o.FormatMemo(ot.Flags.MemoFormat), nil
   833  }
   834  
   835  // Expr parses the input directly into an expression; see exprgen.Build.
   836  func (ot *OptTester) Expr() (opt.Expr, error) {
   837  	var f norm.Factory
   838  	f.Init(&ot.evalCtx, ot.catalog)
   839  	f.DisableOptimizations()
   840  
   841  	return exprgen.Build(ot.catalog, &f, ot.sql)
   842  }
   843  
   844  // ExprNorm parses the input directly into an expression and runs
   845  // normalization; see exprgen.Build.
   846  func (ot *OptTester) ExprNorm() (opt.Expr, error) {
   847  	var f norm.Factory
   848  	f.Init(&ot.evalCtx, ot.catalog)
   849  
   850  	f.NotifyOnMatchedRule(func(ruleName opt.RuleName) bool {
   851  		// exprgen.Build doesn't run optimization, so we don't need to explicitly
   852  		// disallow exploration rules here.
   853  
   854  		if ot.Flags.DisableRules.Contains(int(ruleName)) {
   855  			return false
   856  		}
   857  		ot.seenRules.Add(int(ruleName))
   858  		return true
   859  	})
   860  
   861  	return exprgen.Build(ot.catalog, &f, ot.sql)
   862  }
   863  
   864  // RuleStats performs the optimization and returns statistics about how many
   865  // rules were applied.
   866  func (ot *OptTester) RuleStats() (string, error) {
   867  	type ruleStats struct {
   868  		rule       opt.RuleName
   869  		numApplied int
   870  		numAdded   int
   871  	}
   872  	stats := make([]ruleStats, opt.NumRuleNames)
   873  	for i := range stats {
   874  		stats[i].rule = opt.RuleName(i)
   875  	}
   876  
   877  	o := ot.makeOptimizer()
   878  	o.NotifyOnAppliedRule(
   879  		func(ruleName opt.RuleName, source, target opt.Expr) {
   880  			stats[ruleName].numApplied++
   881  			if target != nil {
   882  				stats[ruleName].numAdded++
   883  				if rel, ok := target.(memo.RelExpr); ok {
   884  					for {
   885  						rel = rel.NextExpr()
   886  						if rel == nil {
   887  							break
   888  						}
   889  						stats[ruleName].numAdded++
   890  					}
   891  				}
   892  			}
   893  		},
   894  	)
   895  	if _, err := ot.optimizeExpr(o); err != nil {
   896  		return "", err
   897  	}
   898  
   899  	// Split the rules.
   900  	var norm, explore []ruleStats
   901  	var allNorm, allExplore ruleStats
   902  	for i := range stats {
   903  		if stats[i].numApplied > 0 {
   904  			if stats[i].rule.IsNormalize() {
   905  				allNorm.numApplied += stats[i].numApplied
   906  				norm = append(norm, stats[i])
   907  			} else {
   908  				allExplore.numApplied += stats[i].numApplied
   909  				allExplore.numAdded += stats[i].numAdded
   910  				explore = append(explore, stats[i])
   911  			}
   912  		}
   913  	}
   914  	// Sort with most applied rules first.
   915  	sort.SliceStable(norm, func(i, j int) bool {
   916  		return norm[i].numApplied > norm[j].numApplied
   917  	})
   918  	sort.SliceStable(explore, func(i, j int) bool {
   919  		return explore[i].numApplied > explore[j].numApplied
   920  	})
   921  
   922  	// Only show the top 5 rules.
   923  	const topK = 5
   924  	if len(norm) > topK {
   925  		norm = norm[:topK]
   926  	}
   927  	if len(explore) > topK {
   928  		explore = explore[:topK]
   929  	}
   930  
   931  	// Ready to report.
   932  	var res strings.Builder
   933  	fmt.Fprintf(&res, "Normalization rules applied %d times.\n", allNorm.numApplied)
   934  	if len(norm) > 0 {
   935  		fmt.Fprintf(&res, "Top normalization rules:\n")
   936  		tw := tabwriter.NewWriter(&res, 1 /* minwidth */, 1 /* tabwidth */, 1 /* padding */, ' ', 0)
   937  		for _, s := range norm {
   938  			fmt.Fprintf(tw, "  %s\tapplied\t%d\ttimes.\n", s.rule, s.numApplied)
   939  		}
   940  		_ = tw.Flush()
   941  	}
   942  
   943  	fmt.Fprintf(
   944  		&res, "Exploration rules applied %d times, added %d expressions.\n",
   945  		allExplore.numApplied, allExplore.numAdded,
   946  	)
   947  
   948  	if len(explore) > 0 {
   949  		fmt.Fprintf(&res, "Top exploration rules:\n")
   950  		tw := tabwriter.NewWriter(&res, 1 /* minwidth */, 1 /* tabwidth */, 1 /* padding */, ' ', 0)
   951  		for _, s := range explore {
   952  			fmt.Fprintf(
   953  				tw, "  %s\tapplied\t%d\ttimes, added\t%d\texpressions.\n", s.rule, s.numApplied, s.numAdded,
   954  			)
   955  		}
   956  		_ = tw.Flush()
   957  	}
   958  	return res.String(), nil
   959  }
   960  
   961  // OptSteps steps through the transformations performed by the optimizer on the
   962  // memo, one-by-one. The output of each step is the lowest cost expression tree
   963  // that also contains the expressions that were changed or added by the
   964  // transformation. The output of each step is diff'd against the output of a
   965  // previous step, using the standard unified diff format.
   966  //
   967  //   CREATE TABLE a (x INT PRIMARY KEY, y INT, UNIQUE INDEX (y))
   968  //
   969  //   SELECT x FROM a WHERE x=1
   970  //
   971  // At the time of this writing, this query triggers 6 rule applications:
   972  //   EnsureSelectFilters     Wrap Select predicate with Filters operator
   973  //   FilterUnusedSelectCols  Do not return unused "y" column from Scan
   974  //   EliminateProject        Remove unneeded Project operator
   975  //   GenerateIndexScans      Explore scanning "y" index to get "x" values
   976  //   ConstrainScan           Explore pushing "x=1" into "x" index Scan
   977  //   ConstrainScan           Explore pushing "x=1" into "y" index Scan
   978  //
   979  // Some steps produce better plans that have a lower execution cost. Other steps
   980  // don't. However, it's useful to see both kinds of steps. The optsteps output
   981  // distinguishes these two cases by using stronger "====" header delimiters when
   982  // a better plan has been found, and weaker "----" header delimiters when not.
   983  // In both cases, the output shows the expressions that were changed or added by
   984  // the rule, even if the total expression tree cost worsened.
   985  //
   986  func (ot *OptTester) OptSteps() (string, error) {
   987  	var prevBest, prev, next string
   988  	ot.builder.Reset()
   989  
   990  	os := newOptSteps(ot)
   991  	for {
   992  		err := os.Next()
   993  		if err != nil {
   994  			return "", err
   995  		}
   996  
   997  		next = os.fo.o.FormatExpr(os.Root(), ot.Flags.ExprFormat)
   998  
   999  		// This call comes after setting "next", because we want to output the
  1000  		// final expression, even though there were no diffs from the previous
  1001  		// iteration.
  1002  		if os.Done() {
  1003  			break
  1004  		}
  1005  
  1006  		if prev == "" {
  1007  			// Output starting tree.
  1008  			ot.optStepsDisplay("", next, os)
  1009  			prevBest = next
  1010  		} else if next == prev || next == prevBest {
  1011  			ot.optStepsDisplay(next, next, os)
  1012  		} else if os.IsBetter() {
  1013  			// New expression is better than the previous expression. Diff
  1014  			// it against the previous *best* expression (might not be the
  1015  			// previous expression).
  1016  			ot.optStepsDisplay(prevBest, next, os)
  1017  			prevBest = next
  1018  		} else {
  1019  			// New expression is not better than the previous expression, but
  1020  			// still show the change. Diff it against the previous expression,
  1021  			// regardless if it was a "best" expression or not.
  1022  			ot.optStepsDisplay(prev, next, os)
  1023  		}
  1024  
  1025  		prev = next
  1026  	}
  1027  
  1028  	// Output ending tree.
  1029  	ot.optStepsDisplay(next, "", os)
  1030  
  1031  	return ot.builder.String(), nil
  1032  }
  1033  
  1034  func (ot *OptTester) optStepsDisplay(before string, after string, os *optSteps) {
  1035  	// bestHeader is used when the expression is an improvement over the previous
  1036  	// expression.
  1037  	bestHeader := func(e opt.Expr, format string, args ...interface{}) {
  1038  		ot.separator("=")
  1039  		ot.output(format, args...)
  1040  		if rel, ok := e.(memo.RelExpr); ok {
  1041  			ot.output("  Cost: %.2f\n", rel.Cost())
  1042  		} else {
  1043  			ot.output("\n")
  1044  		}
  1045  		ot.separator("=")
  1046  	}
  1047  
  1048  	// altHeader is used when the expression doesn't improve over the previous
  1049  	// expression, but it's still desirable to see what changed.
  1050  	altHeader := func(format string, args ...interface{}) {
  1051  		ot.separator("-")
  1052  		ot.output(format, args...)
  1053  		ot.separator("-")
  1054  	}
  1055  
  1056  	if before == "" {
  1057  		if ot.Flags.Verbose {
  1058  			fmt.Print("------ optsteps verbose output starts ------\n")
  1059  		}
  1060  		bestHeader(os.Root(), "Initial expression\n")
  1061  		ot.indent(after)
  1062  		return
  1063  	}
  1064  
  1065  	if before == after {
  1066  		altHeader("%s (no changes)\n", os.LastRuleName())
  1067  		return
  1068  	}
  1069  
  1070  	if after == "" {
  1071  		bestHeader(os.Root(), "Final best expression\n")
  1072  		ot.indent(before)
  1073  
  1074  		if ot.Flags.Verbose {
  1075  			fmt.Print("------ optsteps verbose output ends ------\n")
  1076  		}
  1077  		return
  1078  	}
  1079  
  1080  	var diff difflib.UnifiedDiff
  1081  	if os.IsBetter() {
  1082  		// New expression is better than the previous expression. Diff
  1083  		// it against the previous *best* expression (might not be the
  1084  		// previous expression).
  1085  		bestHeader(os.Root(), "%s\n", os.LastRuleName())
  1086  	} else {
  1087  		altHeader("%s (higher cost)\n", os.LastRuleName())
  1088  	}
  1089  
  1090  	diff = difflib.UnifiedDiff{
  1091  		A:       difflib.SplitLines(before),
  1092  		B:       difflib.SplitLines(after),
  1093  		Context: 100,
  1094  	}
  1095  	text, _ := difflib.GetUnifiedDiffString(diff)
  1096  	// Skip the "@@ ... @@" header (first line).
  1097  	text = strings.SplitN(text, "\n", 2)[1]
  1098  	ot.indent(text)
  1099  }
  1100  
  1101  // ExploreTrace steps through exploration transformations performed by the
  1102  // optimizer, one-by-one. The output of each step is the expression on which the
  1103  // rule was applied, and the expressions that were generated by the rule.
  1104  func (ot *OptTester) ExploreTrace() (string, error) {
  1105  	ot.builder.Reset()
  1106  
  1107  	et := newExploreTracer(ot)
  1108  
  1109  	for step := 0; ; step++ {
  1110  		if step > 2000 {
  1111  			ot.output("step limit reached\n")
  1112  			break
  1113  		}
  1114  		err := et.Next()
  1115  		if err != nil {
  1116  			return "", err
  1117  		}
  1118  		if et.Done() {
  1119  			break
  1120  		}
  1121  
  1122  		if ot.Flags.ExploreTraceRule != opt.InvalidRuleName &&
  1123  			et.LastRuleName() != ot.Flags.ExploreTraceRule {
  1124  			continue
  1125  		}
  1126  		newNodes := et.NewExprs()
  1127  		if ot.Flags.ExploreTraceSkipNoop && len(newNodes) == 0 {
  1128  			continue
  1129  		}
  1130  
  1131  		if ot.builder.Len() > 0 {
  1132  			ot.output("\n")
  1133  		}
  1134  		ot.separator("=")
  1135  		ot.output("%s\n", et.LastRuleName())
  1136  		ot.separator("=")
  1137  		ot.output("Source expression:\n")
  1138  		ot.indent(et.fo.o.FormatExpr(et.SrcExpr(), ot.Flags.ExprFormat))
  1139  		if len(newNodes) == 0 {
  1140  			ot.output("\nNo new expressions.\n")
  1141  		}
  1142  		for i := range newNodes {
  1143  			ot.output("\nNew expression %d of %d:\n", i+1, len(newNodes))
  1144  			ot.indent(memo.FormatExpr(newNodes[i], ot.Flags.ExprFormat, et.fo.o.Memo(), ot.catalog))
  1145  		}
  1146  	}
  1147  	return ot.builder.String(), nil
  1148  }
  1149  
  1150  // Stats compares the estimated statistics of a relational expression with
  1151  // actual statistics collected from running CREATE STATISTICS on the output
  1152  // of the relational expression. If the -rewrite-actual-stats flag is
  1153  // used, the actual stats are recalculated.
  1154  func (ot *OptTester) Stats(tb testing.TB, d *datadriven.TestData) (string, error) {
  1155  	if ot.Flags.Table == "" {
  1156  		tb.Fatal("table not specified")
  1157  	}
  1158  	catalog, ok := ot.catalog.(*testcat.Catalog)
  1159  	if !ok {
  1160  		return "", fmt.Errorf("stats can only be used with TestCatalog")
  1161  	}
  1162  
  1163  	st := statsTester{}
  1164  	return st.testStats(catalog, d, ot.Flags.Table)
  1165  }
  1166  
  1167  // Import imports a file containing exec-ddl commands in order to add tables
  1168  // and/or stats to the catalog. This allows commonly-used schemas such as
  1169  // TPC-C or TPC-H to be used by multiple test files without copying the schemas
  1170  // and stats multiple times.
  1171  func (ot *OptTester) Import(tb testing.TB) {
  1172  	if ot.Flags.File == "" {
  1173  		tb.Fatal("file not specified")
  1174  	}
  1175  	// Find the file to be imported in opttester/testfixtures.
  1176  	_, optTesterFile, _, ok := runtime.Caller(1)
  1177  	if !ok {
  1178  		tb.Fatalf("unable to find file %s", ot.Flags.File)
  1179  	}
  1180  	path := filepath.Join(filepath.Dir(optTesterFile), "testfixtures", ot.Flags.File)
  1181  	datadriven.RunTest(tb.(*testing.T), path, func(t *testing.T, d *datadriven.TestData) string {
  1182  		tester := New(ot.catalog, d.Input)
  1183  		return tester.RunCommand(t, d)
  1184  	})
  1185  }
  1186  
  1187  // InjectStats constructs and executes an ALTER TABLE INJECT STATISTICS
  1188  // statement using the statistics in a separate json file.
  1189  func (ot *OptTester) InjectStats(tb testing.TB, d *datadriven.TestData) {
  1190  	if ot.Flags.File == "" {
  1191  		tb.Fatal("file not specified")
  1192  	}
  1193  	if ot.Flags.Table == "" {
  1194  		tb.Fatal("table not specified")
  1195  	}
  1196  	// We get the file path from the Pos string which always of the form
  1197  	// "file:linenum".
  1198  	testfilePath := strings.SplitN(d.Pos, ":", 1)[0]
  1199  	path := filepath.Join(filepath.Dir(testfilePath), ot.Flags.File)
  1200  	stats, err := ioutil.ReadFile(path)
  1201  	if err != nil {
  1202  		tb.Fatalf("error reading %s: %v", path, err)
  1203  	}
  1204  	stmt := fmt.Sprintf(
  1205  		"ALTER TABLE %s INJECT STATISTICS '%s'",
  1206  		ot.Flags.Table,
  1207  		strings.Replace(string(stats), "'", "''", -1),
  1208  	)
  1209  	testCatalog, ok := ot.catalog.(*testcat.Catalog)
  1210  	if !ok {
  1211  		d.Fatalf(tb, "inject-stats can only be used with TestCatalog")
  1212  	}
  1213  	_, err = testCatalog.ExecuteDDL(stmt)
  1214  	if err != nil {
  1215  		d.Fatalf(tb, "%v", err)
  1216  	}
  1217  }
  1218  
  1219  // SaveTables optimizes the given query and saves the subexpressions as tables
  1220  // in the test catalog with their estimated statistics injected.
  1221  // If rewriteActualStats=true, it also executes the given query against a
  1222  // running database and saves the intermediate results as tables.
  1223  func (ot *OptTester) SaveTables() (opt.Expr, error) {
  1224  	if *rewriteActualStats {
  1225  		if err := ot.saveActualTables(); err != nil {
  1226  			return nil, err
  1227  		}
  1228  	}
  1229  
  1230  	expr, err := ot.Optimize()
  1231  	if err != nil {
  1232  		return nil, err
  1233  	}
  1234  
  1235  	// Create a table in the test catalog for each relational expression in the
  1236  	// tree.
  1237  	nameGen := memo.NewExprNameGenerator(ot.Flags.SaveTablesPrefix)
  1238  	var traverse func(e opt.Expr) error
  1239  	traverse = func(e opt.Expr) error {
  1240  		if r, ok := e.(memo.RelExpr); ok {
  1241  			// GenerateName is called in a pre-order traversal of the query tree.
  1242  			tabName := nameGen.GenerateName(e.Op())
  1243  			_, err := ot.createTableAs(tree.MakeUnqualifiedTableName(tree.Name(tabName)), r)
  1244  			if err != nil {
  1245  				return err
  1246  			}
  1247  		}
  1248  		for i, n := 0, e.ChildCount(); i < n; i++ {
  1249  			if err := traverse(e.Child(i)); err != nil {
  1250  				return err
  1251  			}
  1252  		}
  1253  		return nil
  1254  	}
  1255  	if err := traverse(expr); err != nil {
  1256  		return nil, err
  1257  	}
  1258  
  1259  	return expr, nil
  1260  }
  1261  
  1262  // saveActualTables executes the given query against a running database and
  1263  // saves the intermediate results as tables.
  1264  func (ot *OptTester) saveActualTables() error {
  1265  	db, err := gosql.Open("postgres", *pgurl)
  1266  	if err != nil {
  1267  		return errors.Wrap(err,
  1268  			"can only execute a statement when pointed at a running Cockroach cluster",
  1269  		)
  1270  	}
  1271  
  1272  	ctx := context.Background()
  1273  	c, err := db.Conn(ctx)
  1274  	if err != nil {
  1275  		return err
  1276  	}
  1277  
  1278  	if _, err := c.ExecContext(ctx,
  1279  		fmt.Sprintf("DROP DATABASE IF EXISTS %s CASCADE", opt.SaveTablesDatabase),
  1280  	); err != nil {
  1281  		return err
  1282  	}
  1283  
  1284  	if _, err := c.ExecContext(ctx,
  1285  		fmt.Sprintf("CREATE DATABASE %s", opt.SaveTablesDatabase),
  1286  	); err != nil {
  1287  		return err
  1288  	}
  1289  
  1290  	if _, err := c.ExecContext(ctx, fmt.Sprintf("USE %s", ot.Flags.Database)); err != nil {
  1291  		return err
  1292  	}
  1293  
  1294  	if _, err := c.ExecContext(ctx,
  1295  		fmt.Sprintf("SET save_tables_prefix = '%s'", ot.Flags.SaveTablesPrefix),
  1296  	); err != nil {
  1297  		return err
  1298  	}
  1299  
  1300  	if _, err := c.ExecContext(ctx, ot.sql); err != nil {
  1301  		return err
  1302  	}
  1303  
  1304  	return nil
  1305  }
  1306  
  1307  // createTableAs creates a table in the test catalog based on the output
  1308  // of the given relational expression. It also injects the estimated stats
  1309  // for the relational expression into the catalog table. It returns a pointer
  1310  // to the new table.
  1311  func (ot *OptTester) createTableAs(name tree.TableName, rel memo.RelExpr) (*testcat.Table, error) {
  1312  	catalog, ok := ot.catalog.(*testcat.Catalog)
  1313  	if !ok {
  1314  		return nil, fmt.Errorf("createTableAs can only be used with TestCatalog")
  1315  	}
  1316  
  1317  	relProps := rel.Relational()
  1318  	outputCols := relProps.OutputCols
  1319  	colNameGen := memo.NewColumnNameGenerator(rel)
  1320  
  1321  	// Create each of the columns and their estimated stats for the test catalog
  1322  	// table.
  1323  	columns := make([]*testcat.Column, outputCols.Len())
  1324  	jsonStats := make([]stats.JSONStatistic, outputCols.Len())
  1325  	i := 0
  1326  	for col, ok := outputCols.Next(0); ok; col, ok = outputCols.Next(col + 1) {
  1327  		colMeta := rel.Memo().Metadata().ColumnMeta(col)
  1328  		colName := colNameGen.GenerateName(col)
  1329  
  1330  		columns[i] = &testcat.Column{
  1331  			Ordinal:  i,
  1332  			Name:     colName,
  1333  			Type:     colMeta.Type,
  1334  			Nullable: !relProps.NotNullCols.Contains(col),
  1335  		}
  1336  
  1337  		// Make sure we have estimated stats for this column.
  1338  		colSet := opt.MakeColSet(col)
  1339  		memo.RequestColStat(&ot.evalCtx, rel, colSet)
  1340  		stat, ok := relProps.Stats.ColStats.Lookup(colSet)
  1341  		if !ok {
  1342  			return nil, fmt.Errorf("could not find statistic for column %s", colName)
  1343  		}
  1344  		jsonStats[i] = ot.makeStat(
  1345  			[]string{colName},
  1346  			uint64(int64(math.Round(relProps.Stats.RowCount))),
  1347  			uint64(int64(math.Round(stat.DistinctCount))),
  1348  			uint64(int64(math.Round(stat.NullCount))),
  1349  		)
  1350  
  1351  		i++
  1352  	}
  1353  
  1354  	tab := catalog.CreateTableAs(name, columns)
  1355  	if err := ot.injectStats(name, jsonStats); err != nil {
  1356  		return nil, err
  1357  	}
  1358  	return tab, nil
  1359  }
  1360  
  1361  // injectStats injects statistics into the given table in the test catalog.
  1362  func (ot *OptTester) injectStats(name tree.TableName, jsonStats []stats.JSONStatistic) error {
  1363  	catalog, ok := ot.catalog.(*testcat.Catalog)
  1364  	if !ok {
  1365  		return fmt.Errorf("injectStats can only be used with TestCatalog")
  1366  	}
  1367  
  1368  	encoded, err := json.Marshal(jsonStats)
  1369  	if err != nil {
  1370  		return err
  1371  	}
  1372  	alterStmt := fmt.Sprintf("ALTER TABLE %s INJECT STATISTICS '%s'", name.String(), encoded)
  1373  	stmt, err := parser.ParseOne(alterStmt)
  1374  	if err != nil {
  1375  		return err
  1376  	}
  1377  	catalog.AlterTable(stmt.AST.(*tree.AlterTable))
  1378  	return nil
  1379  }
  1380  
  1381  // makeStat creates a JSONStatistic for the given columns, rowCount,
  1382  // distinctCount, and nullCount.
  1383  func (ot *OptTester) makeStat(
  1384  	columns []string, rowCount, distinctCount, nullCount uint64,
  1385  ) stats.JSONStatistic {
  1386  	return stats.JSONStatistic{
  1387  		Name: stats.AutoStatsName,
  1388  		CreatedAt: tree.AsStringWithFlags(
  1389  			&tree.DTimestamp{Time: timeutil.Now()}, tree.FmtBareStrings,
  1390  		),
  1391  		Columns:       columns,
  1392  		RowCount:      rowCount,
  1393  		DistinctCount: distinctCount,
  1394  		NullCount:     nullCount,
  1395  	}
  1396  }
  1397  
  1398  func (ot *OptTester) buildExpr(factory *norm.Factory) error {
  1399  	stmt, err := parser.ParseOne(ot.sql)
  1400  	if err != nil {
  1401  		return err
  1402  	}
  1403  	if err := ot.semaCtx.Placeholders.Init(stmt.NumPlaceholders, nil /* typeHints */); err != nil {
  1404  		return err
  1405  	}
  1406  	ot.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations)
  1407  	b := optbuilder.New(ot.ctx, &ot.semaCtx, &ot.evalCtx, ot.catalog, factory, stmt.AST)
  1408  	return b.Build()
  1409  }
  1410  
  1411  func (ot *OptTester) makeOptimizer() *xform.Optimizer {
  1412  	var o xform.Optimizer
  1413  	o.Init(&ot.evalCtx, ot.catalog)
  1414  	return &o
  1415  }
  1416  
  1417  func (ot *OptTester) optimizeExpr(o *xform.Optimizer) (opt.Expr, error) {
  1418  	err := ot.buildExpr(o.Factory())
  1419  	if err != nil {
  1420  		return nil, err
  1421  	}
  1422  	root, err := o.Optimize()
  1423  	if err != nil {
  1424  		return nil, err
  1425  	}
  1426  	if ot.Flags.PerturbCost != 0 {
  1427  		o.RecomputeCost()
  1428  	}
  1429  	return root, nil
  1430  }
  1431  
  1432  func (ot *OptTester) output(format string, args ...interface{}) {
  1433  	fmt.Fprintf(&ot.builder, format, args...)
  1434  	if ot.Flags.Verbose {
  1435  		fmt.Printf(format, args...)
  1436  	}
  1437  }
  1438  
  1439  func (ot *OptTester) separator(sep string) {
  1440  	ot.output("%s\n", strings.Repeat(sep, 80))
  1441  }
  1442  
  1443  func (ot *OptTester) indent(str string) {
  1444  	str = strings.TrimRight(str, " \n\t\r")
  1445  	lines := strings.Split(str, "\n")
  1446  	for _, line := range lines {
  1447  		ot.output("  %s\n", line)
  1448  	}
  1449  }
  1450  
  1451  // ruleFromString returns the rule that matches the given string,
  1452  // or InvalidRuleName if there is no such rule.
  1453  func ruleFromString(str string) (opt.RuleName, error) {
  1454  	for i := opt.RuleName(1); i < opt.NumRuleNames; i++ {
  1455  		if i.String() == str {
  1456  			return i, nil
  1457  		}
  1458  	}
  1459  
  1460  	return opt.InvalidRuleName, fmt.Errorf("rule '%s' does not exist", str)
  1461  }