github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/trigger.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  )
    20  
    21  type TriggerEvent string
    22  
    23  const (
    24  	InsertTrigger TriggerEvent = "insert"
    25  	UpdateTrigger              = "update"
    26  	DeleteTrigger              = "delete"
    27  )
    28  
    29  type TriggerTime string
    30  
    31  const (
    32  	BeforeTrigger TriggerTime = "before"
    33  	AfterTrigger              = "after"
    34  )
    35  
    36  // TriggerExecutor is node that wraps, or is wrapped by, an INSERT, UPDATE, or DELETE node to execute defined trigger
    37  // logic either before or after that operation. When a table has multiple triggers defined, TriggerExecutor nodes can
    38  // wrap each other as well.
    39  type TriggerExecutor struct {
    40  	BinaryNode        // Left = wrapped node, Right = trigger execution logic
    41  	TriggerEvent      TriggerEvent
    42  	TriggerTime       TriggerTime
    43  	TriggerDefinition sql.TriggerDefinition
    44  }
    45  
    46  var _ sql.Node = (*TriggerExecutor)(nil)
    47  var _ sql.CollationCoercible = (*TriggerExecutor)(nil)
    48  
    49  func NewTriggerExecutor(child, triggerLogic sql.Node, triggerEvent TriggerEvent, triggerTime TriggerTime, triggerDefinition sql.TriggerDefinition) *TriggerExecutor {
    50  	return &TriggerExecutor{
    51  		BinaryNode: BinaryNode{
    52  			left:  child,
    53  			right: triggerLogic,
    54  		},
    55  		TriggerEvent:      triggerEvent,
    56  		TriggerTime:       triggerTime,
    57  		TriggerDefinition: triggerDefinition,
    58  	}
    59  }
    60  
    61  func (t *TriggerExecutor) String() string {
    62  	pr := sql.NewTreePrinter()
    63  	_ = pr.WriteNode("Trigger(%s)", t.TriggerDefinition.CreateStatement)
    64  	_ = pr.WriteChildren(t.left.String())
    65  	return pr.String()
    66  }
    67  
    68  func (t *TriggerExecutor) IsReadOnly() bool {
    69  	return t.left.IsReadOnly() && t.right.IsReadOnly()
    70  }
    71  
    72  func (t *TriggerExecutor) DebugString() string {
    73  	pr := sql.NewTreePrinter()
    74  	_ = pr.WriteNode("Trigger(%s)", t.TriggerDefinition.CreateStatement)
    75  	_ = pr.WriteChildren(sql.DebugString(t.left), sql.DebugString(t.right))
    76  	return pr.String()
    77  }
    78  
    79  func (t *TriggerExecutor) Schema() sql.Schema {
    80  	return t.left.Schema()
    81  }
    82  
    83  func (t *TriggerExecutor) WithChildren(children ...sql.Node) (sql.Node, error) {
    84  	if len(children) != 2 {
    85  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 2)
    86  	}
    87  
    88  	return NewTriggerExecutor(children[0], children[1], t.TriggerEvent, t.TriggerTime, t.TriggerDefinition), nil
    89  }
    90  
    91  // CheckPrivileges implements the interface sql.Node.
    92  func (t *TriggerExecutor) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    93  	// TODO: Figure out exactly how triggers work, not exactly clear whether trigger creator AND user needs the privileges
    94  	subject := sql.PrivilegeCheckSubject{
    95  		Database: GetDatabaseName(t.right),
    96  		Table:    getTableName(t.right),
    97  	}
    98  	return t.left.CheckPrivileges(ctx, opChecker) &&
    99  		opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Trigger))
   100  }
   101  
   102  // CollationCoercibility implements the interface sql.CollationCoercible.
   103  func (t *TriggerExecutor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   104  	return sql.GetCoercibility(ctx, t.left)
   105  }
   106  
   107  // TriggerRollback is a node that wraps the entire tree iff it contains a trigger, creates a savepoint, and performs a
   108  // rollback if something went wrong during execution
   109  type TriggerRollback struct {
   110  	UnaryNode
   111  }
   112  
   113  var _ sql.Node = (*TriggerRollback)(nil)
   114  var _ sql.CollationCoercible = (*TriggerRollback)(nil)
   115  
   116  func NewTriggerRollback(child sql.Node) *TriggerRollback {
   117  	return &TriggerRollback{
   118  		UnaryNode: UnaryNode{Child: child},
   119  	}
   120  }
   121  
   122  func (t *TriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
   123  	if len(children) != 1 {
   124  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
   125  	}
   126  
   127  	return NewTriggerRollback(children[0]), nil
   128  }
   129  
   130  // CheckPrivileges implements the interface sql.Node.
   131  func (t *TriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   132  	return t.Child.CheckPrivileges(ctx, opChecker)
   133  }
   134  
   135  // CollationCoercibility implements the interface sql.CollationCoercible.
   136  func (t *TriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   137  	return sql.GetCoercibility(ctx, t.Child)
   138  }
   139  
   140  func (t *TriggerRollback) IsReadOnly() bool {
   141  	return t.Child.IsReadOnly()
   142  }
   143  
   144  func (t *TriggerRollback) String() string {
   145  	pr := sql.NewTreePrinter()
   146  	_ = pr.WriteNode("TriggerRollback()")
   147  	_ = pr.WriteChildren(t.Child.String())
   148  	return pr.String()
   149  }
   150  
   151  func (t *TriggerRollback) DebugString() string {
   152  	pr := sql.NewTreePrinter()
   153  	_ = pr.WriteNode("TriggerRollback")
   154  	_ = pr.WriteChildren(sql.DebugString(t.Child))
   155  	return pr.String()
   156  }
   157  
   158  type NoopTriggerRollback struct {
   159  	UnaryNode
   160  }
   161  
   162  var _ sql.Node = (*NoopTriggerRollback)(nil)
   163  var _ sql.CollationCoercible = (*NoopTriggerRollback)(nil)
   164  
   165  func NewNoopTriggerRollback(child sql.Node) *NoopTriggerRollback {
   166  	return &NoopTriggerRollback{
   167  		UnaryNode: UnaryNode{Child: child},
   168  	}
   169  }
   170  
   171  func (t *NoopTriggerRollback) WithChildren(children ...sql.Node) (sql.Node, error) {
   172  	if len(children) != 1 {
   173  		return nil, sql.ErrInvalidChildrenNumber.New(t, len(children), 1)
   174  	}
   175  
   176  	return NewNoopTriggerRollback(children[0]), nil
   177  }
   178  
   179  // CheckPrivileges implements the interface sql.Node.
   180  func (t *NoopTriggerRollback) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   181  	return t.Child.CheckPrivileges(ctx, opChecker)
   182  }
   183  
   184  // CollationCoercibility implements the interface sql.CollationCoercible.
   185  func (t *NoopTriggerRollback) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   186  	return sql.GetCoercibility(ctx, t.Child)
   187  }
   188  
   189  func (t *NoopTriggerRollback) IsReadOnly() bool {
   190  	return true
   191  }
   192  
   193  func (t *NoopTriggerRollback) String() string {
   194  	pr := sql.NewTreePrinter()
   195  	_ = pr.WriteNode("TriggerRollback()")
   196  	_ = pr.WriteChildren(t.Child.String())
   197  	return pr.String()
   198  }
   199  
   200  func (t *NoopTriggerRollback) DebugString() string {
   201  	pr := sql.NewTreePrinter()
   202  	_ = pr.WriteNode("TriggerRollback")
   203  	_ = pr.WriteChildren(sql.DebugString(t.Child))
   204  	return pr.String()
   205  }