github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/onduplicatekey/on_duplicate_key.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package onduplicatekey 16 17 import ( 18 "bytes" 19 "fmt" 20 "github.com/matrixorigin/matrixone/pkg/common/moerr" 21 "github.com/matrixorigin/matrixone/pkg/container/batch" 22 "github.com/matrixorigin/matrixone/pkg/container/types" 23 "github.com/matrixorigin/matrixone/pkg/container/vector" 24 "github.com/matrixorigin/matrixone/pkg/pb/plan" 25 "github.com/matrixorigin/matrixone/pkg/sql/colexec" 26 plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan" 27 "github.com/matrixorigin/matrixone/pkg/vm" 28 "github.com/matrixorigin/matrixone/pkg/vm/process" 29 ) 30 31 const argName = "on_duplicate_key" 32 33 func (arg *Argument) String(buf *bytes.Buffer) { 34 buf.WriteString(argName) 35 buf.WriteString(": processing on duplicate key before insert") 36 } 37 38 func (arg *Argument) Prepare(p *process.Process) error { 39 ap := arg 40 ap.ctr = &container{} 41 ap.ctr.InitReceiver(p, true) 42 return nil 43 } 44 45 func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) { 46 if err, isCancel := vm.CancelCheck(proc); isCancel { 47 return vm.CancelResult, err 48 } 49 50 anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor()) 51 anal.Start() 52 defer anal.Stop() 53 54 ctr := arg.ctr 55 result := vm.NewCallResult() 56 57 for { 58 switch ctr.state { 59 case Build: 60 for { 61 bat, end, err := ctr.ReceiveFromAllRegs(anal) 62 if err != nil { 63 result.Status = vm.ExecStop 64 return result, nil 65 } 66 67 if end { 68 break 69 } 70 anal.Input(bat, arg.GetIsFirst()) 71 err = resetInsertBatchForOnduplicateKey(proc, bat, arg) 72 if err != nil { 73 bat.Clean(proc.Mp()) 74 return result, err 75 } 76 77 } 78 ctr.state = Eval 79 80 case Eval: 81 if ctr.rbat != nil { 82 anal.Output(ctr.rbat, arg.GetIsLast()) 83 } 84 result.Batch = ctr.rbat 85 ctr.state = End 86 return result, nil 87 88 case End: 89 result.Batch = nil 90 result.Status = vm.ExecStop 91 return result, nil 92 } 93 } 94 } 95 96 func resetInsertBatchForOnduplicateKey(proc *process.Process, originBatch *batch.Batch, insertArg *Argument) error { 97 //get rowid vec index 98 rowIdIdx := int32(-1) 99 for _, idx := range insertArg.OnDuplicateIdx { 100 if originBatch.Vecs[idx].GetType().Oid == types.T_Rowid { 101 rowIdIdx = idx 102 break 103 } 104 } 105 if rowIdIdx == -1 { 106 return moerr.NewConstraintViolation(proc.Ctx, "can not find rowid when insert with on duplicate key") 107 } 108 109 insertColCount := int(insertArg.InsertColCount) //columns without hidden columns 110 if insertArg.ctr.rbat == nil { 111 insertArg.ctr.rbat = batch.NewWithSize(len(insertArg.Attrs)) 112 insertArg.ctr.rbat.Attrs = insertArg.Attrs 113 114 insertArg.ctr.checkConflictBat = batch.NewWithSize(len(insertArg.Attrs)) 115 insertArg.ctr.checkConflictBat.Attrs = append(insertArg.ctr.checkConflictBat.Attrs, insertArg.Attrs...) 116 117 for i, v := range originBatch.Vecs { 118 newVec := proc.GetVector(*v.GetType()) 119 insertArg.ctr.rbat.SetVector(int32(i), newVec) 120 121 ckVec := proc.GetVector(*v.GetType()) 122 insertArg.ctr.checkConflictBat.SetVector(int32(i), ckVec) 123 } 124 } 125 126 insertBatch := insertArg.ctr.rbat 127 checkConflictBatch := insertArg.ctr.checkConflictBat 128 attrs := make([]string, len(insertBatch.Attrs)) 129 copy(attrs, insertBatch.Attrs) 130 131 updateExpr := insertArg.OnDuplicateExpr 132 oldRowIdVec := vector.MustFixedCol[types.Rowid](originBatch.Vecs[rowIdIdx]) 133 134 checkExpressionExecutors, err := colexec.NewExpressionExecutorsFromPlanExpressions(proc, insertArg.UniqueColCheckExpr) 135 if err != nil { 136 return err 137 } 138 defer func() { 139 for _, executor := range checkExpressionExecutors { 140 executor.Free() 141 } 142 }() 143 144 for i := 0; i < originBatch.RowCount(); i++ { 145 newBatch, err := fetchOneRowAsBatch(i, originBatch, proc, attrs) 146 if err != nil { 147 return err 148 } 149 150 // check if uniqueness conflict found in checkConflictBatch 151 oldConflictIdx, conflictMsg, err := checkConflict(proc, newBatch, checkConflictBatch, checkExpressionExecutors, insertArg.UniqueCols, insertColCount) 152 if err != nil { 153 newBatch.Clean(proc.GetMPool()) 154 return err 155 } 156 if oldConflictIdx > -1 { 157 158 if insertArg.IsIgnore { 159 continue 160 } 161 162 // if conflict with origin row. and row_id is not equal row_id of insertBatch's inflict row. then throw error 163 if !newBatch.Vecs[rowIdIdx].GetNulls().Contains(0) { 164 oldRowId := vector.MustFixedCol[types.Rowid](insertBatch.Vecs[rowIdIdx])[oldConflictIdx] 165 newRowId := vector.MustFixedCol[types.Rowid](newBatch.Vecs[rowIdIdx])[0] 166 if !bytes.Equal(oldRowId[:], newRowId[:]) { 167 newBatch.Clean(proc.GetMPool()) 168 return moerr.NewConstraintViolation(proc.Ctx, conflictMsg) 169 } 170 } 171 172 for j := 0; j < insertColCount; j++ { 173 fromVec := insertBatch.Vecs[j] 174 toVec := newBatch.Vecs[j+insertColCount] 175 err := toVec.Copy(fromVec, 0, int64(oldConflictIdx), proc.Mp()) 176 if err != nil { 177 newBatch.Clean(proc.GetMPool()) 178 return err 179 } 180 } 181 tmpBatch, err := updateOldBatch(newBatch, updateExpr, proc, insertColCount, attrs) 182 if err != nil { 183 newBatch.Clean(proc.GetMPool()) 184 return err 185 } 186 // update the oldConflictIdx of insertBatch by newBatch 187 for j := 0; j < insertColCount; j++ { 188 fromVec := tmpBatch.Vecs[j] 189 toVec := insertBatch.Vecs[j] 190 err := toVec.Copy(fromVec, int64(oldConflictIdx), 0, proc.Mp()) 191 if err != nil { 192 tmpBatch.Clean(proc.GetMPool()) 193 newBatch.Clean(proc.GetMPool()) 194 return err 195 } 196 197 toVec2 := checkConflictBatch.Vecs[j] 198 err = toVec2.Copy(fromVec, int64(oldConflictIdx), 0, proc.Mp()) 199 if err != nil { 200 tmpBatch.Clean(proc.GetMPool()) 201 newBatch.Clean(proc.GetMPool()) 202 return err 203 } 204 } 205 proc.PutBatch(tmpBatch) 206 } else { 207 // row id is null: means no uniqueness conflict found in origin rows 208 if len(oldRowIdVec) == 0 || originBatch.Vecs[rowIdIdx].GetNulls().Contains(uint64(i)) { 209 _, err := insertBatch.Append(proc.Ctx, proc.Mp(), newBatch) 210 if err != nil { 211 newBatch.Clean(proc.GetMPool()) 212 return err 213 } 214 _, err = checkConflictBatch.Append(proc.Ctx, proc.Mp(), newBatch) 215 if err != nil { 216 newBatch.Clean(proc.GetMPool()) 217 return err 218 } 219 } else { 220 221 if insertArg.IsIgnore { 222 proc.PutBatch(newBatch) 223 continue 224 } 225 226 tmpBatch, err := updateOldBatch(newBatch, updateExpr, proc, insertColCount, attrs) 227 if err != nil { 228 newBatch.Clean(proc.GetMPool()) 229 return err 230 } 231 conflictIdx, conflictMsg, err := checkConflict(proc, tmpBatch, checkConflictBatch, checkExpressionExecutors, insertArg.UniqueCols, insertColCount) 232 if err != nil { 233 tmpBatch.Clean(proc.GetMPool()) 234 newBatch.Clean(proc.GetMPool()) 235 return err 236 } 237 if conflictIdx > -1 { 238 tmpBatch.Clean(proc.GetMPool()) 239 newBatch.Clean(proc.GetMPool()) 240 return moerr.NewConstraintViolation(proc.Ctx, conflictMsg) 241 } else { 242 // append batch to insertBatch 243 _, err = insertBatch.Append(proc.Ctx, proc.Mp(), tmpBatch) 244 if err != nil { 245 tmpBatch.Clean(proc.GetMPool()) 246 newBatch.Clean(proc.GetMPool()) 247 return err 248 } 249 _, err = checkConflictBatch.Append(proc.Ctx, proc.Mp(), tmpBatch) 250 if err != nil { 251 tmpBatch.Clean(proc.GetMPool()) 252 newBatch.Clean(proc.GetMPool()) 253 return err 254 } 255 } 256 proc.PutBatch(tmpBatch) 257 } 258 } 259 proc.PutBatch(newBatch) 260 } 261 262 return nil 263 } 264 265 func resetColPos(e *plan.Expr, columnCount int) { 266 switch tmpExpr := e.Expr.(type) { 267 case *plan.Expr_Col: 268 tmpExpr.Col.ColPos = tmpExpr.Col.ColPos + int32(columnCount) 269 case *plan.Expr_F: 270 if tmpExpr.F.Func.ObjName != "values" { 271 for _, arg := range tmpExpr.F.Args { 272 resetColPos(arg, columnCount) 273 } 274 } 275 } 276 } 277 278 func fetchOneRowAsBatch(idx int, originBatch *batch.Batch, proc *process.Process, attrs []string) (*batch.Batch, error) { 279 newBatch := batch.NewWithSize(len(attrs)) 280 newBatch.Attrs = attrs 281 var uErr error 282 for i, v := range originBatch.Vecs { 283 newVec := proc.GetVector(*v.GetType()) 284 uErr = newVec.UnionOne(v, int64(idx), proc.Mp()) 285 if uErr != nil { 286 newBatch.Clean(proc.Mp()) 287 return nil, uErr 288 } 289 newBatch.SetVector(int32(i), newVec) 290 } 291 newBatch.SetRowCount(1) 292 return newBatch, nil 293 } 294 295 func updateOldBatch(evalBatch *batch.Batch, updateExpr map[string]*plan.Expr, proc *process.Process, columnCount int, attrs []string) (*batch.Batch, error) { 296 var originVec *vector.Vector 297 newBatch := batch.NewWithSize(len(attrs)) 298 newBatch.Attrs = attrs 299 for i, attr := range newBatch.Attrs { 300 if i < columnCount { 301 // update insert cols 302 if expr, exists := updateExpr[attr]; exists { 303 runExpr := plan2.DeepCopyExpr(expr) 304 resetColPos(runExpr, columnCount) 305 newVec, err := colexec.EvalExpressionOnce(proc, runExpr, []*batch.Batch{evalBatch}) 306 if err != nil { 307 newBatch.Clean(proc.Mp()) 308 return nil, err 309 } 310 newBatch.SetVector(int32(i), newVec) 311 } else { 312 originVec = evalBatch.Vecs[i+columnCount] 313 newVec := proc.GetVector(*originVec.GetType()) 314 err := newVec.UnionOne(originVec, int64(0), proc.Mp()) 315 if err != nil { 316 newBatch.Clean(proc.Mp()) 317 return nil, err 318 } 319 newBatch.SetVector(int32(i), newVec) 320 } 321 } else { 322 // keep old cols 323 originVec = evalBatch.Vecs[i] 324 newVec := proc.GetVector(*originVec.GetType()) 325 err := newVec.UnionOne(originVec, int64(0), proc.Mp()) 326 if err != nil { 327 newBatch.Clean(proc.Mp()) 328 return nil, err 329 } 330 newBatch.SetVector(int32(i), newVec) 331 } 332 } 333 334 newBatch.SetRowCount(1) 335 return newBatch, nil 336 } 337 338 func checkConflict(proc *process.Process, newBatch *batch.Batch, checkConflictBatch *batch.Batch, 339 checkExpressionExecutor []colexec.ExpressionExecutor, uniqueCols []string, colCount int) (int, string, error) { 340 if checkConflictBatch.RowCount() == 0 { 341 return -1, "", nil 342 } 343 for j := 0; j < colCount; j++ { 344 fromVec := newBatch.Vecs[j] 345 toVec := checkConflictBatch.Vecs[j+colCount] 346 for i := 0; i < checkConflictBatch.RowCount(); i++ { 347 err := toVec.Copy(fromVec, int64(i), 0, proc.Mp()) 348 if err != nil { 349 return 0, "", err 350 } 351 } 352 } 353 354 // build the check expr 355 for i, executor := range checkExpressionExecutor { 356 result, err := executor.Eval(proc, []*batch.Batch{checkConflictBatch}) 357 if err != nil { 358 return 0, "", err 359 } 360 361 // run expr row by row. if result is true, break 362 isConflict := vector.MustFixedCol[bool](result) 363 for _, flag := range isConflict { 364 if flag { 365 conflictMsg := fmt.Sprintf("Duplicate entry for key '%s'", uniqueCols[i]) 366 return i, conflictMsg, nil 367 } 368 } 369 } 370 371 return -1, "", nil 372 }