github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/semi/join.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package semi 16 17 import ( 18 "bytes" 19 20 "github.com/matrixorigin/matrixone/pkg/common/hashmap" 21 "github.com/matrixorigin/matrixone/pkg/container/batch" 22 "github.com/matrixorigin/matrixone/pkg/container/vector" 23 "github.com/matrixorigin/matrixone/pkg/sql/colexec" 24 "github.com/matrixorigin/matrixone/pkg/vm" 25 "github.com/matrixorigin/matrixone/pkg/vm/process" 26 ) 27 28 const argName = "semi" 29 30 func (arg *Argument) String(buf *bytes.Buffer) { 31 buf.WriteString(argName) 32 buf.WriteString(": semi join ") 33 } 34 35 func (arg *Argument) Prepare(proc *process.Process) (err error) { 36 ap := arg 37 ap.ctr = new(container) 38 ap.ctr.InitReceiver(proc, false) 39 ap.ctr.inBuckets = make([]uint8, hashmap.UnitLimit) 40 ap.ctr.vecs = make([]*vector.Vector, len(ap.Conditions[0])) 41 42 ap.ctr.evecs = make([]evalVector, len(ap.Conditions[0])) 43 for i := range ap.ctr.evecs { 44 ap.ctr.evecs[i].executor, err = colexec.NewExpressionExecutor(proc, ap.Conditions[0][i]) 45 if err != nil { 46 return err 47 } 48 } 49 50 if ap.Cond != nil { 51 ap.ctr.expr, err = colexec.NewExpressionExecutor(proc, ap.Cond) 52 } 53 return err 54 } 55 56 func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) { 57 if err, isCancel := vm.CancelCheck(proc); isCancel { 58 return vm.CancelResult, err 59 } 60 61 anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor()) 62 anal.Start() 63 defer anal.Stop() 64 ap := arg 65 ctr := ap.ctr 66 result := vm.NewCallResult() 67 for { 68 switch ctr.state { 69 case Build: 70 if err := ctr.build(anal); err != nil { 71 return result, err 72 } 73 if ctr.mp == nil && !arg.IsShuffle { 74 // for inner ,right and semi join, if hashmap is empty, we can finish this pipeline 75 // shuffle join can't stop early for this moment 76 ctr.state = End 77 } else { 78 ctr.state = Probe 79 } 80 if ctr.mp != nil && ctr.mp.PushedRuntimeFilterIn() && ap.Cond == nil { 81 ctr.skipProbe = true 82 } 83 84 case Probe: 85 bat, _, err := ctr.ReceiveFromSingleReg(0, anal) 86 if err != nil { 87 return result, err 88 } 89 90 if bat == nil { 91 ctr.state = End 92 continue 93 } 94 if bat.IsEmpty() { 95 proc.PutBatch(bat) 96 continue 97 } 98 if ctr.skipProbe { 99 vecused := make([]bool, len(bat.Vecs)) 100 newvecs := make([]*vector.Vector, len(ap.Result)) 101 for i, pos := range ap.Result { 102 vecused[pos] = true 103 newvecs[i] = bat.Vecs[pos] 104 } 105 for i := range bat.Vecs { 106 if !vecused[i] { 107 bat.Vecs[i].Free(proc.Mp()) 108 } 109 } 110 bat.Vecs = newvecs 111 result.Batch = bat 112 anal.Output(bat, arg.GetIsLast()) 113 return result, nil 114 } 115 if ctr.mp == nil { 116 proc.PutBatch(bat) 117 continue 118 } 119 if err := ctr.probe(bat, ap, proc, anal, arg.GetIsFirst(), arg.GetIsLast(), &result); err != nil { 120 bat.Clean(proc.Mp()) 121 return result, err 122 } 123 proc.PutBatch(bat) 124 return result, nil 125 126 default: 127 result.Batch = nil 128 result.Status = vm.ExecStop 129 return result, nil 130 } 131 } 132 } 133 134 func (ctr *container) receiveHashMap(anal process.Analyze) error { 135 bat, _, err := ctr.ReceiveFromSingleReg(1, anal) 136 if err != nil { 137 return err 138 } 139 if bat != nil && bat.AuxData != nil { 140 ctr.mp = bat.DupJmAuxData() 141 ctr.maxAllocSize = max(ctr.maxAllocSize, ctr.mp.Size()) 142 } 143 return nil 144 } 145 146 func (ctr *container) receiveBatch(anal process.Analyze) error { 147 for { 148 bat, _, err := ctr.ReceiveFromSingleReg(1, anal) 149 if err != nil { 150 return err 151 } 152 if bat != nil { 153 ctr.batchRowCount += bat.RowCount() 154 ctr.batches = append(ctr.batches, bat) 155 } else { 156 break 157 } 158 } 159 for i := 0; i < len(ctr.batches)-1; i++ { 160 if ctr.batches[i].RowCount() != colexec.DefaultBatchSize { 161 panic("wrong batch received for hash build!") 162 } 163 } 164 return nil 165 } 166 167 func (ctr *container) build(anal process.Analyze) error { 168 err := ctr.receiveHashMap(anal) 169 if err != nil { 170 return err 171 } 172 return ctr.receiveBatch(anal) 173 } 174 175 func (ctr *container) probe(bat *batch.Batch, ap *Argument, proc *process.Process, anal process.Analyze, isFirst bool, isLast bool, result *vm.CallResult) error { 176 anal.Input(bat, isFirst) 177 if ctr.rbat != nil { 178 proc.PutBatch(ctr.rbat) 179 ctr.rbat = nil 180 } 181 ctr.rbat = batch.NewWithSize(len(ap.Result)) 182 for i, pos := range ap.Result { 183 ctr.rbat.Vecs[i] = proc.GetVector(*bat.Vecs[pos].GetType()) 184 // for semi join, if left batch is sorted , then output batch is sorted 185 ctr.rbat.Vecs[i].SetSorted(bat.Vecs[pos].GetSorted()) 186 } 187 if err := ctr.evalJoinCondition(bat, proc); err != nil { 188 return err 189 } 190 if ctr.joinBat1 == nil { 191 ctr.joinBat1, ctr.cfs1 = colexec.NewJoinBatch(bat, proc.Mp()) 192 } 193 if ctr.joinBat2 == nil && ctr.batchRowCount > 0 { 194 ctr.joinBat2, ctr.cfs2 = colexec.NewJoinBatch(ctr.batches[0], proc.Mp()) 195 } 196 count := bat.RowCount() 197 mSels := ctr.mp.Sels() 198 itr := ctr.mp.NewIterator() 199 200 rowCountIncrease := 0 201 eligible := make([]int32, 0, hashmap.UnitLimit) 202 for i := 0; i < count; i += hashmap.UnitLimit { 203 n := count - i 204 if n > hashmap.UnitLimit { 205 n = hashmap.UnitLimit 206 } 207 copy(ctr.inBuckets, hashmap.OneUInt8s) 208 vals, zvals := itr.Find(i, n, ctr.vecs, ctr.inBuckets) 209 for k := 0; k < n; k++ { 210 if ctr.inBuckets[k] == 0 || zvals[k] == 0 || vals[k] == 0 { 211 continue 212 } 213 if ap.Cond != nil { 214 matched := false // mark if any tuple satisfies the condition 215 if ap.HashOnPK { 216 idx1, idx2 := int64(vals[k]-1)/colexec.DefaultBatchSize, int64(vals[k]-1)%colexec.DefaultBatchSize 217 if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k), 218 1, ctr.cfs1); err != nil { 219 return err 220 } 221 if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], idx2, 222 1, ctr.cfs2); err != nil { 223 return err 224 } 225 vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2}) 226 if err != nil { 227 return err 228 } 229 if vec.IsConstNull() || vec.GetNulls().Contains(0) { 230 continue 231 } 232 bs := vector.MustFixedCol[bool](vec) 233 if bs[0] { 234 matched = true 235 } 236 } else { 237 sels := mSels[vals[k]-1] 238 for _, sel := range sels { 239 idx1, idx2 := sel/colexec.DefaultBatchSize, sel%colexec.DefaultBatchSize 240 if err := colexec.SetJoinBatchValues(ctr.joinBat1, bat, int64(i+k), 241 1, ctr.cfs1); err != nil { 242 return err 243 } 244 if err := colexec.SetJoinBatchValues(ctr.joinBat2, ctr.batches[idx1], int64(idx2), 245 1, ctr.cfs2); err != nil { 246 return err 247 } 248 vec, err := ctr.expr.Eval(proc, []*batch.Batch{ctr.joinBat1, ctr.joinBat2}) 249 if err != nil { 250 return err 251 } 252 if vec.IsConstNull() || vec.GetNulls().Contains(0) { 253 continue 254 } 255 bs := vector.MustFixedCol[bool](vec) 256 if bs[0] { 257 matched = true 258 break 259 } 260 } 261 262 } 263 if !matched { 264 continue 265 } 266 } 267 eligible = append(eligible, int32(i+k)) 268 rowCountIncrease++ 269 } 270 for j, pos := range ap.Result { 271 if err := ctr.rbat.Vecs[j].Union(bat.Vecs[pos], eligible, proc.Mp()); err != nil { 272 return err 273 } 274 } 275 eligible = eligible[:0] 276 } 277 278 ctr.rbat.AddRowCount(rowCountIncrease) 279 anal.Output(ctr.rbat, isLast) 280 result.Batch = ctr.rbat 281 return nil 282 } 283 284 func (ctr *container) evalJoinCondition(bat *batch.Batch, proc *process.Process) error { 285 for i := range ctr.evecs { 286 vec, err := ctr.evecs[i].executor.Eval(proc, []*batch.Batch{bat}) 287 if err != nil { 288 return err 289 } 290 ctr.vecs[i] = vec 291 ctr.evecs[i].vec = vec 292 } 293 return nil 294 }