github.com/emmansun/gmsm@v0.29.1/internal/sm2ec/fiat/generate.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build ignore
     6  
     7  package main
     8  
import (
	"bytes"
	"fmt"
	"go/format"
	"io"
	"log"
	"os"
	"os/exec"
	"text/template"
)
    18  
    19  var curves = []struct {
    20  	Element  string
    21  	Prime    string
    22  	Prefix   string
    23  	FiatType string
    24  	BytesLen int
    25  }{
    26  	{
    27  		Element:  "SM2P256Element",
    28  		Prime:    "2^256 - 2^224 - 2^96 + 2^64 - 1",
    29  		Prefix:   "sm2p256",
    30  		FiatType: "[4]uint64",
    31  		BytesLen: 32,
    32  	},
    33  	{
    34  		Element:  "SM2P256OrderElement",
    35  		Prime:    "2^256 - 2^224 - 188730267045675049073202170516080344797",
    36  		Prefix:   "sm2p256scalar",
    37  		FiatType: "[4]uint64",
    38  		BytesLen: 32,
    39  	},	
    40  }
    41  
    42  func main() {
    43  	t := template.Must(template.New("montgomery").Parse(tmplWrapper))
    44  
    45  	tmplAddchainFile, err := os.CreateTemp("", "addchain-template")
    46  	if err != nil {
    47  		log.Fatal(err)
    48  	}
    49  	defer os.Remove(tmplAddchainFile.Name())
    50  	if _, err := io.WriteString(tmplAddchainFile, tmplAddchain); err != nil {
    51  		log.Fatal(err)
    52  	}
    53  	if err := tmplAddchainFile.Close(); err != nil {
    54  		log.Fatal(err)
    55  	}
    56  
    57  	for _, c := range curves {
    58  		log.Printf("Generating %s.go...", c.Prefix)
    59  		f, err := os.Create(c.Prefix + ".go")
    60  		if err != nil {
    61  			log.Fatal(err)
    62  		}
    63  		if err := t.Execute(f, c); err != nil {
    64  			log.Fatal(err)
    65  		}
    66  		if err := f.Close(); err != nil {
    67  			log.Fatal(err)
    68  		}
    69  
    70  		log.Printf("Generating %s_fiat64.go...", c.Prefix)
    71  		cmd := exec.Command("docker", "run", "--rm", "--entrypoint", "word_by_word_montgomery",
    72  			"fiat-crypto:v0.0.9", "--lang", "Go", "--no-wide-int", "--cmovznz-by-mul",
    73  			"--relax-primitive-carry-to-bitwidth", "32,64", "--internal-static",
    74  			"--public-function-case", "camelCase", "--public-type-case", "camelCase",
    75  			"--private-function-case", "camelCase", "--private-type-case", "camelCase",
    76  			"--doc-text-before-function-name", "", "--doc-newline-before-package-declaration",
    77  			"--doc-prepend-header", "Code generated by Fiat Cryptography. DO NOT EDIT.",
    78  			"--package-name", "fiat", "--no-prefix-fiat", c.Prefix, "64", c.Prime,
    79  			"mul", "square", "add", "sub", "one", "from_montgomery", "to_montgomery",
    80  			"selectznz", "to_bytes", "from_bytes", "nonzero", "opp", "msat", "divstep", "divstep_precomp")
    81  		cmd.Stderr = os.Stderr
    82  		out, err := cmd.Output()
    83  		if err != nil {
    84  			log.Fatal(err)
    85  		}
    86  		out, err = format.Source(out)
    87  		if err != nil {
    88  			log.Fatal(err)
    89  		}
    90  		if err := os.WriteFile(c.Prefix+"_fiat64.go", out, 0644); err != nil {
    91  			log.Fatal(err)
    92  		}
    93  
    94  		log.Printf("Generating %s_invert.go...", c.Prefix)
    95  		f, err = os.CreateTemp("", "addchain-"+c.Prefix)
    96  		if err != nil {
    97  			log.Fatal(err)
    98  		}
    99  		defer os.Remove(f.Name())
   100  		cmd = exec.Command("addchain", "search", c.Prime+" - 2")
   101  		cmd.Stderr = os.Stderr
   102  		cmd.Stdout = f
   103  		if err := cmd.Run(); err != nil {
   104  			log.Fatal(err)
   105  		}
   106  		if err := f.Close(); err != nil {
   107  			log.Fatal(err)
   108  		}
   109  		cmd = exec.Command("addchain", "gen", "-tmpl", tmplAddchainFile.Name(), f.Name())
   110  		cmd.Stderr = os.Stderr
   111  		out, err = cmd.Output()
   112  		if err != nil {
   113  			log.Fatal(err)
   114  		}
   115  		out = bytes.Replace(out, []byte("Element"), []byte(c.Element), -1)
   116  		out, err = format.Source(out)
   117  		if err != nil {
   118  			log.Fatal(err)
   119  		}
   120  		if err := os.WriteFile(c.Prefix+"_invert.go", out, 0644); err != nil {
   121  			log.Fatal(err)
   122  		}
   123  	}
   124  }
   125  
// tmplWrapper is the text/template that main executes once per entry of curves
// to produce <Prefix>.go. It reads the Element, Prime, Prefix, FiatType and
// BytesLen fields of its data.
//
// The rendered file defines:
//   - the {{ .Element }} type, which keeps its value in the Montgomery domain and
//     has One, Equal, IsZero, Set, Bytes, SetBytes, Add, Sub, Mul, Square and
//     Select methods;
//   - the {{ .Prefix }}ElementLen constant;
//   - the {{ .Prefix }}UntypedFieldElement alias;
//   - the {{ .Prefix }}InvertEndianness helper.
//
// The other {{ .Prefix }}-prefixed names it references are not defined by this
// template, for example {{ .Prefix }}Mul and
// {{ .Prefix }}MontgomeryDomainFieldElement. Presumably they come from the
// fiat-crypto output that main writes to <Prefix>_fiat64.go.
const tmplWrapper = `// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by generate.go. DO NOT EDIT.

package fiat

import (
	"crypto/subtle"
	"errors"
)

// {{ .Element }} is an integer modulo {{ .Prime }}.
//
// The zero value is a valid zero element.
type {{ .Element }} struct {
	// Values are represented internally always in the Montgomery domain, and
	// converted in Bytes and SetBytes.
	x {{ .Prefix }}MontgomeryDomainFieldElement
}

const {{ .Prefix }}ElementLen = {{ .BytesLen }}

type {{ .Prefix }}UntypedFieldElement = {{ .FiatType }}

// One sets e = 1, and returns e.
func (e *{{ .Element }}) One() *{{ .Element }} {
	{{ .Prefix }}SetOne(&e.x)
	return e
}

// Equal returns 1 if e == t, and zero otherwise.
func (e *{{ .Element }}) Equal(t *{{ .Element }}) int {
	eBytes := e.Bytes()
	tBytes := t.Bytes()
	return subtle.ConstantTimeCompare(eBytes, tBytes)
}

// IsZero returns 1 if e == 0, and zero otherwise.
func (e *{{ .Element }}) IsZero() int {
	zero := make([]byte, {{ .Prefix }}ElementLen)
	eBytes := e.Bytes()
	return subtle.ConstantTimeCompare(eBytes, zero)
}

// Set sets e = t, and returns e.
func (e *{{ .Element }}) Set(t *{{ .Element }}) *{{ .Element }} {
	e.x = t.x
	return e
}

// Bytes returns the {{ .BytesLen }}-byte big-endian encoding of e.
func (e *{{ .Element }}) Bytes() []byte {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap.
	var out [{{ .Prefix }}ElementLen]byte
	return e.bytes(&out)
}

func (e *{{ .Element }}) bytes(out *[{{ .Prefix }}ElementLen]byte) []byte {
	var tmp {{ .Prefix }}NonMontgomeryDomainFieldElement
	{{ .Prefix }}FromMontgomery(&tmp, &e.x)
	{{ .Prefix }}ToBytes(out, (*{{ .Prefix }}UntypedFieldElement)(&tmp))
	{{ .Prefix }}InvertEndianness(out[:])
	return out[:]
}

// SetBytes sets e = v, where v is a big-endian {{ .BytesLen }}-byte encoding, and returns e.
// If v is not {{ .BytesLen }} bytes or it encodes a value higher than {{ .Prime }},
// SetBytes returns nil and an error, and e is unchanged.
func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
	if len(v) != {{ .Prefix }}ElementLen {
		return nil, errors.New("invalid {{ .Element }} encoding")
	}
	// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
	// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
	var minusOneEncoding = new({{ .Element }}).Sub(
		new({{ .Element }}), new({{ .Element }}).One()).Bytes()
	for i := range v {
		if v[i] < minusOneEncoding[i] {
			break
		}
		if v[i] > minusOneEncoding[i] {
			return nil, errors.New("invalid {{ .Element }} encoding")
		}
	}
	var in [{{ .Prefix }}ElementLen]byte
	copy(in[:], v)
	{{ .Prefix }}InvertEndianness(in[:])
	var tmp {{ .Prefix }}NonMontgomeryDomainFieldElement
	{{ .Prefix }}FromBytes((*{{ .Prefix }}UntypedFieldElement)(&tmp), &in)
	{{ .Prefix }}ToMontgomery(&e.x, &tmp)
	return e, nil
}

// Add sets e = t1 + t2, and returns e.
func (e *{{ .Element }}) Add(t1, t2 *{{ .Element }}) *{{ .Element }} {
	{{ .Prefix }}Add(&e.x, &t1.x, &t2.x)
	return e
}

// Sub sets e = t1 - t2, and returns e.
func (e *{{ .Element }}) Sub(t1, t2 *{{ .Element }}) *{{ .Element }} {
	{{ .Prefix }}Sub(&e.x, &t1.x, &t2.x)
	return e
}

// Mul sets e = t1 * t2, and returns e.
func (e *{{ .Element }}) Mul(t1, t2 *{{ .Element }}) *{{ .Element }} {
	{{ .Prefix }}Mul(&e.x, &t1.x, &t2.x)
	return e
}

// Square sets e = t * t, and returns e.
func (e *{{ .Element }}) Square(t *{{ .Element }}) *{{ .Element }} {
	{{ .Prefix }}Square(&e.x, &t.x)
	return e
}

// Select sets v to a if cond == 1, and to b if cond == 0.
func (v *{{ .Element }}) Select(a, b *{{ .Element }}, cond int) *{{ .Element }} {
	{{ .Prefix }}Selectznz((*{{ .Prefix }}UntypedFieldElement)(&v.x), {{ .Prefix }}Uint1(cond),
		(*{{ .Prefix }}UntypedFieldElement)(&b.x), (*{{ .Prefix }}UntypedFieldElement)(&a.x))
	return v
}

func {{ .Prefix }}InvertEndianness(v []byte) {
	for i := 0; i < len(v)/2; i++ {
		v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]
	}
}
`
   259  
   260  const tmplAddchain = `// Copyright 2021 The Go Authors. All rights reserved.
   261  // Use of this source code is governed by a BSD-style
   262  // license that can be found in the LICENSE file.
   263  // Code generated by {{ .Meta.Name }}. DO NOT EDIT.
   264  package fiat
   265  // Invert sets e = 1/x, and returns e.
   266  //
   267  // If x == 0, Invert returns e = 0.
   268  func (e *Element) Invert(x *Element) *Element {
   269  	// Inversion is implemented as exponentiation with exponent p − 2.
   270  	// The sequence of {{ .Ops.Adds }} multiplications and {{ .Ops.Doubles }} squarings is derived from the
   271  	// following addition chain generated with {{ .Meta.Module }} {{ .Meta.ReleaseTag }}.
   272  	//
   273  	{{- range lines (format .Script) }}
   274  	//	{{ . }}
   275  	{{- end }}
   276  	//
   277  	var z = new(Element).Set(e)
   278  	{{- range .Program.Temporaries }}
   279  	var {{ . }} = new(Element)
   280  	{{- end }}
   281  	{{ range $i := .Program.Instructions -}}
   282  	{{- with add $i.Op }}
   283  	{{ $i.Output }}.Mul({{ .X }}, {{ .Y }})
   284  	{{- end -}}
   285  	{{- with double $i.Op }}
   286  	{{ $i.Output }}.Square({{ .X }})
   287  	{{- end -}}
   288  	{{- with shift $i.Op -}}
   289  	{{- $first := 0 -}}
   290  	{{- if ne $i.Output.Identifier .X.Identifier }}
   291  	{{ $i.Output }}.Square({{ .X }})
   292  	{{- $first = 1 -}}
   293  	{{- end }}
   294  	for s := {{ $first }}; s < {{ .S }}; s++ {
   295  		{{ $i.Output }}.Square({{ $i.Output }})
   296  	}
   297  	{{- end -}}
   298  	{{- end }}
   299  	return e.Set(z)
   300  }
   301  `