github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/table_function/generate_series.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package table_function
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"math"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    26  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    27  	"github.com/matrixorigin/matrixone/pkg/container/types"
    28  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    29  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    30  	"github.com/matrixorigin/matrixone/pkg/sql/plan"
    31  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    32  )
    33  
    34  func generateSeriesString(arg any, buf *bytes.Buffer) {
    35  	buf.WriteString("generate_series")
    36  }
    37  
    38  func generateSeriesPrepare(_ *process.Process, arg *Argument) error {
    39  	return nil
    40  }
    41  
    42  func generateSeriesCall(_ int, proc *process.Process, arg *Argument) (bool, error) {
    43  	var (
    44  		err                                               error
    45  		startVec, endVec, stepVec, startVecTmp, endVecTmp *vector.Vector
    46  		rbat                                              *batch.Batch
    47  	)
    48  	defer func() {
    49  		if err != nil && rbat != nil {
    50  			rbat.Clean(proc.Mp())
    51  		}
    52  		if startVec != nil {
    53  			startVec.Free(proc.Mp())
    54  		}
    55  		if endVec != nil {
    56  			endVec.Free(proc.Mp())
    57  		}
    58  		if stepVec != nil {
    59  			stepVec.Free(proc.Mp())
    60  		}
    61  		if startVecTmp != nil {
    62  			startVecTmp.Free(proc.Mp())
    63  		}
    64  		if endVecTmp != nil {
    65  			endVecTmp.Free(proc.Mp())
    66  		}
    67  	}()
    68  	bat := proc.InputBatch()
    69  	if bat == nil {
    70  		return true, nil
    71  	}
    72  	startVec, err = colexec.EvalExpr(bat, proc, arg.Args[0])
    73  	if err != nil {
    74  		return false, err
    75  	}
    76  	endVec, err = colexec.EvalExpr(bat, proc, arg.Args[1])
    77  	if err != nil {
    78  		return false, err
    79  	}
    80  	rbat = batch.New(false, arg.Attrs)
    81  	rbat.Cnt = 1
    82  	for i := range arg.Attrs {
    83  		rbat.Vecs[i] = vector.New(dupType(plan.MakePlan2Type(&startVec.Typ)))
    84  	}
    85  	if len(arg.Args) == 3 {
    86  		stepVec, err = colexec.EvalExpr(bat, proc, arg.Args[2])
    87  		if err != nil {
    88  			return false, err
    89  		}
    90  	}
    91  	if !startVec.IsScalar() || !endVec.IsScalar() || (stepVec != nil && !stepVec.IsScalar()) {
    92  		return false, moerr.NewInvalidInput(proc.Ctx, "generate_series only support scalar")
    93  	}
    94  	switch startVec.GetType().Oid {
    95  	case types.T_int32:
    96  		if endVec.Typ.Oid != types.T_int32 || (stepVec != nil && stepVec.Typ.Oid != types.T_int32) {
    97  			return false, moerr.NewInvalidInput(proc.Ctx, "generate_series arguments must be of the same type, type1: %s, type2: %s", startVec.Typ.Oid.String(), endVec.Typ.Oid.String())
    98  		}
    99  		err = handleInt(startVec, endVec, stepVec, generateInt32, false, proc, rbat)
   100  		if err != nil {
   101  			return false, err
   102  		}
   103  	case types.T_int64:
   104  		if endVec.Typ.Oid != types.T_int64 || (stepVec != nil && stepVec.Typ.Oid != types.T_int64) {
   105  			return false, moerr.NewInvalidInput(proc.Ctx, "generate_series arguments must be of the same type, type1: %s, type2: %s", startVec.Typ.Oid.String(), endVec.Typ.Oid.String())
   106  		}
   107  		err = handleInt(startVec, endVec, stepVec, generateInt64, false, proc, rbat)
   108  		if err != nil {
   109  			return false, err
   110  		}
   111  	case types.T_datetime:
   112  		if endVec.Typ.Oid != types.T_datetime || (stepVec != nil && stepVec.Typ.Oid != types.T_varchar) {
   113  			return false, moerr.NewInvalidInput(proc.Ctx, "generate_series arguments must be of the same type, type1: %s, type2: %s", startVec.Typ.Oid.String(), endVec.Typ.Oid.String())
   114  		}
   115  		err = handleDatetime(startVec, endVec, stepVec, false, proc, rbat)
   116  	case types.T_varchar:
   117  		if stepVec == nil {
   118  			return false, moerr.NewInvalidInput(proc.Ctx, "generate_series must specify step")
   119  		}
   120  		startSlice := vector.MustStrCols(startVec)
   121  		endSlice := vector.MustStrCols(endVec)
   122  		stepSlice := vector.MustStrCols(stepVec)
   123  		startStr := startSlice[0]
   124  		endStr := endSlice[0]
   125  		stepStr := stepSlice[0]
   126  		precision := int32(findPrecision(startStr, endStr))
   127  		rbat.Vecs[0].Typ.Precision = precision
   128  		start, err := types.ParseDatetime(startStr, precision)
   129  		if err != nil {
   130  			err = tryInt(startStr, endStr, stepStr, proc, rbat)
   131  			if err != nil {
   132  				return false, err
   133  			}
   134  			break
   135  		}
   136  
   137  		end, err := types.ParseDatetime(endStr, precision)
   138  		if err != nil {
   139  			return false, err
   140  		}
   141  		startVecTmp, err = makeVector(types.T_datetime.ToType(), start, proc.Mp())
   142  		if err != nil {
   143  			return false, err
   144  		}
   145  		endVecTmp, err = makeVector(types.T_datetime.ToType(), end, proc.Mp())
   146  		if err != nil {
   147  			return false, err
   148  		}
   149  		err = handleDatetime(startVecTmp, endVecTmp, stepVec, true, proc, rbat)
   150  		if err != nil {
   151  			return false, err
   152  		}
   153  
   154  	default:
   155  		return false, moerr.NewNotSupported(proc.Ctx, "generate_series not support type %s", startVec.Typ.Oid.String())
   156  
   157  	}
   158  	proc.SetInputBatch(rbat)
   159  	return false, nil
   160  }
   161  
   162  func judgeArgs[T generateSeriesNumber](ctx context.Context, start, end, step T) ([]T, error) {
   163  	if step == 0 {
   164  		return nil, moerr.NewInvalidInput(ctx, "step size cannot equal zero")
   165  	}
   166  	if start == end {
   167  		return []T{start}, nil
   168  	}
   169  	s1 := step > 0
   170  	s2 := end > start
   171  	if s1 != s2 {
   172  		return []T{}, nil
   173  	}
   174  	return nil, nil
   175  }
   176  
   177  func trimStep(step string) string {
   178  	step = strings.TrimSpace(step)
   179  	step = strings.TrimSuffix(step, "s")
   180  	step = strings.TrimSuffix(step, "(s)")
   181  	return step
   182  }
   183  
   184  func genStep(ctx context.Context, step string) (num int64, tp types.IntervalType, err error) {
   185  	step = trimStep(step)
   186  	s := strings.Split(step, " ")
   187  	if len(s) != 2 {
   188  		err = moerr.NewInvalidInput(ctx, "invalid step '%s'", step)
   189  		return
   190  	}
   191  	num, err = strconv.ParseInt(s[0], 10, 64)
   192  	if err != nil {
   193  		err = moerr.NewInvalidInput(ctx, "invalid step '%s'", step)
   194  		return
   195  	}
   196  	tp, err = types.IntervalTypeOf(s[1])
   197  	return
   198  }
   199  
   200  func generateInt32(ctx context.Context, start, end, step int32) ([]int32, error) {
   201  	res, err := judgeArgs(ctx, start, end, step)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  	if res != nil {
   206  		return res, nil
   207  	}
   208  	if step > 0 {
   209  		for i := start; i <= end; i += step {
   210  			res = append(res, i)
   211  			if i > 0 && math.MaxInt32-i < step {
   212  				break
   213  			}
   214  		}
   215  	} else {
   216  		for i := start; i >= end; i += step {
   217  			res = append(res, i)
   218  			if i < 0 && math.MinInt32-i > step {
   219  				break
   220  			}
   221  		}
   222  	}
   223  	return res, nil
   224  }
   225  
   226  func generateInt64(ctx context.Context, start, end, step int64) ([]int64, error) {
   227  	res, err := judgeArgs(ctx, start, end, step)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	if res != nil {
   232  		return res, nil
   233  	}
   234  	if step > 0 {
   235  		for i := start; i <= end; i += step {
   236  			res = append(res, i)
   237  			if i > 0 && math.MaxInt64-i < step {
   238  				break
   239  			}
   240  		}
   241  	} else {
   242  		for i := start; i >= end; i += step {
   243  			res = append(res, i)
   244  			if i < 0 && math.MinInt64-i > step {
   245  				break
   246  			}
   247  		}
   248  	}
   249  	return res, nil
   250  }
   251  
   252  func generateDatetime(ctx context.Context, start, end types.Datetime, stepStr string, precision int32) ([]types.Datetime, error) {
   253  	step, tp, err := genStep(ctx, stepStr)
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	var res []types.Datetime
   258  	res, err = judgeArgs(ctx, start, end, types.Datetime(step)) // here, transfer step to types.Datetime may change the inner behavior of datetime, but we just care the sign of step.
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  	if res != nil {
   263  		return res, nil
   264  	}
   265  	if step > 0 {
   266  		for i := start; i <= end; {
   267  			res = append(res, i)
   268  			var ok bool
   269  			i, ok = i.AddInterval(step, tp, types.DateTimeType)
   270  			if !ok {
   271  				return nil, moerr.NewInvalidInput(ctx, "invalid step '%s'", stepStr)
   272  			}
   273  		}
   274  	} else {
   275  		for i := start; i >= end; {
   276  			res = append(res, i)
   277  			var ok bool
   278  			i, ok = i.AddInterval(step, tp, types.DateTimeType)
   279  			if !ok {
   280  				return nil, moerr.NewInvalidInput(ctx, "invalid step '%s'", stepStr)
   281  			}
   282  		}
   283  	}
   284  	return res, nil
   285  }
   286  
   287  func handleInt[T int32 | int64](startVec, endVec, stepVec *vector.Vector, genFn func(context.Context, T, T, T) ([]T, error), toString bool, proc *process.Process, rbat *batch.Batch) error {
   288  	var (
   289  		start, end, step T
   290  	)
   291  	startSlice := vector.MustTCols[T](startVec)
   292  	endSlice := vector.MustTCols[T](endVec)
   293  	start = startSlice[0]
   294  	end = endSlice[0]
   295  	if stepVec != nil {
   296  		stepSlice := vector.MustTCols[T](stepVec)
   297  		step = stepSlice[0]
   298  	} else {
   299  		if start < end {
   300  			step = 1
   301  		} else {
   302  			step = -1
   303  		}
   304  	}
   305  	res, err := genFn(proc.Ctx, start, end, step)
   306  	if err != nil {
   307  		return err
   308  	}
   309  	for i := range res {
   310  		if toString {
   311  			err = rbat.Vecs[0].Append([]byte(strconv.FormatInt(int64(res[i]), 10)), false, proc.Mp())
   312  		} else {
   313  			err = rbat.Vecs[0].Append(res[i], false, proc.Mp())
   314  		}
   315  		if err != nil {
   316  			return err
   317  		}
   318  	}
   319  	rbat.InitZsOne(len(res))
   320  	return nil
   321  }
   322  
   323  func handleDatetime(startVec, endVec, stepVec *vector.Vector, toString bool, proc *process.Process, rbat *batch.Batch) error {
   324  	var (
   325  		start, end types.Datetime
   326  		step       string
   327  	)
   328  	startSlice := vector.MustTCols[types.Datetime](startVec)
   329  	endSlice := vector.MustTCols[types.Datetime](endVec)
   330  	start = startSlice[0]
   331  	end = endSlice[0]
   332  	if stepVec == nil {
   333  		return moerr.NewInvalidInput(proc.Ctx, "generate_series datetime must specify step")
   334  	}
   335  	stepSlice := vector.MustStrCols(stepVec)
   336  	step = stepSlice[0]
   337  	res, err := generateDatetime(proc.Ctx, start, end, step, startVec.Typ.Precision)
   338  	if err != nil {
   339  		return err
   340  	}
   341  	for i := range res {
   342  		if toString {
   343  			err = rbat.Vecs[0].Append([]byte(res[i].String2(rbat.Vecs[0].Typ.Precision)), false, proc.Mp())
   344  		} else {
   345  			err = rbat.Vecs[0].Append(res[i], false, proc.Mp())
   346  		}
   347  		if err != nil {
   348  			return err
   349  		}
   350  	}
   351  	rbat.InitZsOne(len(res))
   352  	return nil
   353  }
   354  
   355  func findPrecision(s1, s2 string) int {
   356  	p1 := 0
   357  	if strings.Contains(s1, ".") {
   358  		p1 = len(s1) - strings.LastIndex(s1, ".")
   359  	}
   360  	p2 := 0
   361  	if strings.Contains(s2, ".") {
   362  		p2 = len(s2) - strings.LastIndex(s2, ".")
   363  	}
   364  	if p2 > p1 {
   365  		p1 = p2
   366  	}
   367  	if p1 > 6 {
   368  		p1 = 6
   369  	}
   370  	return p1
   371  }
   372  func makeVector(p types.Type, v interface{}, mp *mpool.MPool) (*vector.Vector, error) {
   373  	vec := vector.NewConst(p, 1)
   374  	err := vec.Append(v, false, mp)
   375  	if err != nil {
   376  		vec.Free(mp)
   377  		return nil, err
   378  	}
   379  	return vec, nil
   380  }
   381  
   382  func tryInt(startStr, endStr, stepStr string, proc *process.Process, rbat *batch.Batch) error {
   383  	var (
   384  		startVec, endVec, stepVec *vector.Vector
   385  		err                       error
   386  	)
   387  	defer func() {
   388  		if startVec != nil {
   389  			startVec.Free(proc.Mp())
   390  		}
   391  		if endVec != nil {
   392  			endVec.Free(proc.Mp())
   393  		}
   394  		if stepVec != nil {
   395  			stepVec.Free(proc.Mp())
   396  		}
   397  	}()
   398  	startInt, err := strconv.ParseInt(startStr, 10, 64)
   399  	if err != nil {
   400  		return err
   401  	}
   402  	endInt, err := strconv.ParseInt(endStr, 10, 64)
   403  	if err != nil {
   404  		return err
   405  	}
   406  	stepInt, err := strconv.ParseInt(stepStr, 10, 64)
   407  	if err != nil {
   408  		return err
   409  	}
   410  
   411  	startVec, err = makeVector(types.T_int64.ToType(), startInt, proc.Mp())
   412  	if err != nil {
   413  		return err
   414  	}
   415  	endVec, err = makeVector(types.T_int64.ToType(), endInt, proc.Mp())
   416  	if err != nil {
   417  		return err
   418  	}
   419  	stepVec, err = makeVector(types.T_int64.ToType(), stepInt, proc.Mp())
   420  	if err != nil {
   421  		return err
   422  	}
   423  	err = handleInt(startVec, endVec, stepVec, generateInt64, true, proc, rbat)
   424  	if err != nil {
   425  		return err
   426  	}
   427  	return nil
   428  }