github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/onduplicatekey/on_duplicate_key.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package onduplicatekey
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/types"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    25  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    26  	plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan"
    27  	"github.com/matrixorigin/matrixone/pkg/vm"
    28  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    29  )
    30  
    31  const argName = "on_duplicate_key"
    32  
    33  func (arg *Argument) String(buf *bytes.Buffer) {
    34  	buf.WriteString(argName)
    35  	buf.WriteString(": processing on duplicate key before insert")
    36  }
    37  
    38  func (arg *Argument) Prepare(p *process.Process) error {
    39  	ap := arg
    40  	ap.ctr = &container{}
    41  	ap.ctr.InitReceiver(p, true)
    42  	return nil
    43  }
    44  
    45  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    46  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    47  		return vm.CancelResult, err
    48  	}
    49  
    50  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    51  	anal.Start()
    52  	defer anal.Stop()
    53  
    54  	ctr := arg.ctr
    55  	result := vm.NewCallResult()
    56  
    57  	for {
    58  		switch ctr.state {
    59  		case Build:
    60  			for {
    61  				bat, end, err := ctr.ReceiveFromAllRegs(anal)
    62  				if err != nil {
    63  					result.Status = vm.ExecStop
    64  					return result, nil
    65  				}
    66  
    67  				if end {
    68  					break
    69  				}
    70  				anal.Input(bat, arg.GetIsFirst())
    71  				err = resetInsertBatchForOnduplicateKey(proc, bat, arg)
    72  				if err != nil {
    73  					bat.Clean(proc.Mp())
    74  					return result, err
    75  				}
    76  
    77  			}
    78  			ctr.state = Eval
    79  
    80  		case Eval:
    81  			if ctr.rbat != nil {
    82  				anal.Output(ctr.rbat, arg.GetIsLast())
    83  			}
    84  			result.Batch = ctr.rbat
    85  			ctr.state = End
    86  			return result, nil
    87  
    88  		case End:
    89  			result.Batch = nil
    90  			result.Status = vm.ExecStop
    91  			return result, nil
    92  		}
    93  	}
    94  }
    95  
    96  func resetInsertBatchForOnduplicateKey(proc *process.Process, originBatch *batch.Batch, insertArg *Argument) error {
    97  	//get rowid vec index
    98  	rowIdIdx := int32(-1)
    99  	for _, idx := range insertArg.OnDuplicateIdx {
   100  		if originBatch.Vecs[idx].GetType().Oid == types.T_Rowid {
   101  			rowIdIdx = idx
   102  			break
   103  		}
   104  	}
   105  	if rowIdIdx == -1 {
   106  		return moerr.NewConstraintViolation(proc.Ctx, "can not find rowid when insert with on duplicate key")
   107  	}
   108  
   109  	insertColCount := int(insertArg.InsertColCount) //columns without hidden columns
   110  	if insertArg.ctr.rbat == nil {
   111  		insertArg.ctr.rbat = batch.NewWithSize(len(insertArg.Attrs))
   112  		insertArg.ctr.rbat.Attrs = insertArg.Attrs
   113  
   114  		insertArg.ctr.checkConflictBat = batch.NewWithSize(len(insertArg.Attrs))
   115  		insertArg.ctr.checkConflictBat.Attrs = append(insertArg.ctr.checkConflictBat.Attrs, insertArg.Attrs...)
   116  
   117  		for i, v := range originBatch.Vecs {
   118  			newVec := proc.GetVector(*v.GetType())
   119  			insertArg.ctr.rbat.SetVector(int32(i), newVec)
   120  
   121  			ckVec := proc.GetVector(*v.GetType())
   122  			insertArg.ctr.checkConflictBat.SetVector(int32(i), ckVec)
   123  		}
   124  	}
   125  
   126  	insertBatch := insertArg.ctr.rbat
   127  	checkConflictBatch := insertArg.ctr.checkConflictBat
   128  	attrs := make([]string, len(insertBatch.Attrs))
   129  	copy(attrs, insertBatch.Attrs)
   130  
   131  	updateExpr := insertArg.OnDuplicateExpr
   132  	oldRowIdVec := vector.MustFixedCol[types.Rowid](originBatch.Vecs[rowIdIdx])
   133  
   134  	checkExpressionExecutors, err := colexec.NewExpressionExecutorsFromPlanExpressions(proc, insertArg.UniqueColCheckExpr)
   135  	if err != nil {
   136  		return err
   137  	}
   138  	defer func() {
   139  		for _, executor := range checkExpressionExecutors {
   140  			executor.Free()
   141  		}
   142  	}()
   143  
   144  	for i := 0; i < originBatch.RowCount(); i++ {
   145  		newBatch, err := fetchOneRowAsBatch(i, originBatch, proc, attrs)
   146  		if err != nil {
   147  			return err
   148  		}
   149  
   150  		// check if uniqueness conflict found in checkConflictBatch
   151  		oldConflictIdx, conflictMsg, err := checkConflict(proc, newBatch, checkConflictBatch, checkExpressionExecutors, insertArg.UniqueCols, insertColCount)
   152  		if err != nil {
   153  			newBatch.Clean(proc.GetMPool())
   154  			return err
   155  		}
   156  		if oldConflictIdx > -1 {
   157  
   158  			if insertArg.IsIgnore {
   159  				continue
   160  			}
   161  
   162  			// if conflict with origin row. and row_id is not equal row_id of insertBatch's inflict row. then throw error
   163  			if !newBatch.Vecs[rowIdIdx].GetNulls().Contains(0) {
   164  				oldRowId := vector.MustFixedCol[types.Rowid](insertBatch.Vecs[rowIdIdx])[oldConflictIdx]
   165  				newRowId := vector.MustFixedCol[types.Rowid](newBatch.Vecs[rowIdIdx])[0]
   166  				if !bytes.Equal(oldRowId[:], newRowId[:]) {
   167  					newBatch.Clean(proc.GetMPool())
   168  					return moerr.NewConstraintViolation(proc.Ctx, conflictMsg)
   169  				}
   170  			}
   171  
   172  			for j := 0; j < insertColCount; j++ {
   173  				fromVec := insertBatch.Vecs[j]
   174  				toVec := newBatch.Vecs[j+insertColCount]
   175  				err := toVec.Copy(fromVec, 0, int64(oldConflictIdx), proc.Mp())
   176  				if err != nil {
   177  					newBatch.Clean(proc.GetMPool())
   178  					return err
   179  				}
   180  			}
   181  			tmpBatch, err := updateOldBatch(newBatch, updateExpr, proc, insertColCount, attrs)
   182  			if err != nil {
   183  				newBatch.Clean(proc.GetMPool())
   184  				return err
   185  			}
   186  			// update the oldConflictIdx of insertBatch by newBatch
   187  			for j := 0; j < insertColCount; j++ {
   188  				fromVec := tmpBatch.Vecs[j]
   189  				toVec := insertBatch.Vecs[j]
   190  				err := toVec.Copy(fromVec, int64(oldConflictIdx), 0, proc.Mp())
   191  				if err != nil {
   192  					tmpBatch.Clean(proc.GetMPool())
   193  					newBatch.Clean(proc.GetMPool())
   194  					return err
   195  				}
   196  
   197  				toVec2 := checkConflictBatch.Vecs[j]
   198  				err = toVec2.Copy(fromVec, int64(oldConflictIdx), 0, proc.Mp())
   199  				if err != nil {
   200  					tmpBatch.Clean(proc.GetMPool())
   201  					newBatch.Clean(proc.GetMPool())
   202  					return err
   203  				}
   204  			}
   205  			proc.PutBatch(tmpBatch)
   206  		} else {
   207  			// row id is null: means no uniqueness conflict found in origin rows
   208  			if len(oldRowIdVec) == 0 || originBatch.Vecs[rowIdIdx].GetNulls().Contains(uint64(i)) {
   209  				_, err := insertBatch.Append(proc.Ctx, proc.Mp(), newBatch)
   210  				if err != nil {
   211  					newBatch.Clean(proc.GetMPool())
   212  					return err
   213  				}
   214  				_, err = checkConflictBatch.Append(proc.Ctx, proc.Mp(), newBatch)
   215  				if err != nil {
   216  					newBatch.Clean(proc.GetMPool())
   217  					return err
   218  				}
   219  			} else {
   220  
   221  				if insertArg.IsIgnore {
   222  					proc.PutBatch(newBatch)
   223  					continue
   224  				}
   225  
   226  				tmpBatch, err := updateOldBatch(newBatch, updateExpr, proc, insertColCount, attrs)
   227  				if err != nil {
   228  					newBatch.Clean(proc.GetMPool())
   229  					return err
   230  				}
   231  				conflictIdx, conflictMsg, err := checkConflict(proc, tmpBatch, checkConflictBatch, checkExpressionExecutors, insertArg.UniqueCols, insertColCount)
   232  				if err != nil {
   233  					tmpBatch.Clean(proc.GetMPool())
   234  					newBatch.Clean(proc.GetMPool())
   235  					return err
   236  				}
   237  				if conflictIdx > -1 {
   238  					tmpBatch.Clean(proc.GetMPool())
   239  					newBatch.Clean(proc.GetMPool())
   240  					return moerr.NewConstraintViolation(proc.Ctx, conflictMsg)
   241  				} else {
   242  					// append batch to insertBatch
   243  					_, err = insertBatch.Append(proc.Ctx, proc.Mp(), tmpBatch)
   244  					if err != nil {
   245  						tmpBatch.Clean(proc.GetMPool())
   246  						newBatch.Clean(proc.GetMPool())
   247  						return err
   248  					}
   249  					_, err = checkConflictBatch.Append(proc.Ctx, proc.Mp(), tmpBatch)
   250  					if err != nil {
   251  						tmpBatch.Clean(proc.GetMPool())
   252  						newBatch.Clean(proc.GetMPool())
   253  						return err
   254  					}
   255  				}
   256  				proc.PutBatch(tmpBatch)
   257  			}
   258  		}
   259  		proc.PutBatch(newBatch)
   260  	}
   261  
   262  	return nil
   263  }
   264  
   265  func resetColPos(e *plan.Expr, columnCount int) {
   266  	switch tmpExpr := e.Expr.(type) {
   267  	case *plan.Expr_Col:
   268  		tmpExpr.Col.ColPos = tmpExpr.Col.ColPos + int32(columnCount)
   269  	case *plan.Expr_F:
   270  		if tmpExpr.F.Func.ObjName != "values" {
   271  			for _, arg := range tmpExpr.F.Args {
   272  				resetColPos(arg, columnCount)
   273  			}
   274  		}
   275  	}
   276  }
   277  
   278  func fetchOneRowAsBatch(idx int, originBatch *batch.Batch, proc *process.Process, attrs []string) (*batch.Batch, error) {
   279  	newBatch := batch.NewWithSize(len(attrs))
   280  	newBatch.Attrs = attrs
   281  	var uErr error
   282  	for i, v := range originBatch.Vecs {
   283  		newVec := proc.GetVector(*v.GetType())
   284  		uErr = newVec.UnionOne(v, int64(idx), proc.Mp())
   285  		if uErr != nil {
   286  			newBatch.Clean(proc.Mp())
   287  			return nil, uErr
   288  		}
   289  		newBatch.SetVector(int32(i), newVec)
   290  	}
   291  	newBatch.SetRowCount(1)
   292  	return newBatch, nil
   293  }
   294  
   295  func updateOldBatch(evalBatch *batch.Batch, updateExpr map[string]*plan.Expr, proc *process.Process, columnCount int, attrs []string) (*batch.Batch, error) {
   296  	var originVec *vector.Vector
   297  	newBatch := batch.NewWithSize(len(attrs))
   298  	newBatch.Attrs = attrs
   299  	for i, attr := range newBatch.Attrs {
   300  		if i < columnCount {
   301  			// update insert cols
   302  			if expr, exists := updateExpr[attr]; exists {
   303  				runExpr := plan2.DeepCopyExpr(expr)
   304  				resetColPos(runExpr, columnCount)
   305  				newVec, err := colexec.EvalExpressionOnce(proc, runExpr, []*batch.Batch{evalBatch})
   306  				if err != nil {
   307  					newBatch.Clean(proc.Mp())
   308  					return nil, err
   309  				}
   310  				newBatch.SetVector(int32(i), newVec)
   311  			} else {
   312  				originVec = evalBatch.Vecs[i+columnCount]
   313  				newVec := proc.GetVector(*originVec.GetType())
   314  				err := newVec.UnionOne(originVec, int64(0), proc.Mp())
   315  				if err != nil {
   316  					newBatch.Clean(proc.Mp())
   317  					return nil, err
   318  				}
   319  				newBatch.SetVector(int32(i), newVec)
   320  			}
   321  		} else {
   322  			// keep old cols
   323  			originVec = evalBatch.Vecs[i]
   324  			newVec := proc.GetVector(*originVec.GetType())
   325  			err := newVec.UnionOne(originVec, int64(0), proc.Mp())
   326  			if err != nil {
   327  				newBatch.Clean(proc.Mp())
   328  				return nil, err
   329  			}
   330  			newBatch.SetVector(int32(i), newVec)
   331  		}
   332  	}
   333  
   334  	newBatch.SetRowCount(1)
   335  	return newBatch, nil
   336  }
   337  
   338  func checkConflict(proc *process.Process, newBatch *batch.Batch, checkConflictBatch *batch.Batch,
   339  	checkExpressionExecutor []colexec.ExpressionExecutor, uniqueCols []string, colCount int) (int, string, error) {
   340  	if checkConflictBatch.RowCount() == 0 {
   341  		return -1, "", nil
   342  	}
   343  	for j := 0; j < colCount; j++ {
   344  		fromVec := newBatch.Vecs[j]
   345  		toVec := checkConflictBatch.Vecs[j+colCount]
   346  		for i := 0; i < checkConflictBatch.RowCount(); i++ {
   347  			err := toVec.Copy(fromVec, int64(i), 0, proc.Mp())
   348  			if err != nil {
   349  				return 0, "", err
   350  			}
   351  		}
   352  	}
   353  
   354  	// build the check expr
   355  	for i, executor := range checkExpressionExecutor {
   356  		result, err := executor.Eval(proc, []*batch.Batch{checkConflictBatch})
   357  		if err != nil {
   358  			return 0, "", err
   359  		}
   360  
   361  		// run expr row by row. if result is true, break
   362  		isConflict := vector.MustFixedCol[bool](result)
   363  		for _, flag := range isConflict {
   364  			if flag {
   365  				conflictMsg := fmt.Sprintf("Duplicate entry for key '%s'", uniqueCols[i])
   366  				return i, conflictMsg, nil
   367  			}
   368  		}
   369  	}
   370  
   371  	return -1, "", nil
   372  }