github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/function/operator/case_when.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package operator
    16  
    17  import (
    18  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    19  	"github.com/matrixorigin/matrixone/pkg/container/types"
    20  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    21  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    22  	"golang.org/x/exp/constraints"
    23  )
    24  
    25  // case-when operator only support format like that
    26  //
    27  //	`
    28  //		case
    29  //		when A = a1 then ...
    30  //		when A = a2 then ...
    31  //		when A = a3 then ...
    32  //		(else ...)
    33  //	`
    34  //
    35  // format `case A when a1 then ... when a2 then ...` should be converted to required format.
    36  var (
    37  	CaseWhenUint8 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    38  		return cwGeneral[uint8](vs, proc, types.Type{Oid: types.T_uint8})
    39  	}
    40  
    41  	CaseWhenUint16 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    42  		return cwGeneral[uint16](vs, proc, types.Type{Oid: types.T_uint16})
    43  	}
    44  
    45  	CaseWhenUint32 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    46  		return cwGeneral[uint32](vs, proc, types.Type{Oid: types.T_uint32})
    47  	}
    48  
    49  	CaseWhenUint64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    50  		return cwGeneral[uint64](vs, proc, types.Type{Oid: types.T_uint64})
    51  	}
    52  
    53  	CaseWhenInt8 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    54  		return cwGeneral[int8](vs, proc, types.Type{Oid: types.T_int8})
    55  	}
    56  
    57  	CaseWhenInt16 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    58  		return cwGeneral[int16](vs, proc, types.Type{Oid: types.T_int16})
    59  	}
    60  
    61  	CaseWhenInt32 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    62  		return cwGeneral[int32](vs, proc, types.Type{Oid: types.T_int32})
    63  	}
    64  
    65  	CaseWhenInt64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    66  		return cwGeneral[int64](vs, proc, types.Type{Oid: types.T_int64})
    67  	}
    68  
    69  	CaseWhenFloat32 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    70  		return cwGeneral[float32](vs, proc, types.Type{Oid: types.T_float32})
    71  	}
    72  
    73  	CaseWhenFloat64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    74  		return cwGeneral[float64](vs, proc, types.Type{Oid: types.T_float64})
    75  	}
    76  
    77  	CaseWhenBool = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    78  		return cwGeneral[bool](vs, proc, types.Type{Oid: types.T_bool})
    79  	}
    80  
    81  	CaseWhenDate = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    82  		return cwGeneral[types.Date](vs, proc, types.Type{Oid: types.T_date})
    83  	}
    84  
    85  	CaseWhenTime = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    86  		return cwGeneral[types.Time](vs, proc, types.Type{Oid: types.T_time})
    87  	}
    88  
    89  	CaseWhenDateTime = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    90  		return cwGeneral[types.Datetime](vs, proc, types.Type{Oid: types.T_datetime})
    91  	}
    92  
    93  	CaseWhenVarchar = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    94  		return cwString(vs, proc, types.Type{Oid: types.T_varchar, Width: types.MaxVarcharLen})
    95  	}
    96  
    97  	CaseWhenChar = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
    98  		return cwString(vs, proc, types.Type{Oid: types.T_char})
    99  	}
   100  
   101  	CaseWhenDecimal64 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   102  		return cwGeneral[types.Decimal64](vs, proc, types.Type{Oid: types.T_decimal64})
   103  	}
   104  
   105  	CaseWhenDecimal128 = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   106  		return cwGeneral[types.Decimal128](vs, proc, types.Type{Oid: types.T_decimal128})
   107  	}
   108  
   109  	CaseWhenTimestamp = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   110  		return cwGeneral[types.Timestamp](vs, proc, types.Type{Oid: types.T_timestamp})
   111  	}
   112  
   113  	CaseWhenUuid = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   114  		return cwGeneral[types.Uuid](vs, proc, types.Type{Oid: types.T_uuid})
   115  	}
   116  
   117  	CaseWhenBlob = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   118  		return cwString(vs, proc, types.Type{Oid: types.T_blob})
   119  	}
   120  
   121  	CaseWhenText = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   122  		return cwString(vs, proc, types.Type{Oid: types.T_text})
   123  	}
   124  
   125  	CaseWhenJson = func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   126  		return cwString(vs, proc, types.Type{Oid: types.T_json})
   127  	}
   128  )
   129  
   130  // CwTypeCheckFn is type check function for case-when operator
   131  func CwTypeCheckFn(inputTypes []types.T, _ []types.T, ret types.T) bool {
   132  	l := len(inputTypes)
   133  	if l >= 2 {
   134  		for i := 0; i < l-1; i += 2 {
   135  			if inputTypes[i] != types.T_bool {
   136  				return false
   137  			}
   138  		}
   139  
   140  		if l%2 == 1 {
   141  			if inputTypes[l-1] != ret && inputTypes[l-1] != types.T_any {
   142  				return false
   143  			}
   144  		}
   145  
   146  		for i := 1; i < l; i += 2 {
   147  			if inputTypes[i] != ret && inputTypes[i] != types.T_any {
   148  				return false
   149  			}
   150  		}
   151  		return true
   152  	}
   153  	return false
   154  }
   155  
// OrderedValue is a type constraint made of the integer and float kinds plus
// the date, datetime, decimal64 and timestamp types.
// NOTE(review): nothing in the visible part of this file references it —
// confirm its users elsewhere in the package before changing the set.
type OrderedValue interface {
	constraints.Integer | constraints.Float | types.Date | types.Datetime | types.Decimal64 | types.Timestamp
}
   159  
// NormalType is the type constraint on cwGeneral's type parameter: the element
// types a non-string case-when result column can be instantiated with.
// NOTE(review): types.Time is not listed by name although cwGeneral is
// instantiated with it; presumably it is admitted through constraints.Integer
// via its underlying type — confirm against the types package.
type NormalType interface {
	constraints.Integer | constraints.Float | bool | types.Date | types.Datetime |
		types.Decimal64 | types.Decimal128 | types.Timestamp | types.Uuid
}
   164  
   165  // cwGeneral is a general evaluate function for case-when operator
   166  // whose return type is uint / int / float / bool / date / datetime
   167  func cwGeneral[T NormalType](vs []*vector.Vector, proc *process.Process, t types.Type) (*vector.Vector, error) {
   168  	l := vector.Length(vs[0])
   169  
   170  	rs, err := proc.AllocVector(t, int64(l*t.Oid.TypeLen()))
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	rs.Col = vector.DecodeFixedCol[T](rs, t.Oid.TypeLen())
   175  	rs.Col = rs.Col.([]T)[:l]
   176  	rscols := rs.Col.([]T)
   177  
   178  	flag := make([]bool, l) // if flag[i] is false, it couldn't adapt to any case
   179  
   180  	for i := 0; i < len(vs)-1; i += 2 {
   181  		whenv := vs[i]
   182  		thenv := vs[i+1]
   183  		whencols := vector.MustTCols[bool](whenv)
   184  		thencols := vector.MustTCols[T](thenv)
   185  		switch {
   186  		case whenv.IsScalar() && thenv.IsScalar():
   187  			if !whenv.IsScalarNull() && whencols[0] {
   188  				if thenv.IsScalarNull() {
   189  					return proc.AllocScalarNullVector(t), nil
   190  				} else {
   191  					r := proc.AllocScalarVector(t)
   192  					r.Typ.Precision = thenv.Typ.Precision
   193  					r.Typ.Width = thenv.Typ.Width
   194  					r.Typ.Scale = thenv.Typ.Scale
   195  					r.Col = make([]T, 1)
   196  					r.Col.([]T)[0] = thencols[0]
   197  					return r, nil
   198  				}
   199  			}
   200  		case whenv.IsScalar() && !thenv.IsScalar():
   201  			rs.Typ.Precision = thenv.Typ.Precision
   202  			rs.Typ.Width = thenv.Typ.Width
   203  			rs.Typ.Scale = thenv.Typ.Scale
   204  			if !whenv.IsScalarNull() && whencols[0] {
   205  				copy(rscols, thencols)
   206  				rs.Nsp.Or(thenv.Nsp)
   207  				return rs, nil
   208  			}
   209  		case !whenv.IsScalar() && thenv.IsScalar():
   210  			rs.Typ.Precision = thenv.Typ.Precision
   211  			rs.Typ.Width = thenv.Typ.Width
   212  			rs.Typ.Scale = thenv.Typ.Scale
   213  			if thenv.IsScalarNull() {
   214  				var j uint64
   215  				temp := make([]uint64, 0, l)
   216  				for j = 0; j < uint64(l); j++ {
   217  					if flag[j] {
   218  						continue
   219  					}
   220  					if whencols[j] {
   221  						temp = append(temp, j)
   222  						flag[j] = true
   223  					}
   224  				}
   225  				nulls.Add(rs.Nsp, temp...)
   226  			} else {
   227  				for j := 0; j < l; j++ {
   228  					if flag[j] {
   229  						continue
   230  					}
   231  					if whencols[j] {
   232  						rscols[j] = thencols[0]
   233  						flag[j] = true
   234  					}
   235  				}
   236  			}
   237  		case !whenv.IsScalar() && !thenv.IsScalar():
   238  			rs.Typ.Precision = thenv.Typ.Precision
   239  			rs.Typ.Width = thenv.Typ.Width
   240  			rs.Typ.Scale = thenv.Typ.Scale
   241  			if nulls.Any(thenv.Nsp) {
   242  				var j uint64
   243  				temp := make([]uint64, 0, l)
   244  				for j = 0; j < uint64(l); j++ {
   245  					if whencols[j] {
   246  						if flag[j] {
   247  							continue
   248  						}
   249  						if nulls.Contains(thenv.Nsp, j) {
   250  							temp = append(temp, j)
   251  						} else {
   252  							rscols[j] = thencols[j]
   253  						}
   254  						flag[j] = true
   255  					}
   256  				}
   257  				nulls.Add(rs.Nsp, temp...)
   258  			} else {
   259  				for j := 0; j < l; j++ {
   260  					if whencols[j] {
   261  						if flag[j] {
   262  							continue
   263  						}
   264  						rscols[j] = thencols[j]
   265  						flag[j] = true
   266  					}
   267  				}
   268  			}
   269  		}
   270  	}
   271  
   272  	// deal the ELSE part
   273  	if len(vs)%2 == 0 || vs[len(vs)-1].IsScalarNull() {
   274  		var i uint64
   275  		temp := make([]uint64, 0, l)
   276  		for i = 0; i < uint64(l); i++ {
   277  			if !flag[i] {
   278  				temp = append(temp, i)
   279  			}
   280  		}
   281  		nulls.Add(rs.Nsp, temp...)
   282  	} else {
   283  		ev := vs[len(vs)-1]
   284  		ecols := ev.Col.([]T)
   285  		if ev.IsScalar() {
   286  			for i := 0; i < l; i++ {
   287  				if !flag[i] {
   288  					rscols[i] = ecols[0]
   289  				}
   290  			}
   291  		} else {
   292  			if nulls.Any(ev.Nsp) {
   293  				var i uint64
   294  				temp := make([]uint64, 0, l)
   295  				for i = 0; i < uint64(l); i++ {
   296  					if !flag[i] {
   297  						if nulls.Contains(ev.Nsp, i) {
   298  							temp = append(temp, i)
   299  						} else {
   300  							rscols[i] = ecols[i]
   301  						}
   302  					}
   303  				}
   304  				nulls.Add(rs.Nsp, temp...)
   305  			} else {
   306  				for i := 0; i < l; i++ {
   307  					if !flag[i] {
   308  						rscols[i] = ecols[i]
   309  					}
   310  				}
   311  			}
   312  		}
   313  	}
   314  
   315  	return rs, nil
   316  }
   317  
   318  // cwString is an evaluate function for case-when operator
   319  // whose return type is char / varchar
   320  func cwString(vs []*vector.Vector, proc *process.Process, typ types.Type) (*vector.Vector, error) {
   321  	nres := vector.Length(vs[0])
   322  	results := make([]string, nres)
   323  	nsp := nulls.NewWithSize(nres)
   324  	flag := make([]bool, nres)
   325  
   326  	for i := 0; i < len(vs)-1; i += 2 {
   327  		whenv := vs[i]
   328  		thenv := vs[i+1]
   329  		whencols := vector.MustTCols[bool](whenv)
   330  		thencols := vector.MustStrCols(thenv)
   331  		switch {
   332  		case whenv.IsScalar() && thenv.IsScalar():
   333  			if !whenv.IsScalarNull() && whencols[0] {
   334  				if thenv.IsScalarNull() {
   335  					for idx := range results {
   336  						if !flag[idx] {
   337  							nsp.Np.Add(uint64(idx))
   338  							flag[idx] = true
   339  						}
   340  					}
   341  				} else {
   342  					for idx := range results {
   343  						if !flag[idx] {
   344  							results[idx] = thencols[0]
   345  							flag[idx] = true
   346  						}
   347  					}
   348  				}
   349  			}
   350  		case whenv.IsScalar() && !thenv.IsScalar():
   351  			if !whenv.IsScalarNull() && whencols[0] {
   352  				for idx := range results {
   353  					if !flag[idx] {
   354  						if nulls.Contains(thenv.Nsp, uint64(idx)) {
   355  							nsp.Np.Add(uint64(idx))
   356  						} else {
   357  							results[idx] = thencols[idx]
   358  						}
   359  						flag[idx] = true
   360  					}
   361  				}
   362  			}
   363  		case !whenv.IsScalar() && thenv.IsScalar():
   364  			if thenv.IsScalarNull() {
   365  				for idx := range results {
   366  					if !flag[idx] {
   367  						if !nulls.Contains(whenv.Nsp, uint64(idx)) && whencols[idx] {
   368  							nsp.Np.Add(uint64(idx))
   369  							flag[idx] = true
   370  						}
   371  					}
   372  				}
   373  			} else {
   374  				for idx := range results {
   375  					if !flag[idx] {
   376  						if !nulls.Contains(whenv.Nsp, uint64(idx)) && whencols[idx] {
   377  							results[idx] = thencols[0]
   378  							flag[idx] = true
   379  						}
   380  					}
   381  				}
   382  			}
   383  		case !whenv.IsScalar() && !thenv.IsScalar():
   384  			for idx := range results {
   385  				if !flag[idx] {
   386  					if !nulls.Contains(whenv.Nsp, uint64(idx)) && whencols[idx] {
   387  						if nulls.Contains(thenv.Nsp, uint64(idx)) {
   388  							nsp.Np.Add(uint64(idx))
   389  						} else {
   390  							results[idx] = thencols[idx]
   391  						}
   392  						flag[idx] = true
   393  					}
   394  				}
   395  			}
   396  		}
   397  	}
   398  
   399  	// deal the ELSE part
   400  	if len(vs)%2 == 0 || vs[len(vs)-1].IsScalarNull() {
   401  		for idx := range results {
   402  			if !flag[idx] {
   403  				nulls.Add(nsp, uint64(idx))
   404  				flag[idx] = true
   405  			}
   406  		}
   407  	} else {
   408  		ev := vs[len(vs)-1]
   409  		ecols := vector.MustStrCols(ev)
   410  		if ev.IsScalar() {
   411  			for idx := range results {
   412  				if !flag[idx] {
   413  					results[idx] = ecols[0]
   414  					flag[idx] = true
   415  				}
   416  			}
   417  		} else {
   418  			for idx := range results {
   419  				if !flag[idx] {
   420  					if nulls.Contains(ev.Nsp, uint64(idx)) {
   421  						nulls.Add(nsp, uint64(idx))
   422  					} else {
   423  						results[idx] = ecols[idx]
   424  					}
   425  					flag[idx] = true
   426  				}
   427  			}
   428  		}
   429  	}
   430  
   431  	return vector.NewWithStrings(typ, results, nsp, proc.Mp()), nil
   432  }