github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/merge.go (about)

     1  package parquet
     2  
     3  import (
     4  	"container/heap"
     5  	"fmt"
     6  	"io"
     7  )
     8  
     9  // MergeRowGroups constructs a row group which is a merged view of rowGroups. If
    10  // rowGroups are sorted and the passed options include sorting, the merged row
    11  // group will also be sorted.
    12  //
    13  // The function validates the input to ensure that the merge operation is
    14  // possible, ensuring that the schemas match or can be converted to an
    15  // optionally configured target schema passed as argument in the option list.
    16  //
    17  // The sorting columns of each row group are also consulted to determine whether
    18  // the output can be represented. If sorting columns are configured on the merge
    19  // they must be a prefix of sorting columns of all row groups being merged.
    20  func MergeRowGroups(rowGroups []RowGroup, options ...RowGroupOption) (RowGroup, error) {
    21  	config, err := NewRowGroupConfig(options...)
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  
    26  	schema := config.Schema
    27  	if len(rowGroups) == 0 {
    28  		return newEmptyRowGroup(schema), nil
    29  	}
    30  	if schema == nil {
    31  		schema = rowGroups[0].Schema()
    32  
    33  		for _, rowGroup := range rowGroups[1:] {
    34  			if !nodesAreEqual(schema, rowGroup.Schema()) {
    35  				return nil, ErrRowGroupSchemaMismatch
    36  			}
    37  		}
    38  	}
    39  
    40  	mergedRowGroups := make([]RowGroup, len(rowGroups))
    41  	copy(mergedRowGroups, rowGroups)
    42  
    43  	for i, rowGroup := range mergedRowGroups {
    44  		if rowGroupSchema := rowGroup.Schema(); !nodesAreEqual(schema, rowGroupSchema) {
    45  			conv, err := Convert(schema, rowGroupSchema)
    46  			if err != nil {
    47  				return nil, fmt.Errorf("cannot merge row groups: %w", err)
    48  			}
    49  			mergedRowGroups[i] = ConvertRowGroup(rowGroup, conv)
    50  		}
    51  	}
    52  
    53  	m := &mergedRowGroup{sorting: config.Sorting.SortingColumns}
    54  	m.init(schema, mergedRowGroups)
    55  
    56  	if len(m.sorting) == 0 {
    57  		// When the row group has no ordering, use a simpler version of the
    58  		// merger which simply concatenates rows from each of the row groups.
    59  		// This is preferable because it makes the output deterministic, the
    60  		// heap merge may otherwise reorder rows across groups.
    61  		return &m.multiRowGroup, nil
    62  	}
    63  
    64  	for _, rowGroup := range m.rowGroups {
    65  		if !sortingColumnsHavePrefix(rowGroup.SortingColumns(), m.sorting) {
    66  			return nil, ErrRowGroupSortingColumnsMismatch
    67  		}
    68  	}
    69  
    70  	m.compare = compareRowsFuncOf(schema, m.sorting)
    71  	return m, nil
    72  }
    73  
// mergedRowGroup is the RowGroup returned by MergeRowGroups when sorting
// columns are configured. It embeds multiRowGroup (declared elsewhere in the
// package) and overrides Rows and SortingColumns.
type mergedRowGroup struct {
	multiRowGroup
	// sorting holds the sorting columns configured on the merge.
	sorting []SortingColumn
	// compare orders two rows according to sorting; it is set by
	// MergeRowGroups from compareRowsFuncOf.
	compare func(Row, Row) int
}
    79  
    80  func (m *mergedRowGroup) SortingColumns() []SortingColumn {
    81  	return m.sorting
    82  }
    83  
    84  func (m *mergedRowGroup) Rows() Rows {
    85  	// The row group needs to respect a sorting order; the merged row reader
    86  	// uses a heap to merge rows from the row groups.
    87  	rows := make([]Rows, len(m.rowGroups))
    88  	for i := range rows {
    89  		rows[i] = m.rowGroups[i].Rows()
    90  	}
    91  	return &mergedRowGroupRows{
    92  		merge: mergedRowReader{
    93  			compare: m.compare,
    94  			readers: makeBufferedRowReaders(len(rows), func(i int) RowReader { return rows[i] }),
    95  		},
    96  		rows:   rows,
    97  		schema: m.schema,
    98  	}
    99  }
   100  
// mergedRowGroupRows is the Rows implementation returned by
// mergedRowGroup.Rows. It reads from a heap-based merge of the readers of
// each row group and supports forward-only seeking.
type mergedRowGroupRows struct {
	// merge produces the rows of all row groups in sorted order.
	merge mergedRowReader
	// rowIndex counts the rows read from merge so far.
	rowIndex int64
	// seekToRow is the row index requested by SeekToRow; ReadRows discards
	// rows until rowIndex reaches it.
	seekToRow int64
	// rows holds the reader of each row group; they are closed by Close.
	rows []Rows
	// schema is the value returned by Schema.
	schema *Schema
}
   108  
   109  func (r *mergedRowGroupRows) readInternal(rows []Row) (int, error) {
   110  	n, err := r.merge.ReadRows(rows)
   111  	r.rowIndex += int64(n)
   112  	return n, err
   113  }
   114  
   115  func (r *mergedRowGroupRows) Close() (lastErr error) {
   116  	r.merge.close()
   117  	r.rowIndex = 0
   118  	r.seekToRow = 0
   119  
   120  	for _, rows := range r.rows {
   121  		if err := rows.Close(); err != nil {
   122  			lastErr = err
   123  		}
   124  	}
   125  
   126  	return lastErr
   127  }
   128  
   129  func (r *mergedRowGroupRows) ReadRows(rows []Row) (int, error) {
   130  	for r.rowIndex < r.seekToRow {
   131  		n := int(r.seekToRow - r.rowIndex)
   132  		if n > len(rows) {
   133  			n = len(rows)
   134  		}
   135  		n, err := r.readInternal(rows[:n])
   136  		if err != nil {
   137  			return 0, err
   138  		}
   139  	}
   140  
   141  	return r.readInternal(rows)
   142  }
   143  
   144  func (r *mergedRowGroupRows) SeekToRow(rowIndex int64) error {
   145  	if rowIndex >= r.rowIndex {
   146  		r.seekToRow = rowIndex
   147  		return nil
   148  	}
   149  	return fmt.Errorf("SeekToRow: merged row reader cannot seek backward from row %d to %d", r.rowIndex, rowIndex)
   150  }
   151  
   152  func (r *mergedRowGroupRows) Schema() *Schema {
   153  	return r.schema
   154  }
   155  
   156  // MergeRowReader constructs a RowReader which creates an ordered sequence of
   157  // all the readers using the given compare function as the ordering predicate.
   158  func MergeRowReaders(readers []RowReader, compare func(Row, Row) int) RowReader {
   159  	return &mergedRowReader{
   160  		compare: compare,
   161  		readers: makeBufferedRowReaders(len(readers), func(i int) RowReader { return readers[i] }),
   162  	}
   163  }
   164  
   165  func makeBufferedRowReaders(numReaders int, readerAt func(int) RowReader) []*bufferedRowReader {
   166  	buffers := make([]bufferedRowReader, numReaders)
   167  	readers := make([]*bufferedRowReader, numReaders)
   168  
   169  	for i := range readers {
   170  		buffers[i].rows = readerAt(i)
   171  		readers[i] = &buffers[i]
   172  	}
   173  
   174  	return readers
   175  }
   176  
// mergedRowReader merges the rows of several readers into one ordered
// sequence. It implements heap.Interface over readers, ordered by the head
// row of each reader.
type mergedRowReader struct {
	// compare orders two rows; a negative result sorts the first row first.
	compare func(Row, Row) int
	// readers is the heap of readers which still have rows to produce.
	readers []*bufferedRowReader
	// initialized is set on the first call to ReadRows, before the heap is
	// built.
	initialized bool
}
   182  
   183  func (m *mergedRowReader) initialize() error {
   184  	for i, r := range m.readers {
   185  		switch err := r.read(); err {
   186  		case nil:
   187  		case io.EOF:
   188  			m.readers[i] = nil
   189  		default:
   190  			m.readers = nil
   191  			return err
   192  		}
   193  	}
   194  
   195  	n := 0
   196  	for _, r := range m.readers {
   197  		if r != nil {
   198  			m.readers[n] = r
   199  			n++
   200  		}
   201  	}
   202  
   203  	clear := m.readers[n:]
   204  	for i := range clear {
   205  		clear[i] = nil
   206  	}
   207  
   208  	m.readers = m.readers[:n]
   209  	heap.Init(m)
   210  	return nil
   211  }
   212  
   213  func (m *mergedRowReader) close() {
   214  	for _, r := range m.readers {
   215  		r.close()
   216  	}
   217  	m.readers = nil
   218  }
   219  
   220  func (m *mergedRowReader) ReadRows(rows []Row) (n int, err error) {
   221  	if !m.initialized {
   222  		m.initialized = true
   223  
   224  		if err := m.initialize(); err != nil {
   225  			return 0, err
   226  		}
   227  	}
   228  
   229  	for n < len(rows) && len(m.readers) != 0 {
   230  		r := m.readers[0]
   231  
   232  		rows[n] = append(rows[n][:0], r.head()...)
   233  		n++
   234  
   235  		if err := r.next(); err != nil {
   236  			if err != io.EOF {
   237  				return n, err
   238  			}
   239  			heap.Pop(m)
   240  		} else {
   241  			heap.Fix(m, 0)
   242  		}
   243  	}
   244  
   245  	if len(m.readers) == 0 {
   246  		err = io.EOF
   247  	}
   248  
   249  	return n, err
   250  }
   251  
   252  func (m *mergedRowReader) Less(i, j int) bool {
   253  	return m.compare(m.readers[i].head(), m.readers[j].head()) < 0
   254  }
   255  
   256  func (m *mergedRowReader) Len() int {
   257  	return len(m.readers)
   258  }
   259  
   260  func (m *mergedRowReader) Swap(i, j int) {
   261  	m.readers[i], m.readers[j] = m.readers[j], m.readers[i]
   262  }
   263  
   264  func (m *mergedRowReader) Push(x interface{}) {
   265  	panic("NOT IMPLEMENTED")
   266  }
   267  
   268  func (m *mergedRowReader) Pop() interface{} {
   269  	i := len(m.readers) - 1
   270  	r := m.readers[i]
   271  	m.readers = m.readers[:i]
   272  	return r
   273  }
   274  
// bufferedRowReader reads rows from a RowReader in small batches and exposes
// them one at a time through head and next.
type bufferedRowReader struct {
	// rows is the source of rows; it is nil once the reader has been closed.
	rows RowReader
	// off is the index in buf of the current head row.
	off int32
	// end is the number of rows of buf that have been filled by read.
	end int32
	// buf holds the rows of the current batch.
	buf [10]Row
}
   281  
   282  func (r *bufferedRowReader) head() Row {
   283  	return r.buf[r.off]
   284  }
   285  
   286  func (r *bufferedRowReader) next() error {
   287  	if r.off++; r.off == r.end {
   288  		r.off = 0
   289  		r.end = 0
   290  		return r.read()
   291  	}
   292  	return nil
   293  }
   294  
   295  func (r *bufferedRowReader) read() error {
   296  	if r.rows == nil {
   297  		return io.EOF
   298  	}
   299  	n, err := r.rows.ReadRows(r.buf[r.end:])
   300  	if err != nil && n == 0 {
   301  		return err
   302  	}
   303  	r.end += int32(n)
   304  	return nil
   305  }
   306  
   307  func (r *bufferedRowReader) close() {
   308  	r.rows = nil
   309  	r.off = 0
   310  	r.end = 0
   311  }
   312  
var (
	// Compile-time check that *mergedRowGroupRows implements
	// RowReaderWithSchema.
	_ RowReaderWithSchema = (*mergedRowGroupRows)(nil)
)