github.com/consensys/gnark-crypto@v0.14.0/internal/generator/fft/template/tests/bitreverse.go.tmpl (about)

     1  import (
     2  	"fmt"
     3  	"testing"
     4  
     5  	{{ template "import_fr" . }}
     6  )
     7  
     8  
     9  type bitReverseVariant struct {
    10  	name string
    11  	buf  []fr.Element
    12  	fn   func([]fr.Element)
    13  }
    14  
    15  
    16  
    17  const maxSizeBitReverse = 1 << 23
    18  
    19  var bitReverse = []bitReverseVariant{
    20  	{name: "bitReverseNaive", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseNaive},
    21  	{name: "BitReverse", buf: make([]fr.Element, maxSizeBitReverse), fn: BitReverse},
    22  	{name: "bitReverseCobraInPlace", buf: make([]fr.Element, maxSizeBitReverse), fn: bitReverseCobraInPlace},
    23  }
    24  
    25  func TestBitReverse(t *testing.T) {
    26  
    27  	// generate a random []fr.Element array of size 2**20
    28  	pol := make([]fr.Element, maxSizeBitReverse)
    29  	one := fr.One()
    30  	pol[0].SetRandom()
    31  	for i := 1; i < maxSizeBitReverse; i++ {
    32  		pol[i].Add(&pol[i-1], &one)
    33  	}
    34  
    35  	// for each size, check that all the bitReverse functions fn compute the same result.
    36  	for size := 2; size <= maxSizeBitReverse; size <<= 1 {
    37  
    38  		// copy pol into the buffers
    39  		for _, data := range bitReverse {
    40  			copy(data.buf, pol[:size])
    41  		}
    42  
    43  		// compute bit reverse shuffling
    44  		for _, data := range bitReverse {
    45  			data.fn(data.buf[:size])
    46  		}
    47  
    48  		// all bitReverse.buf should hold the same result
    49  		for i := 0; i < size; i++ {
    50  			for j := 1; j < len(bitReverse); j++ {
    51  				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
    52  					t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
    53  				}
    54  			}
    55  		}
    56  
    57  		// bitReverse back should be identity
    58  		for _, data := range bitReverse {
    59  			data.fn(data.buf[:size])
    60  		}
    61  
    62  		for i := 0; i < size; i++ {
    63  			for j := 1; j < len(bitReverse); j++ {
    64  				if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) {
    65  					t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name)
    66  				}
    67  			}
    68  		}
    69  	}
    70  
    71  }
    72  
    73  func BenchmarkBitReverse(b *testing.B) {
    74  	// generate a random []fr.Element array of size 2**22
    75  	pol := make([]fr.Element, maxSizeBitReverse)
    76  	one := fr.One()
    77  	pol[0].SetRandom()
    78  	for i := 1; i < maxSizeBitReverse; i++ {
    79  		pol[i].Add(&pol[i-1], &one)
    80  	}
    81  
    82  	// copy pol into the buffers
    83  	for _, data := range bitReverse {
    84  		copy(data.buf, pol[:maxSizeBitReverse])
    85  	}
    86  
    87  	// benchmark for each size, each bitReverse function
    88  	for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 {
    89  		for _, data := range bitReverse {
    90  			b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) {
    91  				b.ResetTimer()
    92  				for j := 0; j < b.N; j++ {
    93  					data.fn(data.buf[:size])
    94  				}
    95  			})
    96  		}
    97  	}
    98  }