github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/load_triggers.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"strings"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/plan"
    22  	"github.com/dolthub/go-mysql-server/sql/planbuilder"
    23  	"github.com/dolthub/go-mysql-server/sql/transform"
    24  )
    25  
    26  // loadTriggers loads any triggers that are required for a plan node to operate properly (except for nodes dealing with
    27  // trigger execution).
    28  func loadTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    29  	span, ctx := ctx.Span("loadTriggers")
    30  	defer span.End()
    31  
    32  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    33  		switch node := n.(type) {
    34  		case *plan.ShowTriggers:
    35  			newShowTriggers := *node
    36  			loadedTriggers, err := loadTriggersFromDb(ctx, a, newShowTriggers.Database())
    37  			if err != nil {
    38  				return nil, transform.SameTree, err
    39  			}
    40  			if len(loadedTriggers) != 0 {
    41  				newShowTriggers.Triggers = loadedTriggers
    42  			} else {
    43  				newShowTriggers.Triggers = make([]*plan.CreateTrigger, 0)
    44  			}
    45  			return &newShowTriggers, transform.NewTree, nil
    46  		case *plan.DropTrigger:
    47  			loadedTriggers, err := loadTriggersFromDb(ctx, a, node.Database())
    48  			if err != nil {
    49  				return nil, transform.SameTree, err
    50  			}
    51  			lowercasedTriggerName := strings.ToLower(node.TriggerName)
    52  			for _, trigger := range loadedTriggers {
    53  				if strings.ToLower(trigger.TriggerName) == lowercasedTriggerName {
    54  					node.TriggerName = trigger.TriggerName
    55  				} else if trigger.TriggerOrder != nil &&
    56  					strings.ToLower(trigger.TriggerOrder.OtherTriggerName) == lowercasedTriggerName {
    57  					return nil, transform.SameTree, sql.ErrTriggerCannotBeDropped.New(node.TriggerName, trigger.TriggerName)
    58  				}
    59  			}
    60  			return node, transform.NewTree, nil
    61  		case *plan.DropTable:
    62  			// if there is no table left after filtering out non-existent tables, no need to load triggers
    63  			if len(node.Tables) == 0 {
    64  				return node, transform.SameTree, nil
    65  			}
    66  
    67  			// the table has to be TableNode as this rule is executed after resolve-table rule
    68  			var dropTableDb sql.Database
    69  			if t, ok := node.Tables[0].(*plan.ResolvedTable); ok {
    70  				dropTableDb = t.SqlDatabase
    71  			}
    72  
    73  			loadedTriggers, err := loadTriggersFromDb(ctx, a, dropTableDb)
    74  			if err != nil {
    75  				return nil, transform.SameTree, err
    76  			}
    77  			lowercasedNames := make(map[string]struct{})
    78  			tblNames, err := node.TableNames()
    79  			if err != nil {
    80  				return nil, transform.SameTree, err
    81  			}
    82  			for _, tableName := range tblNames {
    83  				lowercasedNames[strings.ToLower(tableName)] = struct{}{}
    84  			}
    85  			var triggersForTable []string
    86  			for _, trigger := range loadedTriggers {
    87  				if _, ok := lowercasedNames[strings.ToLower(trigger.Table.(sql.Nameable).Name())]; ok {
    88  					triggersForTable = append(triggersForTable, trigger.TriggerName)
    89  				}
    90  			}
    91  			return node.WithTriggers(triggersForTable), transform.NewTree, nil
    92  		default:
    93  			return node, transform.SameTree, nil
    94  		}
    95  	})
    96  }
    97  
    98  func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database) ([]*plan.CreateTrigger, error) {
    99  	var loadedTriggers []*plan.CreateTrigger
   100  	if triggerDb, ok := db.(sql.TriggerDatabase); ok {
   101  		triggers, err := triggerDb.GetTriggers(ctx)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  		for _, trigger := range triggers {
   106  			var parsedTrigger sql.Node
   107  			sqlMode := sql.NewSqlModeFromString(trigger.SqlMode)
   108  			parsedTrigger, err = planbuilder.ParseWithOptions(ctx, a.Catalog, trigger.CreateStatement, sqlMode.ParserOptions())
   109  			if err != nil {
   110  				return nil, err
   111  			}
   112  			triggerPlan, ok := parsedTrigger.(*plan.CreateTrigger)
   113  			if !ok {
   114  				return nil, sql.ErrTriggerCreateStatementInvalid.New(trigger.CreateStatement)
   115  			}
   116  			triggerPlan.CreatedAt = trigger.CreatedAt // use the stored created time
   117  			loadedTriggers = append(loadedTriggers, triggerPlan)
   118  		}
   119  	}
   120  	return loadedTriggers, nil
   121  }