github.com/ethereum/go-ethereum@v1.16.1/crypto/secp256k1/libsecp256k1/sage/group_prover.sage (about)

     1  # This code supports verifying group implementations which have branches
     2  # or conditional statements (like cmovs), by allowing each execution path
     3  # to independently set assumptions on input or intermediary variables.
     4  #
     5  # The general approach is:
     6  # * A constraint is a tuple of two sets of symbolic expressions:
     7  #   the first of which are required to evaluate to zero, the second of which
     8  #   are required to evaluate to nonzero.
     9  #   - A constraint is said to be conflicting if any of its nonzero expressions
    10  #     is in the ideal with basis the zero expressions (in other words: when the
    11  #     zero expressions imply that one of the nonzero expressions is zero).
    12  # * There is a list of laws that describe the intended behaviour, including
    13  #   laws for addition and doubling. Each law is called with the symbolic point
    14  #   coordinates as arguments, and returns:
    15  #   - A constraint describing the assumptions under which it is applicable,
    16  #     called "assumeLaw"
    17  #   - A constraint describing the requirements of the law, called "require"
    18  # * Implementations are transliterated into functions that operate as well on
    19  #   algebraic input points, and are called once per combination of branches
    20  #   executed. Each execution returns:
    21  #   - A constraint describing the assumptions this implementation requires
    22  #     (such as Z1=1), called "assumeFormula"
    23  #   - A constraint describing the assumptions this specific branch requires,
    24  #     but which is by construction guaranteed to cover the entire space by
    25  #     merging the results from all branches, called "assumeBranch"
    26  #   - The result of the computation
    27  # * All combinations of laws with implementation branches are tried, and:
    28  #   - If the combination of assumeLaw, assumeFormula, and assumeBranch results
    29  #     in a conflict, it means this law does not apply to this branch, and it is
    30  #     skipped.
    31  #   - For others, we try to prove the require constraints hold, assuming the
    32  #     information in assumeLaw + assumeFormula + assumeBranch, and if this does
    33  #     not succeed, we fail.
    34  #     + To prove an expression is zero, we check whether it belongs to the
    35  #       ideal with the assumed zero expressions as basis. This test is exact.
    36  #     + To prove an expression is nonzero, we check whether each of its
    37  #       factors is contained in the set of nonzero assumptions' factors.
    38  #       This test is not exact, so various combinations of original and
    39  #       reduced expressions' factors are tried.
    40  #   - If we succeed, we print out the assumptions from assumeFormula that
    41  #     weren't implied by assumeLaw already. Those from assumeBranch are skipped,
    42  #     as we assume that all constraints in it are complementary with each other.
    43  #
    44  # Based on the sage verification scripts used in the Explicit-Formulas Database
    45  # by Tanja Lange and others, see https://hyperelliptic.org/EFD
    46  
    47  class fastfrac:
    48    """Fractions over rings."""
    49  
    50    def __init__(self,R,top,bot=1):
    51      """Construct a fractional, given a ring, a numerator, and denominator."""
    52      self.R = R
    53      if parent(top) == ZZ or parent(top) == R:
    54        self.top = R(top)
    55        self.bot = R(bot)
    56      elif top.__class__ == fastfrac:
    57        self.top = top.top
    58        self.bot = top.bot * bot
    59      else:
    60        self.top = R(numerator(top))
    61        self.bot = R(denominator(top)) * bot
    62  
    63    def iszero(self,I):
    64      """Return whether this fraction is zero given an ideal."""
    65      return self.top in I and self.bot not in I
    66  
    67    def reduce(self,assumeZero):
    68      zero = self.R.ideal(list(map(numerator, assumeZero)))
    69      return fastfrac(self.R, zero.reduce(self.top)) / fastfrac(self.R, zero.reduce(self.bot))
    70  
    71    def __add__(self,other):
    72      """Add two fractions."""
    73      if parent(other) == ZZ:
    74        return fastfrac(self.R,self.top + self.bot * other,self.bot)
    75      if other.__class__ == fastfrac:
    76        return fastfrac(self.R,self.top * other.bot + self.bot * other.top,self.bot * other.bot)
    77      return NotImplemented
    78  
    79    def __sub__(self,other):
    80      """Subtract two fractions."""
    81      if parent(other) == ZZ:
    82        return fastfrac(self.R,self.top - self.bot * other,self.bot)
    83      if other.__class__ == fastfrac:
    84        return fastfrac(self.R,self.top * other.bot - self.bot * other.top,self.bot * other.bot)
    85      return NotImplemented
    86  
    87    def __neg__(self):
    88      """Return the negation of a fraction."""
    89      return fastfrac(self.R,-self.top,self.bot)
    90  
    91    def __mul__(self,other):
    92      """Multiply two fractions."""
    93      if parent(other) == ZZ:
    94        return fastfrac(self.R,self.top * other,self.bot)
    95      if other.__class__ == fastfrac:
    96        return fastfrac(self.R,self.top * other.top,self.bot * other.bot)
    97      return NotImplemented
    98  
    99    def __rmul__(self,other):
   100      """Multiply something else with a fraction."""
   101      return self.__mul__(other)
   102  
   103    def __truediv__(self,other):
   104      """Divide two fractions."""
   105      if parent(other) == ZZ:
   106        return fastfrac(self.R,self.top,self.bot * other)
   107      if other.__class__ == fastfrac:
   108        return fastfrac(self.R,self.top * other.bot,self.bot * other.top)
   109      return NotImplemented
   110  
   111    # Compatibility wrapper for Sage versions based on Python 2
   112    def __div__(self,other):
   113       """Divide two fractions."""
   114       return self.__truediv__(other)
   115  
   116    def __pow__(self,other):
   117      """Compute a power of a fraction."""
   118      if parent(other) == ZZ:
   119        if other < 0:
   120          # Negative powers require flipping top and bottom
   121          return fastfrac(self.R,self.bot ^ (-other),self.top ^ (-other))
   122        else:
   123          return fastfrac(self.R,self.top ^ other,self.bot ^ other)
   124      return NotImplemented
   125  
   126    def __str__(self):
   127      return "fastfrac((" + str(self.top) + ") / (" + str(self.bot) + "))"
   128    def __repr__(self):
   129      return "%s" % self
   130  
   131    def numerator(self):
   132      return self.top
   133  
   134  class constraints:
   135    """A set of constraints, consisting of zero and nonzero expressions.
   136  
   137    Constraints can either be used to express knowledge or a requirement.
   138  
   139    Both the fields zero and nonzero are maps from expressions to description
   140    strings. The expressions that are the keys in zero are required to be zero,
   141    and the expressions that are the keys in nonzero are required to be nonzero.
   142  
   143    Note that (a != 0) and (b != 0) is the same as (a*b != 0), so all keys in
   144    nonzero could be multiplied into a single key. This is often much less
   145    efficient to work with though, so we keep them separate inside the
   146    constraints. This allows higher-level code to do fast checks on the individual
   147    nonzero elements, or combine them if needed for stronger checks.
   148  
   149    We can't multiply the different zero elements, as it would suffice for one of
   150    the factors to be zero, instead of all of them. Instead, the zero elements are
   151    typically combined into an ideal first.
   152    """
   153  
   154    def __init__(self, **kwargs):
   155      if 'zero' in kwargs:
   156        self.zero = dict(kwargs['zero'])
   157      else:
   158        self.zero = dict()
   159      if 'nonzero' in kwargs:
   160        self.nonzero = dict(kwargs['nonzero'])
   161      else:
   162        self.nonzero = dict()
   163  
   164    def negate(self):
   165      return constraints(zero=self.nonzero, nonzero=self.zero)
   166  
   167    def map(self, fun):
   168      return constraints(zero={fun(k): v for k, v in self.zero.items()}, nonzero={fun(k): v for k, v in self.nonzero.items()})
   169  
   170    def __add__(self, other):
   171      zero = self.zero.copy()
   172      zero.update(other.zero)
   173      nonzero = self.nonzero.copy()
   174      nonzero.update(other.nonzero)
   175      return constraints(zero=zero, nonzero=nonzero)
   176  
   177    def __str__(self):
   178      return "constraints(zero=%s,nonzero=%s)" % (self.zero, self.nonzero)
   179  
   180    def __repr__(self):
   181      return "%s" % self
   182  
   183  def normalize_factor(p):
   184    """Normalizes the sign of primitive polynomials (as returned by factor())
   185  
   186    This function ensures that the polynomial has a positive leading coefficient.
   187  
   188    This is necessary because recent sage versions (starting with v9.3 or v9.4,
   189    we don't know) are inconsistent about the placement of the minus sign in
   190    polynomial factorizations:
   191    ```
   192    sage: R.<ax,bx,ay,by,Az,Bz,Ai,Bi> = PolynomialRing(QQ,8,order='invlex')
   193    sage: R((-2 * (bx - ax)) ^ 1).factor()
   194    (-2) * (bx - ax)
   195    sage: R((-2 * (bx - ax)) ^ 2).factor()
   196    (4) * (-bx + ax)^2
   197    sage: R((-2 * (bx - ax)) ^ 3).factor()
   198    (8) * (-bx + ax)^3
   199    ```
   200    """
   201    # Assert p is not 0 and that its non-zero coefficients are coprime.
   202    # (We could just work with the primitive part p/p.content() but we want to be
   203    # aware if factor() does not return a primitive part in future sage versions.)
   204    assert p.content() == 1
   205    # Ensure that the first non-zero coefficient is positive.
   206    return p if p.lc() > 0 else -p
   207  
   208  def conflicts(R, con):
   209    """Check whether any of the passed non-zero assumptions is implied by the zero assumptions"""
   210    zero = R.ideal(list(map(numerator, con.zero)))
   211    if 1 in zero:
   212      return True
   213    # First a cheap check whether any of the individual nonzero terms conflict on
   214    # their own.
   215    for nonzero in con.nonzero:
   216      if nonzero.iszero(zero):
   217        return True
   218    # It can be the case that entries in the nonzero set do not individually
   219    # conflict with the zero set, but their combination does. For example, knowing
   220    # that either x or y is zero is equivalent to having x*y in the zero set.
   221    # Having x or y individually in the nonzero set is not a conflict, but both
   222    # simultaneously is, so that is the right thing to check for.
   223    if reduce(lambda a,b: a * b, con.nonzero, fastfrac(R, 1)).iszero(zero):
   224      return True
   225    return False
   226  
   227  
   228  def get_nonzero_set(R, assume):
   229    """Calculate a simple set of nonzero expressions"""
   230    zero = R.ideal(list(map(numerator, assume.zero)))
   231    nonzero = set()
   232    for nz in map(numerator, assume.nonzero):
   233      for (f,n) in nz.factor():
   234        nonzero.add(normalize_factor(f))
   235      rnz = zero.reduce(nz)
   236      for (f,n) in rnz.factor():
   237        nonzero.add(normalize_factor(f))
   238    return nonzero
   239  
   240  
   241  def prove_nonzero(R, exprs, assume):
   242    """Check whether an expression is provably nonzero, given assumptions"""
   243    zero = R.ideal(list(map(numerator, assume.zero)))
   244    nonzero = get_nonzero_set(R, assume)
   245    expl = set()
   246    ok = True
   247    for expr in exprs:
   248      if numerator(expr) in zero:
   249        return (False, [exprs[expr]])
   250    allexprs = reduce(lambda a,b: numerator(a)*numerator(b), exprs, 1)
   251    for (f, n) in allexprs.factor():
   252      if normalize_factor(f) not in nonzero:
   253        ok = False
   254    if ok:
   255      return (True, None)
   256    ok = True
   257    for (f, n) in zero.reduce(allexprs).factor():
   258      if normalize_factor(f) not in nonzero:
   259        ok = False
   260    if ok:
   261      return (True, None)
   262    ok = True
   263    for expr in exprs:
   264      for (f,n) in numerator(expr).factor():
   265        if normalize_factor(f) not in nonzero:
   266          ok = False
   267    if ok:
   268      return (True, None)
   269    ok = True
   270    for expr in exprs:
   271      for (f,n) in zero.reduce(numerator(expr)).factor():
   272        if normalize_factor(f) not in nonzero:
   273          expl.add(exprs[expr])
   274    if expl:
   275      return (False, list(expl))
   276    else:
   277      return (True, None)
   278  
   279  
   280  def prove_zero(R, exprs, assume):
   281    """Check whether all of the passed expressions are provably zero, given assumptions"""
   282    r, e = prove_nonzero(R, dict(map(lambda x: (fastfrac(R, x.bot, 1), exprs[x]), exprs)), assume)
   283    if not r:
   284      return (False, list(map(lambda x: "Possibly zero denominator: %s" % x, e)))
   285    zero = R.ideal(list(map(numerator, assume.zero)))
   286    nonzero = prod(x for x in assume.nonzero)
   287    expl = []
   288    for expr in exprs:
   289      if not expr.iszero(zero):
   290        expl.append(exprs[expr])
   291    if not expl:
   292      return (True, None)
   293    return (False, expl)
   294  
   295  
   296  def describe_extra(R, assume, assumeExtra):
   297    """Describe what assumptions are added, given existing assumptions"""
   298    zerox = assume.zero.copy()
   299    zerox.update(assumeExtra.zero)
   300    zero = R.ideal(list(map(numerator, assume.zero)))
   301    zeroextra = R.ideal(list(map(numerator, zerox)))
   302    nonzero = get_nonzero_set(R, assume)
   303    ret = set()
   304    # Iterate over the extra zero expressions
   305    for base in assumeExtra.zero:
   306      if base not in zero:
   307        add = []
   308        for (f, n) in numerator(base).factor():
   309          if normalize_factor(f) not in nonzero:
   310            add += ["%s" % normalize_factor(f)]
   311        if add:
   312          ret.add((" * ".join(add)) + " = 0 [%s]" % assumeExtra.zero[base])
   313    # Iterate over the extra nonzero expressions
   314    for nz in assumeExtra.nonzero:
   315      nzr = zeroextra.reduce(numerator(nz))
   316      if nzr not in zeroextra:
   317        for (f,n) in nzr.factor():
   318          if normalize_factor(zeroextra.reduce(f)) not in nonzero:
   319            ret.add("%s != 0" % normalize_factor(zeroextra.reduce(f)))
   320    return ", ".join(x for x in ret)
   321  
   322  
   323  def check_symbolic(R, assumeLaw, assumeAssert, assumeBranch, require):
   324    """Check a set of zero and nonzero requirements, given a set of zero and nonzero assumptions"""
   325    assume = assumeLaw + assumeAssert + assumeBranch
   326  
   327    if conflicts(R, assume):
   328      # This formula does not apply
   329      return (True, None)
   330  
   331    describe = describe_extra(R, assumeLaw + assumeBranch, assumeAssert)
   332    if describe != "":
   333      describe = " (assuming " + describe + ")"
   334  
   335    ok, msg = prove_zero(R, require.zero, assume)
   336    if not ok:
   337      return (False, "FAIL, %s fails%s" % (str(msg), describe))
   338  
   339    res, expl = prove_nonzero(R, require.nonzero, assume)
   340    if not res:
   341      return (False, "FAIL, %s fails%s" % (str(expl), describe))
   342  
   343    return (True, "OK%s" % describe)
   344  
   345  
   346  def concrete_verify(c):
   347    for k in c.zero:
   348      if k != 0:
   349        return (False, c.zero[k])
   350    for k in c.nonzero:
   351      if k == 0:
   352        return (False, c.nonzero[k])
   353    return (True, None)