github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/merge.go (about)

     1  package parquet
     2  
     3  import (
     4  	"container/heap"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"sync"
     9  )
    10  
    11  // MergeRowGroups constructs a row group which is a merged view of rowGroups. If
    12  // rowGroups are sorted and the passed options include sorting, the merged row
    13  // group will also be sorted.
    14  //
    15  // The function validates the input to ensure that the merge operation is
    16  // possible, ensuring that the schemas match or can be converted to an
    17  // optionally configured target schema passed as argument in the option list.
    18  //
    19  // The sorting columns of each row group are also consulted to determine whether
    20  // the output can be represented. If sorting columns are configured on the merge
    21  // they must be a prefix of sorting columns of all row groups being merged.
    22  func MergeRowGroups(rowGroups []RowGroup, options ...RowGroupOption) (RowGroup, error) {
    23  	config, err := NewRowGroupConfig(options...)
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  
    28  	schema := config.Schema
    29  	if len(rowGroups) == 0 {
    30  		return newEmptyRowGroup(schema), nil
    31  	}
    32  	if schema == nil {
    33  		schema = rowGroups[0].Schema()
    34  
    35  		for _, rowGroup := range rowGroups[1:] {
    36  			if !nodesAreEqual(schema, rowGroup.Schema()) {
    37  				return nil, ErrRowGroupSchemaMismatch
    38  			}
    39  		}
    40  	}
    41  
    42  	mergedRowGroups := make([]RowGroup, len(rowGroups))
    43  	copy(mergedRowGroups, rowGroups)
    44  
    45  	for i, rowGroup := range mergedRowGroups {
    46  		if rowGroupSchema := rowGroup.Schema(); !nodesAreEqual(schema, rowGroupSchema) {
    47  			conv, err := Convert(schema, rowGroupSchema)
    48  			if err != nil {
    49  				return nil, fmt.Errorf("cannot merge row groups: %w", err)
    50  			}
    51  			mergedRowGroups[i] = ConvertRowGroup(rowGroup, conv)
    52  		}
    53  	}
    54  
    55  	m := &mergedRowGroup{sorting: config.Sorting.SortingColumns}
    56  	m.init(schema, mergedRowGroups)
    57  
    58  	if len(m.sorting) == 0 {
    59  		// When the row group has no ordering, use a simpler version of the
    60  		// merger which simply concatenates rows from each of the row groups.
    61  		// This is preferable because it makes the output deterministic, the
    62  		// heap merge may otherwise reorder rows across groups.
    63  		return &m.multiRowGroup, nil
    64  	}
    65  
    66  	for _, rowGroup := range m.rowGroups {
    67  		if !sortingColumnsHavePrefix(rowGroup.SortingColumns(), m.sorting) {
    68  			return nil, ErrRowGroupSortingColumnsMismatch
    69  		}
    70  	}
    71  
    72  	m.compare = compareRowsFuncOf(schema, m.sorting)
    73  	return m, nil
    74  }
    75  
    76  type mergedRowGroup struct {
    77  	multiRowGroup
    78  	sorting []SortingColumn
    79  	compare func(Row, Row) int
    80  }
    81  
    82  func (m *mergedRowGroup) SortingColumns() []SortingColumn {
    83  	return m.sorting
    84  }
    85  
    86  func (m *mergedRowGroup) Rows() Rows {
    87  	// The row group needs to respect a sorting order; the merged row reader
    88  	// uses a heap to merge rows from the row groups.
    89  	rows := make([]Rows, len(m.rowGroups))
    90  	for i := range rows {
    91  		rows[i] = m.rowGroups[i].Rows()
    92  	}
    93  	return &mergedRowGroupRows{
    94  		merge: mergedRowReader{
    95  			compare: m.compare,
    96  			readers: makeBufferedRowReaders(len(rows), func(i int) RowReader { return rows[i] }),
    97  		},
    98  		rows:   rows,
    99  		schema: m.schema,
   100  	}
   101  }
   102  
   103  type mergedRowGroupRows struct {
   104  	merge     mergedRowReader
   105  	rowIndex  int64
   106  	seekToRow int64
   107  	rows      []Rows
   108  	schema    *Schema
   109  }
   110  
   111  func (r *mergedRowGroupRows) WriteRowsTo(w RowWriter) (n int64, err error) {
   112  	b := newMergeBuffer()
   113  	b.setup(r.rows, r.merge.compare)
   114  	n, err = b.WriteRowsTo(w)
   115  	r.rowIndex += int64(n)
   116  	b.release()
   117  	return
   118  }
   119  
   120  func (r *mergedRowGroupRows) readInternal(rows []Row) (int, error) {
   121  	n, err := r.merge.ReadRows(rows)
   122  	r.rowIndex += int64(n)
   123  	return n, err
   124  }
   125  
   126  func (r *mergedRowGroupRows) Close() (lastErr error) {
   127  	r.merge.close()
   128  	r.rowIndex = 0
   129  	r.seekToRow = 0
   130  
   131  	for _, rows := range r.rows {
   132  		if err := rows.Close(); err != nil {
   133  			lastErr = err
   134  		}
   135  	}
   136  
   137  	return lastErr
   138  }
   139  
   140  func (r *mergedRowGroupRows) ReadRows(rows []Row) (int, error) {
   141  	for r.rowIndex < r.seekToRow {
   142  		n := int(r.seekToRow - r.rowIndex)
   143  		if n > len(rows) {
   144  			n = len(rows)
   145  		}
   146  		n, err := r.readInternal(rows[:n])
   147  		if err != nil {
   148  			return 0, err
   149  		}
   150  	}
   151  
   152  	return r.readInternal(rows)
   153  }
   154  
   155  func (r *mergedRowGroupRows) SeekToRow(rowIndex int64) error {
   156  	if rowIndex >= r.rowIndex {
   157  		r.seekToRow = rowIndex
   158  		return nil
   159  	}
   160  	return fmt.Errorf("SeekToRow: merged row reader cannot seek backward from row %d to %d", r.rowIndex, rowIndex)
   161  }
   162  
   163  func (r *mergedRowGroupRows) Schema() *Schema {
   164  	return r.schema
   165  }
   166  
   167  // MergeRowReader constructs a RowReader which creates an ordered sequence of
   168  // all the readers using the given compare function as the ordering predicate.
   169  func MergeRowReaders(readers []RowReader, compare func(Row, Row) int) RowReader {
   170  	return &mergedRowReader{
   171  		compare: compare,
   172  		readers: makeBufferedRowReaders(len(readers), func(i int) RowReader { return readers[i] }),
   173  	}
   174  }
   175  
   176  func makeBufferedRowReaders(numReaders int, readerAt func(int) RowReader) []*bufferedRowReader {
   177  	buffers := make([]bufferedRowReader, numReaders)
   178  	readers := make([]*bufferedRowReader, numReaders)
   179  
   180  	for i := range readers {
   181  		buffers[i].rows = readerAt(i)
   182  		readers[i] = &buffers[i]
   183  	}
   184  
   185  	return readers
   186  }
   187  
   188  type mergedRowReader struct {
   189  	compare     func(Row, Row) int
   190  	readers     []*bufferedRowReader
   191  	initialized bool
   192  }
   193  
   194  func (m *mergedRowReader) initialize() error {
   195  	for i, r := range m.readers {
   196  		switch err := r.read(); err {
   197  		case nil:
   198  		case io.EOF:
   199  			m.readers[i] = nil
   200  		default:
   201  			m.readers = nil
   202  			return err
   203  		}
   204  	}
   205  
   206  	n := 0
   207  	for _, r := range m.readers {
   208  		if r != nil {
   209  			m.readers[n] = r
   210  			n++
   211  		}
   212  	}
   213  
   214  	clear := m.readers[n:]
   215  	for i := range clear {
   216  		clear[i] = nil
   217  	}
   218  
   219  	m.readers = m.readers[:n]
   220  	heap.Init(m)
   221  	return nil
   222  }
   223  
   224  func (m *mergedRowReader) close() {
   225  	for _, r := range m.readers {
   226  		r.close()
   227  	}
   228  	m.readers = nil
   229  }
   230  
   231  func (m *mergedRowReader) ReadRows(rows []Row) (n int, err error) {
   232  	if !m.initialized {
   233  		m.initialized = true
   234  
   235  		if err := m.initialize(); err != nil {
   236  			return 0, err
   237  		}
   238  	}
   239  
   240  	for n < len(rows) && len(m.readers) != 0 {
   241  		r := m.readers[0]
   242  
   243  		rows[n] = append(rows[n][:0], r.head()...)
   244  		n++
   245  
   246  		if err := r.next(); err != nil {
   247  			if err != io.EOF {
   248  				return n, err
   249  			}
   250  			heap.Pop(m)
   251  		} else {
   252  			heap.Fix(m, 0)
   253  		}
   254  	}
   255  
   256  	if len(m.readers) == 0 {
   257  		err = io.EOF
   258  	}
   259  
   260  	return n, err
   261  }
   262  
   263  func (m *mergedRowReader) Less(i, j int) bool {
   264  	return m.compare(m.readers[i].head(), m.readers[j].head()) < 0
   265  }
   266  
   267  func (m *mergedRowReader) Len() int {
   268  	return len(m.readers)
   269  }
   270  
   271  func (m *mergedRowReader) Swap(i, j int) {
   272  	m.readers[i], m.readers[j] = m.readers[j], m.readers[i]
   273  }
   274  
   275  func (m *mergedRowReader) Push(x interface{}) {
   276  	panic("NOT IMPLEMENTED")
   277  }
   278  
   279  func (m *mergedRowReader) Pop() interface{} {
   280  	i := len(m.readers) - 1
   281  	r := m.readers[i]
   282  	m.readers = m.readers[:i]
   283  	return r
   284  }
   285  
   286  type bufferedRowReader struct {
   287  	rows RowReader
   288  	off  int32
   289  	end  int32
   290  	buf  [10]Row
   291  }
   292  
   293  func (r *bufferedRowReader) head() Row {
   294  	return r.buf[r.off]
   295  }
   296  
   297  func (r *bufferedRowReader) next() error {
   298  	if r.off++; r.off == r.end {
   299  		r.off = 0
   300  		r.end = 0
   301  		return r.read()
   302  	}
   303  	return nil
   304  }
   305  
   306  func (r *bufferedRowReader) read() error {
   307  	if r.rows == nil {
   308  		return io.EOF
   309  	}
   310  	n, err := r.rows.ReadRows(r.buf[r.end:])
   311  	if err != nil && n == 0 {
   312  		return err
   313  	}
   314  	r.end += int32(n)
   315  	return nil
   316  }
   317  
   318  func (r *bufferedRowReader) close() {
   319  	r.rows = nil
   320  	r.off = 0
   321  	r.end = 0
   322  }
   323  
   324  type mergeBuffer struct {
   325  	compare func(Row, Row) int
   326  	rows    []Rows
   327  	buffer  [][]Row
   328  	head    []int
   329  	len     int
   330  	copy    [mergeBufferSize]Row
   331  }
   332  
   333  const mergeBufferSize = 1 << 10
   334  
   335  func newMergeBuffer() *mergeBuffer {
   336  	return mergeBufferPool.Get().(*mergeBuffer)
   337  }
   338  
   339  var mergeBufferPool = &sync.Pool{
   340  	New: func() any {
   341  		return new(mergeBuffer)
   342  	},
   343  }
   344  
   345  func (m *mergeBuffer) setup(rows []Rows, compare func(Row, Row) int) {
   346  	m.compare = compare
   347  	m.rows = append(m.rows, rows...)
   348  	size := len(rows)
   349  	if len(m.buffer) < size {
   350  		extra := size - len(m.buffer)
   351  		b := make([][]Row, extra)
   352  		for i := range b {
   353  			b[i] = make([]Row, 0, mergeBufferSize)
   354  		}
   355  		m.buffer = append(m.buffer, b...)
   356  		m.head = append(m.head, make([]int, extra)...)
   357  	}
   358  	m.len = size
   359  }
   360  
   361  func (m *mergeBuffer) reset() {
   362  	for i := range m.rows {
   363  		m.buffer[i] = m.buffer[i][:0]
   364  		m.head[i] = 0
   365  	}
   366  	m.rows = m.rows[:0]
   367  	m.compare = nil
   368  	for i := range m.copy {
   369  		m.copy[i] = nil
   370  	}
   371  	m.len = 0
   372  }
   373  
   374  func (m *mergeBuffer) release() {
   375  	m.reset()
   376  	mergeBufferPool.Put(m)
   377  }
   378  
   379  func (m *mergeBuffer) fill() error {
   380  	m.len = len(m.rows)
   381  	for i := range m.rows {
   382  		if m.head[i] < len(m.buffer[i]) {
   383  			// There is still rows data in m.buffer[i]. Skip filling the row buffer until
   384  			// all rows have been read.
   385  			continue
   386  		}
   387  		m.head[i] = 0
   388  		m.buffer[i] = m.buffer[i][:mergeBufferSize]
   389  		n, err := m.rows[i].ReadRows(m.buffer[i])
   390  		if err != nil {
   391  			if !errors.Is(err, io.EOF) {
   392  				return err
   393  			}
   394  		}
   395  		m.buffer[i] = m.buffer[i][:n]
   396  	}
   397  	heap.Init(m)
   398  	return nil
   399  }
   400  
   401  func (m *mergeBuffer) Less(i, j int) bool {
   402  	x := m.buffer[i]
   403  	if len(x) == 0 {
   404  		return false
   405  	}
   406  	y := m.buffer[j]
   407  	if len(y) == 0 {
   408  		return true
   409  	}
   410  	return m.compare(x[m.head[i]], y[m.head[j]]) == -1
   411  }
   412  
   413  func (m *mergeBuffer) Pop() interface{} {
   414  	m.len--
   415  	// We don't use the popped value.
   416  	return nil
   417  }
   418  
   419  func (m *mergeBuffer) Len() int {
   420  	return m.len
   421  }
   422  
   423  func (m *mergeBuffer) Swap(i, j int) {
   424  	m.buffer[i], m.buffer[j] = m.buffer[j], m.buffer[i]
   425  	m.head[i], m.head[j] = m.head[j], m.head[i]
   426  }
   427  
   428  func (m *mergeBuffer) Push(x interface{}) {
   429  	panic("NOT IMPLEMENTED")
   430  }
   431  
   432  func (m *mergeBuffer) WriteRowsTo(w RowWriter) (n int64, err error) {
   433  	err = m.fill()
   434  	if err != nil {
   435  		return 0, err
   436  	}
   437  	var count int
   438  	for m.left() {
   439  		size := m.read()
   440  		if size == 0 {
   441  			break
   442  		}
   443  		count, err = w.WriteRows(m.copy[:size])
   444  		if err != nil {
   445  			return
   446  		}
   447  		n += int64(count)
   448  		err = m.fill()
   449  		if err != nil {
   450  			return
   451  		}
   452  	}
   453  	return
   454  }
   455  
   456  func (m *mergeBuffer) left() bool {
   457  	for i := 0; i < m.len; i++ {
   458  		if m.head[i] < len(m.buffer[i]) {
   459  			return true
   460  		}
   461  	}
   462  	return false
   463  }
   464  
   465  func (m *mergeBuffer) read() (n int64) {
   466  	for n < int64(len(m.copy)) && m.Len() != 0 {
   467  		r := m.buffer[:m.len][0]
   468  		if len(r) == 0 {
   469  			heap.Pop(m)
   470  			continue
   471  		}
   472  		m.copy[n] = append(m.copy[n][:0], r[m.head[0]]...)
   473  		m.head[0]++
   474  		n++
   475  		if m.head[0] < len(r) {
   476  			// There is still rows in this row group. Adjust  the heap
   477  			heap.Fix(m, 0)
   478  		} else {
   479  			heap.Pop(m)
   480  		}
   481  	}
   482  	return
   483  }
   484  
   485  var (
   486  	_ RowReaderWithSchema = (*mergedRowGroupRows)(nil)
   487  	_ RowWriterTo         = (*mergedRowGroupRows)(nil)
   488  	_ heap.Interface      = (*mergeBuffer)(nil)
   489  	_ RowWriterTo         = (*mergeBuffer)(nil)
   490  )