github.com/klaytn/klaytn@v1.12.1/crypto/secp256k1/libsecp256k1/sage/group_prover.sage (about)

     1  # This code supports verifying group implementations which have branches
     2  # or conditional statements (like cmovs), by allowing each execution path
     3  # to independently set assumptions on input or intermediary variables.
     4  #
     5  # The general approach is:
# * A constraint is a tuple of two sets of symbolic expressions:
#   the first of which are required to evaluate to zero, the second of which
#   are required to evaluate to nonzero.
#   - A constraint is said to be conflicting if any of its nonzero expressions
#     is in the ideal with basis the zero expressions (in other words: when the
#     zero expressions imply that one of the nonzero expressions is zero).
    12  # * There is a list of laws that describe the intended behaviour, including
    13  #   laws for addition and doubling. Each law is called with the symbolic point
    14  #   coordinates as arguments, and returns:
    15  #   - A constraint describing the assumptions under which it is applicable,
    16  #     called "assumeLaw"
    17  #   - A constraint describing the requirements of the law, called "require"
# * Implementations are transliterated into functions that also operate on
#   algebraic input points, and are called once per combination of branches
#   executed. Each execution returns:
    21  #   - A constraint describing the assumptions this implementation requires
    22  #     (such as Z1=1), called "assumeFormula"
    23  #   - A constraint describing the assumptions this specific branch requires,
    24  #     but which is by construction guaranteed to cover the entire space by
    25  #     merging the results from all branches, called "assumeBranch"
    26  #   - The result of the computation
    27  # * All combinations of laws with implementation branches are tried, and:
    28  #   - If the combination of assumeLaw, assumeFormula, and assumeBranch results
    29  #     in a conflict, it means this law does not apply to this branch, and it is
    30  #     skipped.
    31  #   - For others, we try to prove the require constraints hold, assuming the
    32  #     information in assumeLaw + assumeFormula + assumeBranch, and if this does
    33  #     not succeed, we fail.
    34  #     + To prove an expression is zero, we check whether it belongs to the
    35  #       ideal with the assumed zero expressions as basis. This test is exact.
    36  #     + To prove an expression is nonzero, we check whether each of its
    37  #       factors is contained in the set of nonzero assumptions' factors.
    38  #       This test is not exact, so various combinations of original and
    39  #       reduced expressions' factors are tried.
    40  #   - If we succeed, we print out the assumptions from assumeFormula that
    41  #     weren't implied by assumeLaw already. Those from assumeBranch are skipped,
    42  #     as we assume that all constraints in it are complementary with each other.
    43  #
    44  # Based on the sage verification scripts used in the Explicit-Formulas Database
    45  # by Tanja Lange and others, see http://hyperelliptic.org/EFD
    46  
    47  class fastfrac:
    48    """Fractions over rings."""
    49  
    50    def __init__(self,R,top,bot=1):
    51      """Construct a fractional, given a ring, a numerator, and denominator."""
    52      self.R = R
    53      if parent(top) == ZZ or parent(top) == R:
    54        self.top = R(top)
    55        self.bot = R(bot)
    56      elif top.__class__ == fastfrac:
    57        self.top = top.top
    58        self.bot = top.bot * bot
    59      else:
    60        self.top = R(numerator(top))
    61        self.bot = R(denominator(top)) * bot
    62  
    63    def iszero(self,I):
    64      """Return whether this fraction is zero given an ideal."""
    65      return self.top in I and self.bot not in I
    66  
    67    def reduce(self,assumeZero):
    68      zero = self.R.ideal(map(numerator, assumeZero))
    69      return fastfrac(self.R, zero.reduce(self.top)) / fastfrac(self.R, zero.reduce(self.bot))
    70  
    71    def __add__(self,other):
    72      """Add two fractions."""
    73      if parent(other) == ZZ:
    74        return fastfrac(self.R,self.top + self.bot * other,self.bot)
    75      if other.__class__ == fastfrac:
    76        return fastfrac(self.R,self.top * other.bot + self.bot * other.top,self.bot * other.bot)
    77      return NotImplemented
    78  
    79    def __sub__(self,other):
    80      """Subtract two fractions."""
    81      if parent(other) == ZZ:
    82        return fastfrac(self.R,self.top - self.bot * other,self.bot)
    83      if other.__class__ == fastfrac:
    84        return fastfrac(self.R,self.top * other.bot - self.bot * other.top,self.bot * other.bot)
    85      return NotImplemented
    86  
    87    def __neg__(self):
    88      """Return the negation of a fraction."""
    89      return fastfrac(self.R,-self.top,self.bot)
    90  
    91    def __mul__(self,other):
    92      """Multiply two fractions."""
    93      if parent(other) == ZZ:
    94        return fastfrac(self.R,self.top * other,self.bot)
    95      if other.__class__ == fastfrac:
    96        return fastfrac(self.R,self.top * other.top,self.bot * other.bot)
    97      return NotImplemented
    98  
    99    def __rmul__(self,other):
   100      """Multiply something else with a fraction."""
   101      return self.__mul__(other)
   102  
   103    def __div__(self,other):
   104      """Divide two fractions."""
   105      if parent(other) == ZZ:
   106        return fastfrac(self.R,self.top,self.bot * other)
   107      if other.__class__ == fastfrac:
   108        return fastfrac(self.R,self.top * other.bot,self.bot * other.top)
   109      return NotImplemented
   110  
   111    def __pow__(self,other):
   112      """Compute a power of a fraction."""
   113      if parent(other) == ZZ:
   114        if other < 0:
   115          # Negative powers require flipping top and bottom
   116          return fastfrac(self.R,self.bot ^ (-other),self.top ^ (-other))
   117        else:
   118          return fastfrac(self.R,self.top ^ other,self.bot ^ other)
   119      return NotImplemented
   120  
   121    def __str__(self):
   122      return "fastfrac((" + str(self.top) + ") / (" + str(self.bot) + "))"
   123    def __repr__(self):
   124      return "%s" % self
   125  
   126    def numerator(self):
   127      return self.top
   128  
   129  class constraints:
   130    """A set of constraints, consisting of zero and nonzero expressions.
   131  
   132    Constraints can either be used to express knowledge or a requirement.
   133  
   134    Both the fields zero and nonzero are maps from expressions to description
   135    strings. The expressions that are the keys in zero are required to be zero,
   136    and the expressions that are the keys in nonzero are required to be nonzero.
   137  
   138    Note that (a != 0) and (b != 0) is the same as (a*b != 0), so all keys in
   139    nonzero could be multiplied into a single key. This is often much less
   140    efficient to work with though, so we keep them separate inside the
   141    constraints. This allows higher-level code to do fast checks on the individual
   142    nonzero elements, or combine them if needed for stronger checks.
   143  
   144    We can't multiply the different zero elements, as it would suffice for one of
   145    the factors to be zero, instead of all of them. Instead, the zero elements are
   146    typically combined into an ideal first.
   147    """
   148  
   149    def __init__(self, **kwargs):
   150      if 'zero' in kwargs:
   151        self.zero = dict(kwargs['zero'])
   152      else:
   153        self.zero = dict()
   154      if 'nonzero' in kwargs:
   155        self.nonzero = dict(kwargs['nonzero'])
   156      else:
   157        self.nonzero = dict()
   158  
   159    def negate(self):
   160      return constraints(zero=self.nonzero, nonzero=self.zero)
   161  
   162    def __add__(self, other):
   163      zero = self.zero.copy()
   164      zero.update(other.zero)
   165      nonzero = self.nonzero.copy()
   166      nonzero.update(other.nonzero)
   167      return constraints(zero=zero, nonzero=nonzero)
   168  
   169    def __str__(self):
   170      return "constraints(zero=%s,nonzero=%s)" % (self.zero, self.nonzero)
   171  
   172    def __repr__(self):
   173      return "%s" % self
   174  
   175  
   176  def conflicts(R, con):
   177    """Check whether any of the passed non-zero assumptions is implied by the zero assumptions"""
   178    zero = R.ideal(map(numerator, con.zero))
   179    if 1 in zero:
   180      return True
   181    # First a cheap check whether any of the individual nonzero terms conflict on
   182    # their own.
   183    for nonzero in con.nonzero:
   184      if nonzero.iszero(zero):
   185        return True
   186    # It can be the case that entries in the nonzero set do not individually
   187    # conflict with the zero set, but their combination does. For example, knowing
   188    # that either x or y is zero is equivalent to having x*y in the zero set.
   189    # Having x or y individually in the nonzero set is not a conflict, but both
   190    # simultaneously is, so that is the right thing to check for.
   191    if reduce(lambda a,b: a * b, con.nonzero, fastfrac(R, 1)).iszero(zero):
   192      return True
   193    return False
   194  
   195  
   196  def get_nonzero_set(R, assume):
   197    """Calculate a simple set of nonzero expressions"""
   198    zero = R.ideal(map(numerator, assume.zero))
   199    nonzero = set()
   200    for nz in map(numerator, assume.nonzero):
   201      for (f,n) in nz.factor():
   202        nonzero.add(f)
   203      rnz = zero.reduce(nz)
   204      for (f,n) in rnz.factor():
   205        nonzero.add(f)
   206    return nonzero
   207  
   208  
   209  def prove_nonzero(R, exprs, assume):
   210    """Check whether an expression is provably nonzero, given assumptions"""
   211    zero = R.ideal(map(numerator, assume.zero))
   212    nonzero = get_nonzero_set(R, assume)
   213    expl = set()
   214    ok = True
   215    for expr in exprs:
   216      if numerator(expr) in zero:
   217        return (False, [exprs[expr]])
   218    allexprs = reduce(lambda a,b: numerator(a)*numerator(b), exprs, 1)
   219    for (f, n) in allexprs.factor():
   220      if f not in nonzero:
   221        ok = False
   222    if ok:
   223      return (True, None)
   224    ok = True
   225    for (f, n) in zero.reduce(numerator(allexprs)).factor():
   226      if f not in nonzero:
   227        ok = False
   228    if ok:
   229      return (True, None)
   230    ok = True
   231    for expr in exprs:
   232      for (f,n) in numerator(expr).factor():
   233        if f not in nonzero:
   234          ok = False
   235    if ok:
   236      return (True, None)
   237    ok = True
   238    for expr in exprs:
   239      for (f,n) in zero.reduce(numerator(expr)).factor():
   240        if f not in nonzero:
   241          expl.add(exprs[expr])
   242    if expl:
   243      return (False, list(expl))
   244    else:
   245      return (True, None)
   246  
   247  
   248  def prove_zero(R, exprs, assume):
   249    """Check whether all of the passed expressions are provably zero, given assumptions"""
   250    r, e = prove_nonzero(R, dict(map(lambda x: (fastfrac(R, x.bot, 1), exprs[x]), exprs)), assume)
   251    if not r:
   252      return (False, map(lambda x: "Possibly zero denominator: %s" % x, e))
   253    zero = R.ideal(map(numerator, assume.zero))
   254    nonzero = prod(x for x in assume.nonzero)
   255    expl = []
   256    for expr in exprs:
   257      if not expr.iszero(zero):
   258        expl.append(exprs[expr])
   259    if not expl:
   260      return (True, None)
   261    return (False, expl)
   262  
   263  
   264  def describe_extra(R, assume, assumeExtra):
   265    """Describe what assumptions are added, given existing assumptions"""
   266    zerox = assume.zero.copy()
   267    zerox.update(assumeExtra.zero)
   268    zero = R.ideal(map(numerator, assume.zero))
   269    zeroextra = R.ideal(map(numerator, zerox))
   270    nonzero = get_nonzero_set(R, assume)
   271    ret = set()
   272    # Iterate over the extra zero expressions
   273    for base in assumeExtra.zero:
   274      if base not in zero:
   275        add = []
   276        for (f, n) in numerator(base).factor():
   277          if f not in nonzero:
   278            add += ["%s" % f]
   279        if add:
   280          ret.add((" * ".join(add)) + " = 0 [%s]" % assumeExtra.zero[base])
   281    # Iterate over the extra nonzero expressions
   282    for nz in assumeExtra.nonzero:
   283      nzr = zeroextra.reduce(numerator(nz))
   284      if nzr not in zeroextra:
   285        for (f,n) in nzr.factor():
   286          if zeroextra.reduce(f) not in nonzero:
   287            ret.add("%s != 0" % zeroextra.reduce(f))
   288    return ", ".join(x for x in ret)
   289  
   290  
   291  def check_symbolic(R, assumeLaw, assumeAssert, assumeBranch, require):
   292    """Check a set of zero and nonzero requirements, given a set of zero and nonzero assumptions"""
   293    assume = assumeLaw + assumeAssert + assumeBranch
   294  
   295    if conflicts(R, assume):
   296      # This formula does not apply
   297      return None
   298  
   299    describe = describe_extra(R, assumeLaw + assumeBranch, assumeAssert)
   300  
   301    ok, msg = prove_zero(R, require.zero, assume)
   302    if not ok:
   303      return "FAIL, %s fails (assuming %s)" % (str(msg), describe)
   304  
   305    res, expl = prove_nonzero(R, require.nonzero, assume)
   306    if not res:
   307      return "FAIL, %s fails (assuming %s)" % (str(expl), describe)
   308  
   309    if describe != "":
   310      return "OK (assuming %s)" % describe
   311    else:
   312      return "OK"
   313  
   314  
   315  def concrete_verify(c):
   316    for k in c.zero:
   317      if k != 0:
   318        return (False, c.zero[k])
   319    for k in c.nonzero:
   320      if k == 0:
   321        return (False, c.nonzero[k])
   322    return (True, None)