github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/analyzer.go

// Copyright 2020-2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package analyzer

import (
	"fmt"
	"os"
	"reflect"
	"strings"

	"github.com/pmezard/go-difflib/difflib"
	"github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel/attribute"
	"gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/memo"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/rowexec"
	"github.com/dolthub/go-mysql-server/sql/transform"
)

const debugAnalyzerKey = "DEBUG_ANALYZER"
const verboseAnalyzerKey = "VERBOSE_ANALYZER"

const maxAnalysisIterations = 8

// ErrMaxAnalysisIters is returned when the maximum number of analysis iterations is exceeded
var ErrMaxAnalysisIters = errors.NewKind("exceeded max analysis iterations (%d)")

// ErrInAnalysis is returned for generic analyzer errors
var ErrInAnalysis = errors.NewKind("error in analysis: %s")

// ErrInvalidNodeType is returned when the analyzer can't handle a particular kind of node type
var ErrInvalidNodeType = errors.NewKind("%s: invalid node of type: %T")

const disablePrepareStmtKey = "DISABLE_PREPARED_STATEMENTS"

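// PreparedStmtDisabled globally disables prepared statement handling. It is set
// at startup when the DISABLE_PREPARED_STATEMENTS environment variable is
// non-empty, and can be changed afterwards with SetPreparedStmts.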
var PreparedStmtDisabled bool

func init() {
	if v := os.Getenv(disablePrepareStmtKey); v != "" {
		PreparedStmtDisabled = true
	}
}

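// SetPreparedStmts sets PreparedStmtDisabled to the value given; passing true
// disables prepared statement handling.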
func SetPreparedStmts(v bool) {
	PreparedStmtDisabled = v
}

// Builder provides an easy way to generate an Analyzer with custom rules and options.
type Builder struct {
	preAnalyzeRules     []Rule
	postAnalyzeRules    []Rule
	preValidationRules  []Rule
	postValidationRules []Rule
	onceBeforeRules     []Rule
	defaultRules        []Rule
	onceAfterRules      []Rule
	validationRules     []Rule
	afterAllRules       []Rule
	provider            sql.DatabaseProvider
	debug               bool
	parallelism         int
}

// NewBuilder creates a new Builder from a specific catalog.
// This builder allows us to add custom Rules and to modify some internal properties.
func NewBuilder(pro sql.DatabaseProvider) *Builder {
	return &Builder{
		provider:        pro,
		onceBeforeRules: OnceBeforeDefault,
		defaultRules:    DefaultRules,
		onceAfterRules:  OnceAfterDefault,
		validationRules: DefaultValidationRules,
		afterAllRules:   OnceAfterAll,
	}
}
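
// A minimal usage sketch (illustrative only; myProvider, myRuleId, and myRuleFn
// are hypothetical placeholders, not identifiers defined in this package):
//
//	a := NewBuilder(myProvider).
//		AddPostAnalyzeRule(myRuleId, myRuleFn).
//		WithParallelism(2).
//		Build()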

// WithDebug activates debug on the Analyzer.
func (ab *Builder) WithDebug() *Builder {
	ab.debug = true

	return ab
}

// WithParallelism sets the parallelism level on the analyzer.
func (ab *Builder) WithParallelism(parallelism int) *Builder {
	ab.parallelism = parallelism
	return ab
}

// AddPreAnalyzeRule adds a new rule to the analyzer, to run before the standard analyzer rules.
func (ab *Builder) AddPreAnalyzeRule(id RuleId, fn RuleFunc) *Builder {
	ab.preAnalyzeRules = append(ab.preAnalyzeRules, Rule{id, fn})

	return ab
}

// AddPostAnalyzeRule adds a new rule to the analyzer, to run after the standard analyzer rules.
func (ab *Builder) AddPostAnalyzeRule(id RuleId, fn RuleFunc) *Builder {
	ab.postAnalyzeRules = append(ab.postAnalyzeRules, Rule{id, fn})

	return ab
}

// AddPreValidationRule adds a new rule to the analyzer, to run before the standard validation rules.
func (ab *Builder) AddPreValidationRule(id RuleId, fn RuleFunc) *Builder {
	ab.preValidationRules = append(ab.preValidationRules, Rule{id, fn})

	return ab
}

// AddPostValidationRule adds a new rule to the analyzer, to run after the standard validation rules.
func (ab *Builder) AddPostValidationRule(id RuleId, fn RuleFunc) *Builder {
	ab.postValidationRules = append(ab.postValidationRules, Rule{id, fn})

	return ab
}

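// duplicateRulesWithout returns a copy of the given rule slice, omitting any
// rule whose Id matches excludedRuleId.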
func duplicateRulesWithout(rules []Rule, excludedRuleId RuleId) []Rule {
	newRules := make([]Rule, 0, len(rules))

	for _, rule := range rules {
		if rule.Id != excludedRuleId {
			newRules = append(newRules, rule)
		}
	}

	return newRules
}

// RemoveOnceBeforeRule removes a default rule from the analyzer which would occur before other rules
func (ab *Builder) RemoveOnceBeforeRule(id RuleId) *Builder {
	ab.onceBeforeRules = duplicateRulesWithout(ab.onceBeforeRules, id)

	return ab
}

// RemoveDefaultRule removes a default rule from the analyzer that is executed as part of the analysis
func (ab *Builder) RemoveDefaultRule(id RuleId) *Builder {
	ab.defaultRules = duplicateRulesWithout(ab.defaultRules, id)

	return ab
}

// RemoveOnceAfterRule removes a default rule from the analyzer which would occur just once after the default analysis
func (ab *Builder) RemoveOnceAfterRule(id RuleId) *Builder {
	ab.onceAfterRules = duplicateRulesWithout(ab.onceAfterRules, id)

	return ab
}

// RemoveValidationRule removes a default rule from the analyzer which would occur as part of the validation rules
func (ab *Builder) RemoveValidationRule(id RuleId) *Builder {
	ab.validationRules = duplicateRulesWithout(ab.validationRules, id)

	return ab
}

// RemoveAfterAllRule removes a default rule from the analyzer which would occur after all other rules
func (ab *Builder) RemoveAfterAllRule(id RuleId) *Builder {
	ab.afterAllRules = duplicateRulesWithout(ab.afterAllRules, id)

	return ab
}

var log = logrus.New()

func init() {
	// TODO: give the option for debug analyzer logging format to match the global one
	log.SetFormatter(simpleLogFormatter{})
}

type simpleLogFormatter struct{}

func (s simpleLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
	lvl := ""
	switch entry.Level {
	case logrus.PanicLevel:
		lvl = "PANIC"
	case logrus.FatalLevel:
		lvl = "FATAL"
	case logrus.ErrorLevel:
		lvl = "ERROR"
	case logrus.WarnLevel:
		lvl = "WARN"
	case logrus.InfoLevel:
		lvl = "INFO"
	case logrus.DebugLevel:
		lvl = "DEBUG"
	case logrus.TraceLevel:
		lvl = "TRACE"
	}

	msg := fmt.Sprintf("%s: %s\n", lvl, entry.Message)
	return ([]byte)(msg), nil
}
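
// For example, an entry logged at Info level with the message "starting analysis"
// is formatted as "INFO: starting analysis\n".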

// Build creates a new Analyzer from the builder parameters
func (ab *Builder) Build() *Analyzer {
	_, debug := os.LookupEnv(debugAnalyzerKey)
	_, verbose := os.LookupEnv(verboseAnalyzerKey)
	var batches = []*Batch{
		{
			Desc:       "pre-analyzer",
			Iterations: maxAnalysisIterations,
			Rules:      ab.preAnalyzeRules,
		},
		{
			Desc:       "once-before",
			Iterations: 1,
			Rules:      ab.onceBeforeRules,
		},
		{
			Desc:       "default-rules",
			Iterations: maxAnalysisIterations,
			Rules:      ab.defaultRules,
		},
		{
			Desc:       "once-after",
			Iterations: 1,
			Rules:      ab.onceAfterRules,
		},
		{
			Desc:       "post-analyzer",
			Iterations: maxAnalysisIterations,
			Rules:      ab.postAnalyzeRules,
		},
		{
			Desc:       "pre-validation",
			Iterations: 1,
			Rules:      ab.preValidationRules,
		},
		{
			Desc:       "validation",
			Iterations: 1,
			Rules:      ab.validationRules,
		},
		{
			Desc:       "post-validation",
			Iterations: 1,
			Rules:      ab.postValidationRules,
		},
		{
			Desc:       "after-all",
			Iterations: 1,
			Rules:      ab.afterAllRules,
		},
	}

	return &Analyzer{
		Debug:        debug || ab.debug,
		Verbose:      verbose,
		contextStack: make([]string, 0),
		Batches:      batches,
		Catalog:      NewCatalog(ab.provider),
		Parallelism:  ab.parallelism,
		Coster:       memo.NewDefaultCoster(),
		ExecBuilder:  rowexec.DefaultBuilder,
	}
}
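
// Debug and verbose output can also be enabled per process via the environment
// variables checked above, e.g. (assuming a typical shell invocation):
//
//	DEBUG_ANALYZER=1 VERBOSE_ANALYZER=1 go test ./sql/analyzer/...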

// Analyzer analyzes nodes of the execution plan and applies rules and validations
// to them.
type Analyzer struct {
	// Whether to log various debugging messages
	Debug bool
	// Whether to output the query plan at each step of the analyzer
	Verbose bool
	// A stack of debugger context. See PushDebugContext, PopDebugContext
	contextStack []string
	Parallelism  int
	// Batches of Rules to apply.
	Batches []*Batch
	// Catalog of databases and registered functions.
	Catalog *Catalog
	// Coster estimates the incremental CPU+memory cost for execution operators.
	Coster memo.Coster
	// ExecBuilder converts a sql.Node tree into an executable iterator.
	ExecBuilder sql.NodeExecBuilder
	// EventScheduler is used to communicate with the event scheduler
	// for any EVENT related statements. It can be nil if EventScheduler is not defined.
	EventScheduler sql.EventScheduler
}

// NewDefault creates a default Analyzer instance with all default Rules and configuration.
// To add custom rules, the easiest way is to use the Builder.
func NewDefault(provider sql.DatabaseProvider) *Analyzer {
	return NewBuilder(provider).Build()
}

// NewDefaultWithVersion creates a default Analyzer instance. It is currently
// identical to NewDefault.
func NewDefaultWithVersion(provider sql.DatabaseProvider) *Analyzer {
	return NewBuilder(provider).Build()
}

// Log prints an INFO message to the analyzer's log output with the given
// message and args if the analyzer is in debug mode.
func (a *Analyzer) Log(msg string, args ...interface{}) {
	if a != nil && a.Debug {
		if len(a.contextStack) > 0 {
			ctx := strings.Join(a.contextStack, "/")
			log.Infof("%s: "+msg, append([]interface{}{ctx}, args...)...)
		} else {
			log.Infof(msg, args...)
		}
	}
}

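// LogFn returns a function that logs the given message and args exactly as Log
// does, for use as a logging callback.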
func (a *Analyzer) LogFn() func(string, ...any) {
	return func(msg string, args ...interface{}) {
		if a != nil && a.Debug {
			if len(a.contextStack) > 0 {
				ctx := strings.Join(a.contextStack, "/")
				log.Infof("%s: "+msg, append([]interface{}{ctx}, args...)...)
			} else {
				log.Infof(msg, args...)
			}
		}
	}
}

// LogNode prints the node given if Verbose logging is enabled.
func (a *Analyzer) LogNode(n sql.Node) {
	if a != nil && n != nil && a.Verbose {
		if len(a.contextStack) > 0 {
			ctx := strings.Join(a.contextStack, "/")
			log.Infof("%s:\n%s", ctx, sql.DebugString(n))
		} else {
			log.Infof("%s", sql.DebugString(n))
		}
	}
}

// LogDiff logs the diff between the query plans after a transformation rule has been applied.
// A diff can only be printed when the string representations of the nodes differ, which isn't always the case.
func (a *Analyzer) LogDiff(prev, next sql.Node) {
	if a.Debug && a.Verbose {
		if !reflect.DeepEqual(next, prev) {
			diff, err := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
				A:        difflib.SplitLines(sql.DebugString(prev)),
				B:        difflib.SplitLines(sql.DebugString(next)),
				FromFile: "Prev",
				FromDate: "",
				ToFile:   "Next",
				ToDate:   "",
				Context:  1,
			})
			if err != nil {
				panic(err)
			}
			if len(diff) > 0 {
				a.Log(diff)
			} else {
				a.Log("nodes are different, but no textual diff found (implement better DebugString?)")
			}
		}
	}
}

// PushDebugContext pushes the given context string onto the context stack, to use when logging debug messages.
func (a *Analyzer) PushDebugContext(msg string) {
	if a != nil && a.Debug {
		a.contextStack = append(a.contextStack, msg)
	}
}

// PopDebugContext pops a context message off the context stack.
func (a *Analyzer) PopDebugContext() {
	if a != nil && len(a.contextStack) > 0 {
		a.contextStack = a.contextStack[:len(a.contextStack)-1]
	}
}

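// SelectAllBatches is a BatchSelector that selects every batch.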
func SelectAllBatches(string) bool { return true }

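// DefaultRuleSelector selects every rule except those that apply only to
// prepared statements.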
func DefaultRuleSelector(id RuleId) bool {
	switch id {
	// prepared statement rules are incompatible with default rules
	case reresolveTablesId,
		resolvePreparedInsertId:
		return false
	}
	return true
}

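// NewProcRuleSelector returns a RuleSelector that wraps |sel|, additionally
// excluding rules that should not run (or should run only once) when analyzing
// stored procedure bodies. Selectors compose, e.g.
// NewProcRuleSelector(DefaultRuleSelector).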
func NewProcRuleSelector(sel RuleSelector) RuleSelector {
	return func(id RuleId) bool {
		switch id {
		case pruneTablesId,
			unnestInSubqueriesId,

			// once after default rules should only be run once
			AutocommitId,
			TrackProcessId,
			parallelizeId,
			clearWarningsId:
			return false
		}
		return sel(id)
	}
}

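// NewResolveSubqueryExprSelector returns a RuleSelector that wraps |sel|,
// additionally excluding finalize rules that must not run while recursively
// resolving subquery expressions.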
func NewResolveSubqueryExprSelector(sel RuleSelector) RuleSelector {
	return func(id RuleId) bool {
		switch id {
		case
			// skip recursive finalize rules
			hoistOutOfScopeFiltersId,
			unnestExistsSubqueriesId,
			unnestInSubqueriesId,
			finalizeSubqueriesId,
			assignExecIndexesId:
			return false
		}
		return sel(id)
	}
}

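// NewFinalizeSubquerySel returns a RuleSelector that wraps |sel|, additionally
// excluding resolve rules that have already run and finalize rules that would
// be redundant when finalizing a subquery.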
func NewFinalizeSubquerySel(sel RuleSelector) RuleSelector {
	return func(id RuleId) bool {
		switch id {
		case
			// skip recursive resolve rules
			resolveSubqueryExprsId,
			resolveSubqueriesId,
			resolveUnionsId,
			// skip redundant finalize rules
			finalizeSubqueriesId,
			hoistOutOfScopeFiltersId,
			cacheSubqueryResultsId,
			TrackProcessId,
			assignExecIndexesId:
			return false
		}
		return sel(id)
	}
}

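// NewFinalizeUnionSel returns a RuleSelector that wraps |sel|, excluding
// recursive resolve rules and forcing the finalize rules on when finalizing
// a union.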
func NewFinalizeUnionSel(sel RuleSelector) RuleSelector {
	return func(id RuleId) bool {
		switch id {
		case
			// skip recursive resolve rules
			resolveSubqueryExprsId,
			resolveSubqueriesId,
			resolveUnionsId,
			parallelizeId:
			return false
		case finalizeSubqueriesId,
			hoistOutOfScopeFiltersId:
			return true
		}
		return sel(id)
	}
}

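// newInsertSourceSelector returns a RuleSelector that wraps |sel|, excluding
// rules that should not be applied to the source of an INSERT.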
func newInsertSourceSelector(sel RuleSelector) RuleSelector {
	return func(id RuleId) bool {
		switch id {
		case unnestInSubqueriesId,
			pushdownSubqueryAliasFiltersId:
			return false
		}
		return sel(id)
	}
}

// Analyze applies the transformation rules to the node given. In the case of an error, the last successfully
// transformed node is returned along with the error.
func (a *Analyzer) Analyze(ctx *sql.Context, n sql.Node, scope *plan.Scope) (sql.Node, error) {
	n, _, err := a.analyzeWithSelector(ctx, n, scope, SelectAllBatches, DefaultRuleSelector)
	return n, err
}
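
// A typical call site looks like the following (an illustrative sketch; ctx and
// parsed are assumed to be a *sql.Context and an unanalyzed plan supplied by the
// caller, and a nil scope is assumed to denote a top-level query):
//
//	analyzed, err := a.Analyze(ctx, parsed, nil)
//	if err != nil {
//		return nil, err
//	}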
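// analyzeThroughBatch runs the analyzer, selecting batches up to and including
// the one whose description matches |until|, and skipping every batch after it.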
func (a *Analyzer) analyzeThroughBatch(ctx *sql.Context, n sql.Node, scope *plan.Scope, until string, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	stop := false
	return a.analyzeWithSelector(ctx, n, scope, func(desc string) bool {
		if stop {
			return false
		}
		if desc == until {
			stop = true
		}
		// we return true even for the matching description; only start
		// returning false after this batch.
		return true
	}, sel)
}

// Every time we recursively invoke the analyzer we increment a depth counter to avoid analyzing queries that could
// cause infinite recursion. This limit is high but arbitrary.
const maxBatchRecursion = 100

func (a *Analyzer) analyzeWithSelector(ctx *sql.Context, n sql.Node, scope *plan.Scope, batchSelector BatchSelector, ruleSelector RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	span, ctx := ctx.Span("analyze")
	// Register the span cleanup up front so the span is also ended on the early
	// error returns below.
	defer func() {
		if n != nil {
			span.SetAttributes(attribute.Bool("IsResolved", n.Resolved()))
		}
		span.End()
	}()

	if scope.RecursionDepth() > maxBatchRecursion {
		return n, transform.SameTree, ErrMaxAnalysisIters.New(maxBatchRecursion)
	}

	var (
		same    = transform.SameTree
		allSame = transform.SameTree
		err     error
	)
	a.Log("starting analysis of node of type: %T", n)
	a.LogNode(n)
	for _, batch := range a.Batches {
		if batchSelector(batch.Desc) {
			a.PushDebugContext(batch.Desc)
			n, same, err = batch.Eval(ctx, a, n, scope, ruleSelector)
			allSame = allSame && same
			if err != nil {
				a.Log("Encountered error: %v", err)
				a.PopDebugContext()
				return n, transform.SameTree, err
			}
			a.PopDebugContext()
		}
	}

	return n, allSame, err
}

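// analyzeStartingAtBatch runs the analyzer, skipping batches until the one
// whose description matches |startAt|, then selecting it and every batch after
// it.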
func (a *Analyzer) analyzeStartingAtBatch(ctx *sql.Context, n sql.Node, scope *plan.Scope, startAt string, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
	start := false
	return a.analyzeWithSelector(ctx, n, scope, func(desc string) bool {
		if desc == startAt {
			start = true
		}
		return start
	}, sel)
}

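// DeepCopyNode returns a copy of the given node with all of its expressions
// cloned.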
func DeepCopyNode(node sql.Node) (sql.Node, error) {
	n, _, err := transform.NodeExprs(node, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
		e, err := transform.Clone(e)
		return e, transform.NewTree, err
	})
	return n, err
}