github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/row_buffer.go (about)

     1  //go:build go1.18
     2  
     3  package parquet
     4  
     5  import (
     6  	"io"
     7  	"sort"
     8  
     9  	"github.com/segmentio/parquet-go/deprecated"
    10  	"github.com/segmentio/parquet-go/encoding"
    11  )
    12  
    13  // RowBuffer is an implementation of the RowGroup interface which stores parquet
    14  // rows in memory.
    15  //
    16  // Unlike GenericBuffer which uses a column layout to store values in memory
    17  // buffers, RowBuffer uses a row layout. The use of row layout provides greater
    18  // efficiency when sorting the buffer, which is the primary use case for the
    19  // RowBuffer type. Applications which intend to sort rows prior to writing them
    20  // to a parquet file will often see lower CPU utilization from using a RowBuffer
    21  // than a GenericBuffer.
    22  //
    23  // RowBuffer values are not safe to use concurrently from multiple goroutines.
    24  type RowBuffer[T any] struct {
    25  	alloc   rowAllocator
    26  	schema  *Schema
    27  	sorting []SortingColumn
    28  	rows    []Row
    29  	values  []Value
    30  	compare func(Row, Row) int
    31  }
    32  
    33  // NewRowBuffer constructs a new row buffer.
    34  func NewRowBuffer[T any](options ...RowGroupOption) *RowBuffer[T] {
    35  	config := DefaultRowGroupConfig()
    36  	config.Apply(options...)
    37  	if err := config.Validate(); err != nil {
    38  		panic(err)
    39  	}
    40  
    41  	t := typeOf[T]()
    42  	if config.Schema == nil && t != nil {
    43  		config.Schema = schemaOf(dereference(t))
    44  	}
    45  
    46  	if config.Schema == nil {
    47  		panic("row buffer must be instantiated with schema or concrete type.")
    48  	}
    49  
    50  	return &RowBuffer[T]{
    51  		schema:  config.Schema,
    52  		sorting: config.Sorting.SortingColumns,
    53  		compare: config.Schema.Comparator(config.Sorting.SortingColumns...),
    54  	}
    55  }
    56  
    57  // Reset clears the content of the buffer without releasing its memory.
    58  func (buf *RowBuffer[T]) Reset() {
    59  	for i := range buf.rows {
    60  		buf.rows[i] = nil
    61  	}
    62  	for i := range buf.values {
    63  		buf.values[i] = Value{}
    64  	}
    65  	buf.rows = buf.rows[:0]
    66  	buf.values = buf.values[:0]
    67  	buf.alloc.reset()
    68  }
    69  
    70  // NumRows returns the number of rows currently written to the buffer.
    71  func (buf *RowBuffer[T]) NumRows() int64 { return int64(len(buf.rows)) }
    72  
    73  // ColumnChunks returns a view of the buffer's columns.
    74  //
    75  // Note that reading columns of a RowBuffer will be less efficient than reading
    76  // columns of a GenericBuffer since the latter uses a column layout. This method
    77  // is mainly exposed to satisfy the RowGroup interface, applications which need
    78  // compute-efficient column scans on in-memory buffers should likely use a
    79  // GenericBuffer instead.
    80  //
    81  // The returned column chunks are snapshots at the time the method is called,
    82  // they remain valid until the next call to Reset on the buffer.
    83  func (buf *RowBuffer[T]) ColumnChunks() []ColumnChunk {
    84  	columns := buf.schema.Columns()
    85  	chunks := make([]rowBufferColumnChunk, len(columns))
    86  
    87  	for i, column := range columns {
    88  		leafColumn, _ := buf.schema.Lookup(column...)
    89  		chunks[i] = rowBufferColumnChunk{
    90  			page: rowBufferPage{
    91  				rows:               buf.rows,
    92  				typ:                leafColumn.Node.Type(),
    93  				column:             leafColumn.ColumnIndex,
    94  				maxRepetitionLevel: byte(leafColumn.MaxRepetitionLevel),
    95  				maxDefinitionLevel: byte(leafColumn.MaxDefinitionLevel),
    96  			},
    97  		}
    98  	}
    99  
   100  	columnChunks := make([]ColumnChunk, len(chunks))
   101  	for i := range chunks {
   102  		columnChunks[i] = &chunks[i]
   103  	}
   104  	return columnChunks
   105  }
   106  
   107  // SortingColumns returns the list of columns that rows are expected to be
   108  // sorted by.
   109  //
   110  // The list of sorting columns is configured when the buffer is created and used
   111  // when it is sorted.
   112  //
   113  // Note that unless the buffer is explicitly sorted, there are no guarantees
   114  // that the rows it contains will be in the order specified by the sorting
   115  // columns.
   116  func (buf *RowBuffer[T]) SortingColumns() []SortingColumn { return buf.sorting }
   117  
   118  // Schema returns the schema of rows in the buffer.
   119  func (buf *RowBuffer[T]) Schema() *Schema { return buf.schema }
   120  
   121  // Len returns the number of rows in the buffer.
   122  //
   123  // The method contributes to satisfying sort.Interface.
   124  func (buf *RowBuffer[T]) Len() int { return len(buf.rows) }
   125  
   126  // Less compares the rows at index i and j according to the sorting columns
   127  // configured on the buffer.
   128  //
   129  // The method contributes to satisfying sort.Interface.
   130  func (buf *RowBuffer[T]) Less(i, j int) bool {
   131  	return buf.compare(buf.rows[i], buf.rows[j]) < 0
   132  }
   133  
   134  // Swap exchanges the rows at index i and j in the buffer.
   135  //
   136  // The method contributes to satisfying sort.Interface.
   137  func (buf *RowBuffer[T]) Swap(i, j int) {
   138  	buf.rows[i], buf.rows[j] = buf.rows[j], buf.rows[i]
   139  }
   140  
   141  // Rows returns a Rows instance exposing rows stored in the buffer.
   142  //
   143  // The rows returned are a snapshot at the time the method is called.
   144  // The returned rows and values read from it remain valid until the next call
   145  // to Reset on the buffer.
   146  func (buf *RowBuffer[T]) Rows() Rows {
   147  	return &rowBufferRows{rows: buf.rows, schema: buf.schema}
   148  }
   149  
   150  // Write writes rows to the buffer, returning the number of rows written.
   151  func (buf *RowBuffer[T]) Write(rows []T) (int, error) {
   152  	for i := range rows {
   153  		off := len(buf.values)
   154  		buf.values = buf.schema.Deconstruct(buf.values, &rows[i])
   155  		end := len(buf.values)
   156  		row := buf.values[off:end:end]
   157  		buf.alloc.capture(row)
   158  		buf.rows = append(buf.rows, row)
   159  	}
   160  	return len(rows), nil
   161  }
   162  
   163  // WriteRows writes parquet rows to the buffer, returing the number of rows
   164  // written.
   165  func (buf *RowBuffer[T]) WriteRows(rows []Row) (int, error) {
   166  	for i := range rows {
   167  		off := len(buf.values)
   168  		buf.values = append(buf.values, rows[i]...)
   169  		end := len(buf.values)
   170  		row := buf.values[off:end:end]
   171  		buf.alloc.capture(row)
   172  		buf.rows = append(buf.rows, row)
   173  	}
   174  	return len(rows), nil
   175  }
   176  
   177  type rowBufferColumnChunk struct{ page rowBufferPage }
   178  
   179  func (c *rowBufferColumnChunk) Type() Type { return c.page.Type() }
   180  
   181  func (c *rowBufferColumnChunk) Column() int { return c.page.Column() }
   182  
   183  func (c *rowBufferColumnChunk) Pages() Pages { return onePage(&c.page) }
   184  
   185  func (c *rowBufferColumnChunk) ColumnIndex() ColumnIndex { return nil }
   186  
   187  func (c *rowBufferColumnChunk) OffsetIndex() OffsetIndex { return nil }
   188  
   189  func (c *rowBufferColumnChunk) BloomFilter() BloomFilter { return nil }
   190  
   191  func (c *rowBufferColumnChunk) NumValues() int64 { return c.page.NumValues() }
   192  
   193  type rowBufferPage struct {
   194  	rows               []Row
   195  	typ                Type
   196  	column             int
   197  	maxRepetitionLevel byte
   198  	maxDefinitionLevel byte
   199  }
   200  
   201  func (p *rowBufferPage) Type() Type { return p.typ }
   202  
   203  func (p *rowBufferPage) Column() int { return p.column }
   204  
   205  func (p *rowBufferPage) Dictionary() Dictionary { return nil }
   206  
   207  func (p *rowBufferPage) NumRows() int64 { return int64(len(p.rows)) }
   208  
   209  func (p *rowBufferPage) NumValues() int64 {
   210  	numValues := int64(0)
   211  	p.scan(func(value Value) {
   212  		if !value.isNull() {
   213  			numValues++
   214  		}
   215  	})
   216  	return numValues
   217  }
   218  
   219  func (p *rowBufferPage) NumNulls() int64 {
   220  	numNulls := int64(0)
   221  	p.scan(func(value Value) {
   222  		if value.isNull() {
   223  			numNulls++
   224  		}
   225  	})
   226  	return numNulls
   227  }
   228  
   229  func (p *rowBufferPage) Bounds() (min, max Value, ok bool) {
   230  	p.scan(func(value Value) {
   231  		if !value.IsNull() {
   232  			switch {
   233  			case !ok:
   234  				min, max, ok = value, value, true
   235  			case p.typ.Compare(value, min) < 0:
   236  				min = value
   237  			case p.typ.Compare(value, max) > 0:
   238  				max = value
   239  			}
   240  		}
   241  	})
   242  	return min, max, ok
   243  }
   244  
   245  func (p *rowBufferPage) Size() int64 { return 0 }
   246  
   247  func (p *rowBufferPage) Values() ValueReader {
   248  	return &rowBufferPageValueReader{
   249  		page:        p,
   250  		columnIndex: ^int16(p.column),
   251  	}
   252  }
   253  
   254  func (p *rowBufferPage) Clone() Page {
   255  	rows := make([]Row, len(p.rows))
   256  	for i := range rows {
   257  		rows[i] = p.rows[i].Clone()
   258  	}
   259  	return &rowBufferPage{
   260  		rows:   rows,
   261  		typ:    p.typ,
   262  		column: p.column,
   263  	}
   264  }
   265  
   266  func (p *rowBufferPage) Slice(i, j int64) Page {
   267  	return &rowBufferPage{
   268  		rows:   p.rows[i:j],
   269  		typ:    p.typ,
   270  		column: p.column,
   271  	}
   272  }
   273  
   274  func (p *rowBufferPage) RepetitionLevels() (repetitionLevels []byte) {
   275  	if p.maxRepetitionLevel != 0 {
   276  		repetitionLevels = make([]byte, 0, len(p.rows))
   277  		p.scan(func(value Value) {
   278  			repetitionLevels = append(repetitionLevels, value.repetitionLevel)
   279  		})
   280  	}
   281  	return repetitionLevels
   282  }
   283  
   284  func (p *rowBufferPage) DefinitionLevels() (definitionLevels []byte) {
   285  	if p.maxDefinitionLevel != 0 {
   286  		definitionLevels = make([]byte, 0, len(p.rows))
   287  		p.scan(func(value Value) {
   288  			definitionLevels = append(definitionLevels, value.definitionLevel)
   289  		})
   290  	}
   291  	return definitionLevels
   292  }
   293  
   294  func (p *rowBufferPage) Data() encoding.Values {
   295  	switch p.typ.Kind() {
   296  	case Boolean:
   297  		values := make([]byte, (len(p.rows)+7)/8)
   298  		numValues := 0
   299  		p.scanNonNull(func(value Value) {
   300  			if value.boolean() {
   301  				i := uint(numValues) / 8
   302  				j := uint(numValues) % 8
   303  				values[i] |= 1 << j
   304  			}
   305  			numValues++
   306  		})
   307  		return encoding.BooleanValues(values[:(numValues+7)/8])
   308  
   309  	case Int32:
   310  		values := make([]int32, 0, len(p.rows))
   311  		p.scanNonNull(func(value Value) { values = append(values, value.int32()) })
   312  		return encoding.Int32Values(values)
   313  
   314  	case Int64:
   315  		values := make([]int64, 0, len(p.rows))
   316  		p.scanNonNull(func(value Value) { values = append(values, value.int64()) })
   317  		return encoding.Int64Values(values)
   318  
   319  	case Int96:
   320  		values := make([]deprecated.Int96, 0, len(p.rows))
   321  		p.scanNonNull(func(value Value) { values = append(values, value.int96()) })
   322  		return encoding.Int96Values(values)
   323  
   324  	case Float:
   325  		values := make([]float32, 0, len(p.rows))
   326  		p.scanNonNull(func(value Value) { values = append(values, value.float()) })
   327  		return encoding.FloatValues(values)
   328  
   329  	case Double:
   330  		values := make([]float64, 0, len(p.rows))
   331  		p.scanNonNull(func(value Value) { values = append(values, value.double()) })
   332  		return encoding.DoubleValues(values)
   333  
   334  	case ByteArray:
   335  		values := make([]byte, 0, p.typ.EstimateSize(len(p.rows)))
   336  		offsets := make([]uint32, 0, len(p.rows))
   337  		p.scanNonNull(func(value Value) {
   338  			offsets = append(offsets, uint32(len(values)))
   339  			values = append(values, value.byteArray()...)
   340  		})
   341  		offsets = append(offsets, uint32(len(values)))
   342  		return encoding.ByteArrayValues(values, offsets)
   343  
   344  	case FixedLenByteArray:
   345  		length := p.typ.Length()
   346  		values := make([]byte, 0, length*len(p.rows))
   347  		p.scanNonNull(func(value Value) { values = append(values, value.byteArray()...) })
   348  		return encoding.FixedLenByteArrayValues(values, length)
   349  
   350  	default:
   351  		return encoding.Values{}
   352  	}
   353  }
   354  
   355  func (p *rowBufferPage) scan(f func(Value)) {
   356  	columnIndex := ^int16(p.column)
   357  
   358  	for _, row := range p.rows {
   359  		for _, value := range row {
   360  			if value.columnIndex == columnIndex {
   361  				f(value)
   362  			}
   363  		}
   364  	}
   365  }
   366  
   367  func (p *rowBufferPage) scanNonNull(f func(Value)) {
   368  	p.scan(func(value Value) {
   369  		if !value.isNull() {
   370  			f(value)
   371  		}
   372  	})
   373  }
   374  
   375  type rowBufferPageValueReader struct {
   376  	page        *rowBufferPage
   377  	rowIndex    int
   378  	valueIndex  int
   379  	columnIndex int16
   380  }
   381  
   382  func (r *rowBufferPageValueReader) ReadValues(values []Value) (n int, err error) {
   383  	for n < len(values) && r.rowIndex < len(r.page.rows) {
   384  		for n < len(values) && r.valueIndex < len(r.page.rows[r.rowIndex]) {
   385  			if v := r.page.rows[r.rowIndex][r.valueIndex]; v.columnIndex == r.columnIndex {
   386  				values[n] = v
   387  				n++
   388  			}
   389  			r.valueIndex++
   390  		}
   391  		r.rowIndex++
   392  		r.valueIndex = 0
   393  	}
   394  	if r.rowIndex == len(r.page.rows) {
   395  		err = io.EOF
   396  	}
   397  	return n, err
   398  }
   399  
   400  type rowBufferRows struct {
   401  	rows   []Row
   402  	index  int
   403  	schema *Schema
   404  }
   405  
   406  func (r *rowBufferRows) Close() error {
   407  	r.index = -1
   408  	return nil
   409  }
   410  
   411  func (r *rowBufferRows) Schema() *Schema {
   412  	return r.schema
   413  }
   414  
   415  func (r *rowBufferRows) SeekToRow(rowIndex int64) error {
   416  	if rowIndex < 0 {
   417  		return ErrSeekOutOfRange
   418  	}
   419  
   420  	if r.index < 0 {
   421  		return io.ErrClosedPipe
   422  	}
   423  
   424  	maxRowIndex := int64(len(r.rows))
   425  	if rowIndex > maxRowIndex {
   426  		rowIndex = maxRowIndex
   427  	}
   428  
   429  	r.index = int(rowIndex)
   430  	return nil
   431  }
   432  
   433  func (r *rowBufferRows) ReadRows(rows []Row) (n int, err error) {
   434  	if r.index < 0 {
   435  		return 0, io.EOF
   436  	}
   437  
   438  	if n = len(r.rows) - r.index; n > len(rows) {
   439  		n = len(rows)
   440  	}
   441  
   442  	for i, row := range r.rows[r.index : r.index+n] {
   443  		rows[i] = append(rows[i][:0], row...)
   444  	}
   445  
   446  	if r.index += n; r.index == len(r.rows) {
   447  		err = io.EOF
   448  	}
   449  
   450  	return n, err
   451  }
   452  
   453  func (r *rowBufferRows) WriteRowsTo(w RowWriter) (int64, error) {
   454  	n, err := w.WriteRows(r.rows[r.index:])
   455  	r.index += n
   456  	return int64(n), err
   457  }
   458  
   459  var (
   460  	_ RowGroup       = (*RowBuffer[any])(nil)
   461  	_ RowWriter      = (*RowBuffer[any])(nil)
   462  	_ sort.Interface = (*RowBuffer[any])(nil)
   463  
   464  	_ RowWriterTo = (*rowBufferRows)(nil)
   465  )