github.com/apache/arrow/go/v16@v16.1.0/parquet/file/file_writer_test.go

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package file_test

import (
	"bytes"
	"fmt"
	"reflect"
	"testing"

	"github.com/apache/arrow/go/v16/arrow/memory"
	"github.com/apache/arrow/go/v16/parquet"
	"github.com/apache/arrow/go/v16/parquet/compress"
	"github.com/apache/arrow/go/v16/parquet/file"
	"github.com/apache/arrow/go/v16/parquet/internal/encoding"
	"github.com/apache/arrow/go/v16/parquet/internal/testutils"
	"github.com/apache/arrow/go/v16/parquet/schema"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)

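// SerializeTestSuite round-trips generated data through the Parquet file
// writer and reader for a single primitive type, exercising both the serial
// and the buffered row group writer APIs.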
type SerializeTestSuite struct {
	testutils.PrimitiveTypedTest
	suite.Suite

	numCols      int
	numRowGroups int
	rowsPerRG    int
	rowsPerBatch int
}

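// SetupTest configures a schema of four optional columns along with the
// row group and batch sizes shared by the serialization tests.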
func (t *SerializeTestSuite) SetupTest() {
	t.numCols = 4
	t.numRowGroups = 4
	t.rowsPerRG = 50
	t.rowsPerBatch = 10
	t.SetupSchema(parquet.Repetitions.Optional, t.numCols)
}

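// fileSerializeTest writes the generated data twice with the given codec,
// flushing an intermediate footer in between, and verifies that the file is
// readable and correctly compressed after both the flush and the final close.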
func (t *SerializeTestSuite) fileSerializeTest(codec compress.Compression, expected compress.Compression) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)

	opts := make([]parquet.WriterProperty, 0)
	for i := 0; i < t.numCols; i++ {
		opts = append(opts, parquet.WithCompressionFor(t.Schema.Column(i).Name(), codec))
	}

	props := parquet.NewWriterProperties(opts...)

	writer := file.NewParquetWriter(sink, t.Schema.Root(), file.WithWriterProps(props))
	t.GenerateData(int64(t.rowsPerRG))

	// Write the first batch of row groups and flush a footer, so the file is
	// already readable at this point.
	t.serializeGeneratedData(writer)
	writer.FlushWithFooter()

	t.validateSerializedData(writer, sink, expected)

	// Write a second batch of row groups; Close writes the final footer
	// covering all of them.
	t.serializeGeneratedData(writer)
	writer.Close()

	t.numRowGroups *= 2
	t.validateSerializedData(writer, sink, expected)
}

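// serializeGeneratedData writes numRowGroups row groups: the first half via
// the serial NextColumn API and the second half via the buffered Column API,
// checking that each writer panics when the other API is used on it.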
func (t *SerializeTestSuite) serializeGeneratedData(writer *file.Writer) {
	// Write the first half as serial row groups.
	for rg := 0; rg < t.numRowGroups/2; rg++ {
		rgw := writer.AppendRowGroup()
		for col := 0; col < t.numCols; col++ {
			cw, _ := rgw.NextColumn()
			t.WriteBatchValues(cw, t.DefLevels, nil)
			cw.Close()
			// Ensure the Column() API, which is specific to buffered row
			// groups, cannot be called on a serial row group writer.
			t.Panics(func() { rgw.(file.BufferedRowGroupWriter).Column(col) })
		}
		rgw.Close()
	}

	// Write the second half as buffered row groups.
	for rg := 0; rg < t.numRowGroups/2; rg++ {
		rgw := writer.AppendBufferedRowGroup()
		for batch := 0; batch < (t.rowsPerRG / t.rowsPerBatch); batch++ {
			for col := 0; col < t.numCols; col++ {
				cw, _ := rgw.Column(col)
				offset := batch * t.rowsPerBatch
				t.WriteBatchSubset(t.rowsPerBatch, offset, cw, t.DefLevels[offset:t.rowsPerBatch+offset], nil)
				// Ensure the NextColumn() API, which is specific to serial row
				// groups, cannot be called on a buffered row group writer.
				t.Panics(func() { rgw.(file.SerialRowGroupWriter).NextColumn() })
			}
		}
		for col := 0; col < t.numCols; col++ {
			cw, _ := rgw.Column(col)
			cw.Close()
		}
		rgw.Close()
	}
}

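// validateSerializedData re-reads the file accumulated in sink and checks the
// row group count, per-group row counts, the compression codec of each column
// chunk, and that values and definition levels round-trip unchanged.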
func (t *SerializeTestSuite) validateSerializedData(writer *file.Writer, sink *encoding.BufferWriter, expected compress.Compression) {
	nrows := t.numRowGroups * t.rowsPerRG
	t.EqualValues(nrows, writer.NumRows())

	reader, err := file.NewParquetReader(bytes.NewReader(sink.Bytes()))
	t.NoError(err)
	t.Equal(t.numCols, reader.MetaData().Schema.NumColumns())
	t.Equal(t.numRowGroups, reader.NumRowGroups())
	t.EqualValues(nrows, reader.NumRows())

	for rg := 0; rg < t.numRowGroups; rg++ {
		rgr := reader.RowGroup(rg)
		t.Equal(t.numCols, rgr.NumColumns())
		t.EqualValues(t.rowsPerRG, rgr.NumRows())
		chunk, _ := rgr.MetaData().ColumnChunk(0)
		t.Equal(expected, chunk.Compression())

		valuesRead := int64(0)

		for i := 0; i < t.numCols; i++ {
			chunk, _ := rgr.MetaData().ColumnChunk(i)
			t.False(chunk.HasIndexPage())
			t.DefLevelsOut = make([]int16, t.rowsPerRG)
			t.RepLevelsOut = make([]int16, t.rowsPerRG)
			colReader, err := rgr.Column(i)
			t.NoError(err)
			t.SetupValuesOut(int64(t.rowsPerRG))
			valuesRead = t.ReadBatch(colReader, int64(t.rowsPerRG), 0, t.DefLevelsOut, t.RepLevelsOut)
			t.EqualValues(t.rowsPerRG, valuesRead)
			t.Equal(t.Values, t.ValuesOut)
			t.Equal(t.DefLevels, t.DefLevelsOut)
		}
	}
}

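// unequalNumRows writes a different number of rows to each column of a serial
// row group and expects closing the row group to report a row count mismatch.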
func (t *SerializeTestSuite) unequalNumRows(maxRows int64, rowsPerCol []int64) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	props := parquet.NewWriterProperties()
	writer := file.NewParquetWriter(sink, t.Schema.Root(), file.WithWriterProps(props))
	defer writer.Close()

	rgw := writer.AppendRowGroup()
	t.GenerateData(maxRows)
	for col := 0; col < t.numCols; col++ {
		cw, _ := rgw.NextColumn()
		t.WriteBatchSubset(int(rowsPerCol[col]), 0, cw, t.DefLevels[:rowsPerCol[col]], nil)
		cw.Close()
	}
	err := rgw.Close()
	t.Error(err)
	t.ErrorContains(err, "row mismatch for unbuffered row group")
}

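// unequalNumRowsBuffered is the buffered row group counterpart of
// unequalNumRows.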
func (t *SerializeTestSuite) unequalNumRowsBuffered(maxRows int64, rowsPerCol []int64) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	writer := file.NewParquetWriter(sink, t.Schema.Root())
	defer writer.Close()

	rgw := writer.AppendBufferedRowGroup()
	t.GenerateData(maxRows)
	for col := 0; col < t.numCols; col++ {
		cw, _ := rgw.Column(col)
		t.WriteBatchSubset(int(rowsPerCol[col]), 0, cw, t.DefLevels[:rowsPerCol[col]], nil)
		cw.Close()
	}
	err := rgw.Close()
	t.Error(err)
	t.ErrorContains(err, "row mismatch for buffered row group")
}

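// TestZeroRows verifies that writing empty row groups, serial or buffered,
// does not panic.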
func (t *SerializeTestSuite) TestZeroRows() {
	t.NotPanics(func() {
		sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
		writer := file.NewParquetWriter(sink, t.Schema.Root())
		defer writer.Close()

		srgw := writer.AppendRowGroup()
		for col := 0; col < t.numCols; col++ {
			cw, _ := srgw.NextColumn()
			cw.Close()
		}
		srgw.Close()

		brgw := writer.AppendBufferedRowGroup()
		for col := 0; col < t.numCols; col++ {
			cw, _ := brgw.Column(col)
			cw.Close()
		}
		brgw.Close()
	})
}

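// TestTooManyColumns verifies that requesting more columns than the schema
// defines panics.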
func (t *SerializeTestSuite) TestTooManyColumns() {
	t.SetupSchema(parquet.Repetitions.Optional, 1)
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	writer := file.NewParquetWriter(sink, t.Schema.Root())
	rgw := writer.AppendRowGroup()

	rgw.NextColumn()                      // first column
	t.Panics(func() { rgw.NextColumn() }) // only one column!
}

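// TestRepeatedTooFewRows verifies that a repetition level that merges two
// values into a single row, leaving the column one row short of its siblings,
// triggers a panic.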
func (t *SerializeTestSuite) TestRepeatedTooFewRows() {
	// The column is repeated, so both definition and repetition levels are
	// written.
	t.SetupSchema(parquet.Repetitions.Repeated, 1)
	const nrows = 100
	t.GenerateData(nrows)

	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	writer := file.NewParquetWriter(sink, t.Schema.Root())

	rgw := writer.AppendRowGroup()
	t.RepLevels = make([]int16, nrows)
	for idx := range t.RepLevels {
		t.RepLevels[idx] = 0
	}

	cw, _ := rgw.NextColumn()
	t.WriteBatchValues(cw, t.DefLevels, t.RepLevels)
	cw.Close()

	// Setting a repetition level of 1 merges values 2 and 3 into a single
	// row, so this column ends up one row short.
	t.RepLevels[3] = 1

	t.Panics(func() {
		cw, _ = rgw.NextColumn()
		t.WriteBatchValues(cw, t.DefLevels, t.RepLevels)
		cw.Close()
	})
}

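// TestTooFewRows checks that a row group whose final column is one row short
// reports a row mismatch for both the serial and buffered writers.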
func (t *SerializeTestSuite) TestTooFewRows() {
	rowsPerCol := []int64{100, 100, 100, 99}
	t.NotPanics(func() { t.unequalNumRows(100, rowsPerCol) })
	t.NotPanics(func() { t.unequalNumRowsBuffered(100, rowsPerCol) })
}

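// TestTooManyRows checks the converse: a final column with one extra row also
// reports a row mismatch for both writers.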
func (t *SerializeTestSuite) TestTooManyRows() {
	rowsPerCol := []int64{100, 100, 100, 101}
	t.NotPanics(func() { t.unequalNumRows(101, rowsPerCol) })
	t.NotPanics(func() { t.unequalNumRowsBuffered(101, rowsPerCol) })
}

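// TestSmallFile runs the round-trip serialization test once per supported
// compression codec.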
func (t *SerializeTestSuite) TestSmallFile() {
	codecs := []compress.Compression{
		compress.Codecs.Uncompressed,
		compress.Codecs.Snappy,
		compress.Codecs.Brotli,
		compress.Codecs.Gzip,
		compress.Codecs.Zstd,
		// compress.Codecs.Lz4,
		// compress.Codecs.Lzo,
	}
	for _, c := range codecs {
		t.Run(c.String(), func() {
			t.NotPanics(func() { t.fileSerializeTest(c, c) })
		})
	}
}

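// TestBufferedDisabledDictionary writes a single value with dictionary
// encoding disabled and verifies that the resulting column chunk has no
// dictionary page.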
func TestBufferedDisabledDictionary(t *testing.T) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	fields := schema.FieldList{schema.NewInt32Node("col", parquet.Repetitions.Required, 1)}
	sc, _ := schema.NewGroupNode("schema", parquet.Repetitions.Required, fields, 0)
	props := parquet.NewWriterProperties(parquet.WithDictionaryDefault(false))

	writer := file.NewParquetWriter(sink, sc, file.WithWriterProps(props))
	rgw := writer.AppendBufferedRowGroup()
	cwr, _ := rgw.Column(0)
	cw := cwr.(*file.Int32ColumnChunkWriter)
	cw.WriteBatch([]int32{1}, nil, nil)
	rgw.Close()
	writer.Close()

	buffer := sink.Finish()
	defer buffer.Release()
	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()))
	assert.NoError(t, err)
	assert.EqualValues(t, 1, reader.NumRowGroups())
	rgReader := reader.RowGroup(0)
	assert.EqualValues(t, 1, rgReader.NumRows())
	chunk, _ := rgReader.MetaData().ColumnChunk(0)
	assert.False(t, chunk.HasDictionaryPage())
}

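// TestBufferedMultiPageDisabledDictionary writes enough values to span
// multiple data pages with dictionary encoding disabled and verifies that all
// values round-trip intact.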
func TestBufferedMultiPageDisabledDictionary(t *testing.T) {
	const (
		valueCount = 10000
		pageSize   = 16384
	)
	var (
		sink  = encoding.NewBufferWriter(0, memory.DefaultAllocator)
		props = parquet.NewWriterProperties(parquet.WithDictionaryDefault(false), parquet.WithDataPageSize(pageSize))
		sc, _ = schema.NewGroupNode("schema", parquet.Repetitions.Required, schema.FieldList{
			schema.NewInt32Node("col", parquet.Repetitions.Required, -1),
		}, -1)
	)

	writer := file.NewParquetWriter(sink, sc, file.WithWriterProps(props))
	rgWriter := writer.AppendBufferedRowGroup()
	cwr, _ := rgWriter.Column(0)
	cw := cwr.(*file.Int32ColumnChunkWriter)
	valuesIn := make([]int32, 0, valueCount)
	for i := int32(0); i < valueCount; i++ {
		valuesIn = append(valuesIn, (i%100)+1)
	}
	cw.WriteBatch(valuesIn, nil, nil)
	rgWriter.Close()
	writer.Close()
	buffer := sink.Finish()
	defer buffer.Release()

	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()))
	assert.NoError(t, err)

	assert.EqualValues(t, 1, reader.NumRowGroups())
	valuesOut := make([]int32, valueCount)

	for r := 0; r < reader.NumRowGroups(); r++ {
		rgr := reader.RowGroup(r)
		assert.EqualValues(t, 1, rgr.NumColumns())
		assert.EqualValues(t, valueCount, rgr.NumRows())

		var totalRead int64
		col, err := rgr.Column(0)
		assert.NoError(t, err)
		colReader := col.(*file.Int32ColumnChunkReader)
		for colReader.HasNext() {
			total, _, _ := colReader.ReadBatch(valueCount-totalRead, valuesOut[totalRead:], nil, nil)
			totalRead += total
		}
		assert.EqualValues(t, valueCount, totalRead)
		assert.Equal(t, valuesIn, valuesOut)
	}
}

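// TestAllNulls writes a column containing only nulls and verifies that the
// reader reports three levels read, zero values read, and all-zero definition
// levels.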
func TestAllNulls(t *testing.T) {
	sc, _ := schema.NewGroupNode("root", parquet.Repetitions.Required, schema.FieldList{
		schema.NewInt32Node("nulls", parquet.Repetitions.Optional, -1),
	}, -1)
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)

	writer := file.NewParquetWriter(sink, sc)
	rgw := writer.AppendRowGroup()
	cwr, _ := rgw.NextColumn()
	cw := cwr.(*file.Int32ColumnChunkWriter)

	var (
		values    [3]int32
		defLevels = [...]int16{0, 0, 0}
	)

	// A definition level of 0 on an optional column marks the value as null.
	cw.WriteBatch(values[:], defLevels[:], nil)
	cw.Close()
	rgw.Close()
	writer.Close()

	buffer := sink.Finish()
	defer buffer.Release()
	props := parquet.NewReaderProperties(memory.DefaultAllocator)
	props.BufferedStreamEnabled = true

	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()), file.WithReadProps(props))
	assert.NoError(t, err)

	rgr := reader.RowGroup(0)
	col, err := rgr.Column(0)
	assert.NoError(t, err)
	cr := col.(*file.Int32ColumnChunkReader)

	// Poison the level buffer with sentinels to confirm the reader overwrites it.
	defLevels[0] = -1
	defLevels[1] = -1
	defLevels[2] = -1
	valRead, read, _ := cr.ReadBatch(3, values[:], defLevels[:], nil)
	assert.EqualValues(t, 3, valRead)
	assert.EqualValues(t, 0, read)
	assert.Equal(t, []int16{0, 0, 0}, defLevels[:])
}

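// TestKeyValueMetadata verifies that key/value metadata appended to the
// writer can be recovered from the file footer.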
func TestKeyValueMetadata(t *testing.T) {
	fields := schema.FieldList{
		schema.NewInt32Node("unused", parquet.Repetitions.Optional, -1),
	}
	sc, _ := schema.NewGroupNode("root", parquet.Repetitions.Required, fields, -1)
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)

	writer := file.NewParquetWriter(sink, sc)

	testKey := "testKey"
	testValue := "testValue"
	writer.AppendKeyValueMetadata(testKey, testValue)
	writer.Close()

	buffer := sink.Finish()
	defer buffer.Release()
	props := parquet.NewReaderProperties(memory.DefaultAllocator)
	props.BufferedStreamEnabled = true

	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()), file.WithReadProps(props))
	assert.NoError(t, err)

	metadata := reader.MetaData()
	got := metadata.KeyValueMetadata().FindValue(testKey)
	require.NotNil(t, got)
	assert.Equal(t, testValue, *got)
}

func createSerializeTestSuite(typ reflect.Type) suite.TestingSuite {
	return &SerializeTestSuite{PrimitiveTypedTest: testutils.NewPrimitiveTypedTest(typ)}
}

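// TestSerialize runs SerializeTestSuite for each supported physical type, in
// parallel.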
func TestSerialize(t *testing.T) {
	t.Parallel()
	types := []struct {
		typ reflect.Type
	}{
		{reflect.TypeOf(true)},
		{reflect.TypeOf(int32(0))},
		{reflect.TypeOf(int64(0))},
		{reflect.TypeOf(float32(0))},
		{reflect.TypeOf(float64(0))},
		{reflect.TypeOf(parquet.Int96{})},
		{reflect.TypeOf(parquet.ByteArray{})},
	}
	for _, tt := range types {
		tt := tt // capture the range variable for the parallel subtest
		t.Run(tt.typ.String(), func(t *testing.T) {
			t.Parallel()
			suite.Run(t, createSerializeTestSuite(tt.typ))
		})
	}
}

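// errCloseWriter wraps a BufferWriter but always fails on Close, to exercise
// error propagation from the underlying sink.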
type errCloseWriter struct {
	sink *encoding.BufferWriter
}

func (c *errCloseWriter) Write(p []byte) (n int, err error) {
	return c.sink.Write(p)
}

func (c *errCloseWriter) Close() error {
	return fmt.Errorf("error during close")
}

func (c *errCloseWriter) Bytes() []byte {
	return c.sink.Bytes()
}

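// TestCloseError verifies that an error from closing the underlying sink is
// surfaced by the writer's Close.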
func TestCloseError(t *testing.T) {
	fields := schema.FieldList{schema.NewInt32Node("col", parquet.Repetitions.Required, 1)}
	sc, _ := schema.NewGroupNode("schema", parquet.Repetitions.Required, fields, 0)
	sink := &errCloseWriter{sink: encoding.NewBufferWriter(0, memory.DefaultAllocator)}
	writer := file.NewParquetWriter(sink, sc)
	assert.Error(t, writer.Close())
}