github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/encoding/test/test.go (about)

     1  package test
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/parquet-go/parquet-go/encoding"
     8  )
     9  
    10  func EncodeInt32(t *testing.T, enc encoding.Encoding, min, max int, bitWidth uint) {
    11  	t.Helper()
    12  	encode(t, enc, min, max,
    13  		encoding.Encoding.EncodeInt32,
    14  		encoding.Encoding.DecodeInt32,
    15  		func(i int) int32 {
    16  			value := int32(i)
    17  			mask := int32((1 << bitWidth) - 1)
    18  			if (i % 2) != 0 {
    19  				value = -value
    20  			}
    21  			return value & mask
    22  		},
    23  	)
    24  }
    25  
    26  func EncodeInt64(t *testing.T, enc encoding.Encoding, min, max int, bitWidth uint) {
    27  	t.Helper()
    28  	encode(t, enc, min, max,
    29  		encoding.Encoding.EncodeInt64,
    30  		encoding.Encoding.DecodeInt64,
    31  		func(i int) int64 {
    32  			value := int64(i)
    33  			mask := int64((1 << bitWidth) - 1)
    34  			if (i % 2) != 0 {
    35  				value = -value
    36  			}
    37  			return value & mask
    38  		},
    39  	)
    40  }
    41  
    42  func EncodeFloat(t *testing.T, enc encoding.Encoding, min, max int) {
    43  	t.Helper()
    44  	encode(t, enc, min, max,
    45  		encoding.Encoding.EncodeFloat,
    46  		encoding.Encoding.DecodeFloat,
    47  		func(i int) float32 { return float32(i) },
    48  	)
    49  }
    50  
    51  func EncodeDouble(t *testing.T, enc encoding.Encoding, min, max int) {
    52  	t.Helper()
    53  	encode(t, enc, min, max,
    54  		encoding.Encoding.EncodeDouble,
    55  		encoding.Encoding.DecodeDouble,
    56  		func(i int) float64 { return float64(i) },
    57  	)
    58  }
    59  
    60  type encodingFunc[T comparable] func(encoding.Encoding, []byte, []T) ([]byte, error)
    61  
    62  type decodingFunc[T comparable] func(encoding.Encoding, []T, []byte) ([]T, error)
    63  
    64  func encode[T comparable](t *testing.T, enc encoding.Encoding, min, max int, encode encodingFunc[T], decode decodingFunc[T], valueOf func(int) T) {
    65  	t.Helper()
    66  
    67  	for k := min; k <= max; k++ {
    68  		t.Run(fmt.Sprintf("N=%d", k), func(t *testing.T) {
    69  			src := make([]T, k)
    70  			for i := range src {
    71  				src[i] = valueOf(i)
    72  			}
    73  
    74  			buf, err := encode(enc, nil, src)
    75  			if err != nil {
    76  				t.Fatalf("encoding %d values: %v", k, err)
    77  			}
    78  
    79  			res, err := decode(enc, nil, buf)
    80  			if err != nil {
    81  				t.Fatalf("decoding %d values: %v", k, err)
    82  			}
    83  
    84  			if err := assertEqual(src, res); err != nil {
    85  				t.Fatalf("testing %d values: %v", k, err)
    86  			}
    87  		})
    88  	}
    89  }
    90  
    91  func assertEqual[T comparable](want, got []T) error {
    92  	if len(want) != len(got) {
    93  		return fmt.Errorf("number of values mismatch: want=%d got=%d", len(want), len(got))
    94  	}
    95  
    96  	for i := range want {
    97  		if want[i] != got[i] {
    98  			return fmt.Errorf("values at index %d/%d mismatch: want=%+v got=%+v", i, len(want), want[i], got[i])
    99  		}
   100  	}
   101  
   102  	return nil
   103  }