github.com/matrixorigin/matrixone@v1.2.0/pkg/vm/engine/tae/mergesort/aobj_merger.go (about)

     1  // Copyright 2024 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mergesort
    16  
    17  import (
    18  	"context"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    21  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    22  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    23  	"github.com/matrixorigin/matrixone/pkg/container/types"
    24  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    25  	"github.com/matrixorigin/matrixone/pkg/sort"
    26  	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/containers"
    27  )
    28  
    29  type AObjMerger interface {
    30  	Merge(context.Context) ([]*batch.Batch, func(), []uint32, error)
    31  }
    32  
    33  type aObjMerger[T any] struct {
    34  	heap *heapSlice[T]
    35  
    36  	cols        [][]T
    37  	nulls       []*nulls.Nulls
    38  	sortKeyType *types.Type
    39  
    40  	resultBlkCnt int
    41  
    42  	bats      []*containers.Batch
    43  	rowIdx    []int64
    44  	accRowCnt []int64
    45  
    46  	mapping   []uint32
    47  	rowPerBlk uint32
    48  	vpool     DisposableVecPool
    49  }
    50  
    51  func MergeAObj(
    52  	ctx context.Context,
    53  	vpool DisposableVecPool,
    54  	batches []*containers.Batch,
    55  	sortKeyPos int,
    56  	rowPerBlk uint32,
    57  	resultBlkCnt int) ([]*batch.Batch, func(), []uint32, error) {
    58  	var merger AObjMerger
    59  	typ := batches[0].Vecs[sortKeyPos].GetType()
    60  	if typ.IsVarlen() {
    61  		merger = newAObjMerger(vpool, batches, sort.GenericLess[string], sortKeyPos, vector.MustStrCol, rowPerBlk, resultBlkCnt)
    62  	} else {
    63  		switch typ.Oid {
    64  		case types.T_bool:
    65  			merger = newAObjMerger(vpool, batches, sort.BoolLess, sortKeyPos, vector.MustFixedCol[bool], rowPerBlk, resultBlkCnt)
    66  		case types.T_bit:
    67  			merger = newAObjMerger(vpool, batches, sort.GenericLess[uint64], sortKeyPos, vector.MustFixedCol[uint64], rowPerBlk, resultBlkCnt)
    68  		case types.T_int8:
    69  			merger = newAObjMerger(vpool, batches, sort.GenericLess[int8], sortKeyPos, vector.MustFixedCol[int8], rowPerBlk, resultBlkCnt)
    70  		case types.T_int16:
    71  			merger = newAObjMerger(vpool, batches, sort.GenericLess[int16], sortKeyPos, vector.MustFixedCol[int16], rowPerBlk, resultBlkCnt)
    72  		case types.T_int32:
    73  			merger = newAObjMerger(vpool, batches, sort.GenericLess[int32], sortKeyPos, vector.MustFixedCol[int32], rowPerBlk, resultBlkCnt)
    74  		case types.T_int64:
    75  			merger = newAObjMerger(vpool, batches, sort.GenericLess[int64], sortKeyPos, vector.MustFixedCol[int64], rowPerBlk, resultBlkCnt)
    76  		case types.T_float32:
    77  			merger = newAObjMerger(vpool, batches, sort.GenericLess[float32], sortKeyPos, vector.MustFixedCol[float32], rowPerBlk, resultBlkCnt)
    78  		case types.T_float64:
    79  			merger = newAObjMerger(vpool, batches, sort.GenericLess[float64], sortKeyPos, vector.MustFixedCol[float64], rowPerBlk, resultBlkCnt)
    80  		case types.T_uint8:
    81  			merger = newAObjMerger(vpool, batches, sort.GenericLess[uint8], sortKeyPos, vector.MustFixedCol[uint8], rowPerBlk, resultBlkCnt)
    82  		case types.T_uint16:
    83  			merger = newAObjMerger(vpool, batches, sort.GenericLess[uint16], sortKeyPos, vector.MustFixedCol[uint16], rowPerBlk, resultBlkCnt)
    84  		case types.T_uint32:
    85  			merger = newAObjMerger(vpool, batches, sort.GenericLess[uint32], sortKeyPos, vector.MustFixedCol[uint32], rowPerBlk, resultBlkCnt)
    86  		case types.T_uint64:
    87  			merger = newAObjMerger(vpool, batches, sort.GenericLess[uint64], sortKeyPos, vector.MustFixedCol[uint64], rowPerBlk, resultBlkCnt)
    88  		case types.T_date:
    89  			merger = newAObjMerger(vpool, batches, sort.GenericLess[types.Date], sortKeyPos, vector.MustFixedCol[types.Date], rowPerBlk, resultBlkCnt)
    90  		case types.T_timestamp:
    91  			merger = newAObjMerger(vpool, batches, sort.GenericLess[types.Timestamp], sortKeyPos, vector.MustFixedCol[types.Timestamp], rowPerBlk, resultBlkCnt)
    92  		case types.T_datetime:
    93  			merger = newAObjMerger(vpool, batches, sort.GenericLess[types.Datetime], sortKeyPos, vector.MustFixedCol[types.Datetime], rowPerBlk, resultBlkCnt)
    94  		case types.T_time:
    95  			merger = newAObjMerger(vpool, batches, sort.GenericLess[types.Time], sortKeyPos, vector.MustFixedCol[types.Time], rowPerBlk, resultBlkCnt)
    96  		case types.T_enum:
    97  			merger = newAObjMerger(vpool, batches, sort.GenericLess[types.Enum], sortKeyPos, vector.MustFixedCol[types.Enum], rowPerBlk, resultBlkCnt)
    98  		case types.T_decimal64:
    99  			merger = newAObjMerger(vpool, batches, sort.Decimal64Less, sortKeyPos, vector.MustFixedCol[types.Decimal64], rowPerBlk, resultBlkCnt)
   100  		case types.T_decimal128:
   101  			merger = newAObjMerger(vpool, batches, sort.Decimal128Less, sortKeyPos, vector.MustFixedCol[types.Decimal128], rowPerBlk, resultBlkCnt)
   102  		case types.T_uuid:
   103  			merger = newAObjMerger(vpool, batches, sort.UuidLess, sortKeyPos, vector.MustFixedCol[types.Uuid], rowPerBlk, resultBlkCnt)
   104  		case types.T_TS:
   105  			merger = newAObjMerger(vpool, batches, sort.TsLess, sortKeyPos, vector.MustFixedCol[types.TS], rowPerBlk, resultBlkCnt)
   106  		case types.T_Rowid:
   107  			merger = newAObjMerger(vpool, batches, sort.RowidLess, sortKeyPos, vector.MustFixedCol[types.Rowid], rowPerBlk, resultBlkCnt)
   108  		case types.T_Blockid:
   109  			merger = newAObjMerger(vpool, batches, sort.BlockidLess, sortKeyPos, vector.MustFixedCol[types.Blockid], rowPerBlk, resultBlkCnt)
   110  		default:
   111  			return nil, nil, nil, moerr.NewErrUnsupportedDataType(ctx, typ)
   112  		}
   113  	}
   114  	return merger.Merge(ctx)
   115  }
   116  
   117  func newAObjMerger[T any](
   118  	vpool DisposableVecPool,
   119  	batches []*containers.Batch,
   120  	lessFunc sort.LessFunc[T],
   121  	sortKeyPos int,
   122  	mustColFunc func(*vector.Vector) []T,
   123  	rowPerBlk uint32,
   124  	resultBlkCnt int) AObjMerger {
   125  	size := len(batches)
   126  	m := &aObjMerger[T]{
   127  		vpool:        vpool,
   128  		heap:         newHeapSlice[T](size, lessFunc),
   129  		cols:         make([][]T, size),
   130  		nulls:        make([]*nulls.Nulls, size),
   131  		rowIdx:       make([]int64, size),
   132  		accRowCnt:    make([]int64, size),
   133  		bats:         batches,
   134  		resultBlkCnt: resultBlkCnt,
   135  		rowPerBlk:    rowPerBlk,
   136  	}
   137  
   138  	totalRowCnt := 0
   139  	for i, blk := range batches {
   140  		sortKeyCol := blk.Vecs[sortKeyPos].GetDownstreamVector()
   141  		m.sortKeyType = sortKeyCol.GetType()
   142  		m.cols[i] = mustColFunc(sortKeyCol)
   143  		m.nulls[i] = sortKeyCol.GetNulls()
   144  		m.rowIdx[i] = 0
   145  		m.accRowCnt[i] = int64(totalRowCnt)
   146  		totalRowCnt += len(m.cols[i])
   147  	}
   148  	m.mapping = make([]uint32, totalRowCnt)
   149  
   150  	return m
   151  }
   152  
   153  func (am *aObjMerger[T]) Merge(ctx context.Context) ([]*batch.Batch, func(), []uint32, error) {
   154  	for i := 0; i < len(am.bats); i++ {
   155  		heapPush(am.heap, heapElem[T]{
   156  			data:   am.cols[i][0],
   157  			isNull: am.nulls[i].Contains(0),
   158  			src:    uint32(i),
   159  		})
   160  	}
   161  
   162  	cnBat := containers.ToCNBatch(am.bats[0])
   163  	batches := make([]*batch.Batch, am.resultBlkCnt)
   164  	releaseFs := make([]func(), am.resultBlkCnt)
   165  
   166  	blkCnt := 0
   167  	bufferRowCnt := 0
   168  	k := uint32(0)
   169  	for am.heap.Len() != 0 {
   170  		select {
   171  		case <-ctx.Done():
   172  			return nil, nil, nil, ctx.Err()
   173  		default:
   174  		}
   175  		blkIdx := am.nextPos()
   176  		rowIdx := am.rowIdx[blkIdx]
   177  		if batches[blkCnt] == nil {
   178  			batches[blkCnt], releaseFs[blkCnt] = getSimilarBatch(cnBat, int(am.rowPerBlk), am.vpool)
   179  		}
   180  		for i := range batches[blkCnt].Vecs {
   181  			err := batches[blkCnt].Vecs[i].UnionOne(am.bats[blkIdx].Vecs[i].GetDownstreamVector(), rowIdx, am.vpool.GetMPool())
   182  			if err != nil {
   183  				return nil, nil, nil, err
   184  			}
   185  		}
   186  
   187  		am.mapping[am.accRowCnt[blkIdx]+rowIdx] = k
   188  		k++
   189  		bufferRowCnt++
   190  		// write new block
   191  		if bufferRowCnt == int(am.rowPerBlk) {
   192  			bufferRowCnt = 0
   193  			blkCnt++
   194  		}
   195  
   196  		am.pushNewElem(blkIdx)
   197  	}
   198  	return batches, func() {
   199  		for _, f := range releaseFs {
   200  			f()
   201  		}
   202  	}, am.mapping, nil
   203  }
   204  
   205  func (am *aObjMerger[T]) nextPos() uint32 {
   206  	return heapPop[T](am.heap).src
   207  }
   208  
   209  func (am *aObjMerger[T]) pushNewElem(blkIdx uint32) bool {
   210  	am.rowIdx[blkIdx]++
   211  	if am.rowIdx[blkIdx] >= int64(len(am.cols[blkIdx])) {
   212  		return false
   213  	}
   214  	nextRow := am.rowIdx[blkIdx]
   215  	heapPush(am.heap, heapElem[T]{
   216  		data:   am.cols[blkIdx][nextRow],
   217  		isNull: am.nulls[blkIdx].Contains(uint64(nextRow)),
   218  		src:    blkIdx,
   219  	})
   220  	return true
   221  }