github.com/vc42/parquet-go@v0.0.0-20240320194221-1a9adb5f23f5/bloom_test.go (about)

     1  package parquet
     2  
     3  import (
     4  	"math/rand"
     5  	"testing"
     6  
     7  	"github.com/vc42/parquet-go/bloom"
     8  	"github.com/vc42/parquet-go/deprecated"
     9  	"github.com/vc42/parquet-go/encoding/plain"
    10  	"github.com/vc42/parquet-go/internal/quick"
    11  	"github.com/vc42/parquet-go/internal/unsafecast"
    12  )
    13  
    14  func TestSplitBlockFilter(t *testing.T) {
    15  	newFilter := func(numValues int) bloom.SplitBlockFilter {
    16  		return make(bloom.SplitBlockFilter, bloom.NumSplitBlocksOf(int64(numValues), 11))
    17  	}
    18  
    19  	encoding := SplitBlockFilter("$").Encoding()
    20  
    21  	check := func(filter bloom.SplitBlockFilter, value Value) bool {
    22  		return filter.Check(value.hash(&bloom.XXH64{}))
    23  	}
    24  
    25  	tests := []struct {
    26  		scenario string
    27  		function interface{}
    28  	}{
    29  		{
    30  			scenario: "BOOLEAN",
    31  			function: func(values []bool) bool {
    32  				filter := newFilter(len(values))
    33  				encoding.EncodeBoolean(filter.Bytes(), unsafecast.BoolToBytes(values))
    34  				for _, v := range values {
    35  					if !check(filter, ValueOf(v)) {
    36  						return false
    37  					}
    38  				}
    39  				return true
    40  			},
    41  		},
    42  
    43  		{
    44  			scenario: "INT32",
    45  			function: func(values []int32) bool {
    46  				filter := newFilter(len(values))
    47  				encoding.EncodeInt32(filter.Bytes(), unsafecast.Int32ToBytes(values))
    48  				for _, v := range values {
    49  					if !check(filter, ValueOf(v)) {
    50  						return false
    51  					}
    52  				}
    53  				return true
    54  			},
    55  		},
    56  
    57  		{
    58  			scenario: "INT64",
    59  			function: func(values []int64) bool {
    60  				filter := newFilter(len(values))
    61  				encoding.EncodeInt64(filter.Bytes(), unsafecast.Int64ToBytes(values))
    62  				for _, v := range values {
    63  					if !check(filter, ValueOf(v)) {
    64  						return false
    65  					}
    66  				}
    67  				return true
    68  			},
    69  		},
    70  
    71  		{
    72  			scenario: "INT96",
    73  			function: func(values []deprecated.Int96) bool {
    74  				filter := newFilter(len(values))
    75  				encoding.EncodeInt96(filter.Bytes(), deprecated.Int96ToBytes(values))
    76  				for _, v := range values {
    77  					if !check(filter, ValueOf(v)) {
    78  						return false
    79  					}
    80  				}
    81  				return true
    82  			},
    83  		},
    84  
    85  		{
    86  			scenario: "FLOAT",
    87  			function: func(values []float32) bool {
    88  				filter := newFilter(len(values))
    89  				encoding.EncodeFloat(filter.Bytes(), unsafecast.Float32ToBytes(values))
    90  				for _, v := range values {
    91  					if !check(filter, ValueOf(v)) {
    92  						return false
    93  					}
    94  				}
    95  				return true
    96  			},
    97  		},
    98  
    99  		{
   100  			scenario: "DOUBLE",
   101  			function: func(values []float64) bool {
   102  				filter := newFilter(len(values))
   103  				encoding.EncodeDouble(filter.Bytes(), unsafecast.Float64ToBytes(values))
   104  				for _, v := range values {
   105  					if !check(filter, ValueOf(v)) {
   106  						return false
   107  					}
   108  				}
   109  				return true
   110  			},
   111  		},
   112  
   113  		{
   114  			scenario: "BYTE_ARRAY",
   115  			function: func(values [][]byte) bool {
   116  				byteArrays := make([]byte, 0)
   117  				for _, value := range values {
   118  					byteArrays = plain.AppendByteArray(byteArrays, value)
   119  				}
   120  				filter := newFilter(len(values))
   121  				encoding.EncodeByteArray(filter.Bytes(), byteArrays)
   122  				for _, v := range values {
   123  					if !check(filter, ValueOf(v)) {
   124  						return false
   125  					}
   126  				}
   127  				return true
   128  			},
   129  		},
   130  
   131  		{
   132  			scenario: "FIXED_LEN_BYTE_ARRAY",
   133  			function: func(values []byte) bool {
   134  				filter := newFilter(len(values))
   135  				encoding.EncodeFixedLenByteArray(filter.Bytes(), values, 1)
   136  				for _, v := range values {
   137  					if !check(filter, ValueOf([1]byte{v})) {
   138  						return false
   139  					}
   140  				}
   141  				return true
   142  			},
   143  		},
   144  	}
   145  
   146  	for _, test := range tests {
   147  		t.Run(test.scenario, func(t *testing.T) {
   148  			if err := quick.Check(test.function); err != nil {
   149  				t.Error(err)
   150  			}
   151  		})
   152  	}
   153  }
   154  
   155  func BenchmarkSplitBlockFilter(b *testing.B) {
   156  	const N = 1000
   157  	f := make(bloom.SplitBlockFilter, bloom.NumSplitBlocksOf(N, 10)).Bytes()
   158  	e := SplitBlockFilter("$").Encoding()
   159  
   160  	v := make([]int64, N)
   161  	r := rand.NewSource(10)
   162  	for i := range v {
   163  		v[i] = r.Int63()
   164  	}
   165  
   166  	v64 := unsafecast.Int64ToBytes(v)
   167  	for i := 0; i < b.N; i++ {
   168  		e.EncodeInt64(f, v64)
   169  	}
   170  
   171  	b.SetBytes(8 * N)
   172  }