github.com/cloudflare/circl@v1.5.0/pke/kyber/internal/common/poly.go (about)

     1  package common
     2  
// Poly is an element of our base ring R: a polynomial over ℤ_q modulo the
// equation Xᴺ = -1, where q=3329 and N=256.  Each of the N coefficients is
// stored as an int16.
//
// This type is also used to store NTT-transformed polynomials,
// see Poly.NTT().  Several methods additionally expect or produce the
// "tangled" coefficient order, see Tangle().
//
// Coefficients aren't always reduced.  See Normalize().
type Poly [N]int16
    11  
    12  // Sets p to a + b.  Does not normalize coefficients.
    13  func (p *Poly) addGeneric(a, b *Poly) {
    14  	for i := 0; i < N; i++ {
    15  		p[i] = a[i] + b[i]
    16  	}
    17  }
    18  
    19  // Sets p to a - b.  Does not normalize coefficients.
    20  func (p *Poly) subGeneric(a, b *Poly) {
    21  	for i := 0; i < N; i++ {
    22  		p[i] = a[i] - b[i]
    23  	}
    24  }
    25  
    26  // Almost normalizes coefficients.
    27  //
    28  // Ensures each coefficient is in {0, …, q}.
    29  func (p *Poly) barrettReduceGeneric() {
    30  	for i := 0; i < N; i++ {
    31  		p[i] = barrettReduce(p[i])
    32  	}
    33  }
    34  
    35  // Normalizes coefficients.
    36  //
    37  // Ensures each coefficient is in {0, …, q-1}.
    38  func (p *Poly) normalizeGeneric() {
    39  	for i := 0; i < N; i++ {
    40  		p[i] = csubq(barrettReduce(p[i]))
    41  	}
    42  }
    43  
    44  // Multiplies p in-place by the Montgomery factor 2¹⁶.
    45  //
    46  // Coefficients of p can be arbitrary.  Resulting coefficients are bounded
    47  // in absolute value by q.
    48  func (p *Poly) ToMont() {
    49  	for i := 0; i < N; i++ {
    50  		p[i] = toMont(p[i])
    51  	}
    52  }
    53  
    54  // Sets p to the "pointwise" multiplication of a and b.
    55  //
    56  // That is: InvNTT(p) = InvNTT(a) * InvNTT(b).  Assumes a and b are in
    57  // Montgomery form.  Products between coefficients of a and b must be strictly
    58  // bounded in absolute value by 2¹⁵q.  p will be in Montgomery form and
    59  // bounded in absolute value by 2q.
    60  //
    61  // Requires a and b to be in "tangled" order, see Tangle().  p will be in
    62  // tangled order as well.
    63  func (p *Poly) mulHatGeneric(a, b *Poly) {
    64  	// Recall from the discussion in NTT(), that a transformed polynomial is
    65  	// an element of ℤ_q[x]/(x²-ζ) x … x  ℤ_q[x]/(x²+ζ¹²⁷);
    66  	// that is: 128 degree-one polynomials instead of simply 256 elements
    67  	// from ℤ_q as in the regular NTT.  So instead of pointwise multiplication,
    68  	// we multiply the 128 pairs of degree-one polynomials modulo the
    69  	// right equation:
    70  	//
    71  	//  (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x,
    72  	//
    73  	// where ζ' is the appropriate power of ζ.
    74  
    75  	k := 64
    76  	for i := 0; i < N; i += 4 {
    77  		zeta := int32(Zetas[k])
    78  		k++
    79  
    80  		p0 := montReduce(int32(a[i+1]) * int32(b[i+1]))
    81  		p0 = montReduce(int32(p0) * zeta)
    82  		p0 += montReduce(int32(a[i]) * int32(b[i]))
    83  
    84  		p1 := montReduce(int32(a[i]) * int32(b[i+1]))
    85  		p1 += montReduce(int32(a[i+1]) * int32(b[i]))
    86  
    87  		p[i] = p0
    88  		p[i+1] = p1
    89  
    90  		p2 := montReduce(int32(a[i+3]) * int32(b[i+3]))
    91  		p2 = -montReduce(int32(p2) * zeta)
    92  		p2 += montReduce(int32(a[i+2]) * int32(b[i+2]))
    93  
    94  		p3 := montReduce(int32(a[i+2]) * int32(b[i+3]))
    95  		p3 += montReduce(int32(a[i+3]) * int32(b[i+2]))
    96  
    97  		p[i+2] = p2
    98  		p[i+3] = p3
    99  	}
   100  }
   101  
   102  // Packs p into buf.  buf should be of length PolySize.
   103  //
   104  // Assumes p is normalized (and not just Barrett reduced) and "tangled",
   105  // see Tangle().
   106  func (p *Poly) Pack(buf []byte) {
   107  	q := *p
   108  	q.Detangle()
   109  	for i := 0; i < 128; i++ {
   110  		t0 := q[2*i]
   111  		t1 := q[2*i+1]
   112  		buf[3*i] = byte(t0)
   113  		buf[3*i+1] = byte(t0>>8) | byte(t1<<4)
   114  		buf[3*i+2] = byte(t1 >> 4)
   115  	}
   116  }
   117  
   118  // Unpacks p from buf.
   119  //
   120  // buf should be of length PolySize.  p will be "tangled", see Detangle().
   121  //
   122  // p will not be normalized; instead 0 ≤ p[i] < 4096.
   123  func (p *Poly) Unpack(buf []byte) {
   124  	for i := 0; i < 128; i++ {
   125  		p[2*i] = int16(buf[3*i]) | ((int16(buf[3*i+1]) << 8) & 0xfff)
   126  		p[2*i+1] = int16(buf[3*i+1]>>4) | (int16(buf[3*i+2]) << 4)
   127  	}
   128  	p.Tangle()
   129  }
   130  
   131  // Set p to Decompress_q(m, 1).
   132  //
   133  // p will be normalized.  m has to be of PlaintextSize.
   134  func (p *Poly) DecompressMessage(m []byte) {
   135  	// Decompress_q(x, 1) = ⌈xq/2⌋ = ⌊xq/2+½⌋ = (xq+1) >> 1 and so
   136  	// Decompress_q(0, 1) = 0 and Decompress_q(1, 1) = (q+1)/2.
   137  	for i := 0; i < 32; i++ {
   138  		for j := 0; j < 8; j++ {
   139  			bit := (m[i] >> uint(j)) & 1
   140  
   141  			// Set coefficient to either 0 or (q+1)/2 depending on the bit.
   142  			p[8*i+j] = -int16(bit) & ((Q + 1) / 2)
   143  		}
   144  	}
   145  }
   146  
   147  // Writes Compress_q(p, 1) to m.
   148  //
   149  // Assumes p is normalized.  m has to be of length at least PlaintextSize.
   150  func (p *Poly) CompressMessageTo(m []byte) {
   151  	// Compress_q(x, 1) is 1 on {833, …, 2496} and zero elsewhere.
   152  	for i := 0; i < 32; i++ {
   153  		m[i] = 0
   154  		for j := 0; j < 8; j++ {
   155  			x := 1664 - p[8*i+j]
   156  			// With the previous substitution, we want to return 1 if
   157  			// and only if x is in {831, …, -832}.
   158  			x = (x >> 15) ^ x
   159  			// Note (x >> 15)ˣ if x≥0 and -x-1 otherwise. Thus now we want
   160  			// to return 1 iff x ≤ 831, ie. x - 832 < 0.
   161  			x -= 832
   162  			m[i] |= ((byte(x >> 15)) & 1) << uint(j)
   163  		}
   164  	}
   165  }
   166  
// Decompress sets p to Decompress_q(m, d).
//
// m holds the N coefficients as d-bit values packed least-significant bit
// first (the layout written by CompressTo), so it must be at least d·N/8
// bytes long.
//
// Assumes d is in {4, 5, 10, 11}; any other d panics.  p will be normalized.
func (p *Poly) Decompress(m []byte, d int) {
	// Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
	//                    = ⌊(q/2ᵈ)x+½⌋
	//                    = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋
	//                    = (qx + (1<<(d-1))) >> d
	switch d {
	case 4:
		// Two 4-bit values per byte, low nibble first.
		for i := 0; i < N/2; i++ {
			p[2*i] = int16(((1 << 3) +
				uint32(m[i]&15)*uint32(Q)) >> 4)
			p[2*i+1] = int16(((1 << 3) +
				uint32(m[i]>>4)*uint32(Q)) >> 4)
		}
	case 5:
		// Eight 5-bit values per 5 bytes.  Some shifts (e.g. m[idx+1] << 3)
		// happen in byte width, which only drops bits that the 5-bit mask
		// below discards anyway.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			t[0] = uint16(m[idx])
			t[1] = (uint16(m[idx]) >> 5) | (uint16(m[idx+1] << 3))
			t[2] = uint16(m[idx+1]) >> 2
			t[3] = (uint16(m[idx+1]) >> 7) | (uint16(m[idx+2] << 1))
			t[4] = (uint16(m[idx+2]) >> 4) | (uint16(m[idx+3] << 4))
			t[5] = uint16(m[idx+3]) >> 1
			t[6] = (uint16(m[idx+3]) >> 6) | (uint16(m[idx+4] << 2))
			t[7] = uint16(m[idx+4]) >> 3

			// t[j] may still carry bits of the following value above bit
			// 4; mask them off before decompressing.
			for j := 0; j < 8; j++ {
				p[8*i+j] = int16(((1 << 4) +
					uint32(t[j]&((1<<5)-1))*uint32(Q)) >> 5)
			}

			idx += 5
		}

	case 10:
		// Four 10-bit values per 5 bytes.
		var t [4]uint16
		idx := 0
		for i := 0; i < N/4; i++ {
			t[0] = uint16(m[idx]) | (uint16(m[idx+1]) << 8)
			t[1] = (uint16(m[idx+1]) >> 2) | (uint16(m[idx+2]) << 6)
			t[2] = (uint16(m[idx+2]) >> 4) | (uint16(m[idx+3]) << 4)
			t[3] = (uint16(m[idx+3]) >> 6) | (uint16(m[idx+4]) << 2)

			// Keep only the low 10 bits of each gathered value.
			for j := 0; j < 4; j++ {
				p[4*i+j] = int16(((1 << 9) +
					uint32(t[j]&((1<<10)-1))*uint32(Q)) >> 10)
			}

			idx += 5
		}
	case 11:
		// Eight 11-bit values per 11 bytes; t[2] and t[5] straddle three
		// bytes each.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			t[0] = uint16(m[idx]) | (uint16(m[idx+1]) << 8)
			t[1] = (uint16(m[idx+1]) >> 3) | (uint16(m[idx+2]) << 5)
			t[2] = (uint16(m[idx+2]) >> 6) | (uint16(m[idx+3]) << 2) | (uint16(m[idx+4]) << 10)
			t[3] = (uint16(m[idx+4]) >> 1) | (uint16(m[idx+5]) << 7)
			t[4] = (uint16(m[idx+5]) >> 4) | (uint16(m[idx+6]) << 4)
			t[5] = (uint16(m[idx+6]) >> 7) | (uint16(m[idx+7]) << 1) | (uint16(m[idx+8]) << 9)
			t[6] = (uint16(m[idx+8]) >> 2) | (uint16(m[idx+9]) << 6)
			t[7] = (uint16(m[idx+9]) >> 5) | (uint16(m[idx+10]) << 3)

			// Keep only the low 11 bits of each gathered value.
			for j := 0; j < 8; j++ {
				p[8*i+j] = int16(((1 << 10) +
					uint32(t[j]&((1<<11)-1))*uint32(Q)) >> 11)
			}

			idx += 11
		}
	default:
		panic("unsupported d")
	}
}
   244  
// CompressTo writes Compress_q(p, d) to m.
//
// Each of the N coefficients is compressed to a d-bit value.  The values are
// packed least-significant bit first, so m must hold at least d·N/8 bytes.
//
// Assumes p is normalized and d is in {4, 5, 10, 11}; any other d panics.
func (p *Poly) CompressTo(m []byte, d int) {
	// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
	//                  = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
	//                  = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
	//                  = DIV((x << d) + q/2, q) & ((1<<d) - 1)
	//
	// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
	// For d in {10,11} we use 20,642,679/2^36, which computes x/q correctly
	// for 0 ≤ x < 41,522,616.  That range comfortably covers the largest
	// input, (q << 11) + q/2.
	// For d in {4,5} we use 315/2^20, which doesn't compute x/q correctly
	// for all inputs, but it's close enough that the end result of the
	// compression is correct.  The advantage is that we do not need to use
	// a 64-bit intermediate value.
	switch d {
	case 4:
		// Eight 4-bit values into 4 bytes, low nibble first.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			for j := 0; j < 8; j++ {
				t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
					20) & ((1 << 4) - 1)
			}
			m[idx] = byte(t[0]) | byte(t[1]<<4)
			m[idx+1] = byte(t[2]) | byte(t[3]<<4)
			m[idx+2] = byte(t[4]) | byte(t[5]<<4)
			m[idx+3] = byte(t[6]) | byte(t[7]<<4)
			idx += 4
		}

	case 5:
		// Eight 5-bit values into 5 bytes.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			for j := 0; j < 8; j++ {
				t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
					20) & ((1 << 5) - 1)
			}
			m[idx] = byte(t[0]) | byte(t[1]<<5)
			m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
			m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
			m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
			m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
			idx += 5
		}

	case 10:
		// Four 10-bit values into 5 bytes; the 64-bit intermediate is
		// required for the 20,642,679/2^36 approximation.
		var t [4]uint16
		idx := 0
		for i := 0; i < N/4; i++ {
			for j := 0; j < 4; j++ {
				t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
					20642679)>>36) & ((1 << 10) - 1)
			}
			m[idx] = byte(t[0])
			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
			m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
			m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
			m[idx+4] = byte(t[3] >> 2)
			idx += 5
		}
	case 11:
		// Eight 11-bit values into 11 bytes; t[2] and t[5] straddle three
		// bytes each.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			for j := 0; j < 8; j++ {
				t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
					20642679)>>36) & ((1 << 11) - 1)
			}
			m[idx] = byte(t[0])
			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
			m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
			m[idx+3] = byte(t[2] >> 2)
			m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
			m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
			m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
			m[idx+7] = byte(t[5] >> 1)
			m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
			m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
			m[idx+10] = byte(t[7] >> 3)
			idx += 11
		}
	default:
		panic("unsupported d")
	}
}