github.com/emmansun/gmsm@v0.29.1/internal/sm2ec/generate.go (about) 1 // Copyright 2022 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build ignore 6 7 package main 8 9 // Running this generator requires addchain v0.4.0, which can be installed with 10 // 11 // go install github.com/mmcloughlin/addchain/cmd/addchain@v0.4.0 12 // 13 14 import ( 15 "bytes" 16 "crypto/elliptic" 17 "fmt" 18 "go/format" 19 "io" 20 "log" 21 "math/big" 22 "os" 23 "os/exec" 24 "strings" 25 "text/template" 26 27 _sm2ec "github.com/emmansun/gmsm/sm2/sm2ec" 28 ) 29 30 var curves = []struct { 31 P string 32 Element string 33 Params *elliptic.CurveParams 34 BuildTags string 35 }{ 36 { 37 P: "SM2P256", 38 Element: "fiat.SM2P256Element", 39 Params: _sm2ec.P256().Params(), 40 BuildTags: "purego || !(amd64 || arm64)", 41 }, 42 } 43 44 func main() { 45 t := template.Must(template.New("tmplNISTEC").Parse(tmplNISTEC)) 46 47 tmplAddchainFile, err := os.CreateTemp("", "addchain-template") 48 if err != nil { 49 log.Fatal(err) 50 } 51 defer os.Remove(tmplAddchainFile.Name()) 52 if _, err := io.WriteString(tmplAddchainFile, tmplAddchain); err != nil { 53 log.Fatal(err) 54 } 55 if err := tmplAddchainFile.Close(); err != nil { 56 log.Fatal(err) 57 } 58 59 for _, c := range curves { 60 p := strings.ToLower(c.P) 61 elementLen := (c.Params.BitSize + 7) / 8 62 B := fmt.Sprintf("%#v", c.Params.B.FillBytes(make([]byte, elementLen))) 63 Gx := fmt.Sprintf("%#v", c.Params.Gx.FillBytes(make([]byte, elementLen))) 64 Gy := fmt.Sprintf("%#v", c.Params.Gy.FillBytes(make([]byte, elementLen))) 65 66 log.Printf("Generating %s.go...", p) 67 f, err := os.Create(p + ".go") 68 if err != nil { 69 log.Fatal(err) 70 } 71 defer f.Close() 72 buf := &bytes.Buffer{} 73 if err := t.Execute(buf, map[string]any{ 74 "P": c.P, "p": p, "B": B, "Gx": Gx, "Gy": Gy, 75 "Element": c.Element, "ElementLen": elementLen, 76 "BuildTags": c.BuildTags, 77 }); err != nil { 78 log.Fatal(err) 79 } 80 out, err := format.Source(buf.Bytes()) 81 if err != nil { 82 log.Fatal(err) 83 } 84 if _, err := f.Write(out); err != nil { 85 log.Fatal(err) 86 } 87 88 // If p = 3 mod 4, implement modular square root by exponentiation. 89 mod4 := new(big.Int).Mod(c.Params.P, big.NewInt(4)) 90 if mod4.Cmp(big.NewInt(3)) != 0 { 91 continue 92 } 93 94 exp := new(big.Int).Add(c.Params.P, big.NewInt(1)) 95 exp.Div(exp, big.NewInt(4)) 96 97 tmp, err := os.CreateTemp("", "addchain-"+p) 98 if err != nil { 99 log.Fatal(err) 100 } 101 defer os.Remove(tmp.Name()) 102 cmd := exec.Command("addchain", "search", fmt.Sprintf("%d", exp)) 103 cmd.Stderr = os.Stderr 104 cmd.Stdout = tmp 105 if err := cmd.Run(); err != nil { 106 log.Fatal(err) 107 } 108 if err := tmp.Close(); err != nil { 109 log.Fatal(err) 110 } 111 cmd = exec.Command("addchain", "gen", "-tmpl", tmplAddchainFile.Name(), tmp.Name()) 112 cmd.Stderr = os.Stderr 113 out, err = cmd.Output() 114 if err != nil { 115 log.Fatal(err) 116 } 117 out = bytes.Replace(out, []byte("Element"), []byte(c.Element), -1) 118 out = bytes.Replace(out, []byte("sqrtCandidate"), []byte(p+"SqrtCandidate"), -1) 119 out, err = format.Source(out) 120 if err != nil { 121 log.Fatal(err) 122 } 123 if _, err := f.Write(out); err != nil { 124 log.Fatal(err) 125 } 126 } 127 } 128 129 const tmplNISTEC = `// Copyright 2022 The Go Authors. All rights reserved. 130 // Use of this source code is governed by a BSD-style 131 // license that can be found in the LICENSE file. 132 // Code generated by generate.go. DO NOT EDIT. 133 134 {{ if .BuildTags }} 135 //go:build {{ .BuildTags }} 136 {{ end }} 137 138 package sm2ec 139 140 import ( 141 "github.com/emmansun/gmsm/internal/sm2ec/fiat" 142 "crypto/subtle" 143 "errors" 144 "sync" 145 ) 146 147 // {{.p}}ElementLength is the length of an element of the base or scalar field, 148 // which have the same bytes length for all NIST P curves. 149 const {{.p}}ElementLength = {{ .ElementLen }} 150 151 // {{.P}}Point is a {{.P}} point. The zero value is NOT valid. 152 type {{.P}}Point struct { 153 // The point is represented in projective coordinates (X:Y:Z), 154 // where x = X/Z and y = Y/Z. 155 x, y, z *{{.Element}} 156 } 157 158 // New{{.P}}Point returns a new {{.P}}Point representing the point at infinity point. 159 func New{{.P}}Point() *{{.P}}Point { 160 return &{{.P}}Point{ 161 x: new({{.Element}}), 162 y: new({{.Element}}).One(), 163 z: new({{.Element}}), 164 } 165 } 166 167 // SetGenerator sets p to the canonical generator and returns p. 168 func (p *{{.P}}Point) SetGenerator() *{{.P}}Point { 169 p.x.SetBytes({{.Gx}}) 170 p.y.SetBytes({{.Gy}}) 171 p.z.One() 172 return p 173 } 174 175 // Set sets p = q and returns p. 176 func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point { 177 p.x.Set(q.x) 178 p.y.Set(q.y) 179 p.z.Set(q.z) 180 return p 181 } 182 183 // SetBytes sets p to the compressed, uncompressed, or infinity value encoded in 184 // b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on 185 // the curve, it returns nil and an error, and the receiver is unchanged. 186 // Otherwise, it returns p. 187 func (p *{{.P}}Point) SetBytes(b []byte) (*{{.P}}Point, error) { 188 switch { 189 // Point at infinity. 190 case len(b) == 1 && b[0] == 0: 191 return p.Set(New{{.P}}Point()), nil 192 // Uncompressed form. 193 case len(b) == 1+2*{{.p}}ElementLength && b[0] == 4: 194 x, err := new({{.Element}}).SetBytes(b[1 : 1+{{.p}}ElementLength]) 195 if err != nil { 196 return nil, err 197 } 198 y, err := new({{.Element}}).SetBytes(b[1+{{.p}}ElementLength:]) 199 if err != nil { 200 return nil, err 201 } 202 if err := {{.p}}CheckOnCurve(x, y); err != nil { 203 return nil, err 204 } 205 p.x.Set(x) 206 p.y.Set(y) 207 p.z.One() 208 return p, nil 209 // Compressed form. 210 case len(b) == 1+{{.p}}ElementLength && (b[0] == 2 || b[0] == 3): 211 x, err := new({{.Element}}).SetBytes(b[1:]) 212 if err != nil { 213 return nil, err 214 } 215 // y² = x³ - 3x + b 216 y := {{.p}}Polynomial(new({{.Element}}), x) 217 if !{{.p}}Sqrt(y, y) { 218 return nil, errors.New("invalid {{.P}} compressed point encoding") 219 } 220 // Select the positive or negative root, as indicated by the least 221 // significant bit, based on the encoding type byte. 222 otherRoot := new({{.Element}}) 223 otherRoot.Sub(otherRoot, y) 224 cond := y.Bytes()[{{.p}}ElementLength-1]&1 ^ b[0]&1 225 y.Select(otherRoot, y, int(cond)) 226 p.x.Set(x) 227 p.y.Set(y) 228 p.z.One() 229 return p, nil 230 default: 231 return nil, errors.New("invalid {{.P}} point encoding") 232 } 233 } 234 235 var _{{.p}}B *{{.Element}} 236 var _{{.p}}BOnce sync.Once 237 func {{.p}}B() *{{.Element}} { 238 _{{.p}}BOnce.Do(func() { 239 _{{.p}}B, _ = new({{.Element}}).SetBytes({{.B}}) 240 }) 241 return _{{.p}}B 242 } 243 244 // {{.p}}Polynomial sets y2 to x³ - 3x + b, and returns y2. 245 func {{.p}}Polynomial(y2, x *{{.Element}}) *{{.Element}} { 246 y2.Square(x) 247 y2.Mul(y2, x) 248 249 threeX := new({{.Element}}).Add(x, x) 250 threeX.Add(threeX, x) 251 252 y2.Sub(y2, threeX) 253 254 return y2.Add(y2, {{.p}}B()) 255 } 256 257 func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { 258 // y² = x³ - 3x + b 259 rhs := {{.p}}Polynomial(new({{.Element}}), x) 260 lhs := new({{.Element}}).Square(y) 261 if rhs.Equal(lhs) != 1 { 262 return errors.New("{{.P}} point not on curve") 263 } 264 return nil 265 } 266 267 // Bytes returns the uncompressed or infinity encoding of p, as specified in 268 // SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at 269 // infinity is shorter than all other encodings. 270 func (p *{{.P}}Point) Bytes() []byte { 271 // This function is outlined to make the allocations inline in the caller 272 // rather than happen on the heap. 273 var out [1+2*{{.p}}ElementLength]byte 274 return p.bytes(&out) 275 } 276 277 func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte { 278 if p.z.IsZero() == 1 { 279 return append(out[:0], 0) 280 } 281 zinv := new({{.Element}}).Invert(p.z) 282 x := new({{.Element}}).Mul(p.x, zinv) 283 y := new({{.Element}}).Mul(p.y, zinv) 284 buf := append(out[:0], 4) 285 buf = append(buf, x.Bytes()...) 286 buf = append(buf, y.Bytes()...) 287 return buf 288 } 289 290 // BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1, 291 // Version 2.0, Section 2.3.5, or an error if p is the point at infinity. 292 func (p *{{.P}}Point) BytesX() ([]byte, error) { 293 // This function is outlined to make the allocations inline in the caller 294 // rather than happen on the heap. 295 var out [{{.p}}ElementLength]byte 296 return p.bytesX(&out) 297 } 298 299 func (p *{{.P}}Point) bytesX(out *[{{.p}}ElementLength]byte) ([]byte, error) { 300 if p.z.IsZero() == 1 { 301 return nil, errors.New("{{.P}} point is the point at infinity") 302 } 303 zinv := new({{.Element}}).Invert(p.z) 304 x := new({{.Element}}).Mul(p.x, zinv) 305 return append(out[:0], x.Bytes()...), nil 306 } 307 308 // BytesCompressed returns the compressed or infinity encoding of p, as 309 // specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the 310 // point at infinity is shorter than all other encodings. 311 func (p *{{.P}}Point) BytesCompressed() []byte { 312 // This function is outlined to make the allocations inline in the caller 313 // rather than happen on the heap. 314 var out [1 + {{.p}}ElementLength]byte 315 return p.bytesCompressed(&out) 316 } 317 318 func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte { 319 if p.z.IsZero() == 1 { 320 return append(out[:0], 0) 321 } 322 zinv := new({{.Element}}).Invert(p.z) 323 x := new({{.Element}}).Mul(p.x, zinv) 324 y := new({{.Element}}).Mul(p.y, zinv) 325 // Encode the sign of the y coordinate (indicated by the least significant 326 // bit) as the encoding type (2 or 3). 327 buf := append(out[:0], 2) 328 buf[0] |= y.Bytes()[{{.p}}ElementLength-1] & 1 329 buf = append(buf, x.Bytes()...) 330 return buf 331 } 332 333 // Add sets q = p1 + p2, and returns q. The points may overlap. 334 func (q *{{.P}}Point) Add(p1, p2 *{{.P}}Point) *{{.P}}Point { 335 // Complete addition formula for a = -3 from "Complete addition formulas for 336 // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. 337 t0 := new({{.Element}}).Mul(p1.x, p2.x) // t0 := X1 * X2 338 t1 := new({{.Element}}).Mul(p1.y, p2.y) // t1 := Y1 * Y2 339 t2 := new({{.Element}}).Mul(p1.z, p2.z) // t2 := Z1 * Z2 340 t3 := new({{.Element}}).Add(p1.x, p1.y) // t3 := X1 + Y1 341 t4 := new({{.Element}}).Add(p2.x, p2.y) // t4 := X2 + Y2 342 t3.Mul(t3, t4) // t3 := t3 * t4 343 t4.Add(t0, t1) // t4 := t0 + t1 344 t3.Sub(t3, t4) // t3 := t3 - t4 345 t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 346 x3 := new({{.Element}}).Add(p2.y, p2.z) // X3 := Y2 + Z2 347 t4.Mul(t4, x3) // t4 := t4 * X3 348 x3.Add(t1, t2) // X3 := t1 + t2 349 t4.Sub(t4, x3) // t4 := t4 - X3 350 x3.Add(p1.x, p1.z) // X3 := X1 + Z1 351 y3 := new({{.Element}}).Add(p2.x, p2.z) // Y3 := X2 + Z2 352 x3.Mul(x3, y3) // X3 := X3 * Y3 353 y3.Add(t0, t2) // Y3 := t0 + t2 354 y3.Sub(x3, y3) // Y3 := X3 - Y3 355 z3 := new({{.Element}}).Mul({{.p}}B(), t2) // Z3 := b * t2 356 x3.Sub(y3, z3) // X3 := Y3 - Z3 357 z3.Add(x3, x3) // Z3 := X3 + X3 358 x3.Add(x3, z3) // X3 := X3 + Z3 359 z3.Sub(t1, x3) // Z3 := t1 - X3 360 x3.Add(t1, x3) // X3 := t1 + X3 361 y3.Mul({{.p}}B(), y3) // Y3 := b * Y3 362 t1.Add(t2, t2) // t1 := t2 + t2 363 t2.Add(t1, t2) // t2 := t1 + t2 364 y3.Sub(y3, t2) // Y3 := Y3 - t2 365 y3.Sub(y3, t0) // Y3 := Y3 - t0 366 t1.Add(y3, y3) // t1 := Y3 + Y3 367 y3.Add(t1, y3) // Y3 := t1 + Y3 368 t1.Add(t0, t0) // t1 := t0 + t0 369 t0.Add(t1, t0) // t0 := t1 + t0 370 t0.Sub(t0, t2) // t0 := t0 - t2 371 t1.Mul(t4, y3) // t1 := t4 * Y3 372 t2.Mul(t0, y3) // t2 := t0 * Y3 373 y3.Mul(x3, z3) // Y3 := X3 * Z3 374 y3.Add(y3, t2) // Y3 := Y3 + t2 375 x3.Mul(t3, x3) // X3 := t3 * X3 376 x3.Sub(x3, t1) // X3 := X3 - t1 377 z3.Mul(t4, z3) // Z3 := t4 * Z3 378 t1.Mul(t3, t0) // t1 := t3 * t0 379 z3.Add(z3, t1) // Z3 := Z3 + t1 380 381 q.x.Set(x3) 382 q.y.Set(y3) 383 q.z.Set(z3) 384 return q 385 } 386 387 // Double sets q = p + p, and returns q. The points may overlap. 388 func (q *{{.P}}Point) Double(p *{{.P}}Point) *{{.P}}Point { 389 // Complete addition formula for a = -3 from "Complete addition formulas for 390 // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. 391 t0 := new({{.Element}}).Square(p.x) // t0 := X ^ 2 392 t1 := new({{.Element}}).Square(p.y) // t1 := Y ^ 2 393 t2 := new({{.Element}}).Square(p.z) // t2 := Z ^ 2 394 t3 := new({{.Element}}).Mul(p.x, p.y) // t3 := X * Y 395 t3.Add(t3, t3) // t3 := t3 + t3 396 z3 := new({{.Element}}).Mul(p.x, p.z) // Z3 := X * Z 397 z3.Add(z3, z3) // Z3 := Z3 + Z3 398 y3 := new({{.Element}}).Mul({{.p}}B(), t2) // Y3 := b * t2 399 y3.Sub(y3, z3) // Y3 := Y3 - Z3 400 x3 := new({{.Element}}).Add(y3, y3) // X3 := Y3 + Y3 401 y3.Add(x3, y3) // Y3 := X3 + Y3 402 x3.Sub(t1, y3) // X3 := t1 - Y3 403 y3.Add(t1, y3) // Y3 := t1 + Y3 404 y3.Mul(x3, y3) // Y3 := X3 * Y3 405 x3.Mul(x3, t3) // X3 := X3 * t3 406 t3.Add(t2, t2) // t3 := t2 + t2 407 t2.Add(t2, t3) // t2 := t2 + t3 408 z3.Mul({{.p}}B(), z3) // Z3 := b * Z3 409 z3.Sub(z3, t2) // Z3 := Z3 - t2 410 z3.Sub(z3, t0) // Z3 := Z3 - t0 411 t3.Add(z3, z3) // t3 := Z3 + Z3 412 z3.Add(z3, t3) // Z3 := Z3 + t3 413 t3.Add(t0, t0) // t3 := t0 + t0 414 t0.Add(t3, t0) // t0 := t3 + t0 415 t0.Sub(t0, t2) // t0 := t0 - t2 416 t0.Mul(t0, z3) // t0 := t0 * Z3 417 y3.Add(y3, t0) // Y3 := Y3 + t0 418 t0.Mul(p.y, p.z) // t0 := Y * Z 419 t0.Add(t0, t0) // t0 := t0 + t0 420 z3.Mul(t0, z3) // Z3 := t0 * Z3 421 x3.Sub(x3, z3) // X3 := X3 - Z3 422 z3.Mul(t0, t1) // Z3 := t0 * t1 423 z3.Add(z3, z3) // Z3 := Z3 + Z3 424 z3.Add(z3, z3) // Z3 := Z3 + Z3 425 426 q.x.Set(x3) 427 q.y.Set(y3) 428 q.z.Set(z3) 429 return q 430 } 431 432 // Select sets q to p1 if cond == 1, and to p2 if cond == 0. 433 func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point { 434 q.x.Select(p1.x, p2.x, cond) 435 q.y.Select(p1.y, p2.y, cond) 436 q.z.Select(p1.z, p2.z, cond) 437 return q 438 } 439 440 // A {{.p}}Table holds the first 15 multiples of a point at offset -1, so [1]P 441 // is at table[0], [15]P is at table[14], and [0]P is implicitly the identity 442 // point. 443 type {{.p}}Table [15]*{{.P}}Point 444 445 // Select selects the n-th multiple of the table base point into p. It works in 446 // constant time by iterating over every entry of the table. n must be in [0, 15]. 447 func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) { 448 if n >= 16 { 449 panic("sm2ec: internal error: {{.p}}Table called with out-of-bounds value") 450 } 451 p.Set(New{{.P}}Point()) 452 for i := uint8(1); i < 16; i++ { 453 cond := subtle.ConstantTimeByteEq(i, n) 454 p.Select(table[i-1], p, cond) 455 } 456 } 457 458 // ScalarMult sets p = scalar * q, and returns p. 459 func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, error) { 460 // Compute a {{.p}}Table for the base point q. The explicit New{{.P}}Point 461 // calls get inlined, letting the allocations live on the stack. 462 var table = {{.p}}Table{New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), 463 New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), 464 New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), 465 New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point(), New{{.P}}Point()} 466 table[0].Set(q) 467 for i := 1; i < 15; i += 2 { 468 table[i].Double(table[i/2]) 469 table[i+1].Add(table[i], q) 470 } 471 // Instead of doing the classic double-and-add chain, we do it with a 472 // four-bit window: we double four times, and then add [0-15]P. 473 t := New{{.P}}Point() 474 p.Set(New{{.P}}Point()) 475 for i, byte := range scalar { 476 // No need to double on the first iteration, as p is the identity at 477 // this point, and [N]∞ = ∞. 478 if i != 0 { 479 p.Double(p) 480 p.Double(p) 481 p.Double(p) 482 p.Double(p) 483 } 484 windowValue := byte >> 4 485 table.Select(t, windowValue) 486 p.Add(p, t) 487 p.Double(p) 488 p.Double(p) 489 p.Double(p) 490 p.Double(p) 491 windowValue = byte & 0b1111 492 table.Select(t, windowValue) 493 p.Add(p, t) 494 } 495 return p, nil 496 } 497 498 var {{.p}}GeneratorTable *[{{.p}}ElementLength * 2]{{.p}}Table 499 var {{.p}}GeneratorTableOnce sync.Once 500 501 // generatorTable returns a sequence of {{.p}}Tables. The first table contains 502 // multiples of G. Each successive table is the previous table doubled four 503 // times. 504 func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table { 505 {{.p}}GeneratorTableOnce.Do(func() { 506 {{.p}}GeneratorTable = new([{{.p}}ElementLength * 2]{{.p}}Table) 507 base := New{{.P}}Point().SetGenerator() 508 for i := 0; i < {{.p}}ElementLength*2; i++ { 509 {{.p}}GeneratorTable[i][0] = New{{.P}}Point().Set(base) 510 for j := 1; j < 15; j++ { 511 {{.p}}GeneratorTable[i][j] = New{{.P}}Point().Add({{.p}}GeneratorTable[i][j-1], base) 512 } 513 base.Double(base) 514 base.Double(base) 515 base.Double(base) 516 base.Double(base) 517 } 518 }) 519 return {{.p}}GeneratorTable 520 } 521 522 // ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and 523 // returns p. 524 func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) { 525 if len(scalar) != {{.p}}ElementLength { 526 return nil, errors.New("invalid scalar length") 527 } 528 tables := p.generatorTable() 529 // This is also a scalar multiplication with a four-bit window like in 530 // ScalarMult, but in this case the doublings are precomputed. The value 531 // [windowValue]G added at iteration k would normally get doubled 532 // (totIterations-k)×4 times, but with a larger precomputation we can 533 // instead add [2^((totIterations-k)×4)][windowValue]G and avoid the 534 // doublings between iterations. 535 t := New{{.P}}Point() 536 p.Set(New{{.P}}Point()) 537 tableIndex := len(tables) - 1 538 for _, byte := range scalar { 539 windowValue := byte >> 4 540 tables[tableIndex].Select(t, windowValue) 541 p.Add(p, t) 542 tableIndex-- 543 544 windowValue = byte & 0b1111 545 tables[tableIndex].Select(t, windowValue) 546 p.Add(p, t) 547 tableIndex-- 548 } 549 550 return p, nil 551 } 552 553 // {{.p}}Sqrt sets e to a square root of x. If x is not a square, {{.p}}Sqrt returns 554 // false and e is unchanged. e and x can overlap. 555 func {{.p}}Sqrt(e, x *{{ .Element }}) (isSquare bool) { 556 candidate := new({{ .Element }}) 557 {{.p}}SqrtCandidate(candidate, x) 558 square := new({{ .Element }}).Square(candidate) 559 if square.Equal(x) != 1 { 560 return false 561 } 562 e.Set(candidate) 563 return true 564 } 565 ` 566 567 const tmplAddchain = ` 568 // sqrtCandidate sets z to a square root candidate for x. z and x must not overlap. 569 func sqrtCandidate(z, x *Element) { 570 // Since p = 3 mod 4, exponentiation by (p + 1) / 4 yields a square root candidate. 571 // 572 // The sequence of {{ .Ops.Adds }} multiplications and {{ .Ops.Doubles }} squarings is derived from the 573 // following addition chain generated with {{ .Meta.Module }} {{ .Meta.ReleaseTag }}. 574 // 575 {{- range lines (format .Script) }} 576 // {{ . }} 577 {{- end }} 578 // 579 {{- range .Program.Temporaries }} 580 var {{ . }} = new(Element) 581 {{- end }} 582 {{ range $i := .Program.Instructions -}} 583 {{- with add $i.Op }} 584 {{ $i.Output }}.Mul({{ .X }}, {{ .Y }}) 585 {{- end -}} 586 {{- with double $i.Op }} 587 {{ $i.Output }}.Square({{ .X }}) 588 {{- end -}} 589 {{- with shift $i.Op -}} 590 {{- $first := 0 -}} 591 {{- if ne $i.Output.Identifier .X.Identifier }} 592 {{ $i.Output }}.Square({{ .X }}) 593 {{- $first = 1 -}} 594 {{- end }} 595 for s := {{ $first }}; s < {{ .S }}; s++ { 596 {{ $i.Output }}.Square({{ $i.Output }}) 597 } 598 {{- end -}} 599 {{- end }} 600 } 601 `