github.com/apache/arrow/go/v14@v14.0.2/parquet/file/file_writer_test.go

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package file_test

import (
	"bytes"
	"fmt"
	"reflect"
	"testing"

	"github.com/apache/arrow/go/v14/arrow/memory"
	"github.com/apache/arrow/go/v14/parquet"
	"github.com/apache/arrow/go/v14/parquet/compress"
	"github.com/apache/arrow/go/v14/parquet/file"
	"github.com/apache/arrow/go/v14/parquet/internal/encoding"
	"github.com/apache/arrow/go/v14/parquet/internal/testutils"
	"github.com/apache/arrow/go/v14/parquet/schema"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)
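
// SerializeTestSuite exercises the Parquet file writer round trip for a
// single primitive type: data is written through both serial and buffered
// row group writers, then read back and verified against the input.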
type SerializeTestSuite struct {
	testutils.PrimitiveTypedTest
	suite.Suite

	numCols      int
	numRowGroups int
	rowsPerRG    int
	rowsPerBatch int
}

func (t *SerializeTestSuite) SetupTest() {
	t.numCols = 4
	t.numRowGroups = 4
	t.rowsPerRG = 50
	t.rowsPerBatch = 10
	t.SetupSchema(parquet.Repetitions.Optional, t.numCols)
}
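
// fileSerializeTest writes half of the row groups with the serial
// (unbuffered) writer and half with the buffered writer, compressing every
// column with codec, then reads the file back and checks row counts,
// compression, values, and definition levels.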
func (t *SerializeTestSuite) fileSerializeTest(codec compress.Compression, expected compress.Compression) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)

	opts := make([]parquet.WriterProperty, 0)
	for i := 0; i < t.numCols; i++ {
		opts = append(opts, parquet.WithCompressionFor(t.Schema.Column(i).Name(), codec))
	}

	props := parquet.NewWriterProperties(opts...)

	writer := file.NewParquetWriter(sink, t.Schema.Root(), file.WithWriterProps(props))
	t.GenerateData(int64(t.rowsPerRG))
	// write the first half of the row groups with the serial writer
	for rg := 0; rg < t.numRowGroups/2; rg++ {
		rgw := writer.AppendRowGroup()
		for col := 0; col < t.numCols; col++ {
			cw, _ := rgw.NextColumn()
			t.WriteBatchValues(cw, t.DefLevels, nil)
			cw.Close()
			// ensure the Column() API, which is specific to buffered row groups, cannot be called
			t.Panics(func() { rgw.(file.BufferedRowGroupWriter).Column(col) })
		}
		rgw.Close()
	}

	// write the second half of the row groups with the buffered writer
	for rg := 0; rg < t.numRowGroups/2; rg++ {
		rgw := writer.AppendBufferedRowGroup()
		for batch := 0; batch < (t.rowsPerRG / t.rowsPerBatch); batch++ {
			for col := 0; col < t.numCols; col++ {
				cw, _ := rgw.Column(col)
				offset := batch * t.rowsPerBatch
				t.WriteBatchSubset(t.rowsPerBatch, offset, cw, t.DefLevels[offset:t.rowsPerBatch+offset], nil)
				// ensure the NextColumn() API, which is specific to serial row groups, cannot be called
				t.Panics(func() { rgw.(file.SerialRowGroupWriter).NextColumn() })
			}
		}
		for col := 0; col < t.numCols; col++ {
			cw, _ := rgw.Column(col)
			cw.Close()
		}
		rgw.Close()
	}
	writer.Close()

	nrows := t.numRowGroups * t.rowsPerRG
	reader, err := file.NewParquetReader(bytes.NewReader(sink.Bytes()))
	t.NoError(err)
	t.Equal(t.numCols, reader.MetaData().Schema.NumColumns())
	t.Equal(t.numRowGroups, reader.NumRowGroups())
	t.EqualValues(nrows, reader.NumRows())

	for rg := 0; rg < t.numRowGroups; rg++ {
		rgr := reader.RowGroup(rg)
		t.Equal(t.numCols, rgr.NumColumns())
		t.EqualValues(t.rowsPerRG, rgr.NumRows())
		chunk, _ := rgr.MetaData().ColumnChunk(0)
		t.Equal(expected, chunk.Compression())

		valuesRead := int64(0)

		for i := 0; i < t.numCols; i++ {
			chunk, _ := rgr.MetaData().ColumnChunk(i)
			t.False(chunk.HasIndexPage())
			t.DefLevelsOut = make([]int16, t.rowsPerRG)
			t.RepLevelsOut = make([]int16, t.rowsPerRG)
			colReader, err := rgr.Column(i)
			t.NoError(err)
			t.SetupValuesOut(int64(t.rowsPerRG))
			valuesRead = t.ReadBatch(colReader, int64(t.rowsPerRG), 0, t.DefLevelsOut, t.RepLevelsOut)
			t.EqualValues(t.rowsPerRG, valuesRead)
			t.Equal(t.Values, t.ValuesOut)
			t.Equal(t.DefLevels, t.DefLevelsOut)
		}
	}
}
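
// unequalNumRows writes a different number of rows to each column of an
// unbuffered row group and expects Close to report the row count mismatch.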
func (t *SerializeTestSuite) unequalNumRows(maxRows int64, rowsPerCol []int64) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	props := parquet.NewWriterProperties()
	writer := file.NewParquetWriter(sink, t.Schema.Root(), file.WithWriterProps(props))
	defer writer.Close()

	rgw := writer.AppendRowGroup()
	t.GenerateData(maxRows)
	for col := 0; col < t.numCols; col++ {
		cw, _ := rgw.NextColumn()
		t.WriteBatchSubset(int(rowsPerCol[col]), 0, cw, t.DefLevels[:rowsPerCol[col]], nil)
		cw.Close()
	}
	err := rgw.Close()
	t.Error(err)
	t.ErrorContains(err, "row mismatch for unbuffered row group")
}
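
// unequalNumRowsBuffered is the buffered row group variant of
// unequalNumRows.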
func (t *SerializeTestSuite) unequalNumRowsBuffered(maxRows int64, rowsPerCol []int64) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	writer := file.NewParquetWriter(sink, t.Schema.Root())
	defer writer.Close()

	rgw := writer.AppendBufferedRowGroup()
	t.GenerateData(maxRows)
	for col := 0; col < t.numCols; col++ {
		cw, _ := rgw.Column(col)
		t.WriteBatchSubset(int(rowsPerCol[col]), 0, cw, t.DefLevels[:rowsPerCol[col]], nil)
		cw.Close()
	}
	err := rgw.Close()
	t.Error(err)
	t.ErrorContains(err, "row mismatch for buffered row group")
}
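
// TestZeroRows verifies that closing columns and row groups without writing
// any rows is safe for both serial and buffered writers.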
func (t *SerializeTestSuite) TestZeroRows() {
	t.NotPanics(func() {
		sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
		writer := file.NewParquetWriter(sink, t.Schema.Root())
		defer writer.Close()

		srgw := writer.AppendRowGroup()
		for col := 0; col < t.numCols; col++ {
			cw, _ := srgw.NextColumn()
			cw.Close()
		}
		srgw.Close()

		brgw := writer.AppendBufferedRowGroup()
		for col := 0; col < t.numCols; col++ {
			cw, _ := brgw.Column(col)
			cw.Close()
		}
		brgw.Close()
	})
}
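
// TestTooManyColumns verifies that requesting more columns than the schema
// defines panics.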
func (t *SerializeTestSuite) TestTooManyColumns() {
	t.SetupSchema(parquet.Repetitions.Optional, 1)
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	writer := file.NewParquetWriter(sink, t.Schema.Root())
	rgw := writer.AppendRowGroup()

	rgw.NextColumn()                      // first column
	t.Panics(func() { rgw.NextColumn() }) // only one column!
}
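
// TestRepeatedTooFewRows verifies that a repetition level pattern which
// collapses two values into one row leaves the column one row short and
// triggers a panic when the column is closed.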
func (t *SerializeTestSuite) TestRepeatedTooFewRows() {
	// a repeated column carries both definition and repetition levels
	t.SetupSchema(parquet.Repetitions.Repeated, 1)
	const nrows = 100
	t.GenerateData(nrows)

	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	writer := file.NewParquetWriter(sink, t.Schema.Root())

	rgw := writer.AppendRowGroup()
	t.RepLevels = make([]int16, nrows)
	for idx := range t.RepLevels {
		t.RepLevels[idx] = 0
	}

	cw, _ := rgw.NextColumn()
	t.WriteBatchValues(cw, t.DefLevels, t.RepLevels)
	cw.Close()

	t.RepLevels[3] = 1 // makes values 2 and 3 part of a single row,
	// so this column ends up one row short of the first

	t.Panics(func() {
		cw, _ = rgw.NextColumn()
		t.WriteBatchValues(cw, t.DefLevels, t.RepLevels)
		cw.Close()
	})
}
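
// TestTooFewRows checks the mismatch error when one column is a row short.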
func (t *SerializeTestSuite) TestTooFewRows() {
	rowsPerCol := []int64{100, 100, 100, 99}
	t.NotPanics(func() { t.unequalNumRows(100, rowsPerCol) })
	t.NotPanics(func() { t.unequalNumRowsBuffered(100, rowsPerCol) })
}
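
// TestTooManyRows checks the mismatch error when one column has an extra row.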
func (t *SerializeTestSuite) TestTooManyRows() {
	rowsPerCol := []int64{100, 100, 100, 101}
	t.NotPanics(func() { t.unequalNumRows(101, rowsPerCol) })
	t.NotPanics(func() { t.unequalNumRowsBuffered(101, rowsPerCol) })
}
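
// TestSmallFile runs the serialize round trip once for each compression
// codec exercised here.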
func (t *SerializeTestSuite) TestSmallFile() {
	codecs := []compress.Compression{
		compress.Codecs.Uncompressed,
		compress.Codecs.Snappy,
		compress.Codecs.Brotli,
		compress.Codecs.Gzip,
		compress.Codecs.Zstd,
		// compress.Codecs.Lz4,
		// compress.Codecs.Lzo,
	}
	for _, c := range codecs {
		t.Run(c.String(), func() {
			t.NotPanics(func() { t.fileSerializeTest(c, c) })
		})
	}
}
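
// TestBufferedDisabledDictionary writes a single value with dictionary
// encoding disabled and confirms the resulting chunk has no dictionary page.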
func TestBufferedDisabledDictionary(t *testing.T) {
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
	fields := schema.FieldList{schema.NewInt32Node("col", parquet.Repetitions.Required, 1)}
	sc, _ := schema.NewGroupNode("schema", parquet.Repetitions.Required, fields, 0)
	props := parquet.NewWriterProperties(parquet.WithDictionaryDefault(false))

	writer := file.NewParquetWriter(sink, sc, file.WithWriterProps(props))
	rgw := writer.AppendBufferedRowGroup()
	cwr, _ := rgw.Column(0)
	cw := cwr.(*file.Int32ColumnChunkWriter)
	cw.WriteBatch([]int32{1}, nil, nil)
	rgw.Close()
	writer.Close()

	buffer := sink.Finish()
	defer buffer.Release()
	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()))
	assert.NoError(t, err)
	assert.EqualValues(t, 1, reader.NumRowGroups())
	rgReader := reader.RowGroup(0)
	assert.EqualValues(t, 1, rgReader.NumRows())
	chunk, _ := rgReader.MetaData().ColumnChunk(0)
	assert.False(t, chunk.HasDictionaryPage())
}
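
// TestBufferedMultiPageDisabledDictionary forces multiple data pages by
// combining a small page size with a large value count, then reads every
// value back and compares against the input.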
func TestBufferedMultiPageDisabledDictionary(t *testing.T) {
	const (
		valueCount = 10000
		pageSize   = 16384
	)
	var (
		sink  = encoding.NewBufferWriter(0, memory.DefaultAllocator)
		props = parquet.NewWriterProperties(parquet.WithDictionaryDefault(false), parquet.WithDataPageSize(pageSize))
		sc, _ = schema.NewGroupNode("schema", parquet.Repetitions.Required, schema.FieldList{
			schema.NewInt32Node("col", parquet.Repetitions.Required, -1),
		}, -1)
	)

	writer := file.NewParquetWriter(sink, sc, file.WithWriterProps(props))
	rgWriter := writer.AppendBufferedRowGroup()
	cwr, _ := rgWriter.Column(0)
	cw := cwr.(*file.Int32ColumnChunkWriter)
	valuesIn := make([]int32, 0, valueCount)
	for i := int32(0); i < valueCount; i++ {
		valuesIn = append(valuesIn, (i%100)+1)
	}
	cw.WriteBatch(valuesIn, nil, nil)
	rgWriter.Close()
	writer.Close()
	buffer := sink.Finish()
	defer buffer.Release()

	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()))
	assert.NoError(t, err)

	assert.EqualValues(t, 1, reader.NumRowGroups())
	valuesOut := make([]int32, valueCount)

	for r := 0; r < reader.NumRowGroups(); r++ {
		rgr := reader.RowGroup(r)
		assert.EqualValues(t, 1, rgr.NumColumns())
		assert.EqualValues(t, valueCount, rgr.NumRows())

		var totalRead int64
		col, err := rgr.Column(0)
		assert.NoError(t, err)
		colReader := col.(*file.Int32ColumnChunkReader)
		for colReader.HasNext() {
			total, _, _ := colReader.ReadBatch(valueCount-totalRead, valuesOut[totalRead:], nil, nil)
			totalRead += total
		}
		assert.EqualValues(t, valueCount, totalRead)
		assert.Equal(t, valuesIn, valuesOut)
	}
}
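
// TestAllNulls writes a batch whose definition levels mark every value as
// null and confirms the reader reports the levels but zero values read.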
func TestAllNulls(t *testing.T) {
	sc, _ := schema.NewGroupNode("root", parquet.Repetitions.Required, schema.FieldList{
		schema.NewInt32Node("nulls", parquet.Repetitions.Optional, -1),
	}, -1)
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)

	writer := file.NewParquetWriter(sink, sc)
	rgw := writer.AppendRowGroup()
	cwr, _ := rgw.NextColumn()
	cw := cwr.(*file.Int32ColumnChunkWriter)

	var (
		values    [3]int32
		defLevels = [...]int16{0, 0, 0}
	)

	cw.WriteBatch(values[:], defLevels[:], nil)
	cw.Close()
	rgw.Close()
	writer.Close()

	buffer := sink.Finish()
	defer buffer.Release()
	props := parquet.NewReaderProperties(memory.DefaultAllocator)
	props.BufferedStreamEnabled = true

	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()), file.WithReadProps(props))
	assert.NoError(t, err)

	rgr := reader.RowGroup(0)
	col, err := rgr.Column(0)
	assert.NoError(t, err)
	cr := col.(*file.Int32ColumnChunkReader)

	// overwrite the def levels with sentinel values to prove ReadBatch refills them
	defLevels[0] = -1
	defLevels[1] = -1
	defLevels[2] = -1
	valRead, read, _ := cr.ReadBatch(3, values[:], defLevels[:], nil)
	assert.EqualValues(t, 3, valRead)
	assert.EqualValues(t, 0, read)
	assert.Equal(t, []int16{0, 0, 0}, defLevels[:])
}
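
// TestKeyValueMetadata round-trips a key/value pair through the file
// footer's key-value metadata.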
func TestKeyValueMetadata(t *testing.T) {
	fields := schema.FieldList{
		schema.NewInt32Node("unused", parquet.Repetitions.Optional, -1),
	}
	sc, _ := schema.NewGroupNode("root", parquet.Repetitions.Required, fields, -1)
	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)

	writer := file.NewParquetWriter(sink, sc)

	testKey := "testKey"
	testValue := "testValue"
	writer.AppendKeyValueMetadata(testKey, testValue)
	writer.Close()

	buffer := sink.Finish()
	defer buffer.Release()
	props := parquet.NewReaderProperties(memory.DefaultAllocator)
	props.BufferedStreamEnabled = true

	reader, err := file.NewParquetReader(bytes.NewReader(buffer.Bytes()), file.WithReadProps(props))
	assert.NoError(t, err)

	metadata := reader.MetaData()
	got := metadata.KeyValueMetadata().FindValue(testKey)
	require.NotNil(t, got)
	assert.Equal(t, testValue, *got)
}

func createSerializeTestSuite(typ reflect.Type) suite.TestingSuite {
	return &SerializeTestSuite{PrimitiveTypedTest: testutils.NewPrimitiveTypedTest(typ)}
}
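
// TestSerialize runs the SerializeTestSuite in parallel, once for each of
// the physical value types covered below.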
func TestSerialize(t *testing.T) {
	t.Parallel()
	types := []struct {
		typ reflect.Type
	}{
		{reflect.TypeOf(true)},
		{reflect.TypeOf(int32(0))},
		{reflect.TypeOf(int64(0))},
		{reflect.TypeOf(float32(0))},
		{reflect.TypeOf(float64(0))},
		{reflect.TypeOf(parquet.Int96{})},
		{reflect.TypeOf(parquet.ByteArray{})},
	}
	for _, tt := range types {
		tt := tt
		t.Run(tt.typ.String(), func(t *testing.T) {
			t.Parallel()
			suite.Run(t, createSerializeTestSuite(tt.typ))
		})
	}
}
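
// errCloseWriter wraps a BufferWriter and always fails on Close so the
// writer's error propagation can be tested.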
type errCloseWriter struct {
	sink *encoding.BufferWriter
}

func (c *errCloseWriter) Write(p []byte) (n int, err error) {
	return c.sink.Write(p)
}

func (c *errCloseWriter) Close() error {
	return fmt.Errorf("error during close")
}

func (c *errCloseWriter) Bytes() []byte {
	return c.sink.Bytes()
}
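
// TestCloseError confirms that a failure to close the sink surfaces as an
// error from the writer's Close.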
func TestCloseError(t *testing.T) {
	fields := schema.FieldList{schema.NewInt32Node("col", parquet.Repetitions.Required, 1)}
	sc, _ := schema.NewGroupNode("schema", parquet.Repetitions.Required, fields, 0)
	sink := &errCloseWriter{sink: encoding.NewBufferWriter(0, memory.DefaultAllocator)}
	writer := file.NewParquetWriter(sink, sc)
	assert.Error(t, writer.Close())
}