github.com/mad-day/Yawning-crypto@v0.0.0-20190711051033-5a5f8cca32ec/bcns/fft.go (about)

     1  //
     2  // FFT based polynomial multiplication.
     3  //
     4  // To the extent possible under law, Yawning Angel waived all copyright
     5  // and related or neighboring rights to ringlwe, using the Creative
     6  // Commons "CC0" public domain dedication. See LICENSE or
     7  // <http://creativecommons.org/publicdomain/zero/1.0/> for full details.
     8  
     9  package bcns
    10  
    11  import (
    12  	"unsafe"
    13  )
    14  
    15  type fftCtx struct {
    16  	x1 [64][64]uint32
    17  	y1 [64][64]uint32
    18  	z1 [64][64]uint32
    19  	t1 [64]uint32
    20  }
    21  
    22  // Reduction modulo p = 2^32 - 1.
    23  // This is not a prime since 2^32-1 = (2^1+1)*(2^2+1)*(2^4+1)*(2^8+1)*(2^16+1).
    24  // But since 2 is a unit in Z/pZ we can use it for computing FFTs in
    25  // Z/pZ[X]/(X^(2^7)+1)
    26  
    27  // Caution:
    28  // We use a redundant representation where the integer 0 is represented both
    29  // by 0 and 2^32-1.
    30  // This approach follows the describtion from the paper:
    31  // Joppe W. Bos, Craig Costello, Huseyin Hisil, and Kristin Lauter: Fast Cryptography in Genus 2
    32  // EUROCRYPT 2013, Lecture Notes in Computer Science 7881, pp. 194-210, Springer, 2013.
    33  // More specifically see: Section 3 related to Modular Addition/Subtraction.
    34  
    35  // Compute: c = (a+b) mod (2^32-1)
    36  // Let, t = a+b = t_1*2^32 + t0, where 0 <= t_1 <= 1, 0 <= t_0 < 2^32.
    37  // Then t mod (2^32-1) = t0 + t1
    38  
    39  // Yawning: Golang is so fucking stupid sometimes.  Like when I would kill to
    40  // have macros.  Or something that converts a bool to an int that's does not
    41  // involve either branches, or using "unsafe".  I should probably revisit this
    42  // and provide a vectorized assembly implementation of the entire FFT multiply.
    43  
    44  func boolToInt(b bool) uint32 {
    45  	// Yes, unsafe.  Really.  There is no better way to do this, which is all
    46  	// sorts of fucking braindamaged.
    47  	return uint32(*(*byte)(unsafe.Pointer(&b)))
    48  }
    49  
    50  func modadd(a, b uint32) (c uint32) {
    51  	t := a + b
    52  	c = t + boolToInt(t < a)
    53  	return
    54  }
    55  
    56  func modsub(a, b uint32) (c uint32) {
    57  	c = (a - b) - boolToInt(b > a)
    58  	return
    59  }
    60  
    61  func modmul(a, b uint32) (c uint32) {
    62  	t := uint64(a) * uint64(b)
    63  	c = modadd(uint32(t), (uint32(uint64(t) >> 32)))
    64  	return
    65  }
    66  
    67  func modmuladd(c, a, b uint32) uint32 {
    68  	t := uint64(a)*uint64(b) + uint64(c)
    69  	c = modadd(uint32(t), (uint32(t >> 32)))
    70  	return c
    71  }
    72  
    73  func div2(a uint32) (c uint32) {
    74  	c = uint32((uint64(a) + uint64(uint32(0-((a)&1))&0xFFFFFFFF)) >> 1)
    75  	return
    76  }
    77  
    78  func normalize(a uint32) (c uint32) {
    79  	c = a + boolToInt(a == 0xFFFFFFFF)
    80  	return c
    81  }
    82  
    83  func moddiv2(a uint32) (c uint32) {
    84  	c = normalize(a)
    85  	c = div2(c)
    86  	return
    87  }
    88  
    89  func neg(a uint32) (c uint32) {
    90  	c = 0xFFFFFFFF - a
    91  	c = normalize(c)
    92  	return
    93  }
    94  
    95  // Reverse the bits, approach from "Bit Twiddling Hacks"
    96  // See: https://graphics.stanford.edu/~seander/bithacks.html
    97  func reverse(x uint32) uint32 {
    98  	x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1))
    99  	x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2))
   100  	x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4))
   101  	x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8))
   102  	return ((x >> 16) | (x << 16))
   103  }
   104  
   105  // Nussbaumer approach, see:
   106  // H. J. Nussbaumer. Fast polynomial transform algorithms for digital convolution. Acoustics, Speech and
   107  // Signal Processing, IEEE Transactions on, 28(2):205{215, 1980
   108  // We followed the describtion from Knuth:
   109  // D. E. Knuth. Seminumerical Algorithms. The Art of Computer Programming. Addison-Wesley, Reading,
   110  // Massachusetts, USA, 3rd edition, 1997
   111  // Exercise Exercise 4.6.4.59.
   112  
   113  func naive(z, x, y *[64]uint32, n uint) {
   114  	for i := uint(0); i < n; i++ {
   115  		B := uint32(0)
   116  
   117  		A := modmul(x[0], y[i])
   118  
   119  		var j uint
   120  		for j = 1; j <= i; j++ {
   121  			A = modmuladd(A, x[j], y[i-j])
   122  		}
   123  
   124  		for k := uint(1); j < n; j, k = j+1, k+1 {
   125  			B = modmuladd(B, x[j], y[n-k])
   126  		}
   127  		z[i] = modsub(A, B)
   128  	}
   129  }
   130  
   131  func nussbaumerFFT(z []uint32, x []uint32, y []uint32, ctx *fftCtx) {
   132  	X1 := &ctx.x1
   133  	Y1 := &ctx.y1
   134  
   135  	for i := 0; i < 32; i++ {
   136  		for j := 0; j < 32; j++ {
   137  			X1[i][j] = x[32*j+i]
   138  			X1[i+32][j] = x[32*j+i]
   139  
   140  			Y1[i][j] = y[32*j+i]
   141  			Y1[i+32][j] = y[32*j+i]
   142  		}
   143  	}
   144  
   145  	Z1 := &ctx.z1
   146  	T1 := &ctx.t1
   147  
   148  	for j := 4; j >= 0; j-- {
   149  		jj := uint(j)
   150  		for i := uint32(0); i < (1 << (5 - jj)); i++ {
   151  			ssr := reverse(i)
   152  			for t := uint32(0); t < (1 << jj); t++ {
   153  				s := i
   154  				sr := ssr >> (32 - 5 + jj)
   155  				sr <<= jj
   156  				s <<= (jj + 1)
   157  
   158  				// X_i(w) = X_i(w) + w^kX_l(w) can be computed as
   159  				// X_ij = X_ij - X_l(j-k+r)  for  0 <= j < k
   160  				// X_ij = X_ij + X_l(j-k)    for  k <= j < r
   161  				I := s + t
   162  				L := s + t + (1 << jj)
   163  
   164  				for a := sr; a < 32; a++ {
   165  					T1[a] = X1[L][a-sr]
   166  				}
   167  				for a := uint32(0); a < sr; a++ {
   168  					T1[a] = neg(X1[L][32+a-sr])
   169  				}
   170  
   171  				for a := 0; a < 32; a++ {
   172  					X1[L][a] = modsub(X1[I][a], T1[a])
   173  					X1[I][a] = modadd(X1[I][a], T1[a])
   174  				}
   175  
   176  				for a := sr; a < 32; a++ {
   177  					T1[a] = Y1[L][a-sr]
   178  				}
   179  				for a := uint32(0); a < sr; a++ {
   180  					T1[a] = neg(Y1[L][32+a-sr])
   181  				}
   182  
   183  				for a := 0; a < 32; a++ {
   184  					Y1[L][a] = modsub(Y1[I][a], T1[a])
   185  					Y1[I][a] = modadd(Y1[I][a], T1[a])
   186  				}
   187  			}
   188  		}
   189  	}
   190  
   191  	for i := 0; i < 2*32; i++ {
   192  		naive(&Z1[i], &X1[i], &Y1[i], 32)
   193  	}
   194  
   195  	for j := uint32(0); j <= 5; j++ {
   196  		for i := uint32(0); i < (1 << (5 - j)); i++ {
   197  			ssr := reverse(i)
   198  			for t := uint32(0); t < (1 << j); t++ {
   199  				s := i
   200  				sr := (ssr >> (32 - 5 + j))
   201  				sr <<= j
   202  				s <<= (j + 1)
   203  
   204  				A := s + t
   205  				B := s + t + (1 << j)
   206  
   207  				for a := 0; a < 32; a++ {
   208  					T1[a] = modsub(Z1[A][a], Z1[B][a])
   209  					T1[a] = moddiv2(T1[a])
   210  					Z1[A][a] = modadd(Z1[A][a], Z1[B][a])
   211  					Z1[A][a] = moddiv2(Z1[A][a])
   212  				}
   213  
   214  				// w^{-(r/m)s'} (Z_{s+t}(w)-Z_{s+t+2^j}(w))
   215  				for a := uint32(0); a < 32-sr; a++ {
   216  					Z1[B][a] = T1[a+sr]
   217  				}
   218  				for a := 32 - sr; a < 32; a++ {
   219  					Z1[B][a] = neg(T1[a-(32-sr)])
   220  				}
   221  			}
   222  		}
   223  	}
   224  
   225  	for i := 0; i < 32; i++ {
   226  		z[i] = modsub(Z1[i][0], Z1[32+i][32-1])
   227  		for j := 1; j < 32; j++ {
   228  			z[32*j+i] = modadd(Z1[i][j], Z1[32+i][j-1])
   229  		}
   230  	}
   231  }
   232  
   233  func (f *fftCtx) multiply(z, x, y *[1024]uint32) {
   234  	nussbaumerFFT(z[:], x[:], y[:], f)
   235  }
   236  
   237  func (f *fftCtx) add(z, x, y *[1024]uint32) {
   238  	for i := 0; i < 1024; i++ {
   239  		z[i] = modadd(x[i], y[i])
   240  	}
   241  }
   242  
   243  func init() {
   244  	// Validate the assumptions made regarding bool/unsafe, in case the
   245  	// developers decide to torment me further in the future.
   246  	if unsafe.Sizeof(true) != 1 {
   247  		panic("sizeof(bool) != 1")
   248  	}
   249  	if boolToInt(true) != 1 || boolToInt(false) != 0 {
   250  		panic("bool primitive type data format is unexpected.")
   251  	}
   252  }