github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/update_join.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "github.com/dolthub/go-mysql-server/sql" 19 ) 20 21 type UpdateJoin struct { 22 Updaters map[string]sql.RowUpdater 23 UnaryNode 24 } 25 26 // NewUpdateJoin returns an *UpdateJoin node. 27 func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin { 28 return &UpdateJoin{ 29 Updaters: editorMap, 30 UnaryNode: UnaryNode{Child: child}, 31 } 32 } 33 34 var _ sql.Node = (*UpdateJoin)(nil) 35 var _ sql.CollationCoercible = (*UpdateJoin)(nil) 36 37 // String implements the sql.Node interface. 38 func (u *UpdateJoin) String() string { 39 pr := sql.NewTreePrinter() 40 _ = pr.WriteNode("Update Join") 41 _ = pr.WriteChildren(u.Child.String()) 42 return pr.String() 43 } 44 45 // DebugString implements the sql.Node interface. 46 func (u *UpdateJoin) DebugString() string { 47 pr := sql.NewTreePrinter() 48 _ = pr.WriteNode("Update Join") 49 _ = pr.WriteChildren(sql.DebugString(u.Child)) 50 return pr.String() 51 } 52 53 // GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable. 54 func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable { 55 return &updatableJoinTable{ 56 updaters: u.Updaters, 57 joinNode: u.Child.(*UpdateSource).Child, 58 } 59 } 60 61 // WithChildren implements the sql.Node interface. 62 func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) { 63 if len(children) != 1 { 64 return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) 65 } 66 67 return NewUpdateJoin(u.Updaters, children[0]), nil 68 } 69 70 func (u *UpdateJoin) IsReadOnly() bool { 71 return false 72 } 73 74 // CheckPrivileges implements the interface sql.Node. 75 func (u *UpdateJoin) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 76 return u.Child.CheckPrivileges(ctx, opChecker) 77 } 78 79 // CollationCoercibility implements the interface sql.CollationCoercible. 80 func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 81 return sql.GetCoercibility(ctx, u.Child) 82 } 83 84 // updatableJoinTable manages the update of multiple tables. 85 type updatableJoinTable struct { 86 updaters map[string]sql.RowUpdater 87 joinNode sql.Node 88 } 89 90 var _ sql.UpdatableTable = (*updatableJoinTable)(nil) 91 92 // Partitions implements the sql.UpdatableTable interface. 93 func (u *updatableJoinTable) Partitions(context *sql.Context) (sql.PartitionIter, error) { 94 panic("this method should not be called") 95 } 96 97 // PartitionsRows implements the sql.UpdatableTable interface. 98 func (u *updatableJoinTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) { 99 panic("this method should not be called") 100 } 101 102 // Name implements the sql.UpdatableTable interface. 103 func (u *updatableJoinTable) Name() string { 104 panic("this method should not be called") 105 } 106 107 // String implements the sql.UpdatableTable interface. 108 func (u *updatableJoinTable) String() string { 109 panic("this method should not be called") 110 } 111 112 // Schema implements the sql.UpdatableTable interface. 113 func (u *updatableJoinTable) Schema() sql.Schema { 114 return u.joinNode.Schema() 115 } 116 117 // Collation implements the sql.Table interface. 118 func (u *updatableJoinTable) Collation() sql.CollationID { 119 return sql.Collation_Default 120 } 121 122 // Updater implements the sql.UpdatableTable interface. 123 func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater { 124 return &updatableJoinUpdater{ 125 updaterMap: u.updaters, 126 schemaMap: RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()), 127 joinSchema: u.joinNode.Schema(), 128 } 129 } 130 131 // RecreateTableSchemaFromJoinSchema takes a join schema and recreates each individual tables schema. 132 func RecreateTableSchemaFromJoinSchema(joinSchema sql.Schema) map[string]sql.Schema { 133 ret := make(map[string]sql.Schema, 0) 134 135 for _, c := range joinSchema { 136 potential, exists := ret[c.Source] 137 if exists { 138 ret[c.Source] = append(potential, c) 139 } else { 140 ret[c.Source] = sql.Schema{c} 141 } 142 } 143 144 return ret 145 } 146 147 // updatableJoinUpdater manages the process of taking a join row and allocating the respective updates to each updatable 148 // table. 149 type updatableJoinUpdater struct { 150 updaterMap map[string]sql.RowUpdater 151 schemaMap map[string]sql.Schema 152 joinSchema sql.Schema 153 } 154 155 var _ sql.RowUpdater = (*updatableJoinUpdater)(nil) 156 157 // StatementBegin implements the sql.TableEditor interface. 158 func (u *updatableJoinUpdater) StatementBegin(ctx *sql.Context) { 159 for _, v := range u.updaterMap { 160 v.StatementBegin(ctx) 161 } 162 } 163 164 // DiscardChanges implements the sql.TableEditor interface. 165 func (u *updatableJoinUpdater) DiscardChanges(ctx *sql.Context, errorEncountered error) error { 166 for _, v := range u.updaterMap { 167 err := v.DiscardChanges(ctx, errorEncountered) 168 if err != nil { 169 return err 170 } 171 } 172 173 return nil 174 } 175 176 // StatementComplete implements the sql.TableEditor interface. 177 func (u *updatableJoinUpdater) StatementComplete(ctx *sql.Context) error { 178 for _, v := range u.updaterMap { 179 err := v.StatementComplete(ctx) 180 181 if err != nil { 182 return err 183 } 184 } 185 186 return nil 187 } 188 189 // Update implements the sql.RowUpdater interface. 190 func (u *updatableJoinUpdater) Update(ctx *sql.Context, old sql.Row, new sql.Row) error { 191 tableToOldRowMap := SplitRowIntoTableRowMap(old, u.joinSchema) 192 tableToNewRowMap := SplitRowIntoTableRowMap(new, u.joinSchema) 193 194 for tableName, updater := range u.updaterMap { 195 oldRow := tableToOldRowMap[tableName] 196 newRow := tableToNewRowMap[tableName] 197 schema := u.schemaMap[tableName] 198 199 eq, err := oldRow.Equals(newRow, schema) 200 if err != nil { 201 return err 202 } 203 204 if !eq { 205 err = updater.Update(ctx, oldRow, newRow) 206 } 207 208 if err != nil { 209 return err 210 } 211 } 212 213 return nil 214 } 215 216 // SplitRowIntoTableRowMap takes a join table row and breaks into a map of tables and their respective row. 217 func SplitRowIntoTableRowMap(row sql.Row, joinSchema sql.Schema) map[string]sql.Row { 218 ret := make(map[string]sql.Row) 219 220 if len(joinSchema) == 0 { 221 return ret 222 } 223 224 currentTable := joinSchema[0].Source 225 currentRow := sql.Row{row[0]} 226 227 for i := 1; i < len(joinSchema); i++ { 228 c := joinSchema[i] 229 230 if c.Source != currentTable { 231 ret[currentTable] = currentRow 232 currentTable = c.Source 233 currentRow = sql.Row{row[i]} 234 } else { 235 currentTable = c.Source 236 currentRow = append(currentRow, row[i]) 237 } 238 } 239 240 ret[currentTable] = currentRow 241 242 return ret 243 } 244 245 // Close implements the sql.RowUpdater interface. 246 func (u *updatableJoinUpdater) Close(ctx *sql.Context) error { 247 for _, updater := range u.updaterMap { 248 err := updater.Close(ctx) 249 if err != nil { 250 return err 251 } 252 } 253 254 return nil 255 }