github.com/matrixorigin/matrixone@v1.2.0/pkg/vm/engine/tae/mergesort/merger.go (about)

     1  // Copyright 2024 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mergesort
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    24  	"github.com/matrixorigin/matrixone/pkg/container/types"
    25  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    26  	"github.com/matrixorigin/matrixone/pkg/objectio"
    27  	"github.com/matrixorigin/matrixone/pkg/pb/api"
    28  	"github.com/matrixorigin/matrixone/pkg/sort"
    29  	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/blockio"
    30  	"github.com/matrixorigin/matrixone/pkg/vm/engine/tae/options"
    31  )
    32  
    33  type Merger interface {
    34  	merge(context.Context) error
    35  }
    36  
    37  type releasableBatch struct {
    38  	bat      *batch.Batch
    39  	releaseF func()
    40  }
    41  
    42  type merger[T any] struct {
    43  	heap *heapSlice[T]
    44  
    45  	cols    [][]T
    46  	deletes []*nulls.Nulls
    47  	nulls   []*nulls.Nulls
    48  
    49  	buffer *batch.Batch
    50  
    51  	bats   []releasableBatch
    52  	rowIdx []int64
    53  
    54  	objCnt           int
    55  	objBlkCnts       []int
    56  	accObjBlkCnts    []int
    57  	loadedObjBlkCnts []int
    58  
    59  	host MergeTaskHost
    60  
    61  	writer *blockio.BlockWriter
    62  
    63  	sortKeyIdx int
    64  
    65  	mustColFunc func(*vector.Vector) []T
    66  
    67  	totalRowCnt   uint32
    68  	totalSize     uint32
    69  	rowPerBlk     uint32
    70  	blkPerObj     uint16
    71  	rowSize       uint32
    72  	targetObjSize uint32
    73  }
    74  
    75  func newMerger[T any](host MergeTaskHost, lessFunc sort.LessFunc[T], sortKeyPos int, mustColFunc func(*vector.Vector) []T) Merger {
    76  	size := host.GetObjectCnt()
    77  	m := &merger[T]{
    78  		host:       host,
    79  		objCnt:     size,
    80  		bats:       make([]releasableBatch, size),
    81  		rowIdx:     make([]int64, size),
    82  		cols:       make([][]T, size),
    83  		deletes:    make([]*nulls.Nulls, size),
    84  		nulls:      make([]*nulls.Nulls, size),
    85  		heap:       newHeapSlice[T](size, lessFunc),
    86  		sortKeyIdx: sortKeyPos,
    87  
    88  		accObjBlkCnts:    host.GetAccBlkCnts(),
    89  		objBlkCnts:       host.GetBlkCnts(),
    90  		rowPerBlk:        host.GetBlockMaxRows(),
    91  		blkPerObj:        host.GetObjectMaxBlocks(),
    92  		targetObjSize:    host.GetTargetObjSize(),
    93  		totalSize:        host.GetTotalSize(),
    94  		totalRowCnt:      host.GetTotalRowCnt(),
    95  		loadedObjBlkCnts: make([]int, size),
    96  		mustColFunc:      mustColFunc,
    97  	}
    98  	m.rowSize = m.totalSize / m.totalRowCnt
    99  	totalBlkCnt := 0
   100  	for _, cnt := range m.objBlkCnts {
   101  		totalBlkCnt += cnt
   102  	}
   103  	if host.DoTransfer() {
   104  		initTransferMapping(host.GetCommitEntry(), totalBlkCnt)
   105  	}
   106  
   107  	return m
   108  }
   109  
   110  func (m *merger[T]) merge(ctx context.Context) error {
   111  	for i := 0; i < m.objCnt; i++ {
   112  		if ok, err := m.loadBlk(ctx, uint32(i)); !ok {
   113  			return errors.Join(moerr.NewInternalError(ctx, "failed to load first blk"), err)
   114  		}
   115  
   116  		heapPush(m.heap, heapElem[T]{
   117  			data:   m.cols[i][m.rowIdx[i]],
   118  			isNull: m.nulls[i].Contains(uint64(m.rowIdx[i])),
   119  			src:    uint32(i),
   120  		})
   121  	}
   122  
   123  	var releaseF func()
   124  	m.buffer, releaseF = getSimilarBatch(m.bats[0].bat, int(m.rowPerBlk), m.host)
   125  	defer releaseF()
   126  
   127  	objCnt := 0
   128  	objBlkCnt := 0
   129  	bufferRowCnt := 0
   130  	objRowCnt := uint32(0)
   131  	mergedRowCnt := uint32(0)
   132  	commitEntry := m.host.GetCommitEntry()
   133  	for m.heap.Len() != 0 {
   134  		select {
   135  		case <-ctx.Done():
   136  			return ctx.Err()
   137  		default:
   138  		}
   139  		objIdx := m.nextPos()
   140  		if m.deletes[objIdx].Contains(uint64(m.rowIdx[objIdx])) {
   141  			// row is deleted
   142  			if err := m.pushNewElem(ctx, objIdx); err != nil {
   143  				return err
   144  			}
   145  			continue
   146  		}
   147  		rowIdx := m.rowIdx[objIdx]
   148  		for i := range m.buffer.Vecs {
   149  			err := m.buffer.Vecs[i].UnionOne(m.bats[objIdx].bat.Vecs[i], rowIdx, m.host.GetMPool())
   150  			if err != nil {
   151  				return err
   152  			}
   153  		}
   154  
   155  		if m.host.DoTransfer() {
   156  			commitEntry.Booking.Mappings[m.accObjBlkCnts[objIdx]+m.loadedObjBlkCnts[objIdx]-1].M[int32(rowIdx)] = api.TransDestPos{
   157  				ObjIdx: int32(objCnt),
   158  				BlkIdx: int32(uint32(objBlkCnt)),
   159  				RowIdx: int32(bufferRowCnt),
   160  			}
   161  		}
   162  
   163  		bufferRowCnt++
   164  		objRowCnt++
   165  		mergedRowCnt++
   166  		// write new block
   167  		if bufferRowCnt == int(m.rowPerBlk) {
   168  			bufferRowCnt = 0
   169  			objBlkCnt++
   170  
   171  			if m.writer == nil {
   172  				m.writer = m.host.PrepareNewWriter()
   173  			}
   174  
   175  			if _, err := m.writer.WriteBatch(m.buffer); err != nil {
   176  				return err
   177  			}
   178  			// force clean
   179  			m.buffer.CleanOnlyData()
   180  
   181  			// write new object
   182  			if m.needNewObject(objBlkCnt, objRowCnt, mergedRowCnt) {
   183  				// write object and reset writer
   184  				if err := m.syncObject(ctx); err != nil {
   185  					return err
   186  				}
   187  				// reset writer after sync
   188  				objBlkCnt = 0
   189  				objRowCnt = 0
   190  				objCnt++
   191  			}
   192  		}
   193  
   194  		if err := m.pushNewElem(ctx, objIdx); err != nil {
   195  			return err
   196  		}
   197  	}
   198  
   199  	// write remain data
   200  	if bufferRowCnt > 0 {
   201  		objBlkCnt++
   202  
   203  		if m.writer == nil {
   204  			m.writer = m.host.PrepareNewWriter()
   205  		}
   206  		if _, err := m.writer.WriteBatch(m.buffer); err != nil {
   207  			return err
   208  		}
   209  		m.buffer.CleanOnlyData()
   210  	}
   211  	if objBlkCnt > 0 {
   212  		if err := m.syncObject(ctx); err != nil {
   213  			return err
   214  		}
   215  	}
   216  	return nil
   217  }
   218  
   219  func (m *merger[T]) needNewObject(objBlkCnt int, objRowCnt, mergedRowCnt uint32) bool {
   220  	if m.targetObjSize == 0 {
   221  		if m.blkPerObj == 0 {
   222  			return objBlkCnt == int(options.DefaultBlocksPerObject)
   223  		}
   224  		return objBlkCnt == int(m.blkPerObj)
   225  	}
   226  
   227  	if objRowCnt*m.rowSize > m.targetObjSize {
   228  		return (m.totalRowCnt-mergedRowCnt)*m.rowSize > m.targetObjSize
   229  	}
   230  	return false
   231  }
   232  
   233  func (m *merger[T]) nextPos() uint32 {
   234  	return heapPop[T](m.heap).src
   235  }
   236  
   237  func (m *merger[T]) loadBlk(ctx context.Context, objIdx uint32) (bool, error) {
   238  	nextBatch, del, releaseF, err := m.host.LoadNextBatch(ctx, objIdx)
   239  	if m.bats[objIdx].bat != nil {
   240  		m.bats[objIdx].releaseF()
   241  	}
   242  	if err != nil {
   243  		if errors.Is(err, ErrNoMoreBlocks) {
   244  			return false, nil
   245  		}
   246  		return false, err
   247  	}
   248  
   249  	m.bats[objIdx] = releasableBatch{bat: nextBatch, releaseF: releaseF}
   250  	m.loadedObjBlkCnts[objIdx]++
   251  
   252  	vec := nextBatch.GetVector(int32(m.sortKeyIdx))
   253  	m.cols[objIdx] = m.mustColFunc(vec)
   254  	m.nulls[objIdx] = vec.GetNulls()
   255  	m.deletes[objIdx] = del
   256  	m.rowIdx[objIdx] = 0
   257  	return true, nil
   258  }
   259  
   260  func (m *merger[T]) pushNewElem(ctx context.Context, objIdx uint32) error {
   261  	m.rowIdx[objIdx]++
   262  	if m.rowIdx[objIdx] >= int64(len(m.cols[objIdx])) {
   263  		if ok, err := m.loadBlk(ctx, objIdx); !ok {
   264  			return err
   265  		}
   266  	}
   267  	nextRow := m.rowIdx[objIdx]
   268  	heapPush(m.heap, heapElem[T]{
   269  		data:   m.cols[objIdx][nextRow],
   270  		isNull: m.nulls[objIdx].Contains(uint64(nextRow)),
   271  		src:    objIdx,
   272  	})
   273  	return nil
   274  }
   275  
   276  func (m *merger[T]) syncObject(ctx context.Context) error {
   277  	if _, _, err := m.writer.Sync(ctx); err != nil {
   278  		return err
   279  	}
   280  	cobjstats := m.writer.GetObjectStats()[:objectio.SchemaTombstone]
   281  	commitEntry := m.host.GetCommitEntry()
   282  	for _, cobj := range cobjstats {
   283  		commitEntry.CreatedObjs = append(commitEntry.CreatedObjs, cobj.Clone().Marshal())
   284  	}
   285  	m.writer = nil
   286  	return nil
   287  }
   288  
   289  func mergeObjs(ctx context.Context, mergeHost MergeTaskHost, sortKeyPos int) error {
   290  	var merger Merger
   291  	typ := mergeHost.GetSortKeyType()
   292  	if typ.IsVarlen() {
   293  		merger = newMerger(mergeHost, sort.GenericLess[string], sortKeyPos, vector.MustStrCol)
   294  	} else {
   295  		switch typ.Oid {
   296  		case types.T_bool:
   297  			merger = newMerger(mergeHost, sort.BoolLess, sortKeyPos, vector.MustFixedCol[bool])
   298  		case types.T_bit:
   299  			merger = newMerger(mergeHost, sort.GenericLess[uint64], sortKeyPos, vector.MustFixedCol[uint64])
   300  		case types.T_int8:
   301  			merger = newMerger(mergeHost, sort.GenericLess[int8], sortKeyPos, vector.MustFixedCol[int8])
   302  		case types.T_int16:
   303  			merger = newMerger(mergeHost, sort.GenericLess[int16], sortKeyPos, vector.MustFixedCol[int16])
   304  		case types.T_int32:
   305  			merger = newMerger(mergeHost, sort.GenericLess[int32], sortKeyPos, vector.MustFixedCol[int32])
   306  		case types.T_int64:
   307  			merger = newMerger(mergeHost, sort.GenericLess[int64], sortKeyPos, vector.MustFixedCol[int64])
   308  		case types.T_float32:
   309  			merger = newMerger(mergeHost, sort.GenericLess[float32], sortKeyPos, vector.MustFixedCol[float32])
   310  		case types.T_float64:
   311  			merger = newMerger(mergeHost, sort.GenericLess[float64], sortKeyPos, vector.MustFixedCol[float64])
   312  		case types.T_uint8:
   313  			merger = newMerger(mergeHost, sort.GenericLess[uint8], sortKeyPos, vector.MustFixedCol[uint8])
   314  		case types.T_uint16:
   315  			merger = newMerger(mergeHost, sort.GenericLess[uint16], sortKeyPos, vector.MustFixedCol[uint16])
   316  		case types.T_uint32:
   317  			merger = newMerger(mergeHost, sort.GenericLess[uint32], sortKeyPos, vector.MustFixedCol[uint32])
   318  		case types.T_uint64:
   319  			merger = newMerger(mergeHost, sort.GenericLess[uint64], sortKeyPos, vector.MustFixedCol[uint64])
   320  		case types.T_date:
   321  			merger = newMerger(mergeHost, sort.GenericLess[types.Date], sortKeyPos, vector.MustFixedCol[types.Date])
   322  		case types.T_timestamp:
   323  			merger = newMerger(mergeHost, sort.GenericLess[types.Timestamp], sortKeyPos, vector.MustFixedCol[types.Timestamp])
   324  		case types.T_datetime:
   325  			merger = newMerger(mergeHost, sort.GenericLess[types.Datetime], sortKeyPos, vector.MustFixedCol[types.Datetime])
   326  		case types.T_time:
   327  			merger = newMerger(mergeHost, sort.GenericLess[types.Time], sortKeyPos, vector.MustFixedCol[types.Time])
   328  		case types.T_enum:
   329  			merger = newMerger(mergeHost, sort.GenericLess[types.Enum], sortKeyPos, vector.MustFixedCol[types.Enum])
   330  		case types.T_decimal64:
   331  			merger = newMerger(mergeHost, sort.Decimal64Less, sortKeyPos, vector.MustFixedCol[types.Decimal64])
   332  		case types.T_decimal128:
   333  			merger = newMerger(mergeHost, sort.Decimal128Less, sortKeyPos, vector.MustFixedCol[types.Decimal128])
   334  		case types.T_uuid:
   335  			merger = newMerger(mergeHost, sort.UuidLess, sortKeyPos, vector.MustFixedCol[types.Uuid])
   336  		case types.T_TS:
   337  			merger = newMerger(mergeHost, sort.TsLess, sortKeyPos, vector.MustFixedCol[types.TS])
   338  		case types.T_Rowid:
   339  			merger = newMerger(mergeHost, sort.RowidLess, sortKeyPos, vector.MustFixedCol[types.Rowid])
   340  		case types.T_Blockid:
   341  			merger = newMerger(mergeHost, sort.BlockidLess, sortKeyPos, vector.MustFixedCol[types.Blockid])
   342  		default:
   343  			return moerr.NewErrUnsupportedDataType(ctx, typ)
   344  		}
   345  	}
   346  	return merger.merge(ctx)
   347  }