github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/colexec/top/top.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package top
    16  
    17  import (
    18  	"bytes"
    19  	"container/heap"
    20  	"fmt"
    21  	"github.com/matrixorigin/matrixone/pkg/container/types"
    22  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    23  	"github.com/matrixorigin/matrixone/pkg/objectio"
    24  	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/index"
    25  
    26  	"github.com/matrixorigin/matrixone/pkg/sql/colexec"
    27  	"github.com/matrixorigin/matrixone/pkg/vm"
    28  
    29  	"github.com/matrixorigin/matrixone/pkg/compare"
    30  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    31  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    32  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    33  )
    34  
    35  const argName = "top"
    36  
    37  func (arg *Argument) String(buf *bytes.Buffer) {
    38  	buf.WriteString(argName)
    39  	ap := arg
    40  	buf.WriteString(": top([")
    41  	for i, f := range ap.Fs {
    42  		if i > 0 {
    43  			buf.WriteString(", ")
    44  		}
    45  		buf.WriteString(f.String())
    46  	}
    47  	buf.WriteString(fmt.Sprintf("], %v)", ap.Limit))
    48  }
    49  
    50  func (arg *Argument) Prepare(proc *process.Process) (err error) {
    51  	ap := arg
    52  	ap.ctr = new(container)
    53  	if ap.Limit > 1024 {
    54  		ap.ctr.sels = make([]int64, 0, 1024)
    55  	} else {
    56  		ap.ctr.sels = make([]int64, 0, ap.Limit)
    57  	}
    58  	ap.ctr.poses = make([]int32, 0, len(ap.Fs))
    59  
    60  	ctr := ap.ctr
    61  	ctr.executorsForOrderColumn = make([]colexec.ExpressionExecutor, len(ap.Fs))
    62  	for i := range ctr.executorsForOrderColumn {
    63  		ctr.executorsForOrderColumn[i], err = colexec.NewExpressionExecutor(proc, ap.Fs[i].Expr)
    64  		if err != nil {
    65  			return err
    66  		}
    67  	}
    68  	typ := ap.Fs[0].Expr.Typ
    69  	if arg.TopValueTag > 0 {
    70  		ctr.desc = arg.Fs[0].Flag&plan.OrderBySpec_DESC != 0
    71  		ctr.topValueZM = objectio.NewZM(types.T(typ.Id), typ.Scale)
    72  	}
    73  	return nil
    74  }
    75  
    76  func (arg *Argument) Call(proc *process.Process) (vm.CallResult, error) {
    77  	if err, isCancel := vm.CancelCheck(proc); isCancel {
    78  		return vm.CancelResult, err
    79  	}
    80  
    81  	ap := arg
    82  	ctr := ap.ctr
    83  
    84  	anal := proc.GetAnalyze(arg.GetIdx(), arg.GetParallelIdx(), arg.GetParallelMajor())
    85  	anal.Start()
    86  	defer func() {
    87  		anal.Stop()
    88  	}()
    89  
    90  	if ap.Limit == 0 {
    91  		result := vm.NewCallResult()
    92  		result.Status = vm.ExecStop
    93  		return result, nil
    94  	}
    95  
    96  	if ctr.state == vm.Build {
    97  		for {
    98  			result, err := vm.ChildrenCall(arg.GetChildren(0), proc, anal)
    99  			if err != nil {
   100  				return result, err
   101  			}
   102  			bat := result.Batch
   103  			if bat == nil {
   104  				ctr.state = vm.Eval
   105  				break
   106  			}
   107  			if bat.IsEmpty() {
   108  				continue
   109  			}
   110  			err = ctr.build(ap, bat, proc, anal)
   111  			if err != nil {
   112  				return result, err
   113  			}
   114  			if arg.TopValueTag > 0 && arg.updateTopValueZM() {
   115  				proc.SendMessage(process.TopValueMessage{TopValueZM: arg.ctr.topValueZM, Tag: arg.TopValueTag})
   116  			}
   117  		}
   118  	}
   119  
   120  	result := vm.NewCallResult()
   121  	if ctr.state == vm.Eval {
   122  		ctr.state = vm.End
   123  		if ctr.bat != nil {
   124  			err := ctr.eval(ap.Limit, proc, &result)
   125  			if err != nil {
   126  				return result, err
   127  			}
   128  		}
   129  		return result, nil
   130  	}
   131  
   132  	if ctr.state == vm.End {
   133  		return result, nil
   134  	}
   135  
   136  	panic("bug")
   137  }
   138  
   139  func (ctr *container) build(ap *Argument, bat *batch.Batch, proc *process.Process, analyze process.Analyze) error {
   140  	ctr.n = len(bat.Vecs)
   141  	ctr.poses = ctr.poses[:0]
   142  	for i := range ap.Fs {
   143  		vec, err := ctr.executorsForOrderColumn[i].Eval(proc, []*batch.Batch{bat})
   144  		if err != nil {
   145  			return err
   146  		}
   147  		aNewOrderColumn := true
   148  		for j := range bat.Vecs {
   149  			if bat.Vecs[j] == vec {
   150  				aNewOrderColumn = false
   151  				ctr.poses = append(ctr.poses, int32(j))
   152  				break
   153  			}
   154  		}
   155  		if aNewOrderColumn {
   156  			nv, err := vec.Dup(proc.Mp())
   157  			if err != nil {
   158  				return err
   159  			}
   160  			ctr.poses = append(ctr.poses, int32(len(bat.Vecs)))
   161  			bat.Vecs = append(bat.Vecs, nv)
   162  			analyze.Alloc(int64(nv.Size()))
   163  		}
   164  	}
   165  	if ctr.bat == nil {
   166  		mp := make(map[int]int)
   167  		for i, pos := range ctr.poses {
   168  			mp[int(pos)] = i
   169  		}
   170  		ctr.bat = batch.NewWithSize(len(bat.Vecs))
   171  		for i, vec := range bat.Vecs {
   172  			ctr.bat.Vecs[i] = proc.GetVector(*vec.GetType())
   173  		}
   174  		ctr.cmps = make([]compare.Compare, len(bat.Vecs))
   175  		for i := range ctr.cmps {
   176  			var desc, nullsLast bool
   177  			if pos, ok := mp[i]; ok {
   178  				desc = ap.Fs[pos].Flag&plan.OrderBySpec_DESC != 0
   179  				if ap.Fs[pos].Flag&plan.OrderBySpec_NULLS_FIRST != 0 {
   180  					nullsLast = false
   181  				} else if ap.Fs[pos].Flag&plan.OrderBySpec_NULLS_LAST != 0 {
   182  					nullsLast = true
   183  				} else {
   184  					nullsLast = desc
   185  				}
   186  			}
   187  			ctr.cmps[i] = compare.New(*bat.Vecs[i].GetType(), desc, nullsLast)
   188  		}
   189  	}
   190  	err := ctr.processBatch(ap.Limit, bat, proc)
   191  	return err
   192  }
   193  
   194  func (ctr *container) processBatch(limit int64, bat *batch.Batch, proc *process.Process) error {
   195  	var start int64
   196  
   197  	length := int64(bat.RowCount())
   198  	if n := int64(len(ctr.sels)); n < limit {
   199  		start = limit - n
   200  		if start > length {
   201  			start = length
   202  		}
   203  		for i := int64(0); i < start; i++ {
   204  			for j, vec := range ctr.bat.Vecs {
   205  				if err := vec.UnionOne(bat.Vecs[j], i, proc.Mp()); err != nil {
   206  					return err
   207  				}
   208  			}
   209  			ctr.sels = append(ctr.sels, n)
   210  			n++
   211  		}
   212  		ctr.bat.AddRowCount(int(start))
   213  
   214  		if n == limit {
   215  			ctr.sort()
   216  		}
   217  	}
   218  	if start == length {
   219  		return nil
   220  	}
   221  
   222  	// bat is still have items
   223  	for i, cmp := range ctr.cmps {
   224  		cmp.Set(1, bat.Vecs[i])
   225  	}
   226  	for i, j := start, length; i < j; i++ {
   227  		if ctr.compare(1, 0, i, ctr.sels[0]) < 0 {
   228  			for _, cmp := range ctr.cmps {
   229  				if err := cmp.Copy(1, 0, i, ctr.sels[0], proc); err != nil {
   230  					return err
   231  				}
   232  			}
   233  			heap.Fix(ctr, 0)
   234  		}
   235  	}
   236  	return nil
   237  }
   238  
   239  func (ctr *container) eval(limit int64, proc *process.Process, result *vm.CallResult) error {
   240  	if int64(len(ctr.sels)) < limit {
   241  		ctr.sort()
   242  	}
   243  	for i, cmp := range ctr.cmps {
   244  		ctr.bat.Vecs[i] = cmp.Vector()
   245  	}
   246  	sels := make([]int64, len(ctr.sels))
   247  	for i, j := 0, len(ctr.sels); i < j; i++ {
   248  		sels[len(sels)-1-i] = heap.Pop(ctr).(int64)
   249  	}
   250  	if err := ctr.bat.Shuffle(sels, proc.Mp()); err != nil {
   251  		return err
   252  	}
   253  	for i := ctr.n; i < len(ctr.bat.Vecs); i++ {
   254  		ctr.bat.Vecs[i].Free(proc.Mp())
   255  	}
   256  	ctr.bat.Vecs = ctr.bat.Vecs[:ctr.n]
   257  	result.Batch = ctr.bat
   258  	return nil
   259  }
   260  
   261  // do sort work for heap, and result order will be set in container.sels
   262  func (ctr *container) sort() {
   263  	for i, cmp := range ctr.cmps {
   264  		cmp.Set(0, ctr.bat.Vecs[i])
   265  	}
   266  	heap.Init(ctr)
   267  }
   268  
   269  func (arg *Argument) updateTopValueZM() bool {
   270  	v, ok := arg.getTopValue()
   271  	if !ok {
   272  		return false
   273  	}
   274  	zm := arg.ctr.topValueZM
   275  	if !zm.IsInited() {
   276  		index.UpdateZM(zm, v)
   277  		return true
   278  	}
   279  	newZM := objectio.NewZM(zm.GetType(), zm.GetScale())
   280  	index.UpdateZM(newZM, v)
   281  	if arg.ctr.desc && newZM.CompareMax(zm) > 0 {
   282  		arg.ctr.topValueZM = newZM
   283  		return true
   284  	}
   285  	if !arg.ctr.desc && newZM.CompareMin(zm) < 0 {
   286  		arg.ctr.topValueZM = newZM
   287  		return true
   288  	}
   289  	return false
   290  }
   291  
   292  func (arg *Argument) getTopValue() ([]byte, bool) {
   293  	ctr := arg.ctr
   294  	// not enough items in the heap.
   295  	if int64(len(ctr.sels)) < arg.Limit {
   296  		return nil, false
   297  	}
   298  	x := int(ctr.sels[0])
   299  	vec := ctr.cmps[ctr.poses[0]].Vector()
   300  	if vec.GetType().IsVarlen() {
   301  		return vec.GetBytesAt(x), true
   302  	}
   303  	switch vec.GetType().Oid {
   304  	case types.T_int8:
   305  		v := vector.GetFixedAt[int8](vec, x)
   306  		return types.EncodeInt8(&v), true
   307  	case types.T_int16:
   308  		v := vector.GetFixedAt[int16](vec, x)
   309  		return types.EncodeInt16(&v), true
   310  	case types.T_int32:
   311  		v := vector.GetFixedAt[int32](vec, x)
   312  		return types.EncodeInt32(&v), true
   313  	case types.T_int64:
   314  		v := vector.GetFixedAt[int64](vec, x)
   315  		return types.EncodeInt64(&v), true
   316  	case types.T_uint8:
   317  		v := vector.GetFixedAt[uint8](vec, x)
   318  		return types.EncodeUint8(&v), true
   319  	case types.T_uint16:
   320  		v := vector.GetFixedAt[uint16](vec, x)
   321  		return types.EncodeUint16(&v), true
   322  	case types.T_uint32:
   323  		v := vector.GetFixedAt[uint32](vec, x)
   324  		return types.EncodeUint32(&v), true
   325  	case types.T_uint64:
   326  		v := vector.GetFixedAt[uint64](vec, x)
   327  		return types.EncodeUint64(&v), true
   328  	case types.T_float32:
   329  		v := vector.GetFixedAt[float32](vec, x)
   330  		return types.EncodeFloat32(&v), true
   331  	case types.T_float64:
   332  		v := vector.GetFixedAt[float64](vec, x)
   333  		return types.EncodeFloat64(&v), true
   334  	case types.T_date:
   335  		v := vector.GetFixedAt[types.Date](vec, x)
   336  		return types.EncodeDate(&v), true
   337  	case types.T_datetime:
   338  		v := vector.GetFixedAt[types.Datetime](vec, x)
   339  		return types.EncodeDatetime(&v), true
   340  	case types.T_timestamp:
   341  		v := vector.GetFixedAt[types.Timestamp](vec, x)
   342  		return types.EncodeTimestamp(&v), true
   343  	case types.T_time:
   344  		v := vector.GetFixedAt[types.Time](vec, x)
   345  		return types.EncodeTime(&v), true
   346  	case types.T_decimal64:
   347  		v := vector.GetFixedAt[types.Decimal64](vec, x)
   348  		return types.EncodeDecimal64(&v), true
   349  	case types.T_decimal128:
   350  		v := vector.GetFixedAt[types.Decimal128](vec, x)
   351  		return types.EncodeDecimal128(&v), true
   352  	}
   353  	return nil, false
   354  }