github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/update_join.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
import (
	"errors"

	"github.com/dolthub/go-mysql-server/sql"
)
    20  
    21  type UpdateJoin struct {
    22  	Updaters map[string]sql.RowUpdater
    23  	UnaryNode
    24  }
    25  
    26  // NewUpdateJoin returns an *UpdateJoin node.
    27  func NewUpdateJoin(editorMap map[string]sql.RowUpdater, child sql.Node) *UpdateJoin {
    28  	return &UpdateJoin{
    29  		Updaters:  editorMap,
    30  		UnaryNode: UnaryNode{Child: child},
    31  	}
    32  }
    33  
    34  var _ sql.Node = (*UpdateJoin)(nil)
    35  var _ sql.CollationCoercible = (*UpdateJoin)(nil)
    36  
    37  // String implements the sql.Node interface.
    38  func (u *UpdateJoin) String() string {
    39  	pr := sql.NewTreePrinter()
    40  	_ = pr.WriteNode("Update Join")
    41  	_ = pr.WriteChildren(u.Child.String())
    42  	return pr.String()
    43  }
    44  
    45  // DebugString implements the sql.Node interface.
    46  func (u *UpdateJoin) DebugString() string {
    47  	pr := sql.NewTreePrinter()
    48  	_ = pr.WriteNode("Update Join")
    49  	_ = pr.WriteChildren(sql.DebugString(u.Child))
    50  	return pr.String()
    51  }
    52  
    53  // GetUpdatable returns an updateJoinTable which implements sql.UpdatableTable.
    54  func (u *UpdateJoin) GetUpdatable() sql.UpdatableTable {
    55  	return &updatableJoinTable{
    56  		updaters: u.Updaters,
    57  		joinNode: u.Child.(*UpdateSource).Child,
    58  	}
    59  }
    60  
    61  // WithChildren implements the sql.Node interface.
    62  func (u *UpdateJoin) WithChildren(children ...sql.Node) (sql.Node, error) {
    63  	if len(children) != 1 {
    64  		return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1)
    65  	}
    66  
    67  	return NewUpdateJoin(u.Updaters, children[0]), nil
    68  }
    69  
    70  func (u *UpdateJoin) IsReadOnly() bool {
    71  	return false
    72  }
    73  
    74  // CheckPrivileges implements the interface sql.Node.
    75  func (u *UpdateJoin) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    76  	return u.Child.CheckPrivileges(ctx, opChecker)
    77  }
    78  
    79  // CollationCoercibility implements the interface sql.CollationCoercible.
    80  func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    81  	return sql.GetCoercibility(ctx, u.Child)
    82  }
    83  
    84  // updatableJoinTable manages the update of multiple tables.
    85  type updatableJoinTable struct {
    86  	updaters map[string]sql.RowUpdater
    87  	joinNode sql.Node
    88  }
    89  
    90  var _ sql.UpdatableTable = (*updatableJoinTable)(nil)
    91  
    92  // Partitions implements the sql.UpdatableTable interface.
    93  func (u *updatableJoinTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
    94  	panic("this method should not be called")
    95  }
    96  
    97  // PartitionsRows implements the sql.UpdatableTable interface.
    98  func (u *updatableJoinTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
    99  	panic("this method should not be called")
   100  }
   101  
   102  // Name implements the sql.UpdatableTable interface.
   103  func (u *updatableJoinTable) Name() string {
   104  	panic("this method should not be called")
   105  }
   106  
   107  // String implements the sql.UpdatableTable interface.
   108  func (u *updatableJoinTable) String() string {
   109  	panic("this method should not be called")
   110  }
   111  
   112  // Schema implements the sql.UpdatableTable interface.
   113  func (u *updatableJoinTable) Schema() sql.Schema {
   114  	return u.joinNode.Schema()
   115  }
   116  
   117  // Collation implements the sql.Table interface.
   118  func (u *updatableJoinTable) Collation() sql.CollationID {
   119  	return sql.Collation_Default
   120  }
   121  
   122  // Updater implements the sql.UpdatableTable interface.
   123  func (u *updatableJoinTable) Updater(ctx *sql.Context) sql.RowUpdater {
   124  	return &updatableJoinUpdater{
   125  		updaterMap: u.updaters,
   126  		schemaMap:  RecreateTableSchemaFromJoinSchema(u.joinNode.Schema()),
   127  		joinSchema: u.joinNode.Schema(),
   128  	}
   129  }
   130  
   131  // RecreateTableSchemaFromJoinSchema takes a join schema and recreates each individual tables schema.
   132  func RecreateTableSchemaFromJoinSchema(joinSchema sql.Schema) map[string]sql.Schema {
   133  	ret := make(map[string]sql.Schema, 0)
   134  
   135  	for _, c := range joinSchema {
   136  		potential, exists := ret[c.Source]
   137  		if exists {
   138  			ret[c.Source] = append(potential, c)
   139  		} else {
   140  			ret[c.Source] = sql.Schema{c}
   141  		}
   142  	}
   143  
   144  	return ret
   145  }
   146  
   147  // updatableJoinUpdater manages the process of taking a join row and allocating the respective updates to each updatable
   148  // table.
   149  type updatableJoinUpdater struct {
   150  	updaterMap map[string]sql.RowUpdater
   151  	schemaMap  map[string]sql.Schema
   152  	joinSchema sql.Schema
   153  }
   154  
   155  var _ sql.RowUpdater = (*updatableJoinUpdater)(nil)
   156  
   157  // StatementBegin implements the sql.TableEditor interface.
   158  func (u *updatableJoinUpdater) StatementBegin(ctx *sql.Context) {
   159  	for _, v := range u.updaterMap {
   160  		v.StatementBegin(ctx)
   161  	}
   162  }
   163  
   164  // DiscardChanges implements the sql.TableEditor interface.
   165  func (u *updatableJoinUpdater) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
   166  	for _, v := range u.updaterMap {
   167  		err := v.DiscardChanges(ctx, errorEncountered)
   168  		if err != nil {
   169  			return err
   170  		}
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  // StatementComplete implements the sql.TableEditor interface.
   177  func (u *updatableJoinUpdater) StatementComplete(ctx *sql.Context) error {
   178  	for _, v := range u.updaterMap {
   179  		err := v.StatementComplete(ctx)
   180  
   181  		if err != nil {
   182  			return err
   183  		}
   184  	}
   185  
   186  	return nil
   187  }
   188  
   189  // Update implements the sql.RowUpdater interface.
   190  func (u *updatableJoinUpdater) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
   191  	tableToOldRowMap := SplitRowIntoTableRowMap(old, u.joinSchema)
   192  	tableToNewRowMap := SplitRowIntoTableRowMap(new, u.joinSchema)
   193  
   194  	for tableName, updater := range u.updaterMap {
   195  		oldRow := tableToOldRowMap[tableName]
   196  		newRow := tableToNewRowMap[tableName]
   197  		schema := u.schemaMap[tableName]
   198  
   199  		eq, err := oldRow.Equals(newRow, schema)
   200  		if err != nil {
   201  			return err
   202  		}
   203  
   204  		if !eq {
   205  			err = updater.Update(ctx, oldRow, newRow)
   206  		}
   207  
   208  		if err != nil {
   209  			return err
   210  		}
   211  	}
   212  
   213  	return nil
   214  }
   215  
   216  // SplitRowIntoTableRowMap takes a join table row and breaks into a map of tables and their respective row.
   217  func SplitRowIntoTableRowMap(row sql.Row, joinSchema sql.Schema) map[string]sql.Row {
   218  	ret := make(map[string]sql.Row)
   219  
   220  	if len(joinSchema) == 0 {
   221  		return ret
   222  	}
   223  
   224  	currentTable := joinSchema[0].Source
   225  	currentRow := sql.Row{row[0]}
   226  
   227  	for i := 1; i < len(joinSchema); i++ {
   228  		c := joinSchema[i]
   229  
   230  		if c.Source != currentTable {
   231  			ret[currentTable] = currentRow
   232  			currentTable = c.Source
   233  			currentRow = sql.Row{row[i]}
   234  		} else {
   235  			currentTable = c.Source
   236  			currentRow = append(currentRow, row[i])
   237  		}
   238  	}
   239  
   240  	ret[currentTable] = currentRow
   241  
   242  	return ret
   243  }
   244  
   245  // Close implements the sql.RowUpdater interface.
   246  func (u *updatableJoinUpdater) Close(ctx *sql.Context) error {
   247  	for _, updater := range u.updaterMap {
   248  		err := updater.Close(ctx)
   249  		if err != nil {
   250  			return err
   251  		}
   252  	}
   253  
   254  	return nil
   255  }