github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/table_function/generate_series.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package table_function
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"math"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    26  	"github.com/matrixorigin/matrixone/pkg/container/types"
    27  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    28  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    29  	"github.com/matrixorigin/matrixone/pkg/vm"
    30  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    31  )
    32  
    33  const addBatchSize int64 = 8191
    34  
    35  func generateSeriesString(buf *bytes.Buffer) {
    36  	buf.WriteString("generate_series")
    37  }
    38  
    39  func generateSeriesPrepare(proc *process.Process, arg *Argument) (err error) {
    40  	arg.ctr = new(container)
    41  	arg.ctr.executorsForArgs, err = colexec.NewExpressionExecutorsFromPlanExpressions(proc, arg.Args)
    42  	arg.generateSeries = new(generateSeriesArg)
    43  	return err
    44  }
    45  
    46  func resetGenerateSeriesState(proc *process.Process, arg *Argument) error {
    47  	if arg.generateSeries.state == initArg {
    48  		var startVec, endVec, stepVec, startVecTmp, endVecTmp *vector.Vector
    49  		var err error
    50  		arg.generateSeries.state = genBatch
    51  
    52  		defer func() {
    53  			if startVecTmp != nil {
    54  				startVecTmp.Free(proc.Mp())
    55  			}
    56  			if endVecTmp != nil {
    57  				endVecTmp.Free(proc.Mp())
    58  			}
    59  		}()
    60  
    61  		if len(arg.ctr.executorsForArgs) == 1 {
    62  			endVec, err = arg.ctr.executorsForArgs[0].Eval(proc, []*batch.Batch{batch.EmptyForConstFoldBatch})
    63  			if err != nil {
    64  				return err
    65  			}
    66  			startVec, err = vector.NewConstFixed(types.T_int64.ToType(), int64(1), 1, proc.Mp())
    67  		} else {
    68  			startVec, err = arg.ctr.executorsForArgs[0].Eval(proc, []*batch.Batch{batch.EmptyForConstFoldBatch})
    69  			if err != nil {
    70  				return err
    71  			}
    72  			endVec, err = arg.ctr.executorsForArgs[1].Eval(proc, []*batch.Batch{batch.EmptyForConstFoldBatch})
    73  		}
    74  		if err != nil {
    75  			return err
    76  		}
    77  		if len(arg.Args) == 3 {
    78  			stepVec, err = arg.ctr.executorsForArgs[2].Eval(proc, []*batch.Batch{batch.EmptyForConstFoldBatch})
    79  			if err != nil {
    80  				return err
    81  			}
    82  		}
    83  		if !startVec.IsConst() || !endVec.IsConst() || (stepVec != nil && !stepVec.IsConst()) {
    84  			return moerr.NewInvalidInput(proc.Ctx, "generate_series only support scalar")
    85  		}
    86  		arg.generateSeries.startVecType = startVec.GetType()
    87  		switch arg.generateSeries.startVecType.Oid {
    88  		case types.T_int32:
    89  			if endVec.GetType().Oid != types.T_int32 || (stepVec != nil && stepVec.GetType().Oid != types.T_int32) {
    90  				return moerr.NewInvalidInput(proc.Ctx, "generate_series arguments must be of the same type, type1: %s, type2: %s", startVec.GetType().Oid.String(), endVec.GetType().Oid.String())
    91  			}
    92  			initStartAndEnd[int32](arg, startVec, endVec, stepVec)
    93  		case types.T_int64:
    94  			if endVec.GetType().Oid != types.T_int64 || (stepVec != nil && stepVec.GetType().Oid != types.T_int64) {
    95  				return moerr.NewInvalidInput(proc.Ctx, "generate_series arguments must be of the same type, type1: %s, type2: %s", startVec.GetType().Oid.String(), endVec.GetType().Oid.String())
    96  			}
    97  			initStartAndEnd[int64](arg, startVec, endVec, stepVec)
    98  		case types.T_datetime:
    99  			if endVec.GetType().Oid != types.T_datetime || (stepVec != nil && stepVec.GetType().Oid != types.T_varchar) {
   100  				return moerr.NewInvalidInput(proc.Ctx, "generate_series arguments must be of the same type, type1: %s, type2: %s", startVec.GetType().Oid.String(), endVec.GetType().Oid.String())
   101  			}
   102  			startSlice := vector.MustFixedCol[types.Datetime](startVec)
   103  			endSlice := vector.MustFixedCol[types.Datetime](endVec)
   104  			arg.generateSeries.start = startSlice[0]
   105  			arg.generateSeries.end = endSlice[0]
   106  			arg.generateSeries.last = endSlice[0]
   107  			if stepVec == nil {
   108  				return moerr.NewInvalidInput(proc.Ctx, "generate_series datetime must specify step")
   109  			}
   110  			stepSlice := vector.MustStrCol(stepVec)
   111  			arg.generateSeries.step = stepSlice[0]
   112  		case types.T_varchar:
   113  			if stepVec == nil {
   114  				return moerr.NewInvalidInput(proc.Ctx, "generate_series must specify step")
   115  			}
   116  			startSlice := vector.MustStrCol(startVec)
   117  			endSlice := vector.MustStrCol(endVec)
   118  			startStr := startSlice[0]
   119  			endStr := endSlice[0]
   120  			scale := int32(findScale(startStr, endStr))
   121  			startTmp, err := types.ParseDatetime(startStr, scale)
   122  			if err != nil {
   123  				return err
   124  			}
   125  
   126  			endTmp, err := types.ParseDatetime(endStr, scale)
   127  			if err != nil {
   128  				return err
   129  			}
   130  			if startVecTmp, err = vector.NewConstFixed(types.T_datetime.ToType(), startTmp, 1, proc.Mp()); err != nil {
   131  				return err
   132  			}
   133  			if endVecTmp, err = vector.NewConstFixed(types.T_datetime.ToType(), endTmp, 1, proc.Mp()); err != nil {
   134  				return err
   135  			}
   136  
   137  			newStartSlice := vector.MustFixedCol[types.Datetime](startVecTmp)
   138  			newEndSlice := vector.MustFixedCol[types.Datetime](endVecTmp)
   139  			arg.generateSeries.scale = scale
   140  			arg.generateSeries.start = newStartSlice[0]
   141  			arg.generateSeries.end = newEndSlice[0]
   142  			arg.generateSeries.last = newEndSlice[0]
   143  			stepSlice := vector.MustStrCol(stepVec)
   144  			arg.generateSeries.step = stepSlice[0]
   145  		default:
   146  			return moerr.NewNotSupported(proc.Ctx, "generate_series not support type %s", arg.generateSeries.startVecType.Oid.String())
   147  
   148  		}
   149  	}
   150  
   151  	if arg.generateSeries.state == genBatch {
   152  		switch arg.generateSeries.startVecType.Oid {
   153  		case types.T_int32:
   154  			computeNewStartAndEnd[int32](arg)
   155  		case types.T_int64:
   156  			computeNewStartAndEnd[int64](arg)
   157  		case types.T_varchar, types.T_datetime:
   158  			//todo split datetime batch
   159  			arg.generateSeries.state = genFinish
   160  		default:
   161  			arg.generateSeries.state = genFinish
   162  		}
   163  	}
   164  
   165  	return nil
   166  }
   167  
   168  func generateSeriesCall(_ int, proc *process.Process, arg *Argument, result *vm.CallResult) (bool, error) {
   169  	var (
   170  		err  error
   171  		rbat *batch.Batch
   172  	)
   173  	defer func() {
   174  		if err != nil && rbat != nil {
   175  			rbat.Clean(proc.Mp())
   176  		}
   177  	}()
   178  
   179  	if arg.generateSeries.state == genFinish {
   180  		return true, nil
   181  	}
   182  
   183  	err = resetGenerateSeriesState(proc, arg)
   184  	if err != nil {
   185  		return false, err
   186  	}
   187  
   188  	rbat = batch.NewWithSize(len(arg.Attrs))
   189  	rbat.Attrs = arg.Attrs
   190  	for i := range arg.Attrs {
   191  		rbat.Vecs[i] = proc.GetVector(arg.retSchema[i])
   192  	}
   193  
   194  	switch arg.generateSeries.startVecType.Oid {
   195  	case types.T_int32:
   196  		start := arg.generateSeries.start.(int32)
   197  		end := arg.generateSeries.end.(int32)
   198  		step := arg.generateSeries.step.(int32)
   199  		err = handleInt(start, end, step, generateInt32, proc, rbat)
   200  		if err != nil {
   201  			return false, err
   202  		}
   203  	case types.T_int64:
   204  		start := arg.generateSeries.start.(int64)
   205  		end := arg.generateSeries.end.(int64)
   206  		step := arg.generateSeries.step.(int64)
   207  		err = handleInt(start, end, step, generateInt64, proc, rbat)
   208  		if err != nil {
   209  			return false, err
   210  		}
   211  	case types.T_datetime:
   212  		start := arg.generateSeries.start.(types.Datetime)
   213  		end := arg.generateSeries.end.(types.Datetime)
   214  		step := arg.generateSeries.step.(string)
   215  
   216  		err = handleDatetime(start, end, step, -1, proc, rbat)
   217  	case types.T_varchar:
   218  		start := arg.generateSeries.start.(types.Datetime)
   219  		end := arg.generateSeries.end.(types.Datetime)
   220  		step := arg.generateSeries.step.(string)
   221  		scale := arg.generateSeries.scale
   222  		rbat.Vecs[0].GetType().Scale = scale
   223  
   224  		err = handleDatetime(start, end, step, scale, proc, rbat)
   225  		if err != nil {
   226  			return false, err
   227  		}
   228  
   229  	default:
   230  		return false, moerr.NewNotSupported(proc.Ctx, "generate_series not support type %s", arg.generateSeries.startVecType.Oid.String())
   231  
   232  	}
   233  	result.Batch = rbat
   234  	return false, nil
   235  }
   236  
   237  func judgeArgs[T generateSeriesNumber](ctx context.Context, start, end, step T) ([]T, error) {
   238  	if step == 0 {
   239  		return nil, moerr.NewInvalidInput(ctx, "step size cannot equal zero")
   240  	}
   241  	if start == end {
   242  		return []T{start}, nil
   243  	}
   244  	s1 := step > 0
   245  	s2 := end > start
   246  	if s1 != s2 {
   247  		return []T{}, nil
   248  	}
   249  	return nil, nil
   250  }
   251  
   252  func initStartAndEnd[T generateSeriesNumber](arg *Argument, startVec, endVec, stepVec *vector.Vector) {
   253  	startSlice := vector.MustFixedCol[T](startVec)
   254  	endSlice := vector.MustFixedCol[T](endVec)
   255  	start := startSlice[0]
   256  	end := startSlice[0]
   257  	last := endSlice[0]
   258  	var step T
   259  	if stepVec != nil {
   260  		stepSlice := vector.MustFixedCol[T](stepVec)
   261  		step = stepSlice[0]
   262  	} else {
   263  		if startSlice[0] < endSlice[0] {
   264  			step = T(1)
   265  		} else {
   266  			step = T(-1)
   267  		}
   268  	}
   269  	end = end - step
   270  
   271  	arg.generateSeries.start = start
   272  	arg.generateSeries.end = end
   273  	arg.generateSeries.last = last
   274  	arg.generateSeries.step = step
   275  }
   276  
   277  func computeNewStartAndEnd[T generateSeriesNumber](arg *Argument) {
   278  	step := arg.generateSeries.step.(T)
   279  	newStart := arg.generateSeries.end.(T) + step
   280  	last := arg.generateSeries.last.(T)
   281  	newEnd := newStart + step*T(addBatchSize)
   282  	if step > 0 {
   283  		if newEnd < newStart {
   284  			newEnd = last
   285  		} else {
   286  			if newEnd > last {
   287  				newEnd = last
   288  			}
   289  		}
   290  	} else {
   291  		if newEnd > newStart {
   292  			newEnd = last
   293  		} else {
   294  			if newEnd < last {
   295  				newEnd = last
   296  			}
   297  		}
   298  	}
   299  	if newEnd == last {
   300  		arg.generateSeries.state = genFinish
   301  	}
   302  	arg.generateSeries.start = newStart
   303  	arg.generateSeries.end = newEnd
   304  }
   305  
   306  func trimStep(step string) string {
   307  	step = strings.TrimSpace(step)
   308  	step = strings.TrimSuffix(step, "s")
   309  	step = strings.TrimSuffix(step, "(s)")
   310  	return step
   311  }
   312  
   313  func genStep(ctx context.Context, step string) (num int64, tp types.IntervalType, err error) {
   314  	step = trimStep(step)
   315  	s := strings.Split(step, " ")
   316  	if len(s) != 2 {
   317  		err = moerr.NewInvalidInput(ctx, "invalid step '%s'", step)
   318  		return
   319  	}
   320  	num, err = strconv.ParseInt(s[0], 10, 64)
   321  	if err != nil {
   322  		err = moerr.NewInvalidInput(ctx, "invalid step '%s'", step)
   323  		return
   324  	}
   325  	tp, err = types.IntervalTypeOf(s[1])
   326  	return
   327  }
   328  
   329  func generateInt32(ctx context.Context, start, end, step int32) ([]int32, error) {
   330  	res, err := judgeArgs(ctx, start, end, step)
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  	if res != nil {
   335  		return res, nil
   336  	}
   337  	if step > 0 {
   338  		for i := start; i <= end; i += step {
   339  			res = append(res, i)
   340  			if i > 0 && math.MaxInt32-i < step {
   341  				break
   342  			}
   343  		}
   344  	} else {
   345  		for i := start; i >= end; i += step {
   346  			res = append(res, i)
   347  			if i < 0 && math.MinInt32-i > step {
   348  				break
   349  			}
   350  		}
   351  	}
   352  	return res, nil
   353  }
   354  
   355  func generateInt64(ctx context.Context, start, end, step int64) ([]int64, error) {
   356  	res, err := judgeArgs(ctx, start, end, step)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  	if res != nil {
   361  		return res, nil
   362  	}
   363  	if step > 0 {
   364  		for i := start; i <= end; i += step {
   365  			res = append(res, i)
   366  			if i > 0 && math.MaxInt64-i < step {
   367  				break
   368  			}
   369  		}
   370  	} else {
   371  		for i := start; i >= end; i += step {
   372  			res = append(res, i)
   373  			if i < 0 && math.MinInt64-i > step {
   374  				break
   375  			}
   376  		}
   377  	}
   378  	return res, nil
   379  }
   380  
   381  func generateDatetime(ctx context.Context, start, end types.Datetime, stepStr string) ([]types.Datetime, error) {
   382  	step, tp, err := genStep(ctx, stepStr)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	var res []types.Datetime
   387  	res, err = judgeArgs(ctx, start, end, types.Datetime(step)) // here, transfer step to types.Datetime may change the inner behavior of datetime, but we just care the sign of step.
   388  	if err != nil {
   389  		return nil, err
   390  	}
   391  	if res != nil {
   392  		return res, nil
   393  	}
   394  	if step > 0 {
   395  		for i := start; i <= end; {
   396  			res = append(res, i)
   397  			var ok bool
   398  			i, ok = i.AddInterval(step, tp, types.DateTimeType)
   399  			if !ok {
   400  				return nil, moerr.NewInvalidInput(ctx, "invalid step '%s'", stepStr)
   401  			}
   402  		}
   403  	} else {
   404  		for i := start; i >= end; {
   405  			res = append(res, i)
   406  			var ok bool
   407  			i, ok = i.AddInterval(step, tp, types.DateTimeType)
   408  			if !ok {
   409  				return nil, moerr.NewInvalidInput(ctx, "invalid step '%s'", stepStr)
   410  			}
   411  		}
   412  	}
   413  	return res, nil
   414  }
   415  
   416  func handleInt[T int32 | int64](start, end, step T, genFn func(context.Context, T, T, T) ([]T, error), proc *process.Process, rbat *batch.Batch) error {
   417  	res, err := genFn(proc.Ctx, start, end, step)
   418  	if err != nil {
   419  		return err
   420  	}
   421  	for i := range res {
   422  		err = vector.AppendFixed(rbat.Vecs[0], res[i], false, proc.Mp())
   423  		if err != nil {
   424  			return err
   425  		}
   426  	}
   427  	rbat.SetRowCount(len(res))
   428  	return nil
   429  }
   430  
   431  func handleDatetime(start, end types.Datetime, step string, scale int32, proc *process.Process, rbat *batch.Batch) error {
   432  	res, err := generateDatetime(proc.Ctx, start, end, step)
   433  	if err != nil {
   434  		return err
   435  	}
   436  	for i := range res {
   437  		if scale >= 0 {
   438  			err = vector.AppendBytes(rbat.Vecs[0], []byte(res[i].String2(scale)), false, proc.Mp())
   439  		} else {
   440  			err = vector.AppendFixed(rbat.Vecs[0], res[i], false, proc.Mp())
   441  		}
   442  
   443  		if err != nil {
   444  			return err
   445  		}
   446  	}
   447  	rbat.SetRowCount(len(res))
   448  	return nil
   449  }
   450  
   451  func findScale(s1, s2 string) int {
   452  	p1 := 0
   453  	if strings.Contains(s1, ".") {
   454  		p1 = len(s1) - strings.LastIndex(s1, ".")
   455  	}
   456  	p2 := 0
   457  	if strings.Contains(s2, ".") {
   458  		p2 = len(s2) - strings.LastIndex(s2, ".")
   459  	}
   460  	if p2 > p1 {
   461  		p1 = p2
   462  	}
   463  	if p1 > 6 {
   464  		p1 = 6
   465  	}
   466  	return p1
   467  }