github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/triggers.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/vitess/go/vt/sqlparser" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/plan" 26 "github.com/dolthub/go-mysql-server/sql/planbuilder" 27 "github.com/dolthub/go-mysql-server/sql/transform" 28 ) 29 30 // validateCreateTrigger handles CreateTrigger nodes, resolving references to "old" and "new" table references in 31 // the trigger body. Also validates that these old and new references are being used appropriately -- they are only 32 // valid for certain kinds of triggers and certain statements. 33 func validateCreateTrigger(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 34 ct, ok := node.(*plan.CreateTrigger) 35 if !ok { 36 return node, transform.SameTree, nil 37 } 38 39 // We just want to verify that the trigger is correctly defined before creating it. If it is, we replace the 40 // UnresolvedColumn expressions with placeholder expressions that say they are Resolved(). 41 // TODO: this might work badly for databases with tables named new and old. Needs tests. 42 var err error 43 transform.InspectExpressions(ct.Body, func(e sql.Expression) bool { 44 switch e := e.(type) { 45 case *expression.UnresolvedColumn: 46 if strings.ToLower(e.Table()) == "new" { 47 if ct.TriggerEvent == sqlparser.DeleteStr { 48 err = sql.ErrInvalidUseOfOldNew.New("new", ct.TriggerEvent) 49 } 50 } 51 if strings.ToLower(e.Table()) == "old" { 52 if ct.TriggerEvent == sqlparser.InsertStr { 53 err = sql.ErrInvalidUseOfOldNew.New("old", ct.TriggerEvent) 54 } 55 } 56 } 57 return true 58 }) 59 60 if err != nil { 61 return nil, transform.SameTree, err 62 } 63 64 // Check to see if the plan sets a value for "old" rows, or if an AFTER trigger assigns to NEW. Both are illegal. 65 transform.InspectExpressionsWithNode(ct.Body, func(n sql.Node, e sql.Expression) bool { 66 if _, ok := n.(*plan.Set); !ok { 67 return true 68 } 69 70 switch e := e.(type) { 71 case *expression.SetField: 72 switch left := e.LeftChild.(type) { 73 case column: 74 if strings.ToLower(left.Table()) == "old" { 75 err = sql.ErrInvalidUpdateOfOldRow.New() 76 } 77 if ct.TriggerTime == sqlparser.AfterStr && strings.ToLower(left.Table()) == "new" { 78 err = sql.ErrInvalidUpdateInAfterTrigger.New() 79 } 80 } 81 } 82 83 return true 84 }) 85 86 if err != nil { 87 return nil, transform.SameTree, err 88 } 89 90 trigTable := getResolvedTable(ct.Table) 91 sch := trigTable.Schema() 92 colsList := make(map[string]struct{}) 93 for _, c := range sch { 94 colsList[c.Name] = struct{}{} 95 } 96 97 // Check to see if the columns with "new" and "old" table reference are valid columns from the trigger table. 98 transform.InspectExpressions(ct.Body, func(e sql.Expression) bool { 99 switch e := e.(type) { 100 case *expression.UnresolvedColumn: 101 if strings.ToLower(e.Table()) == "old" || strings.ToLower(e.Table()) == "new" { 102 if _, ok := colsList[e.Name()]; !ok { 103 err = sql.ErrUnknownColumn.New(e.Name(), e.Table()) 104 } 105 } 106 } 107 return true 108 }) 109 110 if err != nil { 111 return nil, transform.SameTree, err 112 } 113 return node, transform.NewTree, nil 114 } 115 116 func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 117 // Skip this step for CreateTrigger statements 118 if _, ok := n.(*plan.CreateTrigger); ok { 119 return n, transform.SameTree, nil 120 } 121 122 var affectedTables []string 123 var triggerEvent plan.TriggerEvent 124 db := ctx.GetCurrentDatabase() 125 transform.Inspect(n, func(n sql.Node) bool { 126 switch n := n.(type) { 127 case *plan.InsertInto: 128 affectedTables = append(affectedTables, getTableName(n)) 129 triggerEvent = plan.InsertTrigger 130 if n.Database() != nil && n.Database().Name() != "" { 131 db = n.Database().Name() 132 } 133 case *plan.Update: 134 affectedTables = append(affectedTables, getTableName(n)) 135 triggerEvent = plan.UpdateTrigger 136 if n.Database() != "" { 137 db = n.Database() 138 } 139 case *plan.DeleteFrom: 140 for _, target := range n.GetDeleteTargets() { 141 affectedTables = append(affectedTables, getTableName(target)) 142 } 143 triggerEvent = plan.DeleteTrigger 144 if n.Database() != "" { 145 db = n.Database() 146 } 147 } 148 return true 149 }) 150 151 if len(affectedTables) == 0 { 152 return n, transform.SameTree, nil 153 } 154 155 // TODO: database should be dependent on the table being inserted / updated, but we don't have that info available 156 // from the table object yet. 157 database, err := a.Catalog.Database(ctx, db) 158 if err != nil { 159 return nil, transform.SameTree, err 160 } 161 162 var affectedTriggers []*plan.CreateTrigger 163 if tdb, ok := database.(sql.TriggerDatabase); ok { 164 triggers, err := tdb.GetTriggers(ctx) 165 if err != nil { 166 return nil, transform.SameTree, err 167 } 168 169 b := planbuilder.New(ctx, a.Catalog) 170 prevActive := b.TriggerCtx().Active 171 b.TriggerCtx().Active = true 172 defer func() { 173 b.TriggerCtx().Active = prevActive 174 }() 175 176 for _, trigger := range triggers { 177 var parsedTrigger sql.Node 178 sqlMode := sql.NewSqlModeFromString(trigger.SqlMode) 179 b.SetParserOptions(sqlMode.ParserOptions()) 180 parsedTrigger, _, _, err = b.Parse(trigger.CreateStatement, false) 181 b.Reset() 182 if err != nil { 183 return nil, transform.SameTree, err 184 } 185 186 ct, ok := parsedTrigger.(*plan.CreateTrigger) 187 if !ok { 188 return nil, transform.SameTree, sql.ErrTriggerCreateStatementInvalid.New(trigger.CreateStatement) 189 } 190 191 var triggerTable string 192 switch t := ct.Table.(type) { 193 case *plan.ResolvedTable: 194 triggerTable = t.Name() 195 default: 196 } 197 if stringContains(affectedTables, triggerTable) && triggerEventsMatch(triggerEvent, ct.TriggerEvent) { 198 // first pass allows unresolved before we know whether trigger is relevant 199 // TODO store destination table name with trigger, so we don't have to do parse twice 200 b.TriggerCtx().Call = true 201 parsedTrigger, _, _, err = b.Parse(trigger.CreateStatement, false) 202 b.TriggerCtx().Call = false 203 b.Reset() 204 if err != nil { 205 return nil, transform.SameTree, err 206 } 207 208 ct, ok := parsedTrigger.(*plan.CreateTrigger) 209 if !ok { 210 return nil, transform.SameTree, sql.ErrTriggerCreateStatementInvalid.New(trigger.CreateStatement) 211 } 212 213 if block, ok := ct.Body.(*plan.BeginEndBlock); ok { 214 ct.Body = plan.NewTriggerBeginEndBlock(block) 215 } 216 affectedTriggers = append(affectedTriggers, ct) 217 } 218 } 219 } 220 221 if len(affectedTriggers) == 0 { 222 return n, transform.SameTree, nil 223 } 224 225 triggers := orderTriggersAndReverseAfter(affectedTriggers) 226 originalNode := n 227 same := transform.SameTree 228 allSame := transform.SameTree 229 for _, trigger := range triggers { 230 err = validateNoCircularUpdates(trigger, originalNode, scope) 231 if err != nil { 232 return nil, transform.SameTree, err 233 } 234 235 n, same, err = applyTrigger(ctx, a, originalNode, n, scope, trigger) 236 if err != nil { 237 return nil, transform.SameTree, err 238 } 239 allSame = same && allSame 240 } 241 242 return n, allSame, nil 243 } 244 245 // applyTrigger applies the trigger given to the node given, returning the resulting node 246 func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope *plan.Scope, trigger *plan.CreateTrigger) (sql.Node, transform.TreeIdentity, error) { 247 triggerLogic, err := getTriggerLogic(ctx, a, originalNode, scope, trigger) 248 if err != nil { 249 return nil, transform.SameTree, err 250 } 251 252 return transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { 253 // Don't double-apply trigger executors to the bodies of triggers. To avoid this, don't apply the trigger if the 254 // parent is a trigger body. 255 // TODO: this won't work for BEGIN END blocks, stored procedures, etc. For those, we need to examine all ancestors, 256 // not just the immediate parent. Alternately, we could do something like not walk all children of some node types 257 // (probably better). 258 if _, ok := c.Parent.(*plan.TriggerExecutor); ok { 259 if c.ChildNum == 1 { // Right child is the trigger execution logic 260 return c.Node, transform.SameTree, nil 261 } 262 } 263 264 switch n := c.Node.(type) { 265 case *plan.InsertInto: 266 if trigger.TriggerTime == sqlparser.BeforeStr { 267 triggerExecutor := plan.NewTriggerExecutor(n.Source, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{ 268 Name: trigger.TriggerName, 269 CreateStatement: trigger.CreateTriggerString, 270 }) 271 return n.WithSource(triggerExecutor), transform.NewTree, nil 272 } else { 273 return plan.NewTriggerExecutor(n, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{ 274 Name: trigger.TriggerName, 275 CreateStatement: trigger.CreateTriggerString, 276 }), transform.NewTree, nil 277 } 278 case *plan.Update: 279 if trigger.TriggerTime == sqlparser.BeforeStr { 280 triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{ 281 Name: trigger.TriggerName, 282 CreateStatement: trigger.CreateTriggerString, 283 }) 284 node, err := n.WithChildren(triggerExecutor) 285 return node, transform.NewTree, err 286 } else { 287 return plan.NewTriggerExecutor(n, triggerLogic, plan.UpdateTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{ 288 Name: trigger.TriggerName, 289 CreateStatement: trigger.CreateTriggerString, 290 }), transform.NewTree, nil 291 } 292 case *plan.DeleteFrom: 293 // TODO: This should work correctly when there is only one table that 294 // has a trigger on it, but it won't work if a DELETE FROM JOIN 295 // is deleting from two tables that both have triggers. Seems 296 // like we need something like a MultipleTriggerExecutor node 297 // that could execute multiple triggers on the same row from its 298 // wrapped iterator. There is also an issue with running triggers 299 // because their field indexes assume the row they evalute will 300 // only ever contain the columns from the single table the trigger 301 // is based on, but this isn't true with UPDATE JOIN or DELETE JOIN. 302 if n.HasExplicitTargets() { 303 return nil, transform.SameTree, fmt.Errorf("delete from with explicit target tables " + 304 "does not support triggers; retry with single table deletes") 305 } 306 307 if trigger.TriggerTime == sqlparser.BeforeStr { 308 triggerExecutor := plan.NewTriggerExecutor(n.Child, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{ 309 Name: trigger.TriggerName, 310 CreateStatement: trigger.CreateTriggerString, 311 }) 312 node, err := n.WithChildren(triggerExecutor) 313 return node, transform.NewTree, err 314 } else { 315 return plan.NewTriggerExecutor(n, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{ 316 Name: trigger.TriggerName, 317 CreateStatement: trigger.CreateTriggerString, 318 }), transform.NewTree, nil 319 } 320 } 321 322 return c.Node, transform.SameTree, nil 323 }) 324 } 325 326 // getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the 327 // plan node given, which must be an insert, update, or delete. 328 func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, trigger *plan.CreateTrigger) (sql.Node, error) { 329 // For trigger body analysis, we don't want any row update accumulators applied to insert / update / delete 330 // statements, we need the raw output from them. 331 var noRowUpdateAccumulators RuleSelector 332 noRowUpdateAccumulators = func(id RuleId) bool { 333 return DefaultRuleSelector(id) && id != applyRowUpdateAccumulatorsId 334 } 335 336 // For the reference to the row in the trigger table, we use the scope mechanism. This is a little strange because 337 // scopes for subqueries work with the child schemas of a scope node, but we don't have such a node here. Instead we 338 // fabricate one with the right properties (its child schema matches the table schema, with the right aliased name) 339 var triggerLogic sql.Node 340 var err error 341 switch trigger.TriggerEvent { 342 case sqlparser.InsertStr: 343 scopeNode := plan.NewProject( 344 []sql.Expression{expression.NewStar()}, 345 plan.NewTableAlias("new", getResolvedTable(n)), 346 ) 347 s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache()) 348 triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators) 349 case sqlparser.UpdateStr: 350 scopeNode := plan.NewProject( 351 []sql.Expression{expression.NewStar()}, 352 plan.NewCrossJoin( 353 plan.NewTableAlias("old", getResolvedTable(n)), 354 plan.NewTableAlias("new", getResolvedTable(n)), 355 ), 356 ) 357 s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache()) 358 triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators) 359 case sqlparser.DeleteStr: 360 scopeNode := plan.NewProject( 361 []sql.Expression{expression.NewStar()}, 362 plan.NewTableAlias("old", getResolvedTable(n)), 363 ) 364 s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache()) 365 triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators) 366 } 367 368 return StripPassthroughNodes(triggerLogic), err 369 } 370 371 // validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any 372 // table being updated in an outer scope of this analysis) 373 func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *plan.Scope) error { 374 var circularRef error 375 transform.Inspect(trigger.Body, func(node sql.Node) bool { 376 switch node := node.(type) { 377 case *plan.Update, *plan.InsertInto, *plan.DeleteFrom: 378 for _, n := range append([]sql.Node{n}, scope.MemoNodes()...) { 379 invokingTableName := getUnaliasedTableName(n) 380 updatedTable := getUnaliasedTableName(node) 381 // TODO: need to compare DB as well 382 if updatedTable == invokingTableName { 383 circularRef = sql.ErrTriggerTableInUse.New(updatedTable) 384 return false 385 } 386 } 387 } 388 return true 389 }) 390 391 return circularRef 392 } 393 394 func orderTriggersAndReverseAfter(triggers []*plan.CreateTrigger) []*plan.CreateTrigger { 395 beforeTriggers, afterTriggers := plan.OrderTriggers(triggers) 396 397 // Reverse the order of after triggers. This is because we always apply them to the Insert / Update / Delete node 398 // that initiated the trigger, so after triggers, which wrap the Insert, need be applied in reverse order for them to 399 // run in the correct order. 400 for left, right := 0, len(afterTriggers)-1; left < right; left, right = left+1, right-1 { 401 afterTriggers[left], afterTriggers[right] = afterTriggers[right], afterTriggers[left] 402 } 403 404 return append(beforeTriggers, afterTriggers...) 405 } 406 407 func triggerEventsMatch(event plan.TriggerEvent, event2 string) bool { 408 return strings.ToLower((string)(event)) == strings.ToLower(event2) 409 } 410 411 // wrapWritesWithRollback wraps the entire tree iff it contains a trigger, allowing rollback when a trigger errors 412 func wrapWritesWithRollback(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 413 // Check if tree contains a TriggerExecutor 414 containsTrigger := false 415 transform.Inspect(n, func(n sql.Node) bool { 416 // After Triggers wrap nodes 417 if _, ok := n.(*plan.TriggerExecutor); ok { 418 containsTrigger = true 419 return false // done, don't bother to recurse 420 } 421 422 // Before Triggers on Inserts are inside Source 423 if n, ok := n.(*plan.InsertInto); ok { 424 if _, ok := n.Source.(*plan.TriggerExecutor); ok { 425 containsTrigger = true 426 return false 427 } 428 } 429 430 // Before Triggers on Delete and Update should be in children 431 return true 432 }) 433 434 // No TriggerExecutor, so return same tree 435 if !containsTrigger { 436 return n, transform.SameTree, nil 437 } 438 439 // If we don't have a transaction session we can't do rollbacks 440 _, ok := ctx.Session.(sql.TransactionSession) 441 if !ok { 442 return plan.NewNoopTriggerRollback(n), transform.NewTree, nil 443 } 444 445 // Wrap tree with new node 446 return plan.NewTriggerRollback(n), transform.NewTree, nil 447 }