# The safegcd implementation in libsecp256k1 explained

This document explains the modular inverse and Jacobi symbol implementations in the `src/modinv*.h` files.
It is based on the paper
["Fast constant-time gcd computation and modular inversion"](https://gcd.cr.yp.to/papers.html#safegcd)
by Daniel J. Bernstein and Bo-Yin Yang. The references below are to the version of the paper dated 2019.04.13.

The actual implementation is in C, of course, but for demonstration purposes Python 3 is used here.
Most implementation aspects and optimizations are explained, except those that depend on the specific
number representation used in the C code.

## 1. Computing the Greatest Common Divisor (GCD) using divsteps

The algorithm from the paper (section 11), at a very high level, is this:

```python
def gcd(f, g):
    """Compute the GCD of an odd integer f and another integer g."""
    assert f & 1  # require f to be odd
    delta = 1     # additional state variable
    while g != 0:
        assert f & 1  # f will be odd in every iteration
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, (g    ) // 2
    return abs(f)
```

It computes the greatest common divisor of an odd integer *f* and any integer *g*. Its inner loop
keeps rewriting the variables *f* and *g* alongside a state variable *δ* that starts at *1*, until
*g=0* is reached. At that point, *|f|* gives the GCD. Each of the transitions in the loop is called a
"division step" (referred to as divstep in what follows).

For example, *gcd(21, 14)* would be computed as:
- Start with *δ=1 f=21 g=14*
- Take the third branch: *δ=2 f=21 g=7*
- Take the first branch: *δ=-1 f=7 g=-7*
- Take the second branch: *δ=0 f=7 g=0*
- The answer *|f| = 7*.
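
The trace above can be reproduced mechanically. The sketch below is the `gcd` loop from this section with every intermediate state recorded; the name `gcd_trace` is ours and does not appear in libsecp256k1.

```python
def gcd_trace(f, g):
    """Run the divstep loop of gcd() and record (delta, f, g): the start state and every step."""
    assert f & 1  # f must be odd
    delta = 1
    states = [(delta, f, g)]
    while g != 0:
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2   # first branch
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2   # second branch
        else:
            delta, f, g = 1 + delta, f, g // 2         # third branch
        states.append((delta, f, g))
    return states

for state in gcd_trace(21, 14):
    print(state)
# (1, 21, 14)
# (2, 21, 7)
# (-1, 7, -7)
# (0, 7, 0)
```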

Why it works:
- Divsteps can be decomposed into two steps (see paragraph 8.2 in the paper):
  - (a) If *g* is odd, replace *(f,g)* with *(g,g-f)* or *(f,g+f)*, resulting in an even *g*.
  - (b) Replace *(f,g)* with *(f,g/2)* (where *g* is guaranteed to be even).
- Neither of those two operations changes the GCD:
  - For (a), assume *gcd(f,g)=c*; then it must be the case that *f=a c* and *g=b c* for some integers *a*
    and *b*. As *(g,g-f)=(b c,(b-a)c)* and *(f,f+g)=(a c,(a+b)c)*, the result clearly still has
    common factor *c*. Reasoning in the other direction shows that no common factor can be added by
    doing so either.
  - For (b), we know that *f* is odd, so *gcd(f,g)* clearly has no factor *2*, and we can remove
    it from *g*.
- The algorithm will eventually converge to *g=0*. This is proven in the paper (see theorem G.3).
- It follows that eventually we find a final value *f'* for which *gcd(f,g) = gcd(f',0)*. As the
  gcd of *f'* and *0* is *|f'|* by definition, that is our answer.
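
The argument can also be checked empirically. The sketch below uses only the standard library: it restates the divstep loop as `divstep_gcd` (a name chosen here), asserts after every step that the GCD of the pair is unchanged, and compares the final answer with Python's `math.gcd` on random inputs.

```python
import math
import random

def divstep_gcd(f, g):
    """GCD via divsteps, asserting after each step that gcd(f, g) is preserved."""
    assert f & 1
    expected = math.gcd(f, g)
    delta = 1
    while g != 0:
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, g // 2
        # Neither step (a) nor step (b) may change the GCD.
        assert math.gcd(f, g) == expected
    return abs(f)

rng = random.Random(1)
for _ in range(1000):
    f = rng.randrange(1, 2**64) | 1      # odd f
    g = rng.randrange(-2**64, 2**64)     # any integer g
    assert divstep_gcd(f, g) == math.gcd(f, g)
```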

Compared to more [traditional GCD algorithms](https://en.wikipedia.org/wiki/Euclidean_algorithm), this one has the property of only ever looking at
the low-order bits of the variables to decide the next steps, and of being easy to make
constant-time (in lower-level languages than Python). The *δ* parameter is necessary to
guide the algorithm towards shrinking the numbers' magnitudes without explicitly needing to look
at high-order bits.

Properties that will become important later:
- Performing more divsteps than needed is not a problem, as *f* does not change anymore after *g=0*.
- Only even numbers are divided by *2*. This means that when reasoning about it algebraically we
  do not need to worry about rounding.
- At every point during the algorithm's execution the next *N* steps only depend on the bottom *N*
  bits of *f* and *g*, and on *δ*.
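
The last property is what makes batching possible later on. A quick way to convince yourself of it is to record which branch each of *N* divsteps takes, then repeat with the same *δ* but with *f* and *g* altered only at bit *N* and above. The helper `branches` and the random test setup are introduced here for illustration.

```python
import random

def branches(delta, f, g, n):
    """Return the sequence of branch numbers (1, 2 or 3) taken by n divsteps."""
    taken = []
    for _ in range(n):
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
            taken.append(1)
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
            taken.append(2)
        else:
            delta, f, g = 1 + delta, f, g // 2
            taken.append(3)
    return taken

rng = random.Random(2)
N = 20
for _ in range(200):
    delta = rng.randrange(-10, 11)
    f = rng.randrange(2**64) | 1
    g = rng.randrange(2**64)
    # Change only the bits at position N and above (f2 stays odd).
    f2 = f + rng.randrange(-2**40, 2**40) * 2**N
    g2 = g + rng.randrange(-2**40, 2**40) * 2**N
    assert branches(delta, f, g, N) == branches(delta, f2, g2, N)
```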


## 2. From GCDs to modular inverses

We want an algorithm to compute the inverse *a* of *x* modulo *M*, i.e. the number *a* such that
*a&thinsp;x = 1 mod M*. This inverse only exists if the GCD of *x* and *M* is *1*, but that is always the case if *M* is
prime and *0 < x < M*. In what follows, assume that the modular inverse exists.
It turns out this inverse can be computed as a side effect of computing the GCD by keeping track
of how the internal variables can be written as linear combinations of the inputs at every step
(see the [extended Euclidean algorithm](https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm)).
Since the GCD is *1*, such an algorithm will compute numbers *a* and *b* such that *a&thinsp;x + b&thinsp;M = 1*.
Taking that expression *mod M* gives *a&thinsp;x mod M = 1*, and we see that *a* is the modular inverse of *x
mod M*.

A similar approach can be used to calculate modular inverses using the divsteps-based GCD
algorithm shown above, if the modulus *M* is odd. To do so, compute *gcd(f=M,g=x)*, while keeping
track of extra variables *d* and *e*, for which at every step *d = f/x (mod M)* and *e = g/x (mod M)*.
*f/x* here means the number which, multiplied by *x*, gives *f mod M*. As *f* and *g* are initialized to *M*
and *x* respectively, *d* and *e* just start off being *0* (*M/x mod M = 0/x mod M = 0*) and *1* (*x/x mod M
= 1*).

```python
def div2(M, x):
    """Helper routine to compute x/2 mod M (where M is odd)."""
    assert M & 1
    if x & 1: # If x is odd, make it even by adding M.
        x += M
    # x must be even now, so a clean division by 2 is possible.
    return x // 2

def modinv(M, x):
    """Compute the inverse of x mod M (given that it exists, and M is odd)."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Note that while division by two for f and g is only ever done on even inputs, this is
        # not true for d and e, so we need the div2 helper function.
        if delta > 0 and g & 1:
            delta, f, g, d, e = 1 - delta, g, (g - f) // 2, e, div2(M, e - d)
        elif g & 1:
            delta, f, g, d, e = 1 + delta, f, (g + f) // 2, d, div2(M, e + d)
        else:
            delta, f, g, d, e = 1 + delta, f, (g    ) // 2, d, div2(M, e    )
        # Verify that the invariants d=f/x mod M, e=g/x mod M are maintained.
        assert f % M == (d * x) % M
        assert g % M == (e * x) % M
    assert f == 1 or f == -1  # |f| is the GCD, it must be 1
    # Because of invariant d = f/x (mod M), 1/x = d/f (mod M). As |f|=1, d/f = d*f.
    return (d * f) % M
```

Also note that this approach of tracking *d* and *e* throughout the computation to determine the inverse
differs from the paper. There (see paragraph 12.1 in the paper) a transition matrix for the
entire computation is determined (see section 3 below) and the inverse is computed from that.
The approach here avoids the need for 2x2 matrix multiplications of various sizes, and appears to
be faster at the level of optimization we're able to do in C.
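
For comparison, here is a rough sketch of the matrix-based idea, written with the scaled transition matrix *(u, v, q, r)* that section 3 introduces. It is our own simplification, not the paper's exact procedure: it runs until *g=0* instead of a fixed number of divsteps, and it uses Python 3.8+ `pow(2, -n, M)` for the final division. After *n* divsteps, *2<sup>n</sup>f = u&thinsp;M + v&thinsp;x*, so *v&thinsp;x = 2<sup>n</sup>f (mod M)*, and since *f = ±1* at the end, the inverse is *v&thinsp;f/2<sup>n</sup> mod M*.

```python
def modinv_matrix(M, x):
    """Sketch: compute 1/x mod M from the accumulated (2^n-scaled) transition matrix."""
    assert M & 1 and M > 1
    delta, f, g = 1, M, x
    u, v, q, r = 1, 0, 0, 1  # 2^n times the product of the n step matrices so far
    n = 0                    # number of divsteps performed
    while g != 0:
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q - u, r - v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q + u, r + v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, g // 2, 2*u, 2*v, q, r
        n += 1
    # Invariant: 2^n * f == u*M + v*x, hence v*x == 2^n * f (mod M).
    assert (f << n) == u * M + v * x
    assert f in (1, -1)  # |f| is the GCD, which must be 1 for the inverse to exist
    # Multiply by f (f*f == 1) and divide by 2^n (negative exponent in pow needs Python 3.8+).
    return (v * f * pow(2, -n, M)) % M

print(modinv_matrix(7, 3))  # 5, because 3*5 = 15 = 2*7 + 1
```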


## 3. Batching multiple divsteps

Every divstep can be expressed as a matrix multiplication, applying a transition matrix *(1/2 t)*
to both vectors *[f, g]* and *[d, e]* (see paragraph 8.1 in the paper):

```
  t = [ u,  v ]
      [ q,  r ]

  [ out_f ] = (1/2 * t) * [ in_f ]
  [ out_g ]               [ in_g ]

  [ out_d ] = (1/2 * t) * [ in_d ]  (mod M)
  [ out_e ]               [ in_e ]
```

where *(u, v, q, r)* is *(0, 2, -1, 1)*, *(2, 0, 1, 1)*, or *(2, 0, 0, 1)* when the first, second, or third
branch is taken, respectively. As above, the resulting *f* and *g* are always integers.
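
The three tuples can be checked against the branches of `gcd` directly. In the sketch below (our own; `STEP_MATRICES`, `divstep` and `apply_half_t` are illustrative names), each random state is advanced by one divstep, and the result is compared with *(1/2 t)&thinsp;[f, g]* for the matrix of the branch that was taken.

```python
import random

# (u, v, q, r) for the first, second and third branch of a divstep.
STEP_MATRICES = {1: (0, 2, -1, 1), 2: (2, 0, 1, 1), 3: (2, 0, 0, 1)}

def divstep(delta, f, g):
    """One divstep; returns the new state and the number of the branch taken."""
    if delta > 0 and g & 1:
        return (1 - delta, g, (g - f) // 2), 1
    elif g & 1:
        return (1 + delta, f, (g + f) // 2), 2
    else:
        return (1 + delta, f, g // 2), 3

def apply_half_t(t, f, g):
    """Compute (1/2 * t) * [f, g]; for a valid step both products are even."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    assert cf % 2 == 0 and cg % 2 == 0
    return cf // 2, cg // 2

rng = random.Random(3)
for _ in range(1000):
    delta = rng.randrange(-5, 6)
    f = rng.randrange(-2**32, 2**32) | 1
    g = rng.randrange(-2**32, 2**32)
    (new_delta, new_f, new_g), branch = divstep(delta, f, g)
    assert apply_half_t(STEP_MATRICES[branch], f, g) == (new_f, new_g)
```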

Performing multiple divsteps corresponds to a multiplication with the product of all the
individual divsteps' transition matrices. As each transition matrix consists of integers
divided by *2*, the product of these matrices will consist of integers divided by *2<sup>N</sup>* (see also
theorem 9.2 in the paper). These divisions are expensive when updating *d* and *e*, so we delay
them: we compute the integer coefficients of the combined transition matrix scaled by *2<sup>N</sup>*, and
do one division by *2<sup>N</sup>* as a final step:

```python
def divsteps_n_matrix(delta, f, g):
    """Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1 # start with identity matrix
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, (g    ) // 2, 2*u, 2*v, q  , r
    return delta, (u, v, q, r)
```

As the branches in the divsteps are completely determined by the bottom *N* bits of *f* and *g*, this
function to compute the transition matrix only needs to see those bottom bits. Furthermore, all
intermediate results and outputs fit in *(N+1)*-bit numbers (unsigned for *f* and *g*; signed for *u*, *v*,
*q*, and *r*) (see also paragraph 8.3 in the paper). This means that an implementation using 64-bit
integers could set *N=62* and compute the full transition matrix for 62 steps at once without any
big integer arithmetic at all. This is the reason why this algorithm is efficient: it only needs
to update the full-size *f*, *g*, *d*, and *e* numbers once every *N* steps.
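
These claims can be spot-checked. The sketch below sets *N=62*, computes the matrix from the bottom 62 bits only, and compares the result of applying it to full-size 256-bit *f* and *g* with performing 62 divsteps on the full numbers. It also checks that all matrix entries fit in 64-bit signed integers. The reference helper `divsteps_full` and the random inputs are our own.

```python
import random

N = 62

def divsteps_n_matrix(delta, f, g):
    """Section 3 version: delta and the 2^N-scaled transition matrix after N divsteps."""
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, g // 2, 2*u, 2*v, q, r
    return delta, (u, v, q, r)

def divsteps_full(delta, f, g):
    """Reference: N plain divsteps on the full-size numbers."""
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g - f) // 2
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, g // 2
    return delta, f, g

rng = random.Random(4)
for _ in range(200):
    f = rng.randrange(2**256) | 1
    g = rng.randrange(2**256)
    delta = 1
    # Only the bottom N bits of f and g are passed in.
    new_delta, (u, v, q, r) = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
    # Every entry fits in a 64-bit signed integer.
    assert all(-2**63 <= c < 2**63 for c in (u, v, q, r))
    assert abs(u) + abs(v) <= 2**N and abs(q) + abs(r) <= 2**N
    # Applying t/2^N to the full numbers matches N individual divsteps.
    cf, cg = u*f + v*g, q*f + r*g
    assert cf % 2**N == 0 and cg % 2**N == 0
    assert (new_delta, cf >> N, cg >> N) == divsteps_full(delta, f, g)
```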

We still need functions to compute:

```
  [ out_f ] = (1/2^N * [ u,  v ]) * [ in_f ]
  [ out_g ]   (        [ q,  r ])   [ in_g ]

  [ out_d ] = (1/2^N * [ u,  v ]) * [ in_d ]  (mod M)
  [ out_e ]   (        [ q,  r ])   [ in_e ]
```

Because the divsteps transformation only ever divides even numbers by two, the result of *t&thinsp;[f,g]* is always even for the matrix *t* of a single divstep. When *t* is the (scaled) composition of *N* divsteps, it follows that both entries of *t&thinsp;[f,g]*
are multiples of *2<sup>N</sup>*, and dividing by *2<sup>N</sup>* is simply shifting them down:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    # (t / 2^N) should cleanly apply to [f,g] so the result of t*[f,g] should have N zero
    # bottom bits.
    assert cf % 2**N == 0
    assert cg % 2**N == 0
    return cf >> N, cg >> N
```

The same is not true for *d* and *e*, and we need an equivalent of the `div2` function for division by *2<sup>N</sup> mod M*.
This is easy if we have precomputed *1/M mod 2<sup>N</sup>* (which always exists for odd *M*):

```python
def div2n(M, Mi, x):
    """Compute x/2^N mod M, given Mi = 1/M mod 2^N."""
    assert (M * Mi) % 2**N == 1
    # Find a factor m such that m*M has the same bottom N bits as x. We want:
    #     (m * M) mod 2^N = x mod 2^N
    # <=> m mod 2^N = (x / M) mod 2^N
    # <=> m mod 2^N = (x * Mi) mod 2^N
    m = (Mi * x) % 2**N
    # Subtract that multiple from x, cancelling its bottom N bits.
    x -= m * M
    # Now a clean division by 2^N is possible.
    assert x % 2**N == 0
    return (x >> N) % M

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    return div2n(M, Mi, cd), div2n(M, Mi, ce)
```
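
`Mi` only has to be computed once per modulus; in Python 3.8+ one way to obtain it is `pow(M, -1, 2**N)`. The sketch below (with *N=62* and the secp256k1 field prime as an example modulus, both our choices) checks that `div2n` really divides by *2<sup>N</sup>* modulo *M*, i.e. that multiplying its output by *2<sup>N</sup>* gives back *x mod M*.

```python
import random

N = 62
M = 2**256 - 2**32 - 977   # secp256k1 field prime, used here only as an odd example modulus
Mi = pow(M, -1, 2**N)      # 1/M mod 2^N (exists because M is odd); needs Python 3.8+

def div2n(M, Mi, x):
    """Compute x/2^N mod M, given Mi = 1/M mod 2^N (as in the text)."""
    assert (M * Mi) % 2**N == 1
    m = (Mi * x) % 2**N        # m*M has the same bottom N bits as x
    x -= m * M                 # cancel the bottom N bits
    assert x % 2**N == 0
    return (x >> N) % M

rng = random.Random(5)
for _ in range(1000):
    x = rng.randrange(-2**320, 2**320)
    y = div2n(M, Mi, x)
    assert 0 <= y < M
    assert (y * 2**N - x) % M == 0   # y * 2^N == x (mod M)
```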

With all of those, we can write a version of `modinv` that performs *N* divsteps at once:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        # Compute the delta and transition matrix t for the next N divsteps (this only needs
        # (N+1)-bit signed integer arithmetic).
        delta, t = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
        # Apply the transition matrix t to [f, g]:
        f, g = update_fg(f, g, t)
        # Apply the transition matrix t to [d, e]:
        d, e = update_de(d, e, t, M, Mi)
    return (d * f) % M
```

This means that in practice we'll always perform a multiple of *N* divsteps. This is not a problem
because once *g=0*, further divsteps do not affect *f*, *g*, *d*, or *e* anymore (only *&delta;* keeps
increasing). For variable-time code, such excess iterations will mostly be optimized away in later
sections.
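
Putting the pieces of this section together, the following self-contained sketch runs the batched `modinv` with *N=62* and compares the result with Python's built-in `pow(x, -1, M)` (Python 3.8+). The functions are condensed copies of the ones above; `modinv_fixed` is a hypothetical variant that always runs a fixed number of batches, to illustrate that surplus divsteps after *g=0* leave the result unchanged. The modulus and inputs are our own choices.

```python
N = 62

def divsteps_n_matrix(delta, f, g):
    """delta and the 2^N-scaled transition matrix after N divsteps (section 3)."""
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, g // 2, 2*u, 2*v, q, r
    return delta, (u, v, q, r)

def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g] (the products are exact multiples of 2^N)."""
    u, v, q, r = t
    return (u*f + v*g) >> N, (q*f + r*g) >> N

def div2n(M, Mi, x):
    """Compute x/2^N mod M, given Mi = 1/M mod 2^N."""
    m = (Mi * x) % 2**N
    return ((x - m * M) >> N) % M

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    return div2n(M, Mi, u*d + v*e), div2n(M, Mi, q*d + r*e)

def modinv(M, Mi, x):
    """Batched modinv from the text: N divsteps per loop iteration."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    while g != 0:
        delta, t = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return (d * f) % M

def modinv_fixed(M, Mi, x, rounds):
    """Hypothetical variant: always run `rounds` batches, even after g has reached 0."""
    assert M & 1
    delta, f, g, d, e = 1, M, x, 0, 1
    for _ in range(rounds):
        delta, t = divsteps_n_matrix(delta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    assert g == 0, "not enough rounds"
    return (d * f) % M

M = 2**255 - 19          # an odd prime, chosen only as an example modulus
Mi = pow(M, -1, 2**N)    # 1/M mod 2^N (Python 3.8+)
x = 12345
assert modinv(M, Mi, x) == pow(x, -1, M)
# 20 batches = 1240 divsteps, comfortably more than needed for 256-bit inputs.
assert modinv_fixed(M, Mi, x, 20) == pow(x, -1, M)
```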


## 4. Avoiding modulus operations

So far, there are two places where we compute a remainder of big numbers modulo *M*: at the end of
`div2n` in every `update_de`, and at the very end of `modinv` after potentially negating *d* due to the
sign of *f*. These are relatively expensive operations when done generically.

To deal with the modulus operation in `div2n`, we simply stop requiring *d* and *e* to be in range
*[0,M)* all the time. Let's start by inlining `div2n` into `update_de`, and dropping the modulus
operation at the end:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e] mod M, given Mi=1/M mod 2^N."""
    u, v, q, r = t
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

Let's look at bounds on the ranges of these numbers. It can be shown that *|u|+|v|* and *|q|+|r|*
never exceed *2<sup>N</sup>* (see paragraph 8.3 in the paper), and thus a multiplication with *t* will have
outputs whose absolute values are at most *2<sup>N</sup>* times the maximum absolute input value. If the
inputs *d* and *e* are in *(-M,M)*, which is certainly true for the initial values *d=0* and *e=1* assuming
*M > 1*, the multiplication results in numbers in range *(-2<sup>N</sup>M,2<sup>N</sup>M)*. Subtracting less than *2<sup>N</sup>*
times *M* to cancel out *N* bits widens that to *(-2<sup>N+1</sup>M,2<sup>N</sup>M)*, and
dividing by *2<sup>N</sup>* at the end takes it to *(-2M,M)*. Another application of `update_de` would take that
to *(-3M,2M)*, and so forth. This progressive expansion of the variables' ranges can be
counteracted by incrementing *d* and *e* by *M* whenever they're negative:

```python
    ...
    if d < 0:
        d += M
    if e < 0:
        e += M
    cd, ce = u*d + v*e, q*d + r*e
    # Cancel out bottom N bits of cd and ce.
    ...
```

With inputs in *(-2M,M)*, they will first be shifted into range *(-M,M)*, which means that the
output will again be in *(-2M,M)*, and this remains the case regardless of how many `update_de`
invocations there are. In what follows, we will try to make this more efficient.

Note that increasing *d* by *M* is equivalent to incrementing *cd* by *u&thinsp;M* and *ce* by *q&thinsp;M*. Similarly,
increasing *e* by *M* is equivalent to incrementing *cd* by *v&thinsp;M* and *ce* by *r&thinsp;M*. So we could instead write:

```python
    ...
    cd, ce = u*d + v*e, q*d + r*e
    # Perform the equivalent of incrementing d, e by M when they're negative.
    if d < 0:
        cd += u*M
        ce += q*M
    if e < 0:
        cd += v*M
        ce += r*M
    # Cancel out bottom N bits of cd and ce.
    md = -((Mi * cd) % 2**N)
    me = -((Mi * ce) % 2**N)
    cd += md * M
    ce += me * M
    ...
```

Now note that we have two steps of corrections to *cd* and *ce* that add multiples of *M*: this
increment, and the decrement that cancels out bottom bits. The second one depends on the first
one, but they can still be efficiently combined by only computing the bottom bits of *cd* and *ce*
at first, and using that to compute the final *md*, *me* values:

```python
def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    md, me = 0, 0
    # Compute what multiples of M to add to cd and ce.
    if d < 0:
        md += u
        me += q
    if e < 0:
        md += v
        me += r
    # Compute bottom N bits of t*[d,e] + M*[md,me].
    cd, ce = (u*d + v*e + md*M) % 2**N, (q*d + r*e + me*M) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e] + M*[md,me] are zero.
    md -= (Mi * cd) % 2**N
    me -= (Mi * ce) % 2**N
    # Do the full computation.
    cd, ce = u*d + v*e + md*M, q*d + r*e + me*M
    # And cleanly divide by 2**N.
    return cd >> N, ce >> N
```

One last optimization: we can avoid the *md&thinsp;M* and *me&thinsp;M* multiplications in the bottom bits of *cd*
and *ce* by moving them to the *md* and *me* correction:

```python
    ...
    # Compute bottom N bits of t*[d,e].
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    # Correct md and me such that the bottom N bits of t*[d,e]+M*[md,me] are zero.
    # Note that this is not the same as {md = (-Mi * cd) % 2**N} etc. That would also result in N
    # zero bottom bits, but isn't guaranteed to be a reduction of [0,2^N) compared to the
    # previous md and me values, and thus would violate our bounds analysis.
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    ...
```
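
Putting the last two fragments together gives roughly the following function. The sketch below assembles it (with *N=62* and the secp256k1 field prime as an example modulus, both our choices) and feeds it random transition matrices from the section 3 `divsteps_n_matrix`, checking that the outputs stay in *(-2M,M)* and are congruent to *t/2<sup>N</sup>&thinsp;[d,e]* modulo *M*.

```python
import random

N = 62

def divsteps_n_matrix(delta, f, g):
    """Section 3 version, used here only to produce realistic matrices t."""
    u, v, q, r = 1, 0, 0, 1
    for _ in range(N):
        if delta > 0 and g & 1:
            delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
        elif g & 1:
            delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
        else:
            delta, f, g, u, v, q, r = 1 + delta, f, g // 2, 2*u, 2*v, q, r
    return delta, (u, v, q, r)

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e] mod M; inputs and outputs in (-2M, M)."""
    u, v, q, r = t
    md, me = 0, 0
    # Multiples of M that correspond to adding M to a negative d or e.
    if d < 0:
        md += u
        me += q
    if e < 0:
        md += v
        me += r
    # Bottom N bits of t*[d,e] only; the md*M, me*M terms are folded into the correction.
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    # Lower md, me by less than 2^N so that t*[d,e] + M*[md,me] has N zero bottom bits.
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + md*M, q*d + r*e + me*M
    assert cd % 2**N == 0 and ce % 2**N == 0
    return cd >> N, ce >> N

M = 2**256 - 2**32 - 977
Mi = pow(M, -1, 2**N)          # Python 3.8+
inv2n = pow(2, -N, M)          # 1/2^N mod M, only used for checking
rng = random.Random(6)
d, e = 0, 1
for _ in range(1000):
    _, t = divsteps_n_matrix(rng.randrange(-20, 21), rng.randrange(2**N) | 1, rng.randrange(2**N))
    u, v, q, r = t
    nd, ne = update_de(d, e, t, M, Mi)
    assert -2*M < nd < M and -2*M < ne < M
    assert (nd - (u*d + v*e) * inv2n) % M == 0
    assert (ne - (q*d + r*e) * inv2n) % M == 0
    d, e = nd, ne
```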

The resulting function takes *d* and *e* in range *(-2M,M)* as inputs, and outputs values in the same
range. That also means that the *d* value at the end of `modinv` will be in that range, while we want
a result in *[0,M)*. To do that, we need a normalization function. It's easy to integrate the
conditional negation of *d* (based on the sign of *f*) into it as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v is in range (-2*M,M); output in [0,M)."""
    assert sign == 1 or sign == -1
    # v in (-2*M,M)
    if v < 0:
        v += M
    # v in (-M,M). Now multiply v with sign (which can only be 1 or -1).
    if sign == -1:
        v = -v
    # v in (-M,M)
    if v < 0:
        v += M
    # v in [0,M)
    return v
```

And calling it in `modinv` is simply:

```python
    ...
    return normalize(f, d, M)
```


## 5. Constant-time operation

The primary selling point of the algorithm is fast constant-time operation. What code flow still
depends on the input data so far?

- the number of iterations of the while *g &ne; 0* loop in `modinv`
- the branches inside `divsteps_n_matrix`
- the sign checks in `update_de`
- the sign checks in `normalize`

To make the while loop in `modinv` constant-time, it can be replaced with a constant number of
iterations. The paper proves (Theorem 11.2) that *741* divsteps are sufficient for any *256*-bit
inputs, and [safegcd-bounds](https://github.com/sipa/safegcd-bounds) shows that the slightly better bound *724* is
in fact sufficient. Given that every loop iteration performs *N* divsteps, it will run a total of
*&lceil;724/N&rceil;* times.

To deal with the branches in `divsteps_n_matrix` we will replace them with constant-time bitwise
operations (and hope the C compiler isn't smart enough to turn them back into branches; see
`ctime_tests.c` for automated tests that this isn't the case). To do so, observe that a
divstep can be written instead as follows (compare to the inner loop of `gcd` in section 1):

```python
    x = -f if delta > 0 else f         # set x equal to (input) -f or f
    if g & 1:
        g += x                         # set g to (input) g-f or g+f
        if delta > 0:
            delta = -delta
            f += g                     # set f to (input) g (note that g was set to g-f before)
    delta += 1
    g >>= 1
```

To convert the above to bitwise operations, we rely on a trick to negate conditionally: per the
definition of negative numbers in two's complement, (*-v == ~v + 1*) holds for every number *v*. As
*-1* in two's complement is all *1* bits, bitflipping can be expressed as xor with *-1*. It follows
that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on values *0* or *-1*, then
*(v ^ c) - c* is *v* if *c=0* and *-v* if *c=-1*.

Using this we can write:

```python
    x = -f if delta > 0 else f
```

in constant-time form as:

```python
    c1 = (-delta) >> 63
    # Conditionally negate f based on c1:
    x = (f ^ c1) - c1
```

To use that trick, we need a helper mask variable *c1* that resolves the condition *&delta;>0* to *-1*
(if true) or *0* (if false). We compute *c1* using right shifting, which is equivalent to dividing by
the specified power of *2* and rounding down (in Python, and also in C under the assumption of a typical two's complement system; see
`assumptions.h` for tests that this is the case). Right shifting by *63* thus maps all
numbers in range *[-2<sup>63</sup>,0)* to *-1*, and numbers in range *[0,2<sup>63</sup>)* to *0*.
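
Both the conditional negation trick and the shift-based mask can be checked in Python, whose integers behave like infinitely wide two's complement numbers under `^`, `&` and `>>`. Because Python integers are unbounded, `>> 63` only produces *0* or *-1* for values in *[-2<sup>63</sup>,2<sup>63</sup>)*, matching the 64-bit setting assumed above. The helper names `cond_negate` and `mask_positive` are ours.

```python
def cond_negate(v, c):
    """Return v if c == 0 and -v if c == -1; computed without a branch on c."""
    return (v ^ c) - c

def mask_positive(delta):
    """Return -1 if delta > 0 and 0 otherwise (valid for -2^63 < delta <= 2^63)."""
    return (-delta) >> 63

for v in (0, 1, -1, 12345, -2**63, 2**63 - 1):
    assert (~v) + 1 == -v          # two's complement negation
    assert (v ^ -1) == ~v          # xor with all-ones flips every bit
    assert cond_negate(v, 0) == v
    assert cond_negate(v, -1) == -v

for delta in (-5, -1, 0, 1, 2, 2**63):
    assert mask_positive(delta) == (-1 if delta > 0 else 0)

# Together: x = -f if delta > 0 else f, without branching on delta.
for delta in (-3, 0, 1, 4):
    for f in (-9, 1, 21):
        assert cond_negate(f, mask_positive(delta)) == (-f if delta > 0 else f)
```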

Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again), we can write:

```python
    if g & 1:
        g += x
```

as:

```python
    # Compute c2=0 if g is even and c2=-1 if g is odd.
    c2 = -(g & 1)
    # This masks out x if g is even, and leaves x be if g is odd.
    g += x & c2
```

Using the conditional negation trick again we can write:

```python
    if g & 1:
        if delta > 0:
            delta = -delta
```

as:

```python
    # Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
    c3 = c1 & c2
    # Conditionally negate delta based on c3:
    delta = (delta ^ c3) - c3
```

Finally:

```python
    if g & 1:
        if delta > 0:
            f += g
```

becomes:

```python
    f += g & c3
```

It turns out that this can be implemented more efficiently by applying the substitution
*&eta;=-&delta;*. In this representation, negating *&delta;* corresponds to negating *&eta;*, and incrementing
*&delta;* corresponds to decrementing *&eta;*. This allows us to remove the negation in the *c1*
computation:

```python
    # Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
    c1 = eta >> 63
    x = (f ^ c1) - c1
    # Compute a mask c2 for odd g, and conditionally add x to g:
    c2 = -(g & 1)
    g += x & c2
    # Compute a mask c3 for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
    # and add g to f:
    c3 = c1 & c2
    eta = (eta ^ c3) - c3
    f += g & c3
    # Incrementing delta corresponds to decrementing eta.
    eta -= 1
    g >>= 1
```
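
To gain some confidence that the masked version matches the branching divstep from section 1, the sketch below compares the two on random inputs, translating between them via *&eta;=-&delta;*. The function names are ours, and the inputs are kept far inside the range where `>> 63` yields a valid mask.

```python
import random

def divstep_branching(delta, f, g):
    """The divstep from gcd() in section 1."""
    if delta > 0 and g & 1:
        return 1 - delta, g, (g - f) // 2
    elif g & 1:
        return 1 + delta, f, (g + f) // 2
    else:
        return 1 + delta, f, g // 2

def divstep_masked(eta, f, g):
    """The branch-free divstep from the text, using eta = -delta."""
    c1 = eta >> 63           # -1 if eta < 0, else 0
    x = (f ^ c1) - c1        # x = -f if eta < 0, else f
    c2 = -(g & 1)            # -1 if g is odd, else 0
    g += x & c2              # add x to g only if g is odd
    c3 = c1 & c2             # -1 if eta < 0 and the input g was odd
    eta = (eta ^ c3) - c3    # conditionally negate eta
    f += g & c3              # conditionally set f to the input g
    eta -= 1
    g >>= 1
    return eta, f, g

rng = random.Random(7)
for _ in range(10000):
    delta = rng.randrange(-100, 101)
    f = rng.randrange(-2**40, 2**40) | 1
    g = rng.randrange(-2**40, 2**40)
    nd, nf, ng = divstep_branching(delta, f, g)
    assert divstep_masked(-delta, f, g) == (-nd, nf, ng)
```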

A variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
*1/2* instead of *1*. This reduces the worst-case number of divsteps to *590* for *256*-bit inputs
(which can be shown using convex hull analysis). In this case, the substitution *&zeta;=-(&delta;+1/2)*
is used instead to keep the variable integral. Incrementing *&delta;* by *1* still translates to
decrementing *&zeta;* by *1*, but negating *&delta;* now corresponds to going from *&zeta;* to *-(&zeta;+1)*, or
*~&zeta;*. Doing that conditionally based on *c3* is simply:

```python
    ...
    c3 = c1 & c2
    zeta ^= c3
    ...
```
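
The *&zeta;* bookkeeping can be cross-checked against a branching divstep that starts at *&delta;=1/2*. To keep everything integral, the sketch below tracks *2&delta;* (which stays odd) in the branching version, a device of ours rather than something from the C code, and uses *&zeta;=-(&delta;+1/2)* in the masked one.

```python
import random

def divstep_half(delta2, f, g):
    """Branching divstep with delta2 = 2*delta (delta starts at 1/2, so delta2 stays odd)."""
    if delta2 > 0 and g & 1:
        return 2 - delta2, g, (g - f) // 2      # delta -> 1 - delta
    elif g & 1:
        return delta2 + 2, f, (g + f) // 2      # delta -> 1 + delta
    else:
        return delta2 + 2, f, g // 2

def divstep_zeta(zeta, f, g):
    """Branch-free divstep with zeta = -(delta + 1/2)."""
    c1 = zeta >> 63          # -1 if zeta < 0 (i.e. delta > 0), else 0
    x = (f ^ c1) - c1
    c2 = -(g & 1)
    g += x & c2
    c3 = c1 & c2
    zeta ^= c3               # negating delta maps zeta to ~zeta = -(zeta + 1)
    f += g & c3
    zeta -= 1                # incrementing delta decrements zeta
    g >>= 1
    return zeta, f, g

def zeta_from_delta2(delta2):
    """zeta = -(delta + 1/2) = -(delta2 + 1)/2, exact because delta2 is odd."""
    return -(delta2 + 1) // 2

rng = random.Random(8)
for _ in range(10000):
    delta2 = 2 * rng.randrange(-100, 101) + 1
    f = rng.randrange(-2**40, 2**40) | 1
    g = rng.randrange(-2**40, 2**40)
    nd2, nf, ng = divstep_half(delta2, f, g)
    assert divstep_zeta(zeta_from_delta2(delta2), f, g) == (zeta_from_delta2(nd2), nf, ng)
```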

By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
`divsteps_n_matrix` is obtained. The full code is given in section 7.

These bit-fiddling tricks can also be used to make the conditional negations and additions in
`update_de` and `normalize` constant-time.


## 6. Variable-time optimizations

In section 5, we modified the `divsteps_n_matrix` function (and a few others) to be constant-time.
Constant-time operations are only necessary when computing modular inverses of secret data. In
other cases, they slow down calculations unnecessarily. In this section, we will construct a
faster non-constant-time `divsteps_n_matrix` function.

To do so, first consider yet another way of writing the inner loop of divstep operations in
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
the original version with initial *&delta;=1* and *&eta;=-&delta;* here.

```python
for _ in range(N):
    if g & 1 and eta < 0:
        eta, f, g = -eta, g, -f
    if g & 1:
        g += f
    eta -= 1
    g >>= 1
```

Whenever *g* is even, the loop only shifts *g* down and decreases *&eta;*. When *g* ends in multiple zero
bits, these iterations can be consolidated into one step. This requires counting the bottom zero
bits efficiently, which is possible on most platforms; it is abstracted here as the function
`count_trailing_zeros`.

```python
def count_trailing_zeros(v):
    """
    When v is zero, consider all N zero bits as "trailing".
    For a non-zero value v, find z such that v=(d<<z) for some odd d.
    """
    if v == 0:
        return N
    else:
        return (v & -v).bit_length() - 1

i = N # divsteps left to do
while True:
    # Get rid of all bottom zeros at once. In the first iteration, g may be odd and the following
    # lines have no effect (until "if eta < 0").
    zeros = min(i, count_trailing_zeros(g))
    eta -= zeros
    g >>= zeros
    i -= zeros
    if i == 0:
        break
    # We know g is odd now
    if eta < 0:
        eta, f, g = -eta, g, -f
    g += f
    # g is even now, and the eta decrement and g shift will happen in the next loop.
```
   597  
   598  We can now remove multiple bottom *0* bits from *g* at once, but still need a full iteration whenever
   599  there is a bottom *1* bit. In what follows, we will get rid of multiple *1* bits simultaneously as
   600  well.
   601  
   602  Observe that as long as *&eta; &geq; 0*, the loop does not modify *f*. Instead, it cancels out bottom
   603  bits of *g* and shifts them out, and decreases *&eta;* and *i* accordingly - interrupting only when *&eta;*
   604  becomes negative, or when *i* reaches *0*. Combined, this is equivalent to adding a multiple of *f* to
   605  *g* to cancel out multiple bottom bits, and then shifting them out.
   606  
   607  It is easy to find what that multiple is: we want a number *w* such that *g+w&thinsp;f* has a few bottom
   608  zero bits. If that number of bits is *L*, we want *g+w&thinsp;f mod 2<sup>L</sup> = 0*, or *w = -g/f mod 2<sup>L</sup>*. Since *f*
   609  is odd, such a *w* exists for any *L*. *L* cannot be more than *i* steps (as we'd finish the loop before
   610  doing more) or more than *&eta;+1* steps (as we'd run `eta, f, g = -eta, g, -f` at that point), but
   611  apart from that, we're only limited by the complexity of computing *w*.
   612  
   613  This code demonstrates how to cancel up to 4 bits per step:
   614  
   615  ```python
   616  NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
   617  i = N
   618  while True:
   619      zeros = min(i, count_trailing_zeros(g))
   620      eta -= zeros
   621      g >>= zeros
   622      i -= zeros
   623      if i == 0:
   624          break
   625      # We know g is odd now
   626      if eta < 0:
   627          eta, f, g = -eta, g, -f
   628      # Compute limit on number of bits to cancel
   629      limit = min(min(eta + 1, i), 4)
   630      # Compute w = -g/f mod 2**limit, using the table value for -1/f mod 2**4. Note that f is
   631      # always odd, so its inverse modulo a power of two always exists.
   632      w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
   633      # As w = -g/f mod (2**limit), g+w*f mod 2**limit = 0 mod 2**limit.
   634      g += w * f
   635      assert g % (2**limit) == 0
   636      # The next iteration will now shift out at least limit bottom zero bits from g.
   637  ```

By using a bigger table, more bits can be cancelled at once. The table can also be implemented
as a formula. Several formulas are known for computing modular inverses modulo powers of two;
some can be found in *Hacker's Delight*, second edition, by Henry S. Warren, Jr., pages 245-247.
Here we need the negated modular inverse, which is a simple transformation of those:

- Instead of a 3-bit table:
  - *-f* or *f ^ 6*
- Instead of a 4-bit table:
  - *1 - f(f + 1)*
  - *-(f + (((f + 1) & 4) << 1))*
- For larger tables, the following technique can be used: if *w=-1/f mod 2<sup>L</sup>*, then *w(w&thinsp;f+2)* is
  *-1/f mod 2<sup>2L</sup>*. This allows extending the previous formulas (or tables). In particular, we
  have this 6-bit function (based on the 3-bit function above):
  - *f(f<sup>2</sup> - 2)*
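Because these formulas only involve a few bits, they can be checked exhaustively against a generic modular inverse. The sketch below assumes Python 3.8+ for `pow(f, -1, m)`. It also applies the doubling step *w(w&thinsp;f+2)* repeatedly, starting from the 3-bit value *-f*. The helper names `neginv` and `lift_neginv` are hypothetical:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]  # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n


def neginv(f, L):
    """Reference value of -1/f mod 2**L for odd f."""
    return (-pow(f, -1, 2**L)) % 2**L


def lift_neginv(w, f, L):
    """Given w = -1/f mod 2**L, return -1/f mod 2**(2*L) via w*(w*f + 2)."""
    return (w * (w * f + 2)) % 2**(2 * L)


for f in range(1, 1024, 2):
    # 3-bit formulas.
    assert (-f) % 8 == (f ^ 6) % 8 == neginv(f, 3)
    # 4-bit formulas and the table.
    assert (1 - f * (f + 1)) % 16 == neginv(f, 4)
    assert (-(f + (((f + 1) & 4) << 1))) % 16 == neginv(f, 4)
    assert NEGINV16[(f & 15) // 2] == neginv(f, 4)
    # 6-bit formula f(f^2 - 2): one lifting step applied to the 3-bit w = -f.
    assert (f * (f * f - 2)) % 64 == neginv(f, 6) == lift_neginv((-f) % 8, f, 3)
    # Two more lifting steps give 12 and 24 bits.
    w12 = lift_neginv(neginv(f, 6), f, 6)
    assert w12 == neginv(f, 12)
    assert lift_neginv(w12, f, 12) == neginv(f, 24)
```

Each lifting step doubles the number of correct bits. The reason is that if *w&thinsp;f = -1 + k&thinsp;2<sup>L</sup>*, then *w(w&thinsp;f+2)&thinsp;f = -1 + k<sup>2</sup>&thinsp;2<sup>2L</sup>*.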

This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
`divsteps_n_matrix`, gives a significantly faster, but not constant-time, version.
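Cancelling several bits per iteration should not change the result. It only performs the same *N* divsteps in fewer iterations. One way to check this is to compare it against plain one-at-a-time divsteps, written in terms of *&eta; = -&delta;*, on random inputs. Below is a minimal sketch that tracks only *&eta;*, *f* and *g* (not *u*, *v*, *q*, *r*). All function names are hypothetical:

```python
import random

NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1]  # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n


def ctz_capped(v, cap):
    """Number of trailing zero bits of v, capped at cap (v == 0 gives cap)."""
    if v == 0:
        return cap
    return min(cap, (v & -v).bit_length() - 1)


def divsteps_single(eta, f, g, n):
    """Perform n divsteps one at a time, with eta = -delta."""
    for _ in range(n):
        if eta < 0 and g & 1:
            eta, f, g = -eta, g, -f  # swap case: afterwards g + f == old g - old f
        if g & 1:
            g += f
        eta, g = eta - 1, g >> 1
    return eta, f, g


def divsteps_multi(eta, f, g, n):
    """Perform the same n divsteps, shifting out zeros and cancelling up to 4 bits at once."""
    i = n
    while True:
        zeros = ctz_capped(g, i)
        eta, g, i = eta - zeros, g >> zeros, i - zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, g = -eta, g, -f
        limit = min(eta + 1, i, 4)
        w = (g * NEGINV16[(f & 15) // 2]) % 2**limit
        g += w * f
    return eta, f, g


rng = random.Random(2024)
for _ in range(200):
    f = rng.randrange(-2**256, 2**256) | 1  # f must be odd
    g = rng.randrange(-2**256, 2**256)
    eta = rng.randrange(-20, 20)
    assert divsteps_single(eta, f, g, 62) == divsteps_multi(eta, f, g, 62)
```

The single-step function follows the divstep definition directly. Agreement on many random inputs is therefore a cheap sanity check of the bit-cancelling logic, including the *&eta;+1* and *i* limits, though it is not a proof.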


## 7. Final Python version

Altogether, we need the following functions:

- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
  from section 2, but with its loop replaced by a variant of the constant-time divstep from
  section 5, extended to handle *u*, *v*, *q*, *r*:

```python
def divsteps_n_matrix(zeta, f, g):
    """Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1 # start with identity matrix
    for _ in range(N):
        c1 = zeta >> 63
        # Compute x, y, z as conditionally-negated versions of f, u, v.
        x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
        c2 = -(g & 1)
        # Conditionally add x, y, z to g, q, r.
        g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
        c1 &= c2                     # reusing c1 here for the earlier c3 variable
        zeta = (zeta ^ c1) - 1       # inlining the unconditional zeta decrement here
        # Conditionally add g, q, r to f, u, v.
        f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
        # When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
        # by 2^N. Instead, shift f's coefficients u and v up.
        g, u, v = g >> 1, u << 1, v << 1
    return zeta, (u, v, q, r)
```

- The functions to update *f* and *g*, and *d* and *e*, from section 2 and section 4, with the constant-time
  changes to `update_de` from section 5:

```python
def update_fg(f, g, t):
    """Multiply matrix t/2^N with [f, g]."""
    u, v, q, r = t
    cf, cg = u*f + v*g, q*f + r*g
    return cf >> N, cg >> N

def update_de(d, e, t, M, Mi):
    """Multiply matrix t/2^N with [d, e], modulo M."""
    u, v, q, r = t
    d_sign, e_sign = d >> 257, e >> 257
    md, me = (u & d_sign) + (v & e_sign), (q & d_sign) + (r & e_sign)
    cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
    md -= (Mi*cd + md) % 2**N
    me -= (Mi*ce + me) % 2**N
    cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
    return cd >> N, ce >> N
```

- The `normalize` function from section 4, made constant time as well:

```python
def normalize(sign, v, M):
    """Compute sign*v mod M, where v in (-2*M,M); output in [0,M)."""
    v_sign = v >> 257
    # Conditionally add M to v.
    v += M & v_sign
    c = (sign - 1) >> 1
    # Conditionally negate v.
    v = (v ^ c) - c
    v_sign = v >> 257
    # Conditionally add M to v again.
    v += M & v_sign
    return v
```

- And finally the `modinv` function too, adapted to use *&zeta;* instead of *&delta;*, and using the fixed
  iteration count from section 5:

```python
def modinv(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
    zeta, f, g, d, e = -1, M, x, 0, 1
    for _ in range((590 + N - 1) // N):
        zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```

- To get a variable-time version, replace the `divsteps_n_matrix` function with one that uses the
  divsteps loop from section 6, and a `modinv` version that calls it without the fixed iteration
  count:

```python
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
def divsteps_n_matrix_var(eta, f, g):
    """Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
    u, v, q, r = 1, 0, 0, 1
    i = N
    while True:
        zeros = min(i, count_trailing_zeros(g))
        eta, i = eta - zeros, i - zeros
        g, u, v = g >> zeros, u << zeros, v << zeros
        if i == 0:
            break
        if eta < 0:
            eta, f, u, v, g, q, r = -eta, g, q, r, -f, -u, -v
        limit = min(min(eta + 1, i), 4)
        w = (g * NEGINV16[(f & 15) // 2]) % (2**limit)
        g, q, r = g + w*f, q + w*u, r + w*v
    return eta, (u, v, q, r)

def modinv_var(M, Mi, x):
    """Compute the modular inverse of x mod M, given Mi = 1/M mod 2^N."""
    eta, f, g, d, e = -1, M, x, 0, 1
    while g != 0:
        eta, t = divsteps_n_matrix_var(eta, f % 2**N, g % 2**N)
        f, g = update_fg(f, g, t)
        d, e = update_de(d, e, t, M, Mi)
    return normalize(f, d, M)
```
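The section 7 functions are not entirely self-contained. They use a global `N` and expect the caller to supply *Mi = 1/M mod 2<sup>N</sup>*. The variable-time code also calls a `count_trailing_zeros` helper, as does the loop in section 6, and that helper is not defined in the code shown here.

Below is a minimal sketch of these pieces. It assumes Python 3.8+ (for `pow(M, -1, m)`), uses N = 62 purely for illustration, and takes the secp256k1 field prime as an example modulus. The name `precompute_mi` is hypothetical. For *g = 0*, `count_trailing_zeros` returns any value of at least *N*. That works because callers only use `min(i, count_trailing_zeros(g))` with *i &le; N*:

```python
N = 62  # assumed number of divsteps per batch, for illustration

# secp256k1 field prime 2^256 - 2^32 - 977, used as an example odd modulus.
P = 2**256 - 2**32 - 977


def precompute_mi(M, n=N):
    """Return Mi = 1/M mod 2**n; M must be odd."""
    assert M & 1, "M must be odd for 1/M mod 2**n to exist"
    return pow(M, -1, 2**n)


def count_trailing_zeros(v):
    """Number of trailing zero bits of v; for v == 0, a sentinel larger than any N."""
    if v == 0:
        # Callers take min(i, count_trailing_zeros(g)), so g == 0 consumes all remaining steps.
        return 1 << 30
    # v & -v keeps only the lowest set bit (also for negative v); its index is the answer.
    return (v & -v).bit_length() - 1


Mi = precompute_mi(P)
assert (Mi * P) % 2**N == 1
assert [count_trailing_zeros(v) for v in (1, 8, -12, 3 << 100)] == [0, 3, 2, 100]

# pow(x, -1, P) gives a reference inverse to compare modinv(P, Mi, x) against.
x = 0xDEADBEEF
assert (x * pow(x, -1, P)) % P == 1
```

With these definitions in scope, `modinv(P, Mi, x)` and `modinv_var(P, Mi, x)` should both agree with `pow(x, -1, P)` for *0 < x < P*.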

## 8. From GCDs to Jacobi symbol

We can also use a similar approach to calculate the Jacobi symbol *(x | M)* by keeping track of an
extra variable *j*, for which at every step *(x | M) = j (g | f)*. As we update *f* and *g*, we
make corresponding updates to *j* using
[properties of the Jacobi symbol](https://en.wikipedia.org/wiki/Jacobi_symbol#Properties):
* *((g/2) | f)* is either *(g | f)* or *-(g | f)*, depending on the value of *f mod 8* (negating if it's *3* or *5*).
* *(f | g)* is either *(g | f)* or *-(g | f)*, depending on *f mod 4* and *g mod 4* (negating if both are *3*).

These updates depend only on the values of *f* and *g* modulo *4* or *8*, and can thus be applied
very quickly, as long as we keep track of a few additional bits of *f* and *g*. Overall, this
calculation is slightly simpler than the one for the modular inverse because we no longer need to
keep track of *d* and *e*.

However, one difficulty of this approach is that the Jacobi symbol *(a | n)* is only defined for
positive odd integers *n*, whereas in the original safegcd algorithm, *f, g* can take negative
values. We resolve this by using the following modified steps:

```python
# Before
if delta > 0 and g & 1:
    delta, f, g = 1 - delta, g, (g - f) // 2

# After
if delta > 0 and g & 1:
    delta, f, g = 1 - delta, g, (g + f) // 2
```

The algorithm is still correct, since the changed divstep, called a "posdivstep" (see sections 8.4
and E.5 in the paper), preserves *gcd(f, g)*. However, there's no proof that the modified algorithm
will converge. The justification for posdivsteps is completely empirical: in practice, it appears
that the vast majority of nonzero inputs converge to *f=g=gcd(f<sub>0</sub>, g<sub>0</sub>)* in a
number of steps proportional to their logarithm.
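The empirical convergence behaviour can be explored directly on plain integers. The sketch below assumes *f* is odd and positive and *g* is non-negative, so all intermediate values stay non-negative. It reports the first step count at which *f = g*, or `None` if that does not happen within `max_steps`. The function name and the `max_steps` cap are hypothetical:

```python
import math


def posdivsteps_until_equal(f, g, max_steps=10_000):
    """Apply posdivsteps until f == g; return (f, steps), or None if that takes too long."""
    assert f > 0 and f & 1 and g >= 0
    delta = 1
    for steps in range(max_steps):
        if f == g:
            return f, steps
        if delta > 0 and g & 1:
            delta, f, g = 1 - delta, g, (g + f) // 2  # posdivstep: g + f instead of g - f
        elif g & 1:
            delta, f, g = 1 + delta, f, (g + f) // 2
        else:
            delta, f, g = 1 + delta, f, g // 2
    return (f, max_steps) if f == g else None


# When it converges, f equals the gcd of the inputs, since posdivsteps preserve the gcd.
f0, g0 = 2**127 - 1, 3**80
result = posdivsteps_until_equal(f0, g0)
if result is not None:
    assert result[0] == math.gcd(f0, g0)
```

Recording `steps` against the bit length of random inputs is one way to observe the behaviour described above. It is an experiment, not a proof.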

Note that:
- We require inputs to satisfy *gcd(x, M) = 1*, as otherwise *f=1* is not reached.
- We require inputs *x &ne; 0*, because applying posdivstep with *g=0* has no effect.
- We need to update the termination condition from *g=0* to *f=1*.

We account for the possibility of nonconvergence by only performing a bounded number of
posdivsteps, and then falling back to square-root-based Jacobi calculation if a solution has not
yet been found.
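The update rules for *j* and posdivsteps can be combined into a plain-integer sketch. It keeps the full values of *f* and *g* rather than batching steps into transition matrices or tracking only their low bits. The function name and `max_steps` cap are hypothetical. The fallback uses Euler's criterion, *x<sup>(M-1)/2</sup> mod M*, which assumes *M* is prime. That is a stand-in chosen for brevity, not the square-root-based method described above:

```python
def jacobi_posdivsteps(x, M, max_steps=2000):
    """Sketch: Jacobi symbol (x | M) via a bounded number of posdivsteps.

    The posdivstep loop needs odd M > 2 and gcd(x, M) = 1; the Euler's-criterion
    fallback additionally assumes that M is prime.
    """
    assert M > 2 and M & 1 and 0 < x < M
    delta, f, g, j = 1, M, x, 1  # invariant: (x | M) == j * (g | f)
    for _ in range(max_steps):
        if f == 1:
            return j  # (g | 1) == 1 for every g
        if delta > 0 and (g & 1):
            # Swap f and g; by quadratic reciprocity, negate j if both are 3 mod 4.
            if (f & 3) == 3 and (g & 3) == 3:
                j = -j
            delta, f, g = 1 - delta, g, (g + f) // 2
        elif g & 1:
            delta, g = 1 + delta, (g + f) // 2
        else:
            delta, g = 1 + delta, g // 2
        # Every branch divides an even number by 2 inside (. | f), with f already
        # updated, so j picks up (2 | f), which is -1 exactly when f mod 8 is 3 or 5.
        if (f & 7) in (3, 5):
            j = -j
    if f == 1:
        return j
    # Fallback for prime M: Euler's criterion.
    return 1 if pow(x, (M - 1) // 2, M) == 1 else -1
```

For example, `jacobi_posdivsteps(3, 7)` returns -1, since 3 is not a square modulo 7. Passing `max_steps=0` sends every call straight to the fallback, which makes it easy to compare the two paths on prime moduli.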

The optimizations in sections 3-7 above are described in the context of the original divsteps, but
in the C implementation we also adapt most of them to the posdivsteps version. The exceptions are
"avoiding modulus operations", since it's not necessary to track *d* and *e*, and "constant-time
operation", since we never calculate Jacobi symbols for secret data.