github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/encoding/fuzz/fuzz.go (about)

     1  //go:build go1.18
     2  // +build go1.18
     3  
     4  // Package fuzz contains functions to help fuzz test parquet encodings.
     5  package fuzz
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  	"unsafe"
    11  
    12  	"github.com/segmentio/parquet-go/encoding"
    13  	"github.com/segmentio/parquet-go/internal/unsafecast"
    14  )
    15  
    16  func EncodeBoolean(f *testing.F, e encoding.Encoding) {
    17  	encode(f, e,
    18  		encoding.Encoding.EncodeBoolean,
    19  		encoding.Encoding.DecodeBoolean,
    20  		generate[byte],
    21  	)
    22  }
    23  
    24  func EncodeLevels(f *testing.F, e encoding.Encoding) {
    25  	encode(f, e,
    26  		encoding.Encoding.EncodeLevels,
    27  		encoding.Encoding.DecodeLevels,
    28  		generate[byte],
    29  	)
    30  }
    31  
    32  func EncodeInt32(f *testing.F, e encoding.Encoding) {
    33  	encode(f, e,
    34  		encoding.Encoding.EncodeInt32,
    35  		encoding.Encoding.DecodeInt32,
    36  		generate[int32],
    37  	)
    38  }
    39  
    40  func EncodeInt64(f *testing.F, e encoding.Encoding) {
    41  	encode(f, e,
    42  		encoding.Encoding.EncodeInt64,
    43  		encoding.Encoding.DecodeInt64,
    44  		generate[int64],
    45  	)
    46  }
    47  
    48  func EncodeFloat(f *testing.F, e encoding.Encoding) {
    49  	encode(f, e,
    50  		encoding.Encoding.EncodeFloat,
    51  		encoding.Encoding.DecodeFloat,
    52  		generate[float32],
    53  	)
    54  }
    55  
    56  func EncodeDouble(f *testing.F, e encoding.Encoding) {
    57  	encode(f, e,
    58  		encoding.Encoding.EncodeDouble,
    59  		encoding.Encoding.DecodeDouble,
    60  		generate[float64],
    61  	)
    62  }
    63  
    64  func EncodeByteArray(f *testing.F, e encoding.Encoding) {
    65  	encode(f, e,
    66  		func(enc encoding.Encoding, dst []byte, src []string) ([]byte, error) {
    67  			size := 0
    68  			for _, s := range src {
    69  				size += len(s)
    70  			}
    71  
    72  			offsets := make([]uint32, 0, len(src)+1)
    73  			values := make([]byte, 0, size)
    74  
    75  			for _, s := range src {
    76  				offsets = append(offsets, uint32(len(values)))
    77  				values = append(values, s...)
    78  			}
    79  
    80  			offsets = append(offsets, uint32(len(values)))
    81  			return enc.EncodeByteArray(dst, values, offsets)
    82  		},
    83  
    84  		func(enc encoding.Encoding, dst []string, src []byte) ([]string, error) {
    85  			dst = dst[:0]
    86  
    87  			values, offsets, err := enc.DecodeByteArray(nil, src, nil)
    88  			if err != nil {
    89  				return dst, err
    90  			}
    91  
    92  			if len(offsets) > 0 {
    93  				baseOffset := offsets[0]
    94  
    95  				for _, endOffset := range offsets[1:] {
    96  					dst = append(dst, unsafecast.BytesToString(values[baseOffset:endOffset]))
    97  					baseOffset = endOffset
    98  				}
    99  			}
   100  
   101  			return dst, nil
   102  		},
   103  
   104  		func(dst []string, src []byte, prng *rand.Rand) []string {
   105  			limit := len(src)/10 + 1
   106  
   107  			for i := 0; i < len(src); {
   108  				n := prng.Intn(limit) + 1
   109  				r := len(src) - i
   110  				if n > r {
   111  					n = r
   112  				}
   113  				dst = append(dst, unsafecast.BytesToString(src[i:i+n]))
   114  				i += n
   115  			}
   116  
   117  			return dst
   118  		},
   119  	)
   120  }
   121  
   122  type encodingFunc[T comparable] func(encoding.Encoding, []byte, []T) ([]byte, error)
   123  
   124  type decodingFunc[T comparable] func(encoding.Encoding, []T, []byte) ([]T, error)
   125  
   126  type generateFunc[T comparable] func(dst []T, src []byte, prng *rand.Rand) []T
   127  
   128  func encode[T comparable](f *testing.F, e encoding.Encoding, encode encodingFunc[T], decode decodingFunc[T], generate generateFunc[T]) {
   129  	const bufferSize = 64 * 1024
   130  	var zero T
   131  	var err error
   132  	var buf = make([]T, bufferSize/unsafe.Sizeof(zero))
   133  	var src = make([]T, bufferSize/unsafe.Sizeof(zero))
   134  	var dst = make([]byte, bufferSize)
   135  	var prng = rand.New(rand.NewSource(0))
   136  
   137  	f.Fuzz(func(t *testing.T, input []byte, seed int64) {
   138  		prng.Seed(seed)
   139  		src = generate(src[:0], input, prng)
   140  
   141  		dst, err = encode(e, dst, src)
   142  		if err != nil {
   143  			t.Error(err)
   144  			return
   145  		}
   146  
   147  		buf, err = decode(e, buf, dst)
   148  		if err != nil {
   149  			t.Error(err)
   150  			return
   151  		}
   152  
   153  		if !equal(buf, src) {
   154  			t.Error("decoded output does not match the original input")
   155  			return
   156  		}
   157  	})
   158  }
   159  
   160  func equal[T comparable](a, b []T) bool {
   161  	if len(a) != len(b) {
   162  		return false
   163  	}
   164  	for i := range a {
   165  		if a[i] != b[i] {
   166  			return false
   167  		}
   168  	}
   169  	return true
   170  }
   171  
   172  func generate[T comparable](dst []T, src []byte, prng *rand.Rand) []T {
   173  	return append(dst[:0], unsafecast.Slice[T](src)...)
   174  }