github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/triggers.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/vitess/go/vt/sqlparser"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    27  	"github.com/dolthub/go-mysql-server/sql/transform"
    28  )
    29  
    30  // validateCreateTrigger handles CreateTrigger nodes, resolving references to "old" and "new" table references in
    31  // the trigger body. Also validates that these old and new references are being used appropriately -- they are only
    32  // valid for certain kinds of triggers and certain statements.
    33  func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    34  	ct, ok := node.(*plan.CreateTrigger)
    35  	if !ok {
    36  		return node, transform.SameTree, nil
    37  	}
    38  
    39  	// We just want to verify that the trigger is correctly defined before creating it. If it is, we replace the
    40  	// UnresolvedColumn expressions with placeholder expressions that say they are Resolved().
    41  	// TODO: this might work badly for databases with tables named new and old. Needs tests.
    42  	var err error
    43  	transform.InspectExpressions(ct.Body, func(e sql.Expression) bool {
    44  		switch e := e.(type) {
    45  		case *expression.UnresolvedColumn:
    46  			if strings.ToLower(e.Table()) == "new" {
    47  				if ct.TriggerEvent == sqlparser.DeleteStr {
    48  					err = sql.ErrInvalidUseOfOldNew.New("new", ct.TriggerEvent)
    49  				}
    50  			}
    51  			if strings.ToLower(e.Table()) == "old" {
    52  				if ct.TriggerEvent == sqlparser.InsertStr {
    53  					err = sql.ErrInvalidUseOfOldNew.New("old", ct.TriggerEvent)
    54  				}
    55  			}
    56  		}
    57  		return true
    58  	})
    59  
    60  	if err != nil {
    61  		return nil, transform.SameTree, err
    62  	}
    63  
    64  	// Check to see if the plan sets a value for "old" rows, or if an AFTER trigger assigns to NEW. Both are illegal.
    65  	transform.InspectExpressionsWithNode(ct.Body, func(n sql.Node, e sql.Expression) bool {
    66  		if _, ok := n.(*plan.Set); !ok {
    67  			return true
    68  		}
    69  
    70  		switch e := e.(type) {
    71  		case *expression.SetField:
    72  			switch left := e.LeftChild.(type) {
    73  			case column:
    74  				if strings.ToLower(left.Table()) == "old" {
    75  					err = sql.ErrInvalidUpdateOfOldRow.New()
    76  				}
    77  				if ct.TriggerTime == sqlparser.AfterStr && strings.ToLower(left.Table()) == "new" {
    78  					err = sql.ErrInvalidUpdateInAfterTrigger.New()
    79  				}
    80  			}
    81  		}
    82  
    83  		return true
    84  	})
    85  
    86  	if err != nil {
    87  		return nil, transform.SameTree, err
    88  	}
    89  
    90  	trigTable := getResolvedTable(ct.Table)
    91  	sch := trigTable.Schema()
    92  	colsList := make(map[string]struct{})
    93  	for _, c := range sch {
    94  		colsList[c.Name] = struct{}{}
    95  	}
    96  
    97  	// Check to see if the columns with "new" and "old" table reference are valid columns from the trigger table.
    98  	transform.InspectExpressions(ct.Body, func(e sql.Expression) bool {
    99  		switch e := e.(type) {
   100  		case *expression.UnresolvedColumn:
   101  			if strings.ToLower(e.Table()) == "old" || strings.ToLower(e.Table()) == "new" {
   102  				if _, ok := colsList[e.Name()]; !ok {
   103  					err = sql.ErrUnknownColumn.New(e.Name(), e.Table())
   104  				}
   105  			}
   106  		}
   107  		return true
   108  	})
   109  
   110  	if err != nil {
   111  		return nil, transform.SameTree, err
   112  	}
   113  	return node, transform.NewTree, nil
   114  }
   115  
   116  func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   117  	// Skip this step for CreateTrigger statements
   118  	if _, ok := n.(*plan.CreateTrigger); ok {
   119  		return n, transform.SameTree, nil
   120  	}
   121  
   122  	var affectedTables []string
   123  	var triggerEvent plan.TriggerEvent
   124  	db := ctx.GetCurrentDatabase()
   125  	transform.Inspect(n, func(n sql.Node) bool {
   126  		switch n := n.(type) {
   127  		case *plan.InsertInto:
   128  			affectedTables = append(affectedTables, getTableName(n))
   129  			triggerEvent = plan.InsertTrigger
   130  			if n.Database() != nil && n.Database().Name() != "" {
   131  				db = n.Database().Name()
   132  			}
   133  		case *plan.Update:
   134  			affectedTables = append(affectedTables, getTableName(n))
   135  			triggerEvent = plan.UpdateTrigger
   136  			if n.Database() != "" {
   137  				db = n.Database()
   138  			}
   139  		case *plan.DeleteFrom:
   140  			for _, target := range n.GetDeleteTargets() {
   141  				affectedTables = append(affectedTables, getTableName(target))
   142  			}
   143  			triggerEvent = plan.DeleteTrigger
   144  			if n.Database() != "" {
   145  				db = n.Database()
   146  			}
   147  		}
   148  		return true
   149  	})
   150  
   151  	if len(affectedTables) == 0 {
   152  		return n, transform.SameTree, nil
   153  	}
   154  
   155  	// TODO: database should be dependent on the table being inserted / updated, but we don't have that info available
   156  	//  from the table object yet.
   157  	database, err := a.Catalog.Database(ctx, db)
   158  	if err != nil {
   159  		return nil, transform.SameTree, err
   160  	}
   161  
   162  	var affectedTriggers []*plan.CreateTrigger
   163  	if tdb, ok := database.(sql.TriggerDatabase); ok {
   164  		triggers, err := tdb.GetTriggers(ctx)
   165  		if err != nil {
   166  			return nil, transform.SameTree, err
   167  		}
   168  
   169  		b := planbuilder.New(ctx, a.Catalog)
   170  		prevActive := b.TriggerCtx().Active
   171  		b.TriggerCtx().Active = true
   172  		defer func() {
   173  			b.TriggerCtx().Active = prevActive
   174  		}()
   175  
   176  		for _, trigger := range triggers {
   177  			var parsedTrigger sql.Node
   178  			sqlMode := sql.NewSqlModeFromString(trigger.SqlMode)
   179  			b.SetParserOptions(sqlMode.ParserOptions())
   180  			parsedTrigger, _, _, err = b.Parse(trigger.CreateStatement, false)
   181  			b.Reset()
   182  			if err != nil {
   183  				return nil, transform.SameTree, err
   184  			}
   185  
   186  			ct, ok := parsedTrigger.(*plan.CreateTrigger)
   187  			if !ok {
   188  				return nil, transform.SameTree, sql.ErrTriggerCreateStatementInvalid.New(trigger.CreateStatement)
   189  			}
   190  
   191  			var triggerTable string
   192  			switch t := ct.Table.(type) {
   193  			case *plan.ResolvedTable:
   194  				triggerTable = t.Name()
   195  			default:
   196  			}
   197  			if stringContains(affectedTables, triggerTable) && triggerEventsMatch(triggerEvent, ct.TriggerEvent) {
   198  				// first pass allows unresolved before we know whether trigger is relevant
   199  				// TODO store destination table name with trigger, so we don't have to do parse twice
   200  				b.TriggerCtx().Call = true
   201  				parsedTrigger, _, _, err = b.Parse(trigger.CreateStatement, false)
   202  				b.TriggerCtx().Call = false
   203  				b.Reset()
   204  				if err != nil {
   205  					return nil, transform.SameTree, err
   206  				}
   207  
   208  				ct, ok := parsedTrigger.(*plan.CreateTrigger)
   209  				if !ok {
   210  					return nil, transform.SameTree, sql.ErrTriggerCreateStatementInvalid.New(trigger.CreateStatement)
   211  				}
   212  
   213  				if block, ok := ct.Body.(*plan.BeginEndBlock); ok {
   214  					ct.Body = plan.NewTriggerBeginEndBlock(block)
   215  				}
   216  				affectedTriggers = append(affectedTriggers, ct)
   217  			}
   218  		}
   219  	}
   220  
   221  	if len(affectedTriggers) == 0 {
   222  		return n, transform.SameTree, nil
   223  	}
   224  
   225  	triggers := orderTriggersAndReverseAfter(affectedTriggers)
   226  	originalNode := n
   227  	same := transform.SameTree
   228  	allSame := transform.SameTree
   229  	for _, trigger := range triggers {
   230  		err = validateNoCircularUpdates(trigger, originalNode, scope)
   231  		if err != nil {
   232  			return nil, transform.SameTree, err
   233  		}
   234  
   235  		n, same, err = applyTrigger(ctx, a, originalNode, n, scope, trigger)
   236  		if err != nil {
   237  			return nil, transform.SameTree, err
   238  		}
   239  		allSame = same && allSame
   240  	}
   241  
   242  	return n, allSame, nil
   243  }
   244  
   245  // applyTrigger applies the trigger given to the node given, returning the resulting node
   246  func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope *plan.Scope, trigger *plan.CreateTrigger) (sql.Node, transform.TreeIdentity, error) {
   247  	triggerLogic, err := getTriggerLogic(ctx, a, originalNode, scope, trigger)
   248  	if err != nil {
   249  		return nil, transform.SameTree, err
   250  	}
   251  
   252  	return transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
   253  		// Don't double-apply trigger executors to the bodies of triggers. To avoid this, don't apply the trigger if the
   254  		// parent is a trigger body.
   255  		// TODO: this won't work for BEGIN END blocks, stored procedures, etc. For those, we need to examine all ancestors,
   256  		//  not just the immediate parent. Alternately, we could do something like not walk all children of some node types
   257  		//  (probably better).
   258  		if _, ok := c.Parent.(*plan.TriggerExecutor); ok {
   259  			if c.ChildNum == 1 { // Right child is the trigger execution logic
   260  				return c.Node, transform.SameTree, nil
   261  			}
   262  		}
   263  
   264  		switch n := c.Node.(type) {
   265  		case *plan.InsertInto:
   266  			if trigger.TriggerTime == sqlparser.BeforeStr {
   267  				triggerExecutor := plan.NewTriggerExecutor(n.Source, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
   268  					Name:            trigger.TriggerName,
   269  					CreateStatement: trigger.CreateTriggerString,
   270  				})
   271  				return n.WithSource(triggerExecutor), transform.NewTree, nil
   272  			} else {
   273  				return plan.NewTriggerExecutor(n, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
   274  					Name:            trigger.TriggerName,
   275  					CreateStatement: trigger.CreateTriggerString,
   276  				}), transform.NewTree, nil
   277  			}
   278  		case *plan.Update:
   279  			if trigger.TriggerTime == sqlparser.BeforeStr {
   280  				triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
   281  					Name:            trigger.TriggerName,
   282  					CreateStatement: trigger.CreateTriggerString,
   283  				})
   284  				node, err := n.WithChildren(triggerExecutor)
   285  				return node, transform.NewTree, err
   286  			} else {
   287  				return plan.NewTriggerExecutor(n, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
   288  					Name:            trigger.TriggerName,
   289  					CreateStatement: trigger.CreateTriggerString,
   290  				}), transform.NewTree, nil
   291  			}
   292  		case *plan.DeleteFrom:
   293  			// TODO: This should work correctly when there is only one table that
   294  			//       has a trigger on it, but it won't work if a DELETE FROM JOIN
   295  			//       is deleting from two tables that both have triggers. Seems
   296  			//       like we need something like a MultipleTriggerExecutor node
   297  			//       that could execute multiple triggers on the same row from its
   298  			//       wrapped iterator. There is also an issue with running triggers
   299  			//       because their field indexes assume the row they evalute will
   300  			//       only ever contain the columns from the single table the trigger
   301  			//       is based on, but this isn't true with UPDATE JOIN or DELETE JOIN.
   302  			if n.HasExplicitTargets() {
   303  				return nil, transform.SameTree, fmt.Errorf("delete from with explicit target tables " +
   304  					"does not support triggers; retry with single table deletes")
   305  			}
   306  
   307  			if trigger.TriggerTime == sqlparser.BeforeStr {
   308  				triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
   309  					Name:            trigger.TriggerName,
   310  					CreateStatement: trigger.CreateTriggerString,
   311  				})
   312  				node, err := n.WithChildren(triggerExecutor)
   313  				return node, transform.NewTree, err
   314  			} else {
   315  				return plan.NewTriggerExecutor(n, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
   316  					Name:            trigger.TriggerName,
   317  					CreateStatement: trigger.CreateTriggerString,
   318  				}), transform.NewTree, nil
   319  			}
   320  		}
   321  
   322  		return c.Node, transform.SameTree, nil
   323  	})
   324  }
   325  
   326  // getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the
   327  // plan node given, which must be an insert, update, or delete.
   328  func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, trigger *plan.CreateTrigger) (sql.Node, error) {
   329  	// For trigger body analysis, we don't want any row update accumulators applied to insert / update / delete
   330  	// statements, we need the raw output from them.
   331  	var noRowUpdateAccumulators RuleSelector
   332  	noRowUpdateAccumulators = func(id RuleId) bool {
   333  		return DefaultRuleSelector(id) && id != applyRowUpdateAccumulatorsId
   334  	}
   335  
   336  	// For the reference to the row in the trigger table, we use the scope mechanism. This is a little strange because
   337  	// scopes for subqueries work with the child schemas of a scope node, but we don't have such a node here. Instead we
   338  	// fabricate one with the right properties (its child schema matches the table schema, with the right aliased name)
   339  	var triggerLogic sql.Node
   340  	var err error
   341  	switch trigger.TriggerEvent {
   342  	case sqlparser.InsertStr:
   343  		scopeNode := plan.NewProject(
   344  			[]sql.Expression{expression.NewStar()},
   345  			plan.NewTableAlias("new", getResolvedTable(n)),
   346  		)
   347  		s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
   348  		triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators)
   349  	case sqlparser.UpdateStr:
   350  		scopeNode := plan.NewProject(
   351  			[]sql.Expression{expression.NewStar()},
   352  			plan.NewCrossJoin(
   353  				plan.NewTableAlias("old", getResolvedTable(n)),
   354  				plan.NewTableAlias("new", getResolvedTable(n)),
   355  			),
   356  		)
   357  		s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
   358  		triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators)
   359  	case sqlparser.DeleteStr:
   360  		scopeNode := plan.NewProject(
   361  			[]sql.Expression{expression.NewStar()},
   362  			plan.NewTableAlias("old", getResolvedTable(n)),
   363  		)
   364  		s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
   365  		triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators)
   366  	}
   367  
   368  	return StripPassthroughNodes(triggerLogic), err
   369  }
   370  
   371  // validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any
   372  // table being updated in an outer scope of this analysis)
   373  func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *plan.Scope) error {
   374  	var circularRef error
   375  	transform.Inspect(trigger.Body, func(node sql.Node) bool {
   376  		switch node := node.(type) {
   377  		case *plan.Update, *plan.InsertInto, *plan.DeleteFrom:
   378  			for _, n := range append([]sql.Node{n}, scope.MemoNodes()...) {
   379  				invokingTableName := getUnaliasedTableName(n)
   380  				updatedTable := getUnaliasedTableName(node)
   381  				// TODO: need to compare DB as well
   382  				if updatedTable == invokingTableName {
   383  					circularRef = sql.ErrTriggerTableInUse.New(updatedTable)
   384  					return false
   385  				}
   386  			}
   387  		}
   388  		return true
   389  	})
   390  
   391  	return circularRef
   392  }
   393  
   394  func orderTriggersAndReverseAfter(triggers []*plan.CreateTrigger) []*plan.CreateTrigger {
   395  	beforeTriggers, afterTriggers := plan.OrderTriggers(triggers)
   396  
   397  	// Reverse the order of after triggers. This is because we always apply them to the Insert / Update / Delete node
   398  	// that initiated the trigger, so after triggers, which wrap the Insert, need be applied in reverse order for them to
   399  	// run in the correct order.
   400  	for left, right := 0, len(afterTriggers)-1; left < right; left, right = left+1, right-1 {
   401  		afterTriggers[left], afterTriggers[right] = afterTriggers[right], afterTriggers[left]
   402  	}
   403  
   404  	return append(beforeTriggers, afterTriggers...)
   405  }
   406  
   407  func triggerEventsMatch(event plan.TriggerEvent, event2 string) bool {
   408  	return strings.ToLower((string)(event)) == strings.ToLower(event2)
   409  }
   410  
   411  // wrapWritesWithRollback wraps the entire tree iff it contains a trigger, allowing rollback when a trigger errors
   412  func wrapWritesWithRollback(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   413  	// Check if tree contains a TriggerExecutor
   414  	containsTrigger := false
   415  	transform.Inspect(n, func(n sql.Node) bool {
   416  		// After Triggers wrap nodes
   417  		if _, ok := n.(*plan.TriggerExecutor); ok {
   418  			containsTrigger = true
   419  			return false // done, don't bother to recurse
   420  		}
   421  
   422  		// Before Triggers on Inserts are inside Source
   423  		if n, ok := n.(*plan.InsertInto); ok {
   424  			if _, ok := n.Source.(*plan.TriggerExecutor); ok {
   425  				containsTrigger = true
   426  				return false
   427  			}
   428  		}
   429  
   430  		// Before Triggers on Delete and Update should be in children
   431  		return true
   432  	})
   433  
   434  	// No TriggerExecutor, so return same tree
   435  	if !containsTrigger {
   436  		return n, transform.SameTree, nil
   437  	}
   438  
   439  	// If we don't have a transaction session we can't do rollbacks
   440  	_, ok := ctx.Session.(sql.TransactionSession)
   441  	if !ok {
   442  		return plan.NewNoopTriggerRollback(n), transform.NewTree, nil
   443  	}
   444  
   445  	// Wrap tree with new node
   446  	return plan.NewTriggerRollback(n), transform.NewTree, nil
   447  }