github.com/cloudflare/circl@v1.5.0/pke/kyber/internal/common/poly_test.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"testing"
     8  )
     9  
    10  func (p *Poly) RandAbsLe9Q() {
    11  	max := 9 * uint32(Q)
    12  	r := randSliceUint32WithMax(uint(N), max)
    13  	for i := 0; i < N; i++ {
    14  		p[i] = int16(int32(r[i]))
    15  	}
    16  }
    17  
    18  // Returns x mod^± q
    19  func sModQ(x int16) int16 {
    20  	x = x % Q
    21  	if x >= (Q-1)/2 {
    22  		x = x - Q
    23  	}
    24  	return x
    25  }
    26  
    27  func TestDecompressMessage(t *testing.T) {
    28  	var m, m2 [PlaintextSize]byte
    29  	var p Poly
    30  	for i := 0; i < 1000; i++ {
    31  		if n, err := rand.Read(m[:]); err != nil {
    32  			t.Error(err)
    33  		} else if n != len(m) {
    34  			t.Fatal("short read from RNG")
    35  		}
    36  
    37  		p.DecompressMessage(m[:])
    38  		p.CompressMessageTo(m2[:])
    39  		if m != m2 {
    40  			t.Fatal()
    41  		}
    42  	}
    43  }
    44  
    45  func TestCompress(t *testing.T) {
    46  	for _, d := range []int{4, 5, 10, 11} {
    47  		t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
    48  			var p, q Poly
    49  			bound := (Q + (1 << uint(d))) >> uint(d+1)
    50  			buf := make([]byte, (N*d-1)/8+1)
    51  			for i := 0; i < 1000; i++ {
    52  				p.Rand()
    53  				p.CompressTo(buf, d)
    54  				q.Decompress(buf, d)
    55  				for j := 0; j < N; j++ {
    56  					diff := sModQ(p[j] - q[j])
    57  					if diff < 0 {
    58  						diff = -diff
    59  					}
    60  					if diff > bound {
    61  						t.Logf("%v\n", buf)
    62  						t.Fatalf("|%d - %d mod^± q| = %d > %d, j=%d",
    63  							p[i], q[j], diff, bound, j)
    64  					}
    65  				}
    66  			}
    67  		})
    68  	}
    69  }
    70  
    71  func TestCompressMessage(t *testing.T) {
    72  	var p Poly
    73  	var m [32]byte
    74  	ok := true
    75  	for i := 0; i < int(Q); i++ {
    76  		p[0] = int16(i)
    77  		p.CompressMessageTo(m[:])
    78  		want := byte(0)
    79  		if i >= 833 && i < 2497 {
    80  			want = 1
    81  		}
    82  		if m[0] != want {
    83  			ok = false
    84  			t.Logf("%d %d %d", i, want, m[0])
    85  		}
    86  	}
    87  	if !ok {
    88  		t.Fatal()
    89  	}
    90  }
    91  
    92  func TestMulHat(t *testing.T) {
    93  	for k := 0; k < 1000; k++ {
    94  		var a, b, p, ah, bh, ph Poly
    95  		a.RandAbsLeQ()
    96  		b.RandAbsLeQ()
    97  		b[0] = 1
    98  
    99  		ah = a
   100  		bh = b
   101  		ah.NTT()
   102  		bh.NTT()
   103  		ph.MulHat(&ah, &bh)
   104  		ph.BarrettReduce()
   105  		ph.InvNTT()
   106  
   107  		for i := 0; i < N; i++ {
   108  			for j := 0; j < N; j++ {
   109  				v := montReduce(int32(a[i]) * int32(b[j]))
   110  				k := i + j
   111  				if k >= N {
   112  					// Recall xᴺ = -1.
   113  					k -= N
   114  					v = -v
   115  				}
   116  				p[k] = barrettReduce(v + p[k])
   117  			}
   118  		}
   119  
   120  		for i := 0; i < N; i++ {
   121  			p[i] = int16((int32(p[i]) * ((1 << 16) % int32(Q))) % int32(Q))
   122  		}
   123  
   124  		p.Normalize()
   125  		ph.Normalize()
   126  		a.Normalize()
   127  		b.Normalize()
   128  
   129  		if p != ph {
   130  			t.Fatalf("%v\n%v\n%v\n%v", a, b, p, ph)
   131  		}
   132  	}
   133  }
   134  
   135  func TestAddAgainstGeneric(t *testing.T) {
   136  	for k := 0; k < 1000; k++ {
   137  		var p1, p2, a, b Poly
   138  		a.RandAbsLeQ()
   139  		b.RandAbsLeQ()
   140  		p1.Add(&a, &b)
   141  		p2.addGeneric(&a, &b)
   142  		if p1 != p2 {
   143  			t.Fatalf("Add(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
   144  		}
   145  	}
   146  }
   147  
   148  func BenchmarkAdd(b *testing.B) {
   149  	var p Poly
   150  	for i := 0; i < b.N; i++ {
   151  		p.Add(&p, &p)
   152  	}
   153  }
   154  
   155  func BenchmarkAddGeneric(b *testing.B) {
   156  	var p Poly
   157  	for i := 0; i < b.N; i++ {
   158  		p.addGeneric(&p, &p)
   159  	}
   160  }
   161  
   162  func TestSubAgainstGeneric(t *testing.T) {
   163  	for k := 0; k < 1000; k++ {
   164  		var p1, p2, a, b Poly
   165  		a.RandAbsLeQ()
   166  		b.RandAbsLeQ()
   167  		p1.Sub(&a, &b)
   168  		p2.subGeneric(&a, &b)
   169  		if p1 != p2 {
   170  			t.Fatalf("Sub(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
   171  		}
   172  	}
   173  }
   174  
   175  func BenchmarkSub(b *testing.B) {
   176  	var p Poly
   177  	for i := 0; i < b.N; i++ {
   178  		p.Sub(&p, &p)
   179  	}
   180  }
   181  
   182  func BenchmarkSubGeneric(b *testing.B) {
   183  	var p Poly
   184  	for i := 0; i < b.N; i++ {
   185  		p.subGeneric(&p, &p)
   186  	}
   187  }
   188  
   189  func TestMulHatAgainstGeneric(t *testing.T) {
   190  	for k := 0; k < 1000; k++ {
   191  		var p1, p2, a, b Poly
   192  		a.RandAbsLeQ()
   193  		b.RandAbsLeQ()
   194  		a2 := a
   195  		b2 := b
   196  		a2.Tangle()
   197  		b2.Tangle()
   198  		p1.MulHat(&a2, &b2)
   199  		p1.Detangle()
   200  		p2.mulHatGeneric(&a, &b)
   201  		if p1 != p2 {
   202  			t.Fatalf("MulHat(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
   203  		}
   204  	}
   205  }
   206  
   207  func BenchmarkMulHat(b *testing.B) {
   208  	var p Poly
   209  	for i := 0; i < b.N; i++ {
   210  		p.MulHat(&p, &p)
   211  	}
   212  }
   213  
   214  func BenchmarkMulHatGeneric(b *testing.B) {
   215  	var p Poly
   216  	for i := 0; i < b.N; i++ {
   217  		p.mulHatGeneric(&p, &p)
   218  	}
   219  }
   220  
   221  func BenchmarkBarrettReduce(b *testing.B) {
   222  	var p Poly
   223  	for i := 0; i < b.N; i++ {
   224  		p.BarrettReduce()
   225  	}
   226  }
   227  
   228  func BenchmarkBarrettReduceGeneric(b *testing.B) {
   229  	var p Poly
   230  	for i := 0; i < b.N; i++ {
   231  		p.barrettReduceGeneric()
   232  	}
   233  }
   234  
   235  func TestBarrettReduceAgainstGeneric(t *testing.T) {
   236  	for k := 0; k < 1000; k++ {
   237  		var p1, p2, a Poly
   238  		a.RandAbsLe9Q()
   239  		p1 = a
   240  		p2 = a
   241  		p1.BarrettReduce()
   242  		p2.barrettReduceGeneric()
   243  		if p1 != p2 {
   244  			t.Fatalf("BarrettReduce(%v) = \n%v \n!= %v", a, p1, p2)
   245  		}
   246  	}
   247  }
   248  
   249  func BenchmarkNormalize(b *testing.B) {
   250  	var p Poly
   251  	for i := 0; i < b.N; i++ {
   252  		p.Normalize()
   253  	}
   254  }
   255  
   256  func BenchmarkNormalizeGeneric(b *testing.B) {
   257  	var p Poly
   258  	for i := 0; i < b.N; i++ {
   259  		p.barrettReduceGeneric()
   260  	}
   261  }
   262  
   263  func TestNormalizeAgainstGeneric(t *testing.T) {
   264  	for k := 0; k < 1000; k++ {
   265  		var p1, p2, a Poly
   266  		a.RandAbsLe9Q()
   267  		p1 = a
   268  		p2 = a
   269  		p1.Normalize()
   270  		p2.normalizeGeneric()
   271  		if p1 != p2 {
   272  			t.Fatalf("Normalize(%v) = \n%v \n!= %v", a, p1, p2)
   273  		}
   274  	}
   275  }
   276  
   277  func (p *Poly) OldCompressTo(m []byte, d int) {
   278  	switch d {
   279  	case 4:
   280  		var t [8]uint16
   281  		idx := 0
   282  		for i := 0; i < N/8; i++ {
   283  			for j := 0; j < 8; j++ {
   284  				t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
   285  					uint32(Q)) & ((1 << 4) - 1)
   286  			}
   287  			m[idx] = byte(t[0]) | byte(t[1]<<4)
   288  			m[idx+1] = byte(t[2]) | byte(t[3]<<4)
   289  			m[idx+2] = byte(t[4]) | byte(t[5]<<4)
   290  			m[idx+3] = byte(t[6]) | byte(t[7]<<4)
   291  			idx += 4
   292  		}
   293  
   294  	case 5:
   295  		var t [8]uint16
   296  		idx := 0
   297  		for i := 0; i < N/8; i++ {
   298  			for j := 0; j < 8; j++ {
   299  				t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
   300  					uint32(Q)) & ((1 << 5) - 1)
   301  			}
   302  			m[idx] = byte(t[0]) | byte(t[1]<<5)
   303  			m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
   304  			m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
   305  			m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
   306  			m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
   307  			idx += 5
   308  		}
   309  
   310  	case 10:
   311  		var t [4]uint16
   312  		idx := 0
   313  		for i := 0; i < N/4; i++ {
   314  			for j := 0; j < 4; j++ {
   315  				t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
   316  					uint32(Q)) & ((1 << 10) - 1)
   317  			}
   318  			m[idx] = byte(t[0])
   319  			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
   320  			m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
   321  			m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
   322  			m[idx+4] = byte(t[3] >> 2)
   323  			idx += 5
   324  		}
   325  	case 11:
   326  		var t [8]uint16
   327  		idx := 0
   328  		for i := 0; i < N/8; i++ {
   329  			for j := 0; j < 8; j++ {
   330  				t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
   331  					uint32(Q)) & ((1 << 11) - 1)
   332  			}
   333  			m[idx] = byte(t[0])
   334  			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
   335  			m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
   336  			m[idx+3] = byte(t[2] >> 2)
   337  			m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
   338  			m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
   339  			m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
   340  			m[idx+7] = byte(t[5] >> 1)
   341  			m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
   342  			m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
   343  			m[idx+10] = byte(t[7] >> 3)
   344  			idx += 11
   345  		}
   346  	default:
   347  		panic("unsupported d")
   348  	}
   349  }
   350  
   351  func TestCompressFullInputFirstCoeff(t *testing.T) {
   352  	for _, d := range []int{4, 5, 10, 11} {
   353  		t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
   354  			var p, q Poly
   355  			bound := (Q + (1 << uint(d))) >> uint(d+1)
   356  			buf := make([]byte, (N*d-1)/8+1)
   357  			buf2 := make([]byte, len(buf))
   358  			for i := int16(0); i < Q; i++ {
   359  				p[0] = i
   360  				p.CompressTo(buf, d)
   361  				p.OldCompressTo(buf2, d)
   362  				if !bytes.Equal(buf, buf2) {
   363  					t.Fatalf("%d", i)
   364  				}
   365  				q.Decompress(buf, d)
   366  				diff := sModQ(p[0] - q[0])
   367  				if diff < 0 {
   368  					diff = -diff
   369  				}
   370  				if diff > bound {
   371  					t.Logf("%v\n", buf)
   372  					t.Fatalf("|%d - %d mod^± q| = %d > %d",
   373  						p[0], q[0], diff, bound)
   374  				}
   375  			}
   376  		})
   377  	}
   378  }