github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/update.go

// Copyright 2015 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package sql

import (
	"context"
	"sync"

	"github.com/cockroachdb/cockroach/pkg/sql/rowcontainer"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
	"github.com/cockroachdb/cockroach/pkg/util/tracing"
	"github.com/cockroachdb/errors"
)

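// updateNodePool recycles updateNode allocations across statements.
// Nodes are obtained from the pool when an UPDATE plan is built (via
// updateNodePool.Get().(*updateNode), in plan construction code outside
// this file) and are returned to it in Close below, after being reset
// to the zero value.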
var updateNodePool = sync.Pool{
	New: func() interface{} {
		return &updateNode{}
	},
}

type updateNode struct {
	source planNode

	// columns is set if this UPDATE is returning any rows, to be
	// consumed by a renderNode upstream. This occurs when there is a
	// RETURNING clause with some scalar expressions.
	columns sqlbase.ResultColumns

	run updateRun
}

// updateRun contains the run-time state of updateNode during local execution.
type updateRun struct {
	tu         tableUpdater
	rowsNeeded bool

	checkOrds checkSet

	// rowCount is the number of rows in the current batch.
	rowCount int

	// done informs a new call to BatchedNext() that the previous call to
	// BatchedNext() has already completed the work.
	done bool

	// rows contains the accumulated result rows if rowsNeeded is set.
	rows *rowcontainer.RowContainer

	// traceKV caches the current KV tracing flag.
	traceKV bool

	// computedCols are the columns that need to be (re-)computed as
	// the result of updating some of the columns in updateCols.
	computedCols []sqlbase.ColumnDescriptor
	// computeExprs are the expressions to evaluate to re-compute the
	// columns in computedCols.
	computeExprs []tree.TypedExpr
	// iVarContainerForComputedCols is used as a temporary buffer that
	// holds the updated values for every column in the source, to
	// serve as input for indexed vars contained in the computeExprs.
	iVarContainerForComputedCols sqlbase.RowIndexedVarContainer

	// sourceSlots is the helper that maps RHS expressions to LHS targets.
	// This is necessary because there may be fewer RHS expressions than
	// LHS targets. For example, SET (a, b) = (SELECT 1, 2) has:
	// - 2 targets (a, b)
	// - 1 source slot, the subquery (SELECT 1, 2).
	// Each call to extractValues() on a sourceSlot will return 1 or more
	// datums suitable for assignments. In the example above, the
	// method would return 2 values.
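	//
	// As a further illustration (not in the original comment): with
	//   UPDATE t SET (a, b) = (SELECT 1, 2), c = 3
	// there are 3 targets (a, b, c) but only 2 source slots; the first
	// slot's extractValues() yields 2 datums and the second yields 1,
	// filling updateValues in target order.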
	sourceSlots []sourceSlot

	// updateValues will hold the new values for every column
	// mentioned in the LHS of the SET expressions, in the
	// order specified by those SET expressions (thus potentially
	// a different order than the source).
	updateValues tree.Datums

	// During the update, the expressions provided by the source plan
	// contain the columns that are being assigned in the order
	// specified by the table descriptor.
	//
	// For example, with UPDATE kv SET v=3, k=2, the source plan will
	// provide the values in the order k, v (assuming this is the order
	// the columns are defined in kv's descriptor).
	//
	// Then during the update, the columns are updated in the order of
	// the setExprs (or, equivalently, the order of the sourceSlots);
	// for the example above that would be v, k. The results
	// are stored in updateValues above.
	//
	// Then at the end of the update, the values need to be presented
	// back to the tableUpdater in the order of the table descriptor
	// again.
	//
	// updateValues (above) is the buffer for this 2nd stage.
	// updateColsIdx maps the order of the 2nd stage into the order of
	// the 3rd stage; it provides the inverse mapping of sourceSlots.
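	//
	// For example (illustrative): continuing UPDATE kv SET v=3, k=2
	// above, updateValues is populated in SET order (v, k), and
	// updateColsIdx contains {v.ID: 0, k.ID: 1}, so the 3rd stage can
	// look up each column's new value by its descriptor column ID.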
	updateColsIdx map[sqlbase.ColumnID]int

	// rowIdxToRetIdx is the mapping from the columns in ru.FetchCols to the
	// columns in the resultRowBuffer. A value of -1 is used to indicate
	// that the column at that index is not part of the resultRowBuffer
	// of the mutation. Otherwise, the value at the i-th index refers to the
	// index of the resultRowBuffer where the i-th column is to be returned.
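	//
	// For example (illustrative): if ru.FetchCols is (a, b, c) and the
	// statement ends with RETURNING c, a, then rowIdxToRetIdx is
	// [1, -1, 0]: a is returned at position 1, b is not returned, and
	// c is returned at position 0.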
	rowIdxToRetIdx []int

	// numPassthrough is the number of columns, in addition to the
	// target table's own columns being returned, that must be passed
	// through from the input node.
	numPassthrough int
}

// maxUpdateBatchSize is the max number of entries in the KV batch for
// the update operation (including secondary index updates, FK
// cascading updates, etc), before the current KV batch is executed
// and a new batch is started.
const maxUpdateBatchSize = 10000
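
// For instance (illustrative): an UPDATE whose mutations amount to
// 25,000 batch entries is executed as three KV batches of roughly
// 10,000, 10,000 and 5,000 entries, each flushed before the next
// batch is started.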

func (u *updateNode) startExec(params runParams) error {
	// cache traceKV during execution, to avoid re-evaluating it for every row.
	u.run.traceKV = params.p.ExtendedEvalContext().Tracing.KVTracingEnabled()

	if u.run.rowsNeeded {
		u.run.rows = rowcontainer.NewRowContainer(
			params.EvalContext().Mon.MakeBoundAccount(),
			sqlbase.ColTypeInfoFromResCols(u.columns), 0)
	}
	return u.run.tu.init(params.ctx, params.p.txn, params.EvalContext())
}

// Next is required because batchedPlanNode inherits from planNode, but
// batchedPlanNode doesn't really provide it. See the explanatory comments
// in plan_batch.go.
func (u *updateNode) Next(params runParams) (bool, error) { panic("not valid") }

// Values is required because batchedPlanNode inherits from planNode, but
// batchedPlanNode doesn't really provide it. See the explanatory comments
// in plan_batch.go.
func (u *updateNode) Values() tree.Datums { panic("not valid") }

// BatchedNext implements the batchedPlanNode interface.
func (u *updateNode) BatchedNext(params runParams) (bool, error) {
	if u.run.done {
		return false, nil
	}

	tracing.AnnotateTrace()

	// Advance one batch. First, clear the current batch.
	u.run.rowCount = 0
	if u.run.rows != nil {
		u.run.rows.Clear(params.ctx)
	}
	// Now consume/accumulate the rows for this batch.
	lastBatch := false
	for {
		if err := params.p.cancelChecker.Check(); err != nil {
			return false, err
		}

		// Advance one individual row.
		if next, err := u.source.Next(params); !next {
			lastBatch = true
			if err != nil {
				return false, err
			}
			break
		}

		// Process the update for the current source row, potentially
		// accumulating the result row for later.
		if err := u.processSourceRow(params, u.source.Values()); err != nil {
			return false, err
		}

		u.run.rowCount++

		// Are we done yet with the current batch?
		if u.run.tu.curBatchSize() >= maxUpdateBatchSize {
			break
		}
	}

	if u.run.rowCount > 0 {
		if err := u.run.tu.atBatchEnd(params.ctx, u.run.traceKV); err != nil {
			return false, err
		}

		if !lastBatch {
			// We only run/commit the batch if there were some rows processed
			// in this batch.
			if err := u.run.tu.flushAndStartNewBatch(params.ctx); err != nil {
				return false, err
			}
		}
	}

	if lastBatch {
		if _, err := u.run.tu.finalize(params.ctx, u.run.traceKV); err != nil {
			return false, err
		}
		// Remember we're done for the next call to BatchedNext().
		u.run.done = true
	}

	// Possibly initiate a run of CREATE STATISTICS.
	params.ExecCfg().StatsRefresher.NotifyMutation(
		u.run.tu.tableDesc().ID,
		u.run.rowCount,
	)

	return u.run.rowCount > 0, nil
}
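
// A sketch of how a consumer drives this node (the real driver lives
// in plan_batch.go; only methods defined by batchedPlanNode are used):
//
//	for {
//		more, err := u.BatchedNext(params)
//		if err != nil || !more {
//			break
//		}
//		for i := 0; i < u.BatchedCount(); i++ {
//			row := u.BatchedValues(i)
//			_ = row // consume the result row
//		}
//	}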

// processSourceRow processes one row from the source for update and, if
// result rows are needed, saves it in the result row container.
func (u *updateNode) processSourceRow(params runParams, sourceVals tree.Datums) error {
	// sourceVals contains values for the columns from the table, in the order of the
	// table descriptor. (One per column in u.run.tu.ru.FetchCols.)
	//
	// And then after that, all the extra expressions potentially added via
	// a renderNode for the RHS of the assignments.
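	//
	// Layout of sourceVals (sketch, matching the slicing done below):
	//
	//	[ FetchCols | RHS renders | passthrough cols | CHECK results ]
	//
	// that is, len(u.run.tu.ru.FetchCols) values, then
	// len(u.run.tu.ru.UpdateCols) rendered RHS values, then
	// u.run.numPassthrough passthrough values, then one boolean per
	// CHECK constraint in checkOrds.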

	// oldValues is the prefix of sourceVals that corresponds to real
	// stored columns in the table, that is, excluding the RHS assignment
	// expressions.
	oldValues := sourceVals[:len(u.run.tu.ru.FetchCols)]

	// valueIdx is used in the loop below to map sourceSlots to
	// entries in updateValues.
	valueIdx := 0

	// Propagate the values computed for the RHS expressions into
	// updateValues at the right positions. The positions in
	// updateValues correspond to the columns named in the LHS
	// operands for SET.
	for _, slot := range u.run.sourceSlots {
		for _, value := range slot.extractValues(sourceVals) {
			u.run.updateValues[valueIdx] = value
			valueIdx++
		}
	}

	// At this point, we have populated updateValues with the result of
	// computing the RHS for every assignment.

	if len(u.run.computeExprs) > 0 {
		// We now need to (re-)compute the computed column values, using
		// the updated values above as input.
		//
		// This needs to happen in the context of a row containing all the
		// table's columns as if they had been updated already. That row is
		// reflected neither by oldValues (which contains the non-updated
		// values) nor by updateValues (which contains only the columns
		// mentioned in the SET LHS).
		//
		// So we need to construct a buffer that groups them together.
		// iVarContainerForComputedCols does this.
		copy(u.run.iVarContainerForComputedCols.CurSourceRow, oldValues)
		for i := range u.run.tu.ru.UpdateCols {
			id := u.run.tu.ru.UpdateCols[i].ID
			u.run.iVarContainerForComputedCols.CurSourceRow[u.run.tu.ru.FetchColIDtoRowIndex[id]] = u.run.updateValues[i]
		}

		// Now (re-)compute the computed columns.
		// Note that it's safe to do this in any order, because we currently
		// prevent computed columns from depending on other computed columns.
		params.EvalContext().PushIVarContainer(&u.run.iVarContainerForComputedCols)
		for i := range u.run.computedCols {
			d, err := u.run.computeExprs[i].Eval(params.EvalContext())
			if err != nil {
				params.EvalContext().IVarContainer = nil
				return errors.Wrapf(err, "computed column %s", tree.ErrString((*tree.Name)(&u.run.computedCols[i].Name)))
			}
			u.run.updateValues[u.run.updateColsIdx[u.run.computedCols[i].ID]] = d
		}
		params.EvalContext().PopIVarContainer()
	}

	// Verify the schema constraints. For consistency with INSERT/UPSERT
	// and compatibility with PostgreSQL, we must do this before
	// processing the CHECK constraints.
	if err := enforceLocalColumnConstraints(u.run.updateValues, u.run.tu.ru.UpdateCols); err != nil {
		return err
	}

	// Run the CHECK constraints, if any. CheckHelper will either evaluate the
	// constraints itself, or else inspect boolean columns from the input that
	// contain the results of evaluation.
	if !u.run.checkOrds.Empty() {
		checkVals := sourceVals[len(u.run.tu.ru.FetchCols)+len(u.run.tu.ru.UpdateCols)+u.run.numPassthrough:]
		if err := checkMutationInput(u.run.tu.tableDesc(), u.run.checkOrds, checkVals); err != nil {
			return err
		}
	}

	// Queue the update in the KV batch.
	newValues, err := u.run.tu.rowForUpdate(params.ctx, oldValues, u.run.updateValues, u.run.traceKV)
	if err != nil {
		return err
	}

	// If result rows need to be accumulated, do it.
	if u.run.rows != nil {
		// The new values can include all columns; the construction of the
		// values has used execinfra.ScanVisibilityPublicAndNotPublic, so the
		// values may contain additional columns for every newly added column
		// not yet visible. We do not want these to be available for RETURNING.
		//
		// MakeUpdater guarantees that the first columns of the new values
		// are those specified in u.columns.
		resultValues := make([]tree.Datum, len(u.columns))
		largestRetIdx := -1
		for i := range u.run.rowIdxToRetIdx {
			retIdx := u.run.rowIdxToRetIdx[i]
			if retIdx >= 0 {
				if retIdx >= largestRetIdx {
					largestRetIdx = retIdx
				}
				resultValues[retIdx] = newValues[i]
			}
		}

		// At this point we've extracted all the RETURNING values that are part
		// of the target table. We must now extract the columns in the RETURNING
		// clause that refer to other tables (from the FROM clause of the update).
		if u.run.numPassthrough > 0 {
			passthroughBegin := len(u.run.tu.ru.FetchCols) + len(u.run.tu.ru.UpdateCols)
			passthroughEnd := passthroughBegin + u.run.numPassthrough
			passthroughValues := sourceVals[passthroughBegin:passthroughEnd]

			for i := 0; i < u.run.numPassthrough; i++ {
				largestRetIdx++
				resultValues[largestRetIdx] = passthroughValues[i]
			}
		}

		if _, err := u.run.rows.AddRow(params.ctx, resultValues); err != nil {
			return err
		}
	}

	return nil
}

// BatchedCount implements the batchedPlanNode interface.
func (u *updateNode) BatchedCount() int { return u.run.rowCount }

// BatchedValues implements the batchedPlanNode interface.
func (u *updateNode) BatchedValues(rowIdx int) tree.Datums { return u.run.rows.At(rowIdx) }

func (u *updateNode) Close(ctx context.Context) {
	u.source.Close(ctx)
	if u.run.rows != nil {
		u.run.rows.Close(ctx)
		u.run.rows = nil
	}
	u.run.tu.close(ctx)
	*u = updateNode{}
	updateNodePool.Put(u)
}

func (u *updateNode) enableAutoCommit() {
	u.run.tu.enableAutoCommit()
}

// sourceSlot abstracts the idea that our update sources can either be tuples
// or scalars. Tuples are for cases such as SET (a, b) = (1, 2) or SET (a, b) =
// (SELECT 1, 2), and scalars are for situations like SET a = b. A sourceSlot
// represents how to extract and type-check the results of the right-hand side
// of a single SET statement. We could treat everything as tuples, including
// scalars as tuples of size 1, and eliminate this indirection, but that makes
// the query plan more complex.
type sourceSlot interface {
	// extractValues returns a slice of the values this slot is responsible for,
	// as extracted from the row of results.
	extractValues(resultRow tree.Datums) tree.Datums
	// checkColumnTypes compares the types of the results that this slot refers to
	// with the types of the columns those values will be assigned to. It returns
	// an error if those types don't match up.
	checkColumnTypes(row []tree.TypedExpr) error
}

type scalarSlot struct {
	column      sqlbase.ColumnDescriptor
	sourceIndex int
}

func (ss scalarSlot) extractValues(row tree.Datums) tree.Datums {
	return row[ss.sourceIndex : ss.sourceIndex+1]
}

func (ss scalarSlot) checkColumnTypes(row []tree.TypedExpr) error {
	renderedResult := row[ss.sourceIndex]
	typ := renderedResult.ResolvedType()
	return sqlbase.CheckDatumTypeFitsColumnType(&ss.column, typ)
}
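
// For example (illustrative): for UPDATE t SET a = b + 1, the source
// plan renders b + 1 as an extra column; the corresponding scalarSlot
// records that column's position in sourceIndex, and extractValues
// returns it as a one-element slice to be assigned to a.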

// enforceLocalColumnConstraints asserts the column constraints that
// do not require data from sources other than the row itself. This
// includes:
// - rejecting null values in non-nullable columns;
// - checking width constraints from the column type;
// - truncating results to the requested precision (not width).
// Note: the second point is what distinguishes this operation
// from a regular SQL cast -- here widths are checked, not
// used to truncate the value silently.
//
// The row buffer is modified in-place with the result of the
// checks.
func enforceLocalColumnConstraints(row tree.Datums, cols []sqlbase.ColumnDescriptor) error {
	for i := range cols {
		col := &cols[i]
		if !col.Nullable && row[i] == tree.DNull {
			return sqlbase.NewNonNullViolationError(col.Name)
		}
		outVal, err := sqlbase.AdjustValueToColumnType(col.Type, row[i], &col.Name)
		if err != nil {
			return err
		}
		row[i] = outVal
	}
	return nil
}
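
// For example (illustrative): for a column declared STRING(2) NOT NULL,
// assigning NULL is rejected with a non-null violation, and assigning
// 'abc' fails the width check rather than being silently truncated the
// way a cast would behave; a DECIMAL value, by contrast, is adjusted to
// the column's requested precision by AdjustValueToColumnType.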