github.com/cloudflare/circl@v1.5.0/pke/kyber/gen.go (about) 1 //go:build ignore 2 // +build ignore 3 4 // Autogenerates wrappers from templates to prevent too much duplicated code 5 // between the code for different modes. 6 package main 7 8 import ( 9 "bytes" 10 "fmt" 11 "go/format" 12 "os" 13 "path" 14 "strings" 15 "text/template" 16 ) 17 18 type Instance struct { 19 Name string 20 K int 21 Eta1 int 22 CiphertextSize int 23 DU int 24 DV int 25 } 26 27 func (m Instance) Pkg() string { 28 return strings.ToLower(m.Name) 29 } 30 31 func (m Instance) Impl() string { 32 return "impl" + m.Name 33 } 34 35 var ( 36 Instances = []Instance{ 37 { 38 Name: "Kyber512", 39 Eta1: 3, 40 K: 2, 41 CiphertextSize: 768, 42 DU: 10, 43 DV: 4, 44 }, 45 { 46 Name: "Kyber768", 47 Eta1: 2, 48 K: 3, 49 CiphertextSize: 1088, 50 DU: 10, 51 DV: 4, 52 }, 53 { 54 Name: "Kyber1024", 55 Eta1: 2, 56 K: 4, 57 CiphertextSize: 1568, 58 DU: 11, 59 DV: 5, 60 }, 61 } 62 TemplateWarning = "// Code generated from" 63 ) 64 65 func main() { 66 generatePackageFiles() 67 generateParamsFiles() 68 generateSourceFiles() 69 } 70 71 // Generates instance/internal/params.go from templates/params.templ.go 72 func generateParamsFiles() { 73 tl, err := template.ParseFiles("templates/params.templ.go") 74 if err != nil { 75 panic(err) 76 } 77 78 for _, mode := range Instances { 79 buf := new(bytes.Buffer) 80 err := tl.Execute(buf, mode) 81 if err != nil { 82 panic(err) 83 } 84 85 // Formating output code 86 code, err := format.Source(buf.Bytes()) 87 if err != nil { 88 panic("error formating code") 89 } 90 91 res := string(code) 92 offset := strings.Index(res, TemplateWarning) 93 if offset == -1 { 94 panic("Missing template warning in params.templ.go") 95 } 96 err = os.WriteFile(mode.Pkg()+"/internal/params.go", 97 []byte(res[offset:]), 0o644) 98 if err != nil { 99 panic(err) 100 } 101 } 102 } 103 104 // Generates instance/kyber.go from templates/pkg.templ.go 105 func generatePackageFiles() { 106 tl, err := template.ParseFiles("templates/pkg.templ.go") 107 if err != nil { 108 panic(err) 109 } 110 111 for _, mode := range Instances { 112 buf := new(bytes.Buffer) 113 err := tl.Execute(buf, mode) 114 if err != nil { 115 panic(err) 116 } 117 118 res := string(buf.Bytes()) 119 offset := strings.Index(res, TemplateWarning) 120 if offset == -1 { 121 panic("Missing template warning in pkg.templ.go") 122 } 123 err = os.WriteFile(mode.Pkg()+"/kyber.go", []byte(res[offset:]), 0o644) 124 if err != nil { 125 panic(err) 126 } 127 } 128 } 129 130 // Copies kyber512 source files to other modes 131 func generateSourceFiles() { 132 files := make(map[string][]byte) 133 134 // Ignore mode specific files. 135 ignored := func(x string) bool { 136 return x == "params.go" || x == "params_test.go" 137 } 138 139 fs, err := os.ReadDir("kyber512/internal") 140 if err != nil { 141 panic(err) 142 } 143 144 // Read files 145 for _, f := range fs { 146 name := f.Name() 147 if ignored(name) { 148 continue 149 } 150 files[name], err = os.ReadFile(path.Join("kyber512/internal", name)) 151 if err != nil { 152 panic(err) 153 } 154 } 155 156 // Go over modes 157 for _, mode := range Instances { 158 if mode.Name == "Kyber512" { 159 continue 160 } 161 162 fs, err = os.ReadDir(path.Join(mode.Pkg(), "internal")) 163 for _, f := range fs { 164 name := f.Name() 165 fn := path.Join(mode.Pkg(), "internal", name) 166 if ignored(name) { 167 continue 168 } 169 _, ok := files[name] 170 if !ok { 171 fmt.Printf("Removing superfluous file: %s\n", fn) 172 err = os.Remove(fn) 173 if err != nil { 174 panic(err) 175 } 176 } 177 if f.IsDir() { 178 panic(fmt.Sprintf("%s: is a directory", fn)) 179 } 180 if f.Type()&os.ModeSymlink != 0 { 181 fmt.Printf("Removing symlink: %s\n", fn) 182 err = os.Remove(fn) 183 if err != nil { 184 panic(err) 185 } 186 } 187 } 188 for name, expected := range files { 189 fn := path.Join(mode.Pkg(), "internal", name) 190 expected = []byte(fmt.Sprintf( 191 "%s kyber512/internal/%s by gen.go\n\n%s", 192 TemplateWarning, 193 name, 194 string(expected), 195 )) 196 got, err := os.ReadFile(fn) 197 if err == nil { 198 if bytes.Equal(got, expected) { 199 continue 200 } 201 } 202 fmt.Printf("Updating %s\n", fn) 203 err = os.WriteFile(fn, expected, 0o644) 204 if err != nil { 205 panic(err) 206 } 207 } 208 } 209 }