github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/colexec/case.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package colexec
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  
    17  	"github.com/cockroachdb/cockroach/pkg/col/coldata"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/colexecbase"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/colexecbase/colexecerror"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/colmem"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/execinfra"
    22  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    23  )
    24  
    25  type caseOp struct {
    26  	allocator *colmem.Allocator
    27  	buffer    *bufferOp
    28  
    29  	caseOps []colexecbase.Operator
    30  	elseOp  colexecbase.Operator
    31  
    32  	thenIdxs  []int
    33  	outputIdx int
    34  	typ       *types.T
    35  
    36  	// origSel is a buffer used to keep track of the original selection vector of
    37  	// the input batch. We need to do this because we're going to destructively
    38  	// modify the selection vector in order to do the work of the case statement.
    39  	origSel []int
    40  	// prevSel is a buffer used to keep track of the selection vector before
    41  	// running a case arm (i.e. "previous to the current case arm"). We need to
    42  	// keep track of it because case arm will modify the selection vector of the
    43  	// batch, and then we need to figure out which tuples have not been matched
    44  	// by the current case arm (those present in the "previous" sel and not
    45  	// present in the "current" sel).
    46  	prevSel []int
    47  }
    48  
    49  var _ InternalMemoryOperator = &caseOp{}
    50  
    51  func (c *caseOp) ChildCount(verbose bool) int {
    52  	return 1 + len(c.caseOps) + 1
    53  }
    54  
    55  func (c *caseOp) Child(nth int, verbose bool) execinfra.OpNode {
    56  	if nth == 0 {
    57  		return c.buffer
    58  	} else if nth < len(c.caseOps)+1 {
    59  		return c.caseOps[nth-1]
    60  	} else if nth == 1+len(c.caseOps) {
    61  		return c.elseOp
    62  	}
    63  	colexecerror.InternalError(fmt.Sprintf("invalid idx %d", nth))
    64  	// This code is unreachable, but the compiler cannot infer that.
    65  	return nil
    66  }
    67  
    68  func (c *caseOp) InternalMemoryUsage() int {
    69  	// We internally use two selection vectors, origSel and prevSel.
    70  	return 2 * colmem.SizeOfBatchSizeSelVector
    71  }
    72  
    73  // NewCaseOp returns an operator that runs a case statement.
    74  // buffer is a bufferOp that will return the input batch repeatedly.
    75  // caseOps is a list of operator chains, one per branch in the case statement.
    76  //   Each caseOp is connected to the input buffer op, and filters the input based
    77  //   on the case arm's WHEN condition, and then projects the remaining selected
    78  //   tuples based on the case arm's THEN condition.
    79  // elseOp is the ELSE condition.
    80  // whenCol is the index into the input batch to read from.
    81  // thenCol is the index into the output batch to write to.
    82  // typ is the type of the CASE expression.
    83  func NewCaseOp(
    84  	allocator *colmem.Allocator,
    85  	buffer colexecbase.Operator,
    86  	caseOps []colexecbase.Operator,
    87  	elseOp colexecbase.Operator,
    88  	thenIdxs []int,
    89  	outputIdx int,
    90  	typ *types.T,
    91  ) colexecbase.Operator {
    92  	return &caseOp{
    93  		allocator: allocator,
    94  		buffer:    buffer.(*bufferOp),
    95  		caseOps:   caseOps,
    96  		elseOp:    elseOp,
    97  		thenIdxs:  thenIdxs,
    98  		outputIdx: outputIdx,
    99  		typ:       typ,
   100  		origSel:   make([]int, coldata.BatchSize()),
   101  		prevSel:   make([]int, coldata.BatchSize()),
   102  	}
   103  }
   104  
   105  func (c *caseOp) Init() {
   106  	for i := range c.caseOps {
   107  		c.caseOps[i].Init()
   108  	}
   109  	c.elseOp.Init()
   110  }
   111  
// Next advances the input buffer and evaluates the CASE expression on the
// buffered batch. Every case arm is run in turn against the tuples that were
// not matched by an earlier arm, and then the else arm is run against whatever
// is left. The result of each arm is copied into the output column
// (c.outputIdx) of the buffered batch. The length and the selection vector of
// the buffered batch are modified while the arms run and are restored to their
// original state before the batch is returned. If the buffered batch has zero
// length, coldata.ZeroBatch is returned.
func (c *caseOp) Next(ctx context.Context) coldata.Batch {
	c.buffer.advance(ctx)
	origLen := c.buffer.batch.Length()
	if origLen == 0 {
		return coldata.ZeroBatch
	}
	// Save the selection vector the batch arrived with (if any) so that it can
	// be restored at the end.
	var origHasSel bool
	if sel := c.buffer.batch.Selection(); sel != nil {
		origHasSel = true
		copy(c.origSel, sel)
	}

	// prevLen, prevHasSel and prevSel describe the set of tuples that have not
	// been matched by any case arm so far. Initially that is the whole input.
	prevLen := origLen
	prevHasSel := false
	if sel := c.buffer.batch.Selection(); sel != nil {
		prevHasSel = true
		c.prevSel = c.prevSel[:origLen]
		copy(c.prevSel[:origLen], sel[:origLen])
	}
	outputCol := c.buffer.batch.ColVec(c.outputIdx)
	if outputCol.MaybeHasNulls() {
		// We need to make sure that there are no left over null values in the
		// output vector.
		// Note: technically, this is not necessary because we're using
		// Vec.Copy method when populating the output vector which itself
		// handles the null values, but we want to be on the safe side, so we
		// have this (at the moment) redundant resetting behavior.
		outputCol.Nulls().UnsetNulls()
	}
	c.allocator.PerformOperation([]coldata.Vec{outputCol}, func() {
		for i := range c.caseOps {
			// Run the next case operator chain. It will project its THEN expression
			// for all tuples that matched its WHEN expression and that were not
			// already matched.
			batch := c.caseOps[i].Next(ctx)
			// The batch's projection column now additionally contains results for all
			// of the tuples that passed the ith WHEN clause. The batch's selection
			// vector is set to the same selection of tuples.
			// Now, we must subtract this selection vector from the previous
			// selection vector, so that the next operator gets to operate on the
			// remaining set of tuples in the input that haven't matched an arm of the
			// case statement.
			// As an example, imagine the first WHEN op matched tuple 3. The following
			// diagram shows the selection vector before running WHEN, after running
			// WHEN, and then the desired selection vector after subtraction:
			// - selection vector before running WHEN (prevSel)
			// | - selection vector after running WHEN
			// | | - desired selection vector after subtraction
			// | | |
			// 1   1
			// 2   2
			// 3 3
			// 4   4
			toSubtract := batch.Selection()
			toSubtract = toSubtract[:batch.Length()]
			// toSubtract is now a selection vector containing all matched tuples of the
			// current case arm.
			// subtractIdx is the read position in toSubtract; curIdx is the write
			// position in the compacted prevSel.
			var subtractIdx int
			var curIdx int
			if batch.Length() > 0 {
				inputCol := batch.ColVec(c.thenIdxs[i])
				// Copy the results into the output vector, using the toSubtract selection
				// vector to copy only the elements that we actually wrote according to the
				// current case arm.
				outputCol.Copy(
					coldata.CopySliceArgs{
						SliceArgs: coldata.SliceArgs{
							Src:         inputCol,
							Sel:         toSubtract,
							SrcStartIdx: 0,
							SrcEndIdx:   len(toSubtract),
						},
						SelOnDest: true,
					})
				// NOTE(review): both subtraction loops below make a single forward
				// pass, which relies on toSubtract listing tuples in the same
				// relative order as the previous selection - confirm that case
				// arms preserve that order.
				if prevHasSel {
					// We have a previous selection vector, which represents the tuples
					// that haven't yet been matched. Remove the ones that just matched
					// from the previous selection vector.
					// Note that the loop variable i below shadows the case arm index i
					// of the enclosing loop. prevSel is compacted in place: curIdx
					// never exceeds i, so unread elements are never overwritten.
					for i := range c.prevSel {
						if subtractIdx < len(toSubtract) && toSubtract[subtractIdx] == c.prevSel[i] {
							// The ith element of the previous selection vector matched the
							// current one in toSubtract. Skip writing this element, removing
							// it from the previous selection vector.
							subtractIdx++
							continue
						}
						c.prevSel[curIdx] = c.prevSel[i]
						curIdx++
					}
				} else {
					// No selection vector means there have been no matches yet, and we were
					// considering the entire batch of tuples for this case arm. Make a new
					// selection vector with all of the tuples but the ones that just matched.
					c.prevSel = c.prevSel[:cap(c.prevSel)]
					for i := 0; i < origLen; i++ {
						if subtractIdx < len(toSubtract) && toSubtract[subtractIdx] == i {
							subtractIdx++
							continue
						}
						c.prevSel[curIdx] = i
						curIdx++
					}
				}
				// Set the buffered batch into the desired state.
				c.buffer.batch.SetLength(curIdx)
				prevLen = curIdx
				c.buffer.batch.SetSelection(true)
				prevHasSel = true
				copy(c.buffer.batch.Selection()[:curIdx], c.prevSel)
				c.prevSel = c.prevSel[:curIdx]
			} else {
				// There were no matches with the current WHEN arm, so we simply need
				// to restore the buffered batch into the previous state.
				c.buffer.batch.SetLength(prevLen)
				c.buffer.batch.SetSelection(prevHasSel)
				if prevHasSel {
					copy(c.buffer.batch.Selection()[:prevLen], c.prevSel)
					c.prevSel = c.prevSel[:prevLen]
				}
			}
			// Now our selection vector is set to exclude all the things that have
			// matched so far. Reset the buffer and run the next case arm.
			c.buffer.rewind()
		}
		// Finally, run the else operator, which will project into all tuples that
		// are remaining in the selection vector (didn't match any case arms). Once
		// that's done, restore the original selection vector and return the batch.
		batch := c.elseOp.Next(ctx)
		if batch.Length() > 0 {
			// The last element of thenIdxs is the column that holds the else arm's
			// result.
			inputCol := batch.ColVec(c.thenIdxs[len(c.thenIdxs)-1])
			outputCol.Copy(
				coldata.CopySliceArgs{
					SliceArgs: coldata.SliceArgs{
						Src:         inputCol,
						Sel:         batch.Selection(),
						SrcStartIdx: 0,
						SrcEndIdx:   batch.Length(),
					},
					SelOnDest: true,
				})
		}
	})
	// Restore the original state of the buffered batch.
	c.buffer.batch.SetLength(origLen)
	c.buffer.batch.SetSelection(origHasSel)
	if origHasSel {
		copy(c.buffer.batch.Selection()[:origLen], c.origSel[:origLen])
	}
	return c.buffer.batch
}