github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/update.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 20 "gopkg.in/src-d/go-errors.v1" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/expression" 24 ) 25 26 var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE") 27 var ErrUpdateForTableNotSupported = errors.NewKind("The target table %s of the UPDATE is not updatable") 28 var ErrUpdateUnexpectedSetResult = errors.NewKind("attempted to set field but expression returned %T") 29 30 // Update is a node for updating rows on tables. 31 type Update struct { 32 UnaryNode 33 checks sql.CheckConstraints 34 Ignore bool 35 } 36 37 var _ sql.Node = (*Update)(nil) 38 var _ sql.Databaseable = (*Update)(nil) 39 var _ sql.CollationCoercible = (*Update)(nil) 40 var _ sql.CheckConstraintNode = (*Update)(nil) 41 42 // NewUpdate creates an Update node. 43 func NewUpdate(n sql.Node, ignore bool, updateExprs []sql.Expression) *Update { 44 return &Update{ 45 UnaryNode: UnaryNode{NewUpdateSource( 46 n, 47 ignore, 48 updateExprs, 49 )}, 50 Ignore: ignore, 51 } 52 } 53 54 func GetUpdatable(node sql.Node) (sql.UpdatableTable, error) { 55 switch node := node.(type) { 56 case sql.UpdatableTable: 57 return node, nil 58 case *IndexedTableAccess: 59 return GetUpdatable(node.TableNode) 60 case *ResolvedTable: 61 return getUpdatableTable(node.Table) 62 case *SubqueryAlias: 63 return nil, ErrUpdateNotSupported.New() 64 case *TriggerExecutor: 65 return GetUpdatable(node.Left()) 66 case sql.TableWrapper: 67 return getUpdatableTable(node.Underlying()) 68 case *UpdateJoin: 69 return node.GetUpdatable(), nil 70 } 71 if len(node.Children()) > 1 { 72 return nil, ErrUpdateNotSupported.New() 73 } 74 for _, child := range node.Children() { 75 updater, _ := GetUpdatable(child) 76 if updater != nil { 77 return updater, nil 78 } 79 } 80 return nil, ErrUpdateNotSupported.New() 81 } 82 83 func getUpdatableTable(t sql.Table) (sql.UpdatableTable, error) { 84 switch t := t.(type) { 85 case sql.UpdatableTable: 86 return t, nil 87 case sql.TableWrapper: 88 return getUpdatableTable(t.Underlying()) 89 default: 90 return nil, ErrUpdateNotSupported.New() 91 } 92 } 93 94 // GetDatabase returns the first database found in the node tree given 95 func GetDatabase(node sql.Node) sql.Database { 96 switch node := node.(type) { 97 case *IndexedTableAccess: 98 return GetDatabase(node.TableNode) 99 case *ResolvedTable: 100 return node.Database() 101 case *UnresolvedTable: 102 return node.Database() 103 } 104 105 for _, child := range node.Children() { 106 return GetDatabase(child) 107 } 108 109 return nil 110 } 111 112 func (u *Update) Checks() sql.CheckConstraints { 113 return u.checks 114 } 115 116 func (u *Update) WithChecks(checks sql.CheckConstraints) sql.Node { 117 ret := *u 118 ret.checks = checks 119 return &ret 120 } 121 122 // DB returns the database being updated. |Database| is already used by another interface we implement. 123 func (u *Update) DB() sql.Database { 124 return GetDatabase(u.Child) 125 } 126 127 func (u *Update) IsReadOnly() bool { 128 return false 129 } 130 131 func (u *Update) Database() string { 132 db := GetDatabase(u.Child) 133 if db == nil { 134 return "" 135 } 136 return db.Name() 137 } 138 139 func (u *Update) Expressions() []sql.Expression { 140 return u.checks.ToExpressions() 141 } 142 143 func (u *Update) Resolved() bool { 144 return u.Child.Resolved() && expression.ExpressionsResolved(u.checks.ToExpressions()...) 145 } 146 147 func (u Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { 148 if len(newExprs) != len(u.checks) { 149 return nil, sql.ErrInvalidChildrenNumber.New(u, len(newExprs), len(u.checks)) 150 } 151 152 var err error 153 u.checks, err = u.checks.FromExpressions(newExprs) 154 if err != nil { 155 return nil, err 156 } 157 158 return &u, nil 159 } 160 161 // UpdateInfo is the Info for OKResults returned by Update nodes. 162 type UpdateInfo struct { 163 Matched, Updated, Warnings int 164 } 165 166 // String implements fmt.Stringer 167 func (ui UpdateInfo) String() string { 168 return fmt.Sprintf("Rows matched: %d Changed: %d Warnings: %d", ui.Matched, ui.Updated, ui.Warnings) 169 } 170 171 // WithChildren implements the Node interface. 172 func (u *Update) WithChildren(children ...sql.Node) (sql.Node, error) { 173 if len(children) != 1 { 174 return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) 175 } 176 np := *u 177 np.Child = children[0] 178 return &np, nil 179 } 180 181 // CheckPrivileges implements the interface sql.Node. 182 func (u *Update) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 183 //TODO: If column values are retrieved then the SELECT privilege is required 184 // For example: "UPDATE table SET x = y + 1 WHERE z > 0" 185 // We would need SELECT privileges on both the "y" and "z" columns as they're retrieving values 186 subject := sql.PrivilegeCheckSubject{ 187 Database: CheckPrivilegeNameForDatabase(u.DB()), 188 Table: getTableName(u.Child), 189 } 190 // TODO: this needs a real database, fix it 191 return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Update)) 192 } 193 194 // CollationCoercibility implements the interface sql.CollationCoercible. 195 func (*Update) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 196 return sql.Collation_binary, 7 197 } 198 199 func (u *Update) String() string { 200 pr := sql.NewTreePrinter() 201 _ = pr.WriteNode("Update") 202 _ = pr.WriteChildren(u.Child.String()) 203 return pr.String() 204 } 205 206 func (u *Update) DebugString() string { 207 pr := sql.NewTreePrinter() 208 _ = pr.WriteNode("Update") 209 _ = pr.WriteChildren(sql.DebugString(u.Child)) 210 return pr.String() 211 }