github.com/cloudflare/circl@v1.5.0/pke/kyber/internal/common/poly_test.go (about) 1 package common 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "fmt" 7 "testing" 8 ) 9 10 func (p *Poly) RandAbsLe9Q() { 11 max := 9 * uint32(Q) 12 r := randSliceUint32WithMax(uint(N), max) 13 for i := 0; i < N; i++ { 14 p[i] = int16(int32(r[i])) 15 } 16 } 17 18 // Returns x mod^± q 19 func sModQ(x int16) int16 { 20 x = x % Q 21 if x >= (Q-1)/2 { 22 x = x - Q 23 } 24 return x 25 } 26 27 func TestDecompressMessage(t *testing.T) { 28 var m, m2 [PlaintextSize]byte 29 var p Poly 30 for i := 0; i < 1000; i++ { 31 if n, err := rand.Read(m[:]); err != nil { 32 t.Error(err) 33 } else if n != len(m) { 34 t.Fatal("short read from RNG") 35 } 36 37 p.DecompressMessage(m[:]) 38 p.CompressMessageTo(m2[:]) 39 if m != m2 { 40 t.Fatal() 41 } 42 } 43 } 44 45 func TestCompress(t *testing.T) { 46 for _, d := range []int{4, 5, 10, 11} { 47 t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) { 48 var p, q Poly 49 bound := (Q + (1 << uint(d))) >> uint(d+1) 50 buf := make([]byte, (N*d-1)/8+1) 51 for i := 0; i < 1000; i++ { 52 p.Rand() 53 p.CompressTo(buf, d) 54 q.Decompress(buf, d) 55 for j := 0; j < N; j++ { 56 diff := sModQ(p[j] - q[j]) 57 if diff < 0 { 58 diff = -diff 59 } 60 if diff > bound { 61 t.Logf("%v\n", buf) 62 t.Fatalf("|%d - %d mod^± q| = %d > %d, j=%d", 63 p[i], q[j], diff, bound, j) 64 } 65 } 66 } 67 }) 68 } 69 } 70 71 func TestCompressMessage(t *testing.T) { 72 var p Poly 73 var m [32]byte 74 ok := true 75 for i := 0; i < int(Q); i++ { 76 p[0] = int16(i) 77 p.CompressMessageTo(m[:]) 78 want := byte(0) 79 if i >= 833 && i < 2497 { 80 want = 1 81 } 82 if m[0] != want { 83 ok = false 84 t.Logf("%d %d %d", i, want, m[0]) 85 } 86 } 87 if !ok { 88 t.Fatal() 89 } 90 } 91 92 func TestMulHat(t *testing.T) { 93 for k := 0; k < 1000; k++ { 94 var a, b, p, ah, bh, ph Poly 95 a.RandAbsLeQ() 96 b.RandAbsLeQ() 97 b[0] = 1 98 99 ah = a 100 bh = b 101 ah.NTT() 102 bh.NTT() 103 ph.MulHat(&ah, &bh) 104 ph.BarrettReduce() 105 ph.InvNTT() 106 107 for i := 0; i < N; i++ { 108 for j := 0; j < N; j++ { 109 v := montReduce(int32(a[i]) * int32(b[j])) 110 k := i + j 111 if k >= N { 112 // Recall xᴺ = -1. 113 k -= N 114 v = -v 115 } 116 p[k] = barrettReduce(v + p[k]) 117 } 118 } 119 120 for i := 0; i < N; i++ { 121 p[i] = int16((int32(p[i]) * ((1 << 16) % int32(Q))) % int32(Q)) 122 } 123 124 p.Normalize() 125 ph.Normalize() 126 a.Normalize() 127 b.Normalize() 128 129 if p != ph { 130 t.Fatalf("%v\n%v\n%v\n%v", a, b, p, ph) 131 } 132 } 133 } 134 135 func TestAddAgainstGeneric(t *testing.T) { 136 for k := 0; k < 1000; k++ { 137 var p1, p2, a, b Poly 138 a.RandAbsLeQ() 139 b.RandAbsLeQ() 140 p1.Add(&a, &b) 141 p2.addGeneric(&a, &b) 142 if p1 != p2 { 143 t.Fatalf("Add(%v, %v) = \n%v \n!= %v", a, b, p1, p2) 144 } 145 } 146 } 147 148 func BenchmarkAdd(b *testing.B) { 149 var p Poly 150 for i := 0; i < b.N; i++ { 151 p.Add(&p, &p) 152 } 153 } 154 155 func BenchmarkAddGeneric(b *testing.B) { 156 var p Poly 157 for i := 0; i < b.N; i++ { 158 p.addGeneric(&p, &p) 159 } 160 } 161 162 func TestSubAgainstGeneric(t *testing.T) { 163 for k := 0; k < 1000; k++ { 164 var p1, p2, a, b Poly 165 a.RandAbsLeQ() 166 b.RandAbsLeQ() 167 p1.Sub(&a, &b) 168 p2.subGeneric(&a, &b) 169 if p1 != p2 { 170 t.Fatalf("Sub(%v, %v) = \n%v \n!= %v", a, b, p1, p2) 171 } 172 } 173 } 174 175 func BenchmarkSub(b *testing.B) { 176 var p Poly 177 for i := 0; i < b.N; i++ { 178 p.Sub(&p, &p) 179 } 180 } 181 182 func BenchmarkSubGeneric(b *testing.B) { 183 var p Poly 184 for i := 0; i < b.N; i++ { 185 p.subGeneric(&p, &p) 186 } 187 } 188 189 func TestMulHatAgainstGeneric(t *testing.T) { 190 for k := 0; k < 1000; k++ { 191 var p1, p2, a, b Poly 192 a.RandAbsLeQ() 193 b.RandAbsLeQ() 194 a2 := a 195 b2 := b 196 a2.Tangle() 197 b2.Tangle() 198 p1.MulHat(&a2, &b2) 199 p1.Detangle() 200 p2.mulHatGeneric(&a, &b) 201 if p1 != p2 { 202 t.Fatalf("MulHat(%v, %v) = \n%v \n!= %v", a, b, p1, p2) 203 } 204 } 205 } 206 207 func BenchmarkMulHat(b *testing.B) { 208 var p Poly 209 for i := 0; i < b.N; i++ { 210 p.MulHat(&p, &p) 211 } 212 } 213 214 func BenchmarkMulHatGeneric(b *testing.B) { 215 var p Poly 216 for i := 0; i < b.N; i++ { 217 p.mulHatGeneric(&p, &p) 218 } 219 } 220 221 func BenchmarkBarrettReduce(b *testing.B) { 222 var p Poly 223 for i := 0; i < b.N; i++ { 224 p.BarrettReduce() 225 } 226 } 227 228 func BenchmarkBarrettReduceGeneric(b *testing.B) { 229 var p Poly 230 for i := 0; i < b.N; i++ { 231 p.barrettReduceGeneric() 232 } 233 } 234 235 func TestBarrettReduceAgainstGeneric(t *testing.T) { 236 for k := 0; k < 1000; k++ { 237 var p1, p2, a Poly 238 a.RandAbsLe9Q() 239 p1 = a 240 p2 = a 241 p1.BarrettReduce() 242 p2.barrettReduceGeneric() 243 if p1 != p2 { 244 t.Fatalf("BarrettReduce(%v) = \n%v \n!= %v", a, p1, p2) 245 } 246 } 247 } 248 249 func BenchmarkNormalize(b *testing.B) { 250 var p Poly 251 for i := 0; i < b.N; i++ { 252 p.Normalize() 253 } 254 } 255 256 func BenchmarkNormalizeGeneric(b *testing.B) { 257 var p Poly 258 for i := 0; i < b.N; i++ { 259 p.barrettReduceGeneric() 260 } 261 } 262 263 func TestNormalizeAgainstGeneric(t *testing.T) { 264 for k := 0; k < 1000; k++ { 265 var p1, p2, a Poly 266 a.RandAbsLe9Q() 267 p1 = a 268 p2 = a 269 p1.Normalize() 270 p2.normalizeGeneric() 271 if p1 != p2 { 272 t.Fatalf("Normalize(%v) = \n%v \n!= %v", a, p1, p2) 273 } 274 } 275 } 276 277 func (p *Poly) OldCompressTo(m []byte, d int) { 278 switch d { 279 case 4: 280 var t [8]uint16 281 idx := 0 282 for i := 0; i < N/8; i++ { 283 for j := 0; j < 8; j++ { 284 t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/ 285 uint32(Q)) & ((1 << 4) - 1) 286 } 287 m[idx] = byte(t[0]) | byte(t[1]<<4) 288 m[idx+1] = byte(t[2]) | byte(t[3]<<4) 289 m[idx+2] = byte(t[4]) | byte(t[5]<<4) 290 m[idx+3] = byte(t[6]) | byte(t[7]<<4) 291 idx += 4 292 } 293 294 case 5: 295 var t [8]uint16 296 idx := 0 297 for i := 0; i < N/8; i++ { 298 for j := 0; j < 8; j++ { 299 t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/ 300 uint32(Q)) & ((1 << 5) - 1) 301 } 302 m[idx] = byte(t[0]) | byte(t[1]<<5) 303 m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7) 304 m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4) 305 m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6) 306 m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3) 307 idx += 5 308 } 309 310 case 10: 311 var t [4]uint16 312 idx := 0 313 for i := 0; i < N/4; i++ { 314 for j := 0; j < 4; j++ { 315 t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/ 316 uint32(Q)) & ((1 << 10) - 1) 317 } 318 m[idx] = byte(t[0]) 319 m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2) 320 m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4) 321 m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6) 322 m[idx+4] = byte(t[3] >> 2) 323 idx += 5 324 } 325 case 11: 326 var t [8]uint16 327 idx := 0 328 for i := 0; i < N/8; i++ { 329 for j := 0; j < 8; j++ { 330 t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/ 331 uint32(Q)) & ((1 << 11) - 1) 332 } 333 m[idx] = byte(t[0]) 334 m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3) 335 m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6) 336 m[idx+3] = byte(t[2] >> 2) 337 m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1) 338 m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4) 339 m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7) 340 m[idx+7] = byte(t[5] >> 1) 341 m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2) 342 m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5) 343 m[idx+10] = byte(t[7] >> 3) 344 idx += 11 345 } 346 default: 347 panic("unsupported d") 348 } 349 } 350 351 func TestCompressFullInputFirstCoeff(t *testing.T) { 352 for _, d := range []int{4, 5, 10, 11} { 353 t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) { 354 var p, q Poly 355 bound := (Q + (1 << uint(d))) >> uint(d+1) 356 buf := make([]byte, (N*d-1)/8+1) 357 buf2 := make([]byte, len(buf)) 358 for i := int16(0); i < Q; i++ { 359 p[0] = i 360 p.CompressTo(buf, d) 361 p.OldCompressTo(buf2, d) 362 if !bytes.Equal(buf, buf2) { 363 t.Fatalf("%d", i) 364 } 365 q.Decompress(buf, d) 366 diff := sModQ(p[0] - q[0]) 367 if diff < 0 { 368 diff = -diff 369 } 370 if diff > bound { 371 t.Logf("%v\n", buf) 372 t.Fatalf("|%d - %d mod^± q| = %d > %d", 373 p[0], q[0], diff, bound) 374 } 375 } 376 }) 377 } 378 }