github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/dictionary_test.go (about)

     1  package parquet_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math/rand"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/segmentio/parquet-go"
    11  )
    12  
    13  var dictionaryTypes = [...]parquet.Type{
    14  	parquet.BooleanType,
    15  	parquet.Int32Type,
    16  	parquet.Int64Type,
    17  	parquet.Int96Type,
    18  	parquet.FloatType,
    19  	parquet.DoubleType,
    20  	parquet.ByteArrayType,
    21  	parquet.FixedLenByteArrayType(10),
    22  	parquet.FixedLenByteArrayType(16),
    23  	parquet.Uint(32).Type(),
    24  	parquet.Uint(64).Type(),
    25  }
    26  
    27  func TestDictionary(t *testing.T) {
    28  	for _, typ := range dictionaryTypes {
    29  		for _, numValues := range []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 1e2, 1e3, 1e4} {
    30  			t.Run(fmt.Sprintf("%s/N=%d", typ, numValues), func(t *testing.T) {
    31  				testDictionary(t, typ, numValues)
    32  			})
    33  		}
    34  	}
    35  }
    36  
    37  func testDictionary(t *testing.T, typ parquet.Type, numValues int) {
    38  	const columnIndex = 1
    39  
    40  	dict := typ.NewDictionary(columnIndex, 0, typ.NewValues(nil, nil))
    41  	values := make([]parquet.Value, numValues)
    42  	indexes := make([]int32, numValues)
    43  	lookups := make([]parquet.Value, numValues)
    44  
    45  	f := randValueFuncOf(typ)
    46  	r := rand.New(rand.NewSource(int64(numValues)))
    47  
    48  	for i := range values {
    49  		values[i] = f(r)
    50  		values[i] = values[i].Level(0, 0, columnIndex)
    51  	}
    52  
    53  	mapping := make(map[int32]parquet.Value, numValues)
    54  
    55  	for i := 0; i < numValues; {
    56  		j := i + ((numValues-i)/2 + 1)
    57  		if j > numValues {
    58  			j = numValues
    59  		}
    60  
    61  		dict.Insert(indexes[i:j], values[i:j])
    62  
    63  		for k, v := range values[i:j] {
    64  			mapping[indexes[i+k]] = v
    65  		}
    66  
    67  		for _, index := range indexes[i:j] {
    68  			if index < 0 || index >= int32(dict.Len()) {
    69  				t.Fatalf("index out of bounds: %d", index)
    70  			}
    71  		}
    72  
    73  		// second insert is a no-op since all the values are already in the dictionary
    74  		lastDictLen := dict.Len()
    75  		dict.Insert(indexes[i:j], values[i:j])
    76  
    77  		if dict.Len() != lastDictLen {
    78  			for k, index := range indexes[i:j] {
    79  				if index >= int32(len(mapping)) {
    80  					t.Log(values[i+k])
    81  				}
    82  			}
    83  
    84  			t.Fatalf("%d values were inserted on the second pass", dict.Len()-len(mapping))
    85  		}
    86  
    87  		r.Shuffle(j-i, func(a, b int) {
    88  			indexes[a+i], indexes[b+i] = indexes[b+i], indexes[a+i]
    89  		})
    90  
    91  		dict.Lookup(indexes[i:j], lookups[i:j])
    92  
    93  		for lookupIndex, valueIndex := range indexes[i:j] {
    94  			want := mapping[valueIndex]
    95  			got := lookups[lookupIndex+i]
    96  
    97  			if !parquet.DeepEqual(want, got) {
    98  				t.Fatalf("wrong value looked up at index %d: want=%#v got=%#v", valueIndex, want, got)
    99  			}
   100  		}
   101  
   102  		minValue := values[i]
   103  		maxValue := values[i]
   104  
   105  		for _, value := range values[i+1 : j] {
   106  			switch {
   107  			case typ.Compare(value, minValue) < 0:
   108  				minValue = value
   109  			case typ.Compare(value, maxValue) > 0:
   110  				maxValue = value
   111  			}
   112  		}
   113  
   114  		lowerBound, upperBound := dict.Bounds(indexes[i:j])
   115  		if !parquet.DeepEqual(lowerBound, minValue) {
   116  			t.Errorf("wrong lower bound between indexes %d and %d: want=%#v got=%#v", i, j, minValue, lowerBound)
   117  		}
   118  		if !parquet.DeepEqual(upperBound, maxValue) {
   119  			t.Errorf("wrong upper bound between indexes %d and %d: want=%#v got=%#v", i, j, maxValue, upperBound)
   120  		}
   121  
   122  		i = j
   123  	}
   124  
   125  	for i := range lookups {
   126  		lookups[i] = parquet.Value{}
   127  	}
   128  
   129  	dict.Lookup(indexes, lookups)
   130  
   131  	for lookupIndex, valueIndex := range indexes {
   132  		want := mapping[valueIndex]
   133  		got := lookups[lookupIndex]
   134  
   135  		if !parquet.Equal(want, got) {
   136  			t.Fatalf("wrong value looked up at index %d: want=%+v got=%+v", valueIndex, want, got)
   137  		}
   138  	}
   139  }
   140  
   141  func BenchmarkDictionary(b *testing.B) {
   142  	tests := []struct {
   143  		scenario string
   144  		init     func(parquet.Dictionary, []int32, []parquet.Value)
   145  		test     func(parquet.Dictionary, []int32, []parquet.Value)
   146  	}{
   147  		{
   148  			scenario: "Bounds",
   149  			init:     parquet.Dictionary.Insert,
   150  			test: func(dict parquet.Dictionary, indexes []int32, _ []parquet.Value) {
   151  				dict.Bounds(indexes)
   152  			},
   153  		},
   154  
   155  		{
   156  			scenario: "Insert",
   157  			test:     parquet.Dictionary.Insert,
   158  		},
   159  
   160  		{
   161  			scenario: "Lookup",
   162  			init:     parquet.Dictionary.Insert,
   163  			test:     parquet.Dictionary.Lookup,
   164  		},
   165  	}
   166  
   167  	for i, test := range tests {
   168  		b.Run(test.scenario, func(b *testing.B) {
   169  			for j, typ := range dictionaryTypes {
   170  				for _, numValues := range []int{1e2, 1e3, 1e4, 1e5, 1e6} {
   171  					buf := typ.NewValues(make([]byte, 0, 4*numValues), nil)
   172  					dict := typ.NewDictionary(0, 0, buf)
   173  					values := make([]parquet.Value, numValues)
   174  
   175  					f := randValueFuncOf(typ)
   176  					r := rand.New(rand.NewSource(int64(i * j * numValues)))
   177  
   178  					for i := range values {
   179  						values[i] = f(r)
   180  					}
   181  
   182  					indexes := make([]int32, len(values))
   183  					if test.init != nil {
   184  						test.init(dict, indexes, values)
   185  					}
   186  
   187  					b.Run(fmt.Sprintf("%s/N=%d", typ, numValues), func(b *testing.B) {
   188  						start := time.Now()
   189  
   190  						for i := 0; i < b.N; i++ {
   191  							test.test(dict, indexes, values)
   192  						}
   193  
   194  						seconds := time.Since(start).Seconds()
   195  						b.ReportMetric(float64(numValues*b.N)/seconds, "value/s")
   196  					})
   197  				}
   198  			}
   199  		})
   200  	}
   201  }
   202  
   203  func TestIssue312(t *testing.T) {
   204  	node := parquet.String()
   205  	node = parquet.Encoded(node, &parquet.RLEDictionary)
   206  	g := parquet.Group{}
   207  	g["mystring"] = node
   208  	schema := parquet.NewSchema("test", g)
   209  
   210  	rows := []parquet.Row{[]parquet.Value{parquet.ValueOf("hello").Level(0, 0, 0)}}
   211  
   212  	var storage bytes.Buffer
   213  
   214  	tests := []struct {
   215  		name        string
   216  		getRowGroup func(t *testing.T) parquet.RowGroup
   217  	}{
   218  		{
   219  			name: "Writer",
   220  			getRowGroup: func(t *testing.T) parquet.RowGroup {
   221  				t.Helper()
   222  
   223  				w := parquet.NewWriter(&storage, schema)
   224  				_, err := w.WriteRows(rows)
   225  				if err != nil {
   226  					t.Fatal(err)
   227  				}
   228  				if err := w.Close(); err != nil {
   229  					t.Fatal(err)
   230  				}
   231  
   232  				r := bytes.NewReader(storage.Bytes())
   233  				f, err := parquet.OpenFile(r, int64(storage.Len()))
   234  				if err != nil {
   235  					t.Fatal(err)
   236  				}
   237  				return f.RowGroups()[0]
   238  			},
   239  		},
   240  		{
   241  			name: "Buffer",
   242  			getRowGroup: func(t *testing.T) parquet.RowGroup {
   243  				t.Helper()
   244  
   245  				b := parquet.NewBuffer(schema)
   246  				_, err := b.WriteRows(rows)
   247  				if err != nil {
   248  					t.Fatal(err)
   249  				}
   250  				return b
   251  			},
   252  		},
   253  	}
   254  
   255  	for _, testCase := range tests {
   256  		t.Run(testCase.name, func(t *testing.T) {
   257  			rowGroup := testCase.getRowGroup(t)
   258  
   259  			chunk := rowGroup.ColumnChunks()[0]
   260  			idx := chunk.ColumnIndex()
   261  			val := idx.MinValue(0)
   262  			columnType := chunk.Type()
   263  			values := columnType.NewValues(val.Bytes(), []uint32{0, uint32(len(val.Bytes()))})
   264  
   265  			// This test ensures that the dictionary type created by column
   266  			// chunks of parquet readers and buffers are the same. We want the
   267  			// column chunk type to be the actual value type, even when the
   268  			// schema uses a dictionary encoding.
   269  			//
   270  			// https://github.com/segmentio/parquet-go/issues/312
   271  			_ = columnType.NewDictionary(0, 1, values)
   272  		})
   273  	}
   274  }