github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/row_buffer.go (about)

     1  package parquet
     2  
     3  import (
     4  	"io"
     5  	"sort"
     6  
     7  	"github.com/parquet-go/parquet-go/deprecated"
     8  	"github.com/parquet-go/parquet-go/encoding"
     9  )
    10  
// RowBuffer is an implementation of the RowGroup interface which stores parquet
// rows in memory.
//
// Unlike GenericBuffer which uses a column layout to store values in memory
// buffers, RowBuffer uses a row layout. The use of row layout provides greater
// efficiency when sorting the buffer, which is the primary use case for the
// RowBuffer type. Applications which intend to sort rows prior to writing them
// to a parquet file will often see lower CPU utilization from using a RowBuffer
// than a GenericBuffer.
//
// RowBuffer values are not safe to use concurrently from multiple goroutines.
type RowBuffer[T any] struct {
	// alloc is handed every row captured by Write/WriteRows and is reset by
	// Reset.
	alloc rowAllocator
	// schema describes the rows held by the buffer.
	schema *Schema
	// sorting lists the sorting columns configured at construction; it is
	// returned by SortingColumns.
	sorting []SortingColumn
	// rows holds one Row per written row, in write order (or in sorted order
	// after sort.Sort). Each Row is a capacity-capped sub-slice of values as
	// it was at the time the row was written.
	rows []Row
	// values is the flat storage that rows are sliced from.
	values []Value
	// compare orders two rows by the sorting columns; used by Less.
	compare func(Row, Row) int
}
    30  
    31  // NewRowBuffer constructs a new row buffer.
    32  func NewRowBuffer[T any](options ...RowGroupOption) *RowBuffer[T] {
    33  	config := DefaultRowGroupConfig()
    34  	config.Apply(options...)
    35  	if err := config.Validate(); err != nil {
    36  		panic(err)
    37  	}
    38  
    39  	t := typeOf[T]()
    40  	if config.Schema == nil && t != nil {
    41  		config.Schema = schemaOf(dereference(t))
    42  	}
    43  
    44  	if config.Schema == nil {
    45  		panic("row buffer must be instantiated with schema or concrete type.")
    46  	}
    47  
    48  	return &RowBuffer[T]{
    49  		schema:  config.Schema,
    50  		sorting: config.Sorting.SortingColumns,
    51  		compare: config.Schema.Comparator(config.Sorting.SortingColumns...),
    52  	}
    53  }
    54  
    55  // Reset clears the content of the buffer without releasing its memory.
    56  func (buf *RowBuffer[T]) Reset() {
    57  	for i := range buf.rows {
    58  		buf.rows[i] = nil
    59  	}
    60  	for i := range buf.values {
    61  		buf.values[i] = Value{}
    62  	}
    63  	buf.rows = buf.rows[:0]
    64  	buf.values = buf.values[:0]
    65  	buf.alloc.reset()
    66  }
    67  
    68  // NumRows returns the number of rows currently written to the buffer.
    69  func (buf *RowBuffer[T]) NumRows() int64 { return int64(len(buf.rows)) }
    70  
    71  // ColumnChunks returns a view of the buffer's columns.
    72  //
    73  // Note that reading columns of a RowBuffer will be less efficient than reading
    74  // columns of a GenericBuffer since the latter uses a column layout. This method
    75  // is mainly exposed to satisfy the RowGroup interface, applications which need
    76  // compute-efficient column scans on in-memory buffers should likely use a
    77  // GenericBuffer instead.
    78  //
    79  // The returned column chunks are snapshots at the time the method is called,
    80  // they remain valid until the next call to Reset on the buffer.
    81  func (buf *RowBuffer[T]) ColumnChunks() []ColumnChunk {
    82  	columns := buf.schema.Columns()
    83  	chunks := make([]rowBufferColumnChunk, len(columns))
    84  
    85  	for i, column := range columns {
    86  		leafColumn, _ := buf.schema.Lookup(column...)
    87  		chunks[i] = rowBufferColumnChunk{
    88  			page: rowBufferPage{
    89  				rows:               buf.rows,
    90  				typ:                leafColumn.Node.Type(),
    91  				column:             leafColumn.ColumnIndex,
    92  				maxRepetitionLevel: byte(leafColumn.MaxRepetitionLevel),
    93  				maxDefinitionLevel: byte(leafColumn.MaxDefinitionLevel),
    94  			},
    95  		}
    96  	}
    97  
    98  	columnChunks := make([]ColumnChunk, len(chunks))
    99  	for i := range chunks {
   100  		columnChunks[i] = &chunks[i]
   101  	}
   102  	return columnChunks
   103  }
   104  
   105  // SortingColumns returns the list of columns that rows are expected to be
   106  // sorted by.
   107  //
   108  // The list of sorting columns is configured when the buffer is created and used
   109  // when it is sorted.
   110  //
   111  // Note that unless the buffer is explicitly sorted, there are no guarantees
   112  // that the rows it contains will be in the order specified by the sorting
   113  // columns.
   114  func (buf *RowBuffer[T]) SortingColumns() []SortingColumn { return buf.sorting }
   115  
   116  // Schema returns the schema of rows in the buffer.
   117  func (buf *RowBuffer[T]) Schema() *Schema { return buf.schema }
   118  
   119  // Len returns the number of rows in the buffer.
   120  //
   121  // The method contributes to satisfying sort.Interface.
   122  func (buf *RowBuffer[T]) Len() int { return len(buf.rows) }
   123  
   124  // Less compares the rows at index i and j according to the sorting columns
   125  // configured on the buffer.
   126  //
   127  // The method contributes to satisfying sort.Interface.
   128  func (buf *RowBuffer[T]) Less(i, j int) bool {
   129  	return buf.compare(buf.rows[i], buf.rows[j]) < 0
   130  }
   131  
   132  // Swap exchanges the rows at index i and j in the buffer.
   133  //
   134  // The method contributes to satisfying sort.Interface.
   135  func (buf *RowBuffer[T]) Swap(i, j int) {
   136  	buf.rows[i], buf.rows[j] = buf.rows[j], buf.rows[i]
   137  }
   138  
   139  // Rows returns a Rows instance exposing rows stored in the buffer.
   140  //
   141  // The rows returned are a snapshot at the time the method is called.
   142  // The returned rows and values read from it remain valid until the next call
   143  // to Reset on the buffer.
   144  func (buf *RowBuffer[T]) Rows() Rows {
   145  	return &rowBufferRows{rows: buf.rows, schema: buf.schema}
   146  }
   147  
   148  // Write writes rows to the buffer, returning the number of rows written.
   149  func (buf *RowBuffer[T]) Write(rows []T) (int, error) {
   150  	for i := range rows {
   151  		off := len(buf.values)
   152  		buf.values = buf.schema.Deconstruct(buf.values, &rows[i])
   153  		end := len(buf.values)
   154  		row := buf.values[off:end:end]
   155  		buf.alloc.capture(row)
   156  		buf.rows = append(buf.rows, row)
   157  	}
   158  	return len(rows), nil
   159  }
   160  
   161  // WriteRows writes parquet rows to the buffer, returing the number of rows
   162  // written.
   163  func (buf *RowBuffer[T]) WriteRows(rows []Row) (int, error) {
   164  	for i := range rows {
   165  		off := len(buf.values)
   166  		buf.values = append(buf.values, rows[i]...)
   167  		end := len(buf.values)
   168  		row := buf.values[off:end:end]
   169  		buf.alloc.capture(row)
   170  		buf.rows = append(buf.rows, row)
   171  	}
   172  	return len(rows), nil
   173  }
   174  
   175  type rowBufferColumnChunk struct{ page rowBufferPage }
   176  
   177  func (c *rowBufferColumnChunk) Type() Type { return c.page.Type() }
   178  
   179  func (c *rowBufferColumnChunk) Column() int { return c.page.Column() }
   180  
   181  func (c *rowBufferColumnChunk) Pages() Pages { return onePage(&c.page) }
   182  
   183  func (c *rowBufferColumnChunk) ColumnIndex() (ColumnIndex, error) { return nil, nil }
   184  
   185  func (c *rowBufferColumnChunk) OffsetIndex() (OffsetIndex, error) { return nil, nil }
   186  
   187  func (c *rowBufferColumnChunk) BloomFilter() BloomFilter { return nil }
   188  
   189  func (c *rowBufferColumnChunk) NumValues() int64 { return c.page.NumValues() }
   190  
// rowBufferPage presents a single leaf column of a set of rows as a Page by
// filtering each row's values on the fly.
type rowBufferPage struct {
	// rows is shared with the buffer the page was created from, not copied.
	rows []Row
	// typ is the type of the leaf column.
	typ Type
	// column is the leaf column index; values are matched by comparing their
	// columnIndex field against ^int16(column).
	column int
	// maxRepetitionLevel is the column's maximum repetition level; when it is
	// zero, RepetitionLevels returns nil.
	maxRepetitionLevel byte
	// maxDefinitionLevel is the column's maximum definition level; when it is
	// zero, DefinitionLevels returns nil.
	maxDefinitionLevel byte
}
   198  
   199  func (p *rowBufferPage) Type() Type { return p.typ }
   200  
   201  func (p *rowBufferPage) Column() int { return p.column }
   202  
   203  func (p *rowBufferPage) Dictionary() Dictionary { return nil }
   204  
   205  func (p *rowBufferPage) NumRows() int64 { return int64(len(p.rows)) }
   206  
   207  func (p *rowBufferPage) NumValues() int64 {
   208  	numValues := int64(0)
   209  	p.scan(func(value Value) {
   210  		if !value.isNull() {
   211  			numValues++
   212  		}
   213  	})
   214  	return numValues
   215  }
   216  
   217  func (p *rowBufferPage) NumNulls() int64 {
   218  	numNulls := int64(0)
   219  	p.scan(func(value Value) {
   220  		if value.isNull() {
   221  			numNulls++
   222  		}
   223  	})
   224  	return numNulls
   225  }
   226  
   227  func (p *rowBufferPage) Bounds() (min, max Value, ok bool) {
   228  	p.scan(func(value Value) {
   229  		if !value.IsNull() {
   230  			switch {
   231  			case !ok:
   232  				min, max, ok = value, value, true
   233  			case p.typ.Compare(value, min) < 0:
   234  				min = value
   235  			case p.typ.Compare(value, max) > 0:
   236  				max = value
   237  			}
   238  		}
   239  	})
   240  	return min, max, ok
   241  }
   242  
   243  func (p *rowBufferPage) Size() int64 { return 0 }
   244  
   245  func (p *rowBufferPage) Values() ValueReader {
   246  	return &rowBufferPageValueReader{
   247  		page:        p,
   248  		columnIndex: ^int16(p.column),
   249  	}
   250  }
   251  
   252  func (p *rowBufferPage) Clone() Page {
   253  	rows := make([]Row, len(p.rows))
   254  	for i := range rows {
   255  		rows[i] = p.rows[i].Clone()
   256  	}
   257  	return &rowBufferPage{
   258  		rows:   rows,
   259  		typ:    p.typ,
   260  		column: p.column,
   261  	}
   262  }
   263  
   264  func (p *rowBufferPage) Slice(i, j int64) Page {
   265  	return &rowBufferPage{
   266  		rows:   p.rows[i:j],
   267  		typ:    p.typ,
   268  		column: p.column,
   269  	}
   270  }
   271  
   272  func (p *rowBufferPage) RepetitionLevels() (repetitionLevels []byte) {
   273  	if p.maxRepetitionLevel != 0 {
   274  		repetitionLevels = make([]byte, 0, len(p.rows))
   275  		p.scan(func(value Value) {
   276  			repetitionLevels = append(repetitionLevels, value.repetitionLevel)
   277  		})
   278  	}
   279  	return repetitionLevels
   280  }
   281  
   282  func (p *rowBufferPage) DefinitionLevels() (definitionLevels []byte) {
   283  	if p.maxDefinitionLevel != 0 {
   284  		definitionLevels = make([]byte, 0, len(p.rows))
   285  		p.scan(func(value Value) {
   286  			definitionLevels = append(definitionLevels, value.definitionLevel)
   287  		})
   288  	}
   289  	return definitionLevels
   290  }
   291  
   292  func (p *rowBufferPage) Data() encoding.Values {
   293  	switch p.typ.Kind() {
   294  	case Boolean:
   295  		values := make([]byte, (len(p.rows)+7)/8)
   296  		numValues := 0
   297  		p.scanNonNull(func(value Value) {
   298  			if value.boolean() {
   299  				i := uint(numValues) / 8
   300  				j := uint(numValues) % 8
   301  				values[i] |= 1 << j
   302  			}
   303  			numValues++
   304  		})
   305  		return encoding.BooleanValues(values[:(numValues+7)/8])
   306  
   307  	case Int32:
   308  		values := make([]int32, 0, len(p.rows))
   309  		p.scanNonNull(func(value Value) { values = append(values, value.int32()) })
   310  		return encoding.Int32Values(values)
   311  
   312  	case Int64:
   313  		values := make([]int64, 0, len(p.rows))
   314  		p.scanNonNull(func(value Value) { values = append(values, value.int64()) })
   315  		return encoding.Int64Values(values)
   316  
   317  	case Int96:
   318  		values := make([]deprecated.Int96, 0, len(p.rows))
   319  		p.scanNonNull(func(value Value) { values = append(values, value.int96()) })
   320  		return encoding.Int96Values(values)
   321  
   322  	case Float:
   323  		values := make([]float32, 0, len(p.rows))
   324  		p.scanNonNull(func(value Value) { values = append(values, value.float()) })
   325  		return encoding.FloatValues(values)
   326  
   327  	case Double:
   328  		values := make([]float64, 0, len(p.rows))
   329  		p.scanNonNull(func(value Value) { values = append(values, value.double()) })
   330  		return encoding.DoubleValues(values)
   331  
   332  	case ByteArray:
   333  		values := make([]byte, 0, p.typ.EstimateSize(len(p.rows)))
   334  		offsets := make([]uint32, 0, len(p.rows))
   335  		p.scanNonNull(func(value Value) {
   336  			offsets = append(offsets, uint32(len(values)))
   337  			values = append(values, value.byteArray()...)
   338  		})
   339  		offsets = append(offsets, uint32(len(values)))
   340  		return encoding.ByteArrayValues(values, offsets)
   341  
   342  	case FixedLenByteArray:
   343  		length := p.typ.Length()
   344  		values := make([]byte, 0, length*len(p.rows))
   345  		p.scanNonNull(func(value Value) { values = append(values, value.byteArray()...) })
   346  		return encoding.FixedLenByteArrayValues(values, length)
   347  
   348  	default:
   349  		return encoding.Values{}
   350  	}
   351  }
   352  
   353  func (p *rowBufferPage) scan(f func(Value)) {
   354  	columnIndex := ^int16(p.column)
   355  
   356  	for _, row := range p.rows {
   357  		for _, value := range row {
   358  			if value.columnIndex == columnIndex {
   359  				f(value)
   360  			}
   361  		}
   362  	}
   363  }
   364  
   365  func (p *rowBufferPage) scanNonNull(f func(Value)) {
   366  	p.scan(func(value Value) {
   367  		if !value.isNull() {
   368  			f(value)
   369  		}
   370  	})
   371  }
   372  
   373  type rowBufferPageValueReader struct {
   374  	page        *rowBufferPage
   375  	rowIndex    int
   376  	valueIndex  int
   377  	columnIndex int16
   378  }
   379  
   380  func (r *rowBufferPageValueReader) ReadValues(values []Value) (n int, err error) {
   381  	for n < len(values) && r.rowIndex < len(r.page.rows) {
   382  		for n < len(values) && r.valueIndex < len(r.page.rows[r.rowIndex]) {
   383  			if v := r.page.rows[r.rowIndex][r.valueIndex]; v.columnIndex == r.columnIndex {
   384  				values[n] = v
   385  				n++
   386  			}
   387  			r.valueIndex++
   388  		}
   389  		r.rowIndex++
   390  		r.valueIndex = 0
   391  	}
   392  	if r.rowIndex == len(r.page.rows) {
   393  		err = io.EOF
   394  	}
   395  	return n, err
   396  }
   397  
   398  type rowBufferRows struct {
   399  	rows   []Row
   400  	index  int
   401  	schema *Schema
   402  }
   403  
   404  func (r *rowBufferRows) Close() error {
   405  	r.index = -1
   406  	return nil
   407  }
   408  
   409  func (r *rowBufferRows) Schema() *Schema {
   410  	return r.schema
   411  }
   412  
   413  func (r *rowBufferRows) SeekToRow(rowIndex int64) error {
   414  	if rowIndex < 0 {
   415  		return ErrSeekOutOfRange
   416  	}
   417  
   418  	if r.index < 0 {
   419  		return io.ErrClosedPipe
   420  	}
   421  
   422  	maxRowIndex := int64(len(r.rows))
   423  	if rowIndex > maxRowIndex {
   424  		rowIndex = maxRowIndex
   425  	}
   426  
   427  	r.index = int(rowIndex)
   428  	return nil
   429  }
   430  
   431  func (r *rowBufferRows) ReadRows(rows []Row) (n int, err error) {
   432  	if r.index < 0 {
   433  		return 0, io.EOF
   434  	}
   435  
   436  	if n = len(r.rows) - r.index; n > len(rows) {
   437  		n = len(rows)
   438  	}
   439  
   440  	for i, row := range r.rows[r.index : r.index+n] {
   441  		rows[i] = append(rows[i][:0], row...)
   442  	}
   443  
   444  	if r.index += n; r.index == len(r.rows) {
   445  		err = io.EOF
   446  	}
   447  
   448  	return n, err
   449  }
   450  
   451  func (r *rowBufferRows) WriteRowsTo(w RowWriter) (int64, error) {
   452  	n, err := w.WriteRows(r.rows[r.index:])
   453  	r.index += n
   454  	return int64(n), err
   455  }
   456  
   457  var (
   458  	_ RowGroup       = (*RowBuffer[any])(nil)
   459  	_ RowWriter      = (*RowBuffer[any])(nil)
   460  	_ sort.Interface = (*RowBuffer[any])(nil)
   461  
   462  	_ RowWriterTo = (*rowBufferRows)(nil)
   463  )