github.com/apache/arrow/go/v14@v14.0.1/parquet/schema/schema_flatten_test.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one
     2  // or more contributor license agreements.  See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership.  The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License.  You may obtain a copy of the License at
     8  //
     9  // http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package schema
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/apache/arrow/go/v14/parquet"
    23  	format "github.com/apache/arrow/go/v14/parquet/internal/gen-go/parquet"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/suite"
    26  )
    27  
    28  func NewPrimitive(name string, repetition format.FieldRepetitionType, typ format.Type, fieldID int32) *format.SchemaElement {
    29  	ret := &format.SchemaElement{
    30  		Name:           name,
    31  		RepetitionType: format.FieldRepetitionTypePtr(repetition),
    32  		Type:           format.TypePtr(typ),
    33  	}
    34  	if fieldID >= 0 {
    35  		ret.FieldID = &fieldID
    36  	}
    37  	return ret
    38  }
    39  
    40  func NewGroup(name string, repetition format.FieldRepetitionType, numChildren, fieldID int32) *format.SchemaElement {
    41  	ret := &format.SchemaElement{
    42  		Name:           name,
    43  		RepetitionType: format.FieldRepetitionTypePtr(repetition),
    44  		NumChildren:    &numChildren,
    45  	}
    46  	if fieldID >= 0 {
    47  		ret.FieldID = &fieldID
    48  	}
    49  	return ret
    50  }
    51  
    52  type SchemaFlattenSuite struct {
    53  	suite.Suite
    54  
    55  	name string
    56  }
    57  
    58  func (s *SchemaFlattenSuite) SetupSuite() {
    59  	s.name = "parquet_schema"
    60  }
    61  
    62  func (s *SchemaFlattenSuite) TestDecimalMetadata() {
    63  	group := MustGroup(NewGroupNodeConverted("group" /* name */, parquet.Repetitions.Repeated, FieldList{
    64  		MustPrimitive(NewPrimitiveNodeConverted("decimal" /* name */, parquet.Repetitions.Required, parquet.Types.Int64,
    65  			ConvertedTypes.Decimal, 0 /* type len */, 8 /* precision */, 4 /* scale */, -1 /* fieldID */)),
    66  	}, ConvertedTypes.List, -1 /* fieldID */))
    67  	elements := ToThrift(group)
    68  
    69  	s.Len(elements, 2)
    70  	s.Equal("decimal", elements[1].GetName())
    71  	s.True(elements[1].IsSetPrecision())
    72  	s.True(elements[1].IsSetScale())
    73  
    74  	group = MustGroup(NewGroupNodeLogical("group" /* name */, parquet.Repetitions.Repeated, FieldList{
    75  		MustPrimitive(NewPrimitiveNodeLogical("decimal" /* name */, parquet.Repetitions.Required, NewDecimalLogicalType(10 /* precision */, 5 /* scale */),
    76  			parquet.Types.Int64, 0 /* type len */, -1 /* fieldID */)),
    77  	}, NewListLogicalType(), -1 /* fieldID */))
    78  	elements = ToThrift(group)
    79  	s.Equal("decimal", elements[1].Name)
    80  	s.True(elements[1].IsSetPrecision())
    81  	s.True(elements[1].IsSetScale())
    82  
    83  	group = MustGroup(NewGroupNodeConverted("group" /* name */, parquet.Repetitions.Repeated, FieldList{
    84  		NewInt64Node("int64" /* name */, parquet.Repetitions.Required, -1 /* fieldID */)}, ConvertedTypes.List, -1 /* fieldID */))
    85  	elements = ToThrift(group)
    86  	s.Equal("int64", elements[1].Name)
    87  	s.False(elements[0].IsSetPrecision())
    88  	s.False(elements[1].IsSetPrecision())
    89  	s.False(elements[0].IsSetScale())
    90  	s.False(elements[1].IsSetScale())
    91  }
    92  
    93  func (s *SchemaFlattenSuite) TestNestedExample() {
    94  	elements := make([]*format.SchemaElement, 0)
    95  	elements = append(elements,
    96  		NewGroup(s.name, format.FieldRepetitionType_REPEATED, 2 /* numChildren */, 0 /* fieldID */),
    97  		NewPrimitive("a" /* name */, format.FieldRepetitionType_REQUIRED, format.Type_INT32, 1 /* fieldID */),
    98  		NewGroup("bag" /* name */, format.FieldRepetitionType_OPTIONAL, 1 /* numChildren */, 2 /* fieldID */))
    99  
   100  	elt := NewGroup("b" /* name */, format.FieldRepetitionType_REPEATED, 1 /* numChildren */, 3 /* fieldID */)
   101  	elt.ConvertedType = format.ConvertedTypePtr(format.ConvertedType_LIST)
   102  	elt.LogicalType = &format.LogicalType{LIST: format.NewListType()}
   103  	elements = append(elements, elt, NewPrimitive("item" /* name */, format.FieldRepetitionType_OPTIONAL, format.Type_INT64, 4 /* fieldID */))
   104  
   105  	fields := FieldList{NewInt32Node("a" /* name */, parquet.Repetitions.Required, 1 /* fieldID */)}
   106  	list := MustGroup(NewGroupNodeConverted("b" /* name */, parquet.Repetitions.Repeated, FieldList{
   107  		NewInt64Node("item" /* name */, parquet.Repetitions.Optional, 4 /* fieldID */)}, ConvertedTypes.List, 3 /* fieldID */))
   108  	fields = append(fields, MustGroup(NewGroupNode("bag" /* name */, parquet.Repetitions.Optional, FieldList{list}, 2 /* fieldID */)))
   109  
   110  	sc := MustGroup(NewGroupNode(s.name, parquet.Repetitions.Repeated, fields, 0 /* fieldID */))
   111  
   112  	flattened := ToThrift(sc)
   113  	s.Len(flattened, len(elements))
   114  	for idx, elem := range flattened {
   115  		s.Equal(elements[idx], elem)
   116  	}
   117  }
   118  
   119  func TestSchemaFlatten(t *testing.T) {
   120  	suite.Run(t, new(SchemaFlattenSuite))
   121  }
   122  
   123  func TestInvalidConvertedTypeInDeserialize(t *testing.T) {
   124  	n := MustPrimitive(NewPrimitiveNodeLogical("string" /* name */, parquet.Repetitions.Required, StringLogicalType{},
   125  		parquet.Types.ByteArray, -1 /* type len */, -1 /* fieldID */))
   126  	assert.True(t, n.LogicalType().Equals(StringLogicalType{}))
   127  	assert.True(t, n.LogicalType().IsValid())
   128  	assert.True(t, n.LogicalType().IsSerialized())
   129  	intermediary := n.toThrift()
   130  	// corrupt it
   131  	intermediary.LogicalType.STRING = nil
   132  	assert.Panics(t, func() {
   133  		PrimitiveNodeFromThrift(intermediary)
   134  	})
   135  }
   136  
   137  func TestInvalidTimeUnitInTimeLogical(t *testing.T) {
   138  	n := MustPrimitive(NewPrimitiveNodeLogical("time" /* name */, parquet.Repetitions.Required,
   139  		NewTimeLogicalType(true /* adjustedToUTC */, TimeUnitNanos), parquet.Types.Int64, -1 /* type len */, -1 /* fieldID */))
   140  	intermediary := n.toThrift()
   141  	// corrupt it
   142  	intermediary.LogicalType.TIME.Unit.NANOS = nil
   143  	assert.Panics(t, func() {
   144  		PrimitiveNodeFromThrift(intermediary)
   145  	})
   146  }
   147  
   148  func TestInvalidTimeUnitInTimestampLogical(t *testing.T) {
   149  	n := MustPrimitive(NewPrimitiveNodeLogical("time" /* name */, parquet.Repetitions.Required,
   150  		NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitNanos), parquet.Types.Int64, -1 /* type len */, -1 /* fieldID */))
   151  	intermediary := n.toThrift()
   152  	// corrupt it
   153  	intermediary.LogicalType.TIMESTAMP.Unit.NANOS = nil
   154  	assert.Panics(t, func() {
   155  		PrimitiveNodeFromThrift(intermediary)
   156  	})
   157  }