github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/bloom_test.go (about)

     1  package parquet
     2  
     3  import (
     4  	"math/rand"
     5  	"testing"
     6  
     7  	"github.com/parquet-go/parquet-go/bloom"
     8  	"github.com/parquet-go/parquet-go/deprecated"
     9  	"github.com/parquet-go/parquet-go/internal/quick"
    10  	"github.com/parquet-go/parquet-go/internal/unsafecast"
    11  )
    12  
    13  func TestSplitBlockFilter(t *testing.T) {
    14  	newFilter := func(numValues int) bloom.SplitBlockFilter {
    15  		return make(bloom.SplitBlockFilter, bloom.NumSplitBlocksOf(int64(numValues), 11))
    16  	}
    17  
    18  	enc := SplitBlockFilter(10, "$").Encoding()
    19  
    20  	check := func(filter bloom.SplitBlockFilter, value Value) bool {
    21  		return filter.Check(value.hash(&bloom.XXH64{}))
    22  	}
    23  
    24  	tests := []struct {
    25  		scenario string
    26  		function interface{}
    27  	}{
    28  		{
    29  			scenario: "BOOLEAN",
    30  			function: func(values []bool) bool {
    31  				filter := newFilter(len(values))
    32  				enc.EncodeBoolean(filter.Bytes(), unsafecast.BoolToBytes(values))
    33  				for _, v := range values {
    34  					if !check(filter, ValueOf(v)) {
    35  						return false
    36  					}
    37  				}
    38  				return true
    39  			},
    40  		},
    41  
    42  		{
    43  			scenario: "INT32",
    44  			function: func(values []int32) bool {
    45  				filter := newFilter(len(values))
    46  				enc.EncodeInt32(filter.Bytes(), values)
    47  				for _, v := range values {
    48  					if !check(filter, ValueOf(v)) {
    49  						return false
    50  					}
    51  				}
    52  				return true
    53  			},
    54  		},
    55  
    56  		{
    57  			scenario: "INT64",
    58  			function: func(values []int64) bool {
    59  				filter := newFilter(len(values))
    60  				enc.EncodeInt64(filter.Bytes(), values)
    61  				for _, v := range values {
    62  					if !check(filter, ValueOf(v)) {
    63  						return false
    64  					}
    65  				}
    66  				return true
    67  			},
    68  		},
    69  
    70  		{
    71  			scenario: "INT96",
    72  			function: func(values []deprecated.Int96) bool {
    73  				filter := newFilter(len(values))
    74  				enc.EncodeInt96(filter.Bytes(), values)
    75  				for _, v := range values {
    76  					if !check(filter, ValueOf(v)) {
    77  						return false
    78  					}
    79  				}
    80  				return true
    81  			},
    82  		},
    83  
    84  		{
    85  			scenario: "FLOAT",
    86  			function: func(values []float32) bool {
    87  				filter := newFilter(len(values))
    88  				enc.EncodeFloat(filter.Bytes(), values)
    89  				for _, v := range values {
    90  					if !check(filter, ValueOf(v)) {
    91  						return false
    92  					}
    93  				}
    94  				return true
    95  			},
    96  		},
    97  
    98  		{
    99  			scenario: "DOUBLE",
   100  			function: func(values []float64) bool {
   101  				filter := newFilter(len(values))
   102  				enc.EncodeDouble(filter.Bytes(), values)
   103  				for _, v := range values {
   104  					if !check(filter, ValueOf(v)) {
   105  						return false
   106  					}
   107  				}
   108  				return true
   109  			},
   110  		},
   111  
   112  		{
   113  			scenario: "BYTE_ARRAY",
   114  			function: func(values [][]byte) bool {
   115  				content := make([]byte, 0, 512)
   116  				offsets := make([]uint32, len(values))
   117  				for _, value := range values {
   118  					offsets = append(offsets, uint32(len(content)))
   119  					content = append(content, value...)
   120  				}
   121  				offsets = append(offsets, uint32(len(content)))
   122  				filter := newFilter(len(values))
   123  				enc.EncodeByteArray(filter.Bytes(), content, offsets)
   124  				for _, v := range values {
   125  					if !check(filter, ValueOf(v)) {
   126  						return false
   127  					}
   128  				}
   129  				return true
   130  			},
   131  		},
   132  
   133  		{
   134  			scenario: "FIXED_LEN_BYTE_ARRAY",
   135  			function: func(values []byte) bool {
   136  				filter := newFilter(len(values))
   137  				enc.EncodeFixedLenByteArray(filter.Bytes(), values, 1)
   138  				for _, v := range values {
   139  					if !check(filter, ValueOf([1]byte{v})) {
   140  						return false
   141  					}
   142  				}
   143  				return true
   144  			},
   145  		},
   146  	}
   147  
   148  	for _, test := range tests {
   149  		t.Run(test.scenario, func(t *testing.T) {
   150  			if err := quick.Check(test.function); err != nil {
   151  				t.Error(err)
   152  			}
   153  		})
   154  	}
   155  }
   156  
   157  func BenchmarkSplitBlockFilter(b *testing.B) {
   158  	const N = 1000
   159  	f := make(bloom.SplitBlockFilter, bloom.NumSplitBlocksOf(N, 10)).Bytes()
   160  	e := SplitBlockFilter(10, "$").Encoding()
   161  
   162  	v := make([]int64, N)
   163  	r := rand.NewSource(10)
   164  	for i := range v {
   165  		v[i] = r.Int63()
   166  	}
   167  
   168  	for i := 0; i < b.N; i++ {
   169  		e.EncodeInt64(f, v)
   170  	}
   171  
   172  	b.SetBytes(8 * N)
   173  }