github.com/cloudflare/circl@v1.5.0/pke/kyber/gen.go (about)

     1  //go:build ignore
     2  // +build ignore
     3  
     4  // Autogenerates wrappers from templates to prevent too much duplicated code
     5  // between the code for different modes.
     6  package main
     7  
import (
	"bytes"
	"fmt"
	"go/format"
	"os"
	"path"
	"sort"
	"strings"
	"text/template"
)
    17  
    18  type Instance struct {
    19  	Name           string
    20  	K              int
    21  	Eta1           int
    22  	CiphertextSize int
    23  	DU             int
    24  	DV             int
    25  }
    26  
    27  func (m Instance) Pkg() string {
    28  	return strings.ToLower(m.Name)
    29  }
    30  
    31  func (m Instance) Impl() string {
    32  	return "impl" + m.Name
    33  }
    34  
    35  var (
    36  	Instances = []Instance{
    37  		{
    38  			Name:           "Kyber512",
    39  			Eta1:           3,
    40  			K:              2,
    41  			CiphertextSize: 768,
    42  			DU:             10,
    43  			DV:             4,
    44  		},
    45  		{
    46  			Name:           "Kyber768",
    47  			Eta1:           2,
    48  			K:              3,
    49  			CiphertextSize: 1088,
    50  			DU:             10,
    51  			DV:             4,
    52  		},
    53  		{
    54  			Name:           "Kyber1024",
    55  			Eta1:           2,
    56  			K:              4,
    57  			CiphertextSize: 1568,
    58  			DU:             11,
    59  			DV:             5,
    60  		},
    61  	}
    62  	TemplateWarning = "// Code generated from"
    63  )
    64  
    65  func main() {
    66  	generatePackageFiles()
    67  	generateParamsFiles()
    68  	generateSourceFiles()
    69  }
    70  
    71  // Generates instance/internal/params.go from templates/params.templ.go
    72  func generateParamsFiles() {
    73  	tl, err := template.ParseFiles("templates/params.templ.go")
    74  	if err != nil {
    75  		panic(err)
    76  	}
    77  
    78  	for _, mode := range Instances {
    79  		buf := new(bytes.Buffer)
    80  		err := tl.Execute(buf, mode)
    81  		if err != nil {
    82  			panic(err)
    83  		}
    84  
    85  		// Formating output code
    86  		code, err := format.Source(buf.Bytes())
    87  		if err != nil {
    88  			panic("error formating code")
    89  		}
    90  
    91  		res := string(code)
    92  		offset := strings.Index(res, TemplateWarning)
    93  		if offset == -1 {
    94  			panic("Missing template warning in params.templ.go")
    95  		}
    96  		err = os.WriteFile(mode.Pkg()+"/internal/params.go",
    97  			[]byte(res[offset:]), 0o644)
    98  		if err != nil {
    99  			panic(err)
   100  		}
   101  	}
   102  }
   103  
   104  // Generates instance/kyber.go from templates/pkg.templ.go
   105  func generatePackageFiles() {
   106  	tl, err := template.ParseFiles("templates/pkg.templ.go")
   107  	if err != nil {
   108  		panic(err)
   109  	}
   110  
   111  	for _, mode := range Instances {
   112  		buf := new(bytes.Buffer)
   113  		err := tl.Execute(buf, mode)
   114  		if err != nil {
   115  			panic(err)
   116  		}
   117  
   118  		res := string(buf.Bytes())
   119  		offset := strings.Index(res, TemplateWarning)
   120  		if offset == -1 {
   121  			panic("Missing template warning in pkg.templ.go")
   122  		}
   123  		err = os.WriteFile(mode.Pkg()+"/kyber.go", []byte(res[offset:]), 0o644)
   124  		if err != nil {
   125  			panic(err)
   126  		}
   127  	}
   128  }
   129  
   130  // Copies kyber512 source files to other modes
   131  func generateSourceFiles() {
   132  	files := make(map[string][]byte)
   133  
   134  	// Ignore mode specific files.
   135  	ignored := func(x string) bool {
   136  		return x == "params.go" || x == "params_test.go"
   137  	}
   138  
   139  	fs, err := os.ReadDir("kyber512/internal")
   140  	if err != nil {
   141  		panic(err)
   142  	}
   143  
   144  	// Read files
   145  	for _, f := range fs {
   146  		name := f.Name()
   147  		if ignored(name) {
   148  			continue
   149  		}
   150  		files[name], err = os.ReadFile(path.Join("kyber512/internal", name))
   151  		if err != nil {
   152  			panic(err)
   153  		}
   154  	}
   155  
   156  	// Go over modes
   157  	for _, mode := range Instances {
   158  		if mode.Name == "Kyber512" {
   159  			continue
   160  		}
   161  
   162  		fs, err = os.ReadDir(path.Join(mode.Pkg(), "internal"))
   163  		for _, f := range fs {
   164  			name := f.Name()
   165  			fn := path.Join(mode.Pkg(), "internal", name)
   166  			if ignored(name) {
   167  				continue
   168  			}
   169  			_, ok := files[name]
   170  			if !ok {
   171  				fmt.Printf("Removing superfluous file: %s\n", fn)
   172  				err = os.Remove(fn)
   173  				if err != nil {
   174  					panic(err)
   175  				}
   176  			}
   177  			if f.IsDir() {
   178  				panic(fmt.Sprintf("%s: is a directory", fn))
   179  			}
   180  			if f.Type()&os.ModeSymlink != 0 {
   181  				fmt.Printf("Removing symlink: %s\n", fn)
   182  				err = os.Remove(fn)
   183  				if err != nil {
   184  					panic(err)
   185  				}
   186  			}
   187  		}
   188  		for name, expected := range files {
   189  			fn := path.Join(mode.Pkg(), "internal", name)
   190  			expected = []byte(fmt.Sprintf(
   191  				"%s kyber512/internal/%s by gen.go\n\n%s",
   192  				TemplateWarning,
   193  				name,
   194  				string(expected),
   195  			))
   196  			got, err := os.ReadFile(fn)
   197  			if err == nil {
   198  				if bytes.Equal(got, expected) {
   199  					continue
   200  				}
   201  			}
   202  			fmt.Printf("Updating %s\n", fn)
   203  			err = os.WriteFile(fn, expected, 0o644)
   204  			if err != nil {
   205  				panic(err)
   206  			}
   207  		}
   208  	}
   209  }