github.com/cloudflare/circl@v1.5.0/dh/x448/curve_test.go (about)

     1  package x448
     2  
     3  import (
     4  	"crypto/rand"
     5  	"math/big"
     6  	"testing"
     7  
     8  	"github.com/cloudflare/circl/internal/conv"
     9  	"github.com/cloudflare/circl/internal/test"
    10  	fp "github.com/cloudflare/circl/math/fp448"
    11  )
    12  
    13  func getModulus() *big.Int {
    14  	p := big.NewInt(1)
    15  	p.Lsh(p, 224)
    16  	p.Sub(p, new(big.Int).SetInt64(1))
    17  	p.Lsh(p, 224)
    18  	p.Sub(p, new(big.Int).SetInt64(1))
    19  	return p
    20  }
    21  
    22  func doubleBig(x1, z1, p *big.Int) {
    23  	// Montgomery point doubling in projective (X:Z) coordinates.
    24  	A24 := big.NewInt(39082)
    25  	A, B, C := big.NewInt(0), big.NewInt(0), big.NewInt(0)
    26  	A.Add(x1, z1).Mod(A, p)
    27  	B.Sub(x1, z1).Mod(B, p)
    28  	A.Mul(A, A)
    29  	B.Mul(B, B)
    30  	C.Sub(A, B)
    31  	x1.Mul(A, B).Mod(x1, p)
    32  	z1.Mul(C, A24).Add(z1, B).Mul(z1, C).Mod(z1, p)
    33  }
    34  
    35  // Equation 7 at https://eprint.iacr.org/2017/264
    36  func diffAddBig(work [5]*big.Int, p *big.Int, b uint) {
    37  	mu, x1, z1, x2, z2 := work[0], work[1], work[2], work[3], work[4]
    38  	A, B := big.NewInt(0), big.NewInt(0)
    39  	if b != 0 {
    40  		t := new(big.Int)
    41  		t.Set(x1)
    42  		x1.Set(x2)
    43  		x2.Set(t)
    44  		t.Set(z1)
    45  		z1.Set(z2)
    46  		z2.Set(t)
    47  	}
    48  	A.Add(x1, z1)
    49  	B.Sub(x1, z1)
    50  	B.Mul(B, mu).Mod(B, p)
    51  	x1.Add(A, B).Mod(x1, p)
    52  	z1.Sub(A, B).Mod(z1, p)
    53  	x1.Mul(x1, x1).Mul(x1, z2).Mod(x1, p)
    54  	z1.Mul(z1, z1).Mul(z1, x2).Mod(z1, p)
    55  	mu.Mod(mu, p)
    56  	x2.Mod(x2, p)
    57  	z2.Mod(z2, p)
    58  }
    59  
    60  func ladderStepBig(work [5]*big.Int, p *big.Int, b uint) {
    61  	A24 := big.NewInt(39082)
    62  	x1 := work[0]
    63  	x2, z2 := work[1], work[2]
    64  	x3, z3 := work[3], work[4]
    65  	A, B, C, D := big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)
    66  	DA, CB, E := big.NewInt(0), big.NewInt(0), big.NewInt(0)
    67  	A.Add(x2, z2).Mod(A, p)
    68  	B.Sub(x2, z2).Mod(B, p)
    69  	C.Add(x3, z3).Mod(C, p)
    70  	D.Sub(x3, z3).Mod(D, p)
    71  	DA.Mul(D, A).Mod(DA, p)
    72  	CB.Mul(C, B).Mod(CB, p)
    73  	if b != 0 {
    74  		t := new(big.Int)
    75  		t.Set(A)
    76  		A.Set(C)
    77  		C.Set(t)
    78  		t.Set(B)
    79  		B.Set(D)
    80  		D.Set(t)
    81  	}
    82  	AA := A.Mul(A, A).Mod(A, p)
    83  	BB := B.Mul(B, B).Mod(B, p)
    84  	E.Sub(AA, BB)
    85  	x1.Mod(x1, p)
    86  	x2.Mul(AA, BB).Mod(x2, p)
    87  	z2.Mul(E, A24).Add(z2, BB).Mul(z2, E).Mod(z2, p)
    88  	x3.Add(DA, CB)
    89  	z3.Sub(DA, CB)
    90  	x3.Mul(x3, x3).Mod(x3, p)
    91  	z3.Mul(z3, z3).Mul(z3, x1).Mod(z3, p)
    92  }
    93  
    94  type (
    95  	tDouble     func(x, z *fp.Elt)
    96  	tDiffAdd    func(w *[5]fp.Elt, b uint)
    97  	tLadderStep func(w *[5]fp.Elt, b uint)
    98  	tMulA24     func(z, x *fp.Elt)
    99  )
   100  
   101  func TestGeneric(t *testing.T) {
   102  	testDouble(t, doubleGeneric)
   103  	testDiffAdd(t, diffAddGeneric)
   104  	testLadderStep(t, ladderStepGeneric)
   105  	testMulA24(t, mulA24Generic)
   106  }
   107  
   108  func TestNative(t *testing.T) {
   109  	testDouble(t, double)
   110  	testDiffAdd(t, diffAdd)
   111  	testLadderStep(t, ladderStep)
   112  	testMulA24(t, mulA24)
   113  }
   114  
   115  func testMulA24(t *testing.T, f tMulA24) {
   116  	const numTests = 1 << 9
   117  	var x, z fp.Elt
   118  	p := getModulus()
   119  	A24 := big.NewInt(39082)
   120  	for i := 0; i < numTests; i++ {
   121  		_, _ = rand.Read(x[:])
   122  		bigX := conv.BytesLe2BigInt(x[:])
   123  		f(&z, &x)
   124  		got := conv.BytesLe2BigInt(z[:])
   125  		got.Mod(got, p)
   126  
   127  		want := bigX.Mul(bigX, A24).Mod(bigX, p)
   128  
   129  		if got.Cmp(want) != 0 {
   130  			test.ReportError(t, got, want, x)
   131  		}
   132  	}
   133  }
   134  
   135  func testDouble(t *testing.T, f tDouble) {
   136  	const numTests = 1 << 9
   137  	var x, z fp.Elt
   138  	p := getModulus()
   139  	for i := 0; i < numTests; i++ {
   140  		_, _ = rand.Read(x[:])
   141  		_, _ = rand.Read(z[:])
   142  
   143  		bigX := conv.BytesLe2BigInt(x[:])
   144  		bigZ := conv.BytesLe2BigInt(z[:])
   145  		f(&x, &z)
   146  		got0 := conv.BytesLe2BigInt(x[:])
   147  		got1 := conv.BytesLe2BigInt(z[:])
   148  		got0.Mod(got0, p)
   149  		got1.Mod(got1, p)
   150  
   151  		doubleBig(bigX, bigZ, p)
   152  		want0 := bigX
   153  		want1 := bigZ
   154  
   155  		if got0.Cmp(want0) != 0 {
   156  			test.ReportError(t, got0, want0, x, z)
   157  		}
   158  		if got1.Cmp(want1) != 0 {
   159  			test.ReportError(t, got1, want1, x, z)
   160  		}
   161  	}
   162  }
   163  
   164  func testDiffAdd(t *testing.T, f tDiffAdd) {
   165  	const numTests = 1 << 9
   166  	var w [5]fp.Elt
   167  	bigWork := [5]*big.Int{}
   168  	p := getModulus()
   169  	for i := 0; i < numTests; i++ {
   170  		for j := range w {
   171  			_, _ = rand.Read(w[j][:])
   172  			bigWork[j] = conv.BytesLe2BigInt(w[j][:])
   173  		}
   174  		b := uint(w[0][0] & 1)
   175  
   176  		f(&w, b)
   177  
   178  		diffAddBig(bigWork, p, b)
   179  
   180  		for j := range bigWork {
   181  			got := conv.BytesLe2BigInt(w[j][:])
   182  			got.Mod(got, p)
   183  			want := bigWork[j]
   184  			if got.Cmp(want) != 0 {
   185  				test.ReportError(t, got, want, w, b)
   186  			}
   187  		}
   188  	}
   189  }
   190  
   191  func testLadderStep(t *testing.T, f tLadderStep) {
   192  	const numTests = 1 << 9
   193  	var w [5]fp.Elt
   194  	bigWork := [5]*big.Int{}
   195  	p := getModulus()
   196  	for i := 0; i < numTests; i++ {
   197  		for j := range w {
   198  			_, _ = rand.Read(w[j][:])
   199  			bigWork[j] = conv.BytesLe2BigInt(w[j][:])
   200  		}
   201  		b := uint(w[0][0] & 1)
   202  
   203  		f(&w, b)
   204  
   205  		ladderStepBig(bigWork, p, b)
   206  
   207  		for j := range bigWork {
   208  			got := conv.BytesLe2BigInt(w[j][:])
   209  			got.Mod(got, p)
   210  			want := bigWork[j]
   211  			if got.Cmp(want) != 0 {
   212  				test.ReportError(t, got, want, w, b)
   213  			}
   214  		}
   215  	}
   216  }