github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/function/operator/case_when.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package operator 16 17 import ( 18 "github.com/matrixorigin/matrixone/pkg/container/nulls" 19 "github.com/matrixorigin/matrixone/pkg/container/types" 20 "github.com/matrixorigin/matrixone/pkg/container/vector" 21 "github.com/matrixorigin/matrixone/pkg/vm/process" 22 "golang.org/x/exp/constraints" 23 ) 24 25 // case-when operator only support format like that 26 // 27 // ` 28 // case 29 // when A = a1 then ... 30 // when A = a2 then ... 31 // when A = a3 then ... 32 // (else ...) 33 // ` 34 // 35 // format `case A when a1 then ... when a2 then ...` should be converted to required format. 36 var ( 37 CaseWhenUint8 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 38 return cwGeneral[uint8](vs, proc, types.Type{Oid: types.T_uint8}) 39 } 40 41 CaseWhenUint16 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 42 return cwGeneral[uint16](vs, proc, types.Type{Oid: types.T_uint16}) 43 } 44 45 CaseWhenUint32 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 46 return cwGeneral[uint32](vs, proc, types.Type{Oid: types.T_uint32}) 47 } 48 49 CaseWhenUint64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 50 return cwGeneral[uint64](vs, proc, types.Type{Oid: types.T_uint64}) 51 } 52 53 CaseWhenInt8 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 54 return cwGeneral[int8](vs, proc, types.Type{Oid: types.T_int8}) 55 } 56 57 CaseWhenInt16 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 58 return cwGeneral[int16](vs, proc, types.Type{Oid: types.T_int16}) 59 } 60 61 CaseWhenInt32 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 62 return cwGeneral[int32](vs, proc, types.Type{Oid: types.T_int32}) 63 } 64 65 CaseWhenInt64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 66 return cwGeneral[int64](vs, proc, types.Type{Oid: types.T_int64}) 67 } 68 69 CaseWhenFloat32 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 70 return cwGeneral[float32](vs, proc, types.Type{Oid: types.T_float32}) 71 } 72 73 CaseWhenFloat64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 74 return cwGeneral[float64](vs, proc, types.Type{Oid: types.T_float64}) 75 } 76 77 CaseWhenBool = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 78 return cwGeneral[bool](vs, proc, types.Type{Oid: types.T_bool}) 79 } 80 81 CaseWhenDate = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 82 return cwGeneral[types.Date](vs, proc, types.Type{Oid: types.T_date}) 83 } 84 85 CaseWhenTime = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 86 return cwGeneral[types.Time](vs, proc, types.Type{Oid: types.T_time}) 87 } 88 89 CaseWhenDateTime = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 90 return cwGeneral[types.Datetime](vs, proc, types.Type{Oid: types.T_datetime}) 91 } 92 93 CaseWhenVarchar = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 94 return cwString(vs, proc, types.Type{Oid: types.T_varchar, Width: types.MaxVarcharLen}) 95 } 96 97 CaseWhenChar = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 98 return cwString(vs, proc, types.Type{Oid: types.T_char}) 99 } 100 101 CaseWhenDecimal64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 102 return cwGeneral[types.Decimal64](vs, proc, types.Type{Oid: types.T_decimal64}) 103 } 104 105 CaseWhenDecimal128 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 106 return cwGeneral[types.Decimal128](vs, proc, types.Type{Oid: types.T_decimal128}) 107 } 108 109 CaseWhenTimestamp = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 110 return cwGeneral[types.Timestamp](vs, proc, types.Type{Oid: types.T_timestamp}) 111 } 112 113 CaseWhenUuid = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 114 return cwGeneral[types.Uuid](vs, proc, types.Type{Oid: types.T_uuid}) 115 } 116 117 CaseWhenBlob = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 118 return cwString(vs, proc, types.Type{Oid: types.T_blob}) 119 } 120 121 CaseWhenText = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 122 return cwString(vs, proc, types.Type{Oid: types.T_text}) 123 } 124 125 CaseWhenJson = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 126 return cwString(vs, proc, types.Type{Oid: types.T_json}) 127 } 128 ) 129 130 // CwTypeCheckFn is type check function for case-when operator 131 func CwTypeCheckFn(inputTypes []types.T, _ []types.T, ret types.T) bool { 132 l := len(inputTypes) 133 if l >= 2 { 134 for i := 0; i < l-1; i += 2 { 135 if inputTypes[i] != types.T_bool { 136 return false 137 } 138 } 139 140 if l%2 == 1 { 141 if inputTypes[l-1] != ret && inputTypes[l-1] != types.T_any { 142 return false 143 } 144 } 145 146 for i := 1; i < l; i += 2 { 147 if inputTypes[i] != ret && inputTypes[i] != types.T_any { 148 return false 149 } 150 } 151 return true 152 } 153 return false 154 } 155 156 type OrderedValue interface { 157 constraints.Integer | constraints.Float | types.Date | types.Datetime | types.Decimal64 | types.Timestamp 158 } 159 160 type NormalType interface { 161 constraints.Integer | constraints.Float | bool | types.Date | types.Datetime | 162 types.Decimal64 | types.Decimal128 | types.Timestamp | types.Uuid 163 } 164 165 // cwGeneral is a general evaluate function for case-when operator 166 // whose return type is uint / int / float / bool / date / datetime 167 func cwGeneral[T NormalType](vs []*vector.Vector, proc *process.Process, t types.Type) (*vector.Vector, error) { 168 l := vector.Length(vs[0]) 169 170 rs, err := proc.AllocVector(t, int64(l*t.Oid.TypeLen())) 171 if err != nil { 172 return nil, err 173 } 174 rs.Col = vector.DecodeFixedCol[T](rs, t.Oid.TypeLen()) 175 rs.Col = rs.Col.([]T)[:l] 176 rscols := rs.Col.([]T) 177 178 flag := make([]bool, l) // if flag[i] is false, it couldn't adapt to any case 179 180 for i := 0; i < len(vs)-1; i += 2 { 181 whenv := vs[i] 182 thenv := vs[i+1] 183 whencols := vector.MustTCols[bool](whenv) 184 thencols := vector.MustTCols[T](thenv) 185 switch { 186 case whenv.IsScalar() && thenv.IsScalar(): 187 if !whenv.IsScalarNull() && whencols[0] { 188 if thenv.IsScalarNull() { 189 return proc.AllocScalarNullVector(t), nil 190 } else { 191 r := proc.AllocScalarVector(t) 192 r.Typ.Precision = thenv.Typ.Precision 193 r.Typ.Width = thenv.Typ.Width 194 r.Typ.Scale = thenv.Typ.Scale 195 r.Col = make([]T, 1) 196 r.Col.([]T)[0] = thencols[0] 197 return r, nil 198 } 199 } 200 case whenv.IsScalar() && !thenv.IsScalar(): 201 rs.Typ.Precision = thenv.Typ.Precision 202 rs.Typ.Width = thenv.Typ.Width 203 rs.Typ.Scale = thenv.Typ.Scale 204 if !whenv.IsScalarNull() && whencols[0] { 205 copy(rscols, thencols) 206 rs.Nsp.Or(thenv.Nsp) 207 return rs, nil 208 } 209 case !whenv.IsScalar() && thenv.IsScalar(): 210 rs.Typ.Precision = thenv.Typ.Precision 211 rs.Typ.Width = thenv.Typ.Width 212 rs.Typ.Scale = thenv.Typ.Scale 213 if thenv.IsScalarNull() { 214 var j uint64 215 temp := make([]uint64, 0, l) 216 for j = 0; j < uint64(l); j++ { 217 if flag[j] { 218 continue 219 } 220 if whencols[j] { 221 temp = append(temp, j) 222 flag[j] = true 223 } 224 } 225 nulls.Add(rs.Nsp, temp...) 226 } else { 227 for j := 0; j < l; j++ { 228 if flag[j] { 229 continue 230 } 231 if whencols[j] { 232 rscols[j] = thencols[0] 233 flag[j] = true 234 } 235 } 236 } 237 case !whenv.IsScalar() && !thenv.IsScalar(): 238 rs.Typ.Precision = thenv.Typ.Precision 239 rs.Typ.Width = thenv.Typ.Width 240 rs.Typ.Scale = thenv.Typ.Scale 241 if nulls.Any(thenv.Nsp) { 242 var j uint64 243 temp := make([]uint64, 0, l) 244 for j = 0; j < uint64(l); j++ { 245 if whencols[j] { 246 if flag[j] { 247 continue 248 } 249 if nulls.Contains(thenv.Nsp, j) { 250 temp = append(temp, j) 251 } else { 252 rscols[j] = thencols[j] 253 } 254 flag[j] = true 255 } 256 } 257 nulls.Add(rs.Nsp, temp...) 258 } else { 259 for j := 0; j < l; j++ { 260 if whencols[j] { 261 if flag[j] { 262 continue 263 } 264 rscols[j] = thencols[j] 265 flag[j] = true 266 } 267 } 268 } 269 } 270 } 271 272 // deal the ELSE part 273 if len(vs)%2 == 0 || vs[len(vs)-1].IsScalarNull() { 274 var i uint64 275 temp := make([]uint64, 0, l) 276 for i = 0; i < uint64(l); i++ { 277 if !flag[i] { 278 temp = append(temp, i) 279 } 280 } 281 nulls.Add(rs.Nsp, temp...) 282 } else { 283 ev := vs[len(vs)-1] 284 ecols := ev.Col.([]T) 285 if ev.IsScalar() { 286 for i := 0; i < l; i++ { 287 if !flag[i] { 288 rscols[i] = ecols[0] 289 } 290 } 291 } else { 292 if nulls.Any(ev.Nsp) { 293 var i uint64 294 temp := make([]uint64, 0, l) 295 for i = 0; i < uint64(l); i++ { 296 if !flag[i] { 297 if nulls.Contains(ev.Nsp, i) { 298 temp = append(temp, i) 299 } else { 300 rscols[i] = ecols[i] 301 } 302 } 303 } 304 nulls.Add(rs.Nsp, temp...) 305 } else { 306 for i := 0; i < l; i++ { 307 if !flag[i] { 308 rscols[i] = ecols[i] 309 } 310 } 311 } 312 } 313 } 314 315 return rs, nil 316 } 317 318 // cwString is an evaluate function for case-when operator 319 // whose return type is char / varchar 320 func cwString(vs []*vector.Vector, proc *process.Process, typ types.Type) (*vector.Vector, error) { 321 nres := vector.Length(vs[0]) 322 results := make([]string, nres) 323 nsp := nulls.NewWithSize(nres) 324 flag := make([]bool, nres) 325 326 for i := 0; i < len(vs)-1; i += 2 { 327 whenv := vs[i] 328 thenv := vs[i+1] 329 whencols := vector.MustTCols[bool](whenv) 330 thencols := vector.MustStrCols(thenv) 331 switch { 332 case whenv.IsScalar() && thenv.IsScalar(): 333 if !whenv.IsScalarNull() && whencols[0] { 334 if thenv.IsScalarNull() { 335 for idx := range results { 336 if !flag[idx] { 337 nsp.Np.Add(uint64(idx)) 338 flag[idx] = true 339 } 340 } 341 } else { 342 for idx := range results { 343 if !flag[idx] { 344 results[idx] = thencols[0] 345 flag[idx] = true 346 } 347 } 348 } 349 } 350 case whenv.IsScalar() && !thenv.IsScalar(): 351 if !whenv.IsScalarNull() && whencols[0] { 352 for idx := range results { 353 if !flag[idx] { 354 if nulls.Contains(thenv.Nsp, uint64(idx)) { 355 nsp.Np.Add(uint64(idx)) 356 } else { 357 results[idx] = thencols[idx] 358 } 359 flag[idx] = true 360 } 361 } 362 } 363 case !whenv.IsScalar() && thenv.IsScalar(): 364 if thenv.IsScalarNull() { 365 for idx := range results { 366 if !flag[idx] { 367 if !nulls.Contains(whenv.Nsp, uint64(idx)) && whencols[idx] { 368 nsp.Np.Add(uint64(idx)) 369 flag[idx] = true 370 } 371 } 372 } 373 } else { 374 for idx := range results { 375 if !flag[idx] { 376 if !nulls.Contains(whenv.Nsp, uint64(idx)) && whencols[idx] { 377 results[idx] = thencols[0] 378 flag[idx] = true 379 } 380 } 381 } 382 } 383 case !whenv.IsScalar() && !thenv.IsScalar(): 384 for idx := range results { 385 if !flag[idx] { 386 if !nulls.Contains(whenv.Nsp, uint64(idx)) && whencols[idx] { 387 if nulls.Contains(thenv.Nsp, uint64(idx)) { 388 nsp.Np.Add(uint64(idx)) 389 } else { 390 results[idx] = thencols[idx] 391 } 392 flag[idx] = true 393 } 394 } 395 } 396 } 397 } 398 399 // deal the ELSE part 400 if len(vs)%2 == 0 || vs[len(vs)-1].IsScalarNull() { 401 for idx := range results { 402 if !flag[idx] { 403 nulls.Add(nsp, uint64(idx)) 404 flag[idx] = true 405 } 406 } 407 } else { 408 ev := vs[len(vs)-1] 409 ecols := vector.MustStrCols(ev) 410 if ev.IsScalar() { 411 for idx := range results { 412 if !flag[idx] { 413 results[idx] = ecols[0] 414 flag[idx] = true 415 } 416 } 417 } else { 418 for idx := range results { 419 if !flag[idx] { 420 if nulls.Contains(ev.Nsp, uint64(idx)) { 421 nulls.Add(nsp, uint64(idx)) 422 } else { 423 results[idx] = ecols[idx] 424 } 425 flag[idx] = true 426 } 427 } 428 } 429 } 430 431 return vector.NewWithStrings(typ, results, nsp, proc.Mp()), nil 432 }