github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/update.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"gopkg.in/src-d/go-errors.v1"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  )
    25  
    26  var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE")
    27  var ErrUpdateForTableNotSupported = errors.NewKind("The target table %s of the UPDATE is not updatable")
    28  var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but expression returned %T")
    29  
// Update is a node for updating rows on tables.
type Update struct {
	// UnaryNode holds the single child; NewUpdate sets it to an UpdateSource
	// wrapping the input node and the update expressions.
	UnaryNode
	// checks are the check constraints attached via WithChecks and exposed
	// as expressions through Expressions/WithExpressions.
	checks sql.CheckConstraints
	// Ignore is the ignore flag passed to NewUpdate (also forwarded to the
	// UpdateSource child). NOTE(review): presumably models UPDATE IGNORE — confirm.
	Ignore bool
}
    36  
    37  var _ sql.Node = (*Update)(nil)
    38  var _ sql.Databaseable = (*Update)(nil)
    39  var _ sql.CollationCoercible = (*Update)(nil)
    40  var _ sql.CheckConstraintNode = (*Update)(nil)
    41  
    42  // NewUpdate creates an Update node.
    43  func NewUpdate(n sql.Node, ignore bool, updateExprs []sql.Expression) *Update {
    44  	return &Update{
    45  		UnaryNode: UnaryNode{NewUpdateSource(
    46  			n,
    47  			ignore,
    48  			updateExprs,
    49  		)},
    50  		Ignore: ignore,
    51  	}
    52  }
    53  
    54  func GetUpdatable(node sql.Node) (sql.UpdatableTable, error) {
    55  	switch node := node.(type) {
    56  	case sql.UpdatableTable:
    57  		return node, nil
    58  	case *IndexedTableAccess:
    59  		return GetUpdatable(node.TableNode)
    60  	case *ResolvedTable:
    61  		return getUpdatableTable(node.Table)
    62  	case *SubqueryAlias:
    63  		return nil, ErrUpdateNotSupported.New()
    64  	case *TriggerExecutor:
    65  		return GetUpdatable(node.Left())
    66  	case sql.TableWrapper:
    67  		return getUpdatableTable(node.Underlying())
    68  	case *UpdateJoin:
    69  		return node.GetUpdatable(), nil
    70  	}
    71  	if len(node.Children()) > 1 {
    72  		return nil, ErrUpdateNotSupported.New()
    73  	}
    74  	for _, child := range node.Children() {
    75  		updater, _ := GetUpdatable(child)
    76  		if updater != nil {
    77  			return updater, nil
    78  		}
    79  	}
    80  	return nil, ErrUpdateNotSupported.New()
    81  }
    82  
    83  func getUpdatableTable(t sql.Table) (sql.UpdatableTable, error) {
    84  	switch t := t.(type) {
    85  	case sql.UpdatableTable:
    86  		return t, nil
    87  	case sql.TableWrapper:
    88  		return getUpdatableTable(t.Underlying())
    89  	default:
    90  		return nil, ErrUpdateNotSupported.New()
    91  	}
    92  }
    93  
    94  // GetDatabase returns the first database found in the node tree given
    95  func GetDatabase(node sql.Node) sql.Database {
    96  	switch node := node.(type) {
    97  	case *IndexedTableAccess:
    98  		return GetDatabase(node.TableNode)
    99  	case *ResolvedTable:
   100  		return node.Database()
   101  	case *UnresolvedTable:
   102  		return node.Database()
   103  	}
   104  
   105  	for _, child := range node.Children() {
   106  		return GetDatabase(child)
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  func (u *Update) Checks() sql.CheckConstraints {
   113  	return u.checks
   114  }
   115  
   116  func (u *Update) WithChecks(checks sql.CheckConstraints) sql.Node {
   117  	ret := *u
   118  	ret.checks = checks
   119  	return &ret
   120  }
   121  
   122  // DB returns the database being updated. |Database| is already used by another interface we implement.
   123  func (u *Update) DB() sql.Database {
   124  	return GetDatabase(u.Child)
   125  }
   126  
   127  func (u *Update) IsReadOnly() bool {
   128  	return false
   129  }
   130  
   131  func (u *Update) Database() string {
   132  	db := GetDatabase(u.Child)
   133  	if db == nil {
   134  		return ""
   135  	}
   136  	return db.Name()
   137  }
   138  
   139  func (u *Update) Expressions() []sql.Expression {
   140  	return u.checks.ToExpressions()
   141  }
   142  
   143  func (u *Update) Resolved() bool {
   144  	return u.Child.Resolved() && expression.ExpressionsResolved(u.checks.ToExpressions()...)
   145  }
   146  
   147  func (u Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
   148  	if len(newExprs) != len(u.checks) {
   149  		return nil, sql.ErrInvalidChildrenNumber.New(u, len(newExprs), len(u.checks))
   150  	}
   151  
   152  	var err error
   153  	u.checks, err = u.checks.FromExpressions(newExprs)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	return &u, nil
   159  }
   160  
   161  // UpdateInfo is the Info for OKResults returned by Update nodes.
   162  type UpdateInfo struct {
   163  	Matched, Updated, Warnings int
   164  }
   165  
   166  // String implements fmt.Stringer
   167  func (ui UpdateInfo) String() string {
   168  	return fmt.Sprintf("Rows matched: %d  Changed: %d  Warnings: %d", ui.Matched, ui.Updated, ui.Warnings)
   169  }
   170  
   171  // WithChildren implements the Node interface.
   172  func (u *Update) WithChildren(children ...sql.Node) (sql.Node, error) {
   173  	if len(children) != 1 {
   174  		return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
   175  	}
   176  	np := *u
   177  	np.Child = children[0]
   178  	return &np, nil
   179  }
   180  
   181  // CheckPrivileges implements the interface sql.Node.
   182  func (u *Update) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   183  	//TODO: If column values are retrieved then the SELECT privilege is required
   184  	// For example: "UPDATE table SET x = y + 1 WHERE z > 0"
   185  	// We would need SELECT privileges on both the "y" and "z" columns as they're retrieving values
   186  	subject := sql.PrivilegeCheckSubject{
   187  		Database: CheckPrivilegeNameForDatabase(u.DB()),
   188  		Table:    getTableName(u.Child),
   189  	}
   190  	// TODO: this needs a real database, fix it
   191  	return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Update))
   192  }
   193  
   194  // CollationCoercibility implements the interface sql.CollationCoercible.
   195  func (*Update) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   196  	return sql.Collation_binary, 7
   197  }
   198  
   199  func (u *Update) String() string {
   200  	pr := sql.NewTreePrinter()
   201  	_ = pr.WriteNode("Update")
   202  	_ = pr.WriteChildren(u.Child.String())
   203  	return pr.String()
   204  }
   205  
   206  func (u *Update) DebugString() string {
   207  	pr := sql.NewTreePrinter()
   208  	_ = pr.WriteNode("Update")
   209  	_ = pr.WriteChildren(sql.DebugString(u.Child))
   210  	return pr.String()
   211  }