github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/transform.go (about)

     1  package parquet
     2  
     3  // TransformRowReader constructs a RowReader which applies the given transform
     4  // to each row rad from reader.
     5  //
     6  // The transformation function appends the transformed src row to dst, returning
     7  // dst and any error that occurred during the transformation. If dst is returned
     8  // unchanged, the row is skipped.
     9  func TransformRowReader(reader RowReader, transform func(dst, src Row) (Row, error)) RowReader {
    10  	return &transformRowReader{reader: reader, transform: transform}
    11  }
    12  
// transformRowReader is the RowReader implementation returned by
// TransformRowReader. It reads rows from an underlying reader into an internal
// buffer and applies transform to each of them.
type transformRowReader struct {
	// reader is the underlying source of rows.
	reader    RowReader
	// transform is applied to every buffered row; it is called as
	// transform(dst, src).
	transform func(Row, Row) (Row, error)
	// rows is the internal buffer handed to reader.ReadRows; it is allocated
	// on the first call to ReadRows, sized after the caller's slice.
	rows      []Row
	// offset is the index in rows of the next buffered row to transform.
	offset    int
	// length is the number of rows stored in rows by the last read.
	length    int
}
    20  
    21  func (t *transformRowReader) ReadRows(rows []Row) (n int, err error) {
    22  	if len(t.rows) == 0 {
    23  		t.rows = makeRows(len(rows))
    24  	}
    25  
    26  	for {
    27  		for n < len(rows) && t.offset < t.length {
    28  			dst := rows[n][:0]
    29  			src := t.rows[t.offset]
    30  			rows[n], err = t.transform(dst, src)
    31  			if err != nil {
    32  				return n, err
    33  			}
    34  			clearValues(src)
    35  			t.rows[t.offset] = src[:0]
    36  			t.offset++
    37  			n++
    38  		}
    39  
    40  		if n == len(rows) {
    41  			return n, nil
    42  		}
    43  
    44  		r, err := t.reader.ReadRows(t.rows)
    45  		if r == 0 && err != nil {
    46  			return n, err
    47  		}
    48  		t.offset = 0
    49  		t.length = r
    50  	}
    51  }
    52  
// transformRowBuffer is a buffer of rows with a read window: the rows that
// are still pending are the ones in buffer[offset:length].
type transformRowBuffer struct {
	// buffer holds the row slots; its length is the capacity reported by cap.
	buffer []Row
	// offset is the index of the first pending row.
	offset int32
	// length is the index one past the last pending row.
	length int32
}
    58  
    59  func (b *transformRowBuffer) init(n int) {
    60  	b.buffer = makeRows(n)
    61  	b.offset = 0
    62  	b.length = 0
    63  }
    64  
    65  func (b *transformRowBuffer) discard() {
    66  	row := b.buffer[b.offset]
    67  	clearValues(row)
    68  	b.buffer[b.offset] = row[:0]
    69  
    70  	if b.offset++; b.offset == b.length {
    71  		b.reset(0)
    72  	}
    73  }
    74  
    75  func (b *transformRowBuffer) reset(n int) {
    76  	b.offset = 0
    77  	b.length = int32(n)
    78  }
    79  
    80  func (b *transformRowBuffer) rows() []Row {
    81  	return b.buffer[b.offset:b.length]
    82  }
    83  
    84  func (b *transformRowBuffer) cap() int {
    85  	return len(b.buffer)
    86  }
    87  
    88  func (b *transformRowBuffer) len() int {
    89  	return int(b.length - b.offset)
    90  }
    91  
    92  // TransformRowWriter constructs a RowWriter which applies the given transform
    93  // to each row writter to writer.
    94  //
    95  // The transformation function appends the transformed src row to dst, returning
    96  // dst and any error that occurred during the transformation. If dst is returned
    97  // unchanged, the row is skipped.
    98  func TransformRowWriter(writer RowWriter, transform func(dst, src Row) (Row, error)) RowWriter {
    99  	return &transformRowWriter{writer: writer, transform: transform}
   100  }
   101  
// transformRowWriter is the RowWriter implementation returned by
// TransformRowWriter. It applies transform to the rows it is given and forwards
// the results to an underlying writer.
type transformRowWriter struct {
	// writer is the destination of the transformed rows.
	writer    RowWriter
	// transform is applied to every incoming row; it is called as
	// transform(dst, src).
	transform func(Row, Row) (Row, error)
	// rows is the scratch buffer receiving transformed rows; it is allocated
	// on the first call to WriteRows, sized after the caller's slice, and its
	// length bounds how many rows are processed per batch.
	rows      []Row
}
   107  
   108  func (t *transformRowWriter) WriteRows(rows []Row) (n int, err error) {
   109  	if len(t.rows) == 0 {
   110  		t.rows = makeRows(len(rows))
   111  	}
   112  
   113  	for n < len(rows) {
   114  		numRows := len(rows) - n
   115  		if numRows > len(t.rows) {
   116  			numRows = len(t.rows)
   117  		}
   118  		if err := t.writeRows(rows[n : n+numRows]); err != nil {
   119  			return n, err
   120  		}
   121  		n += numRows
   122  	}
   123  
   124  	return n, nil
   125  }
   126  
   127  func (t *transformRowWriter) writeRows(rows []Row) (err error) {
   128  	numRows := 0
   129  	defer func() { clearRows(t.rows[:numRows]) }()
   130  
   131  	for _, row := range rows {
   132  		t.rows[numRows], err = t.transform(t.rows[numRows][:0], row)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		if len(t.rows[numRows]) != 0 {
   137  			numRows++
   138  		}
   139  	}
   140  
   141  	_, err = t.writer.WriteRows(t.rows[:numRows])
   142  	return err
   143  }