github.com/cloudflare/circl@v1.5.0/sign/dilithium/gen.go (about) 1 //go:build ignore 2 // +build ignore 3 4 // Autogenerates wrappers from templates to prevent too much duplicated code 5 // between the code for different modes. 6 package main 7 8 import ( 9 "bytes" 10 "fmt" 11 "go/format" 12 "os" 13 "path" 14 "strings" 15 "text/template" 16 17 "github.com/cloudflare/circl/sign/internal/dilithium/params" 18 ) 19 20 type Mode struct { 21 Name string 22 K int 23 L int 24 Eta int 25 DoubleEtaBits int 26 Omega int 27 Tau int 28 Gamma1Bits int 29 Gamma2 int 30 TRSize int 31 CTildeSize int 32 } 33 34 func (m Mode) Pkg() string { 35 return strings.ToLower(m.Mode()) 36 } 37 38 func (m Mode) PkgPath() string { 39 if m.NIST() { 40 return path.Join("..", "mldsa", m.Pkg()) 41 } 42 43 return m.Pkg() 44 } 45 46 func (m Mode) Impl() string { 47 return "impl" + m.Mode() 48 } 49 50 func (m Mode) Mode() string { 51 if m.NIST() { 52 return strings.ReplaceAll(m.Name, "-", "") 53 } 54 55 return strings.ReplaceAll(m.Name, "Dilithium", "Mode") 56 } 57 58 func (m Mode) NIST() bool { 59 return strings.HasPrefix(m.Name, "ML-DSA-") 60 } 61 62 var ( 63 Modes = []Mode{ 64 { 65 Name: "Dilithium2", 66 K: 4, 67 L: 4, 68 Eta: 2, 69 DoubleEtaBits: 3, 70 Omega: 80, 71 Tau: 39, 72 Gamma1Bits: 17, 73 Gamma2: (params.Q - 1) / 88, 74 TRSize: 32, 75 CTildeSize: 32, 76 }, 77 { 78 Name: "Dilithium3", 79 K: 6, 80 L: 5, 81 Eta: 4, 82 DoubleEtaBits: 4, 83 Omega: 55, 84 Tau: 49, 85 Gamma1Bits: 19, 86 Gamma2: (params.Q - 1) / 32, 87 TRSize: 32, 88 CTildeSize: 32, 89 }, 90 { 91 Name: "Dilithium5", 92 K: 8, 93 L: 7, 94 Eta: 2, 95 DoubleEtaBits: 3, 96 Omega: 75, 97 Tau: 60, 98 Gamma1Bits: 19, 99 Gamma2: (params.Q - 1) / 32, 100 TRSize: 32, 101 CTildeSize: 32, 102 }, 103 { 104 Name: "ML-DSA-44", 105 K: 4, 106 L: 4, 107 Eta: 2, 108 DoubleEtaBits: 3, 109 Omega: 80, 110 Tau: 39, 111 Gamma1Bits: 17, 112 Gamma2: (params.Q - 1) / 88, 113 TRSize: 64, 114 CTildeSize: 32, 115 }, 116 { 117 Name: "ML-DSA-65", 118 K: 6, 119 L: 5, 120 Eta: 4, 121 DoubleEtaBits: 4, 122 Omega: 55, 123 Tau: 49, 124 Gamma1Bits: 19, 125 Gamma2: (params.Q - 1) / 32, 126 TRSize: 64, 127 CTildeSize: 48, 128 }, 129 { 130 Name: "ML-DSA-87", 131 K: 8, 132 L: 7, 133 Eta: 2, 134 DoubleEtaBits: 3, 135 Omega: 75, 136 Tau: 60, 137 Gamma1Bits: 19, 138 Gamma2: (params.Q - 1) / 32, 139 TRSize: 64, 140 CTildeSize: 64, 141 }, 142 } 143 TemplateWarning = "// Code generated from" 144 ) 145 146 func main() { 147 generateModePackageFiles() 148 generateACVPTest() 149 generateParamsFiles() 150 generateSourceFiles() 151 } 152 153 // Generates modeX/internal/params.go from templates/params.templ.go 154 func generateParamsFiles() { 155 tl, err := template.ParseFiles("templates/params.templ.go") 156 if err != nil { 157 panic(err) 158 } 159 160 for _, mode := range Modes { 161 buf := new(bytes.Buffer) 162 err := tl.Execute(buf, mode) 163 if err != nil { 164 panic(err) 165 } 166 167 // Formating output code 168 code, err := format.Source(buf.Bytes()) 169 if err != nil { 170 panic("error formating code") 171 } 172 173 res := string(code) 174 offset := strings.Index(res, TemplateWarning) 175 if offset == -1 { 176 panic("Missing template warning in params.templ.go") 177 } 178 err = os.WriteFile(mode.PkgPath()+"/internal/params.go", 179 []byte(res[offset:]), 0o644) 180 if err != nil { 181 panic(err) 182 } 183 } 184 } 185 186 // Generates modeX/dilithium.go from templates/pkg.templ.go 187 func generateModePackageFiles() { 188 tl, err := template.ParseFiles("templates/pkg.templ.go") 189 if err != nil { 190 panic(err) 191 } 192 193 for _, mode := range Modes { 194 buf := new(bytes.Buffer) 195 err := tl.Execute(buf, mode) 196 if err != nil { 197 panic(err) 198 } 199 200 res, err := format.Source(buf.Bytes()) 201 if err != nil { 202 panic("error formating code") 203 } 204 205 offset := strings.Index(string(res), TemplateWarning) 206 if offset == -1 { 207 panic("Missing template warning in pkg.templ.go") 208 } 209 err = os.WriteFile(mode.PkgPath()+"/dilithium.go", res[offset:], 0o644) 210 if err != nil { 211 panic(err) 212 } 213 } 214 } 215 216 // Generates modeX/dilithium.go from templates/pkg.templ.go 217 func generateACVPTest() { 218 tl, err := template.ParseFiles("templates/acvp.templ.go") 219 if err != nil { 220 panic(err) 221 } 222 223 for _, mode := range Modes { 224 if !strings.HasPrefix(mode.Name, "ML-DSA") { 225 continue 226 } 227 228 buf := new(bytes.Buffer) 229 err := tl.Execute(buf, mode) 230 if err != nil { 231 panic(err) 232 } 233 234 res, err := format.Source(buf.Bytes()) 235 if err != nil { 236 panic("error formating code") 237 } 238 239 offset := strings.Index(string(res), TemplateWarning) 240 if offset == -1 { 241 panic("Missing template warning in pkg.templ.go") 242 } 243 err = os.WriteFile(mode.PkgPath()+"/acvp_test.go", res[offset:], 0o644) 244 if err != nil { 245 panic(err) 246 } 247 } 248 } 249 250 // Copies mode3 source files to other modes 251 func generateSourceFiles() { 252 files := make(map[string][]byte) 253 254 // Ignore mode specific files. 255 ignored := func(x string) bool { 256 return x == "params.go" || x == "params_test.go" || 257 strings.HasSuffix(x, ".swp") 258 } 259 260 fs, err := os.ReadDir("mode3/internal") 261 if err != nil { 262 panic(err) 263 } 264 265 // Read files 266 for _, f := range fs { 267 name := f.Name() 268 if ignored(name) { 269 continue 270 } 271 files[name], err = os.ReadFile(path.Join("mode3/internal", name)) 272 if err != nil { 273 panic(err) 274 } 275 } 276 277 // Go over modes 278 for _, mode := range Modes { 279 if mode.Name == "Dilithium3" { 280 continue 281 } 282 283 fs, err = os.ReadDir(path.Join(mode.PkgPath(), "internal")) 284 for _, f := range fs { 285 name := f.Name() 286 fn := path.Join(mode.PkgPath(), "internal", name) 287 if ignored(name) { 288 continue 289 } 290 _, ok := files[name] 291 if !ok { 292 fmt.Printf("Removing superfluous file: %s\n", fn) 293 err = os.Remove(fn) 294 if err != nil { 295 panic(err) 296 } 297 } 298 if f.IsDir() { 299 panic(fmt.Sprintf("%s: is a directory", fn)) 300 } 301 if f.Type()&os.ModeSymlink != 0 { 302 fmt.Printf("Removing symlink: %s\n", fn) 303 err = os.Remove(fn) 304 if err != nil { 305 panic(err) 306 } 307 } 308 } 309 for name, expected := range files { 310 fn := path.Join(mode.PkgPath(), "internal", name) 311 expected = []byte(fmt.Sprintf( 312 "%s mode3/internal/%s by gen.go\n\n%s", 313 TemplateWarning, 314 name, 315 string(expected), 316 )) 317 got, err := os.ReadFile(fn) 318 if err == nil { 319 if bytes.Equal(got, expected) { 320 continue 321 } 322 } 323 fmt.Printf("Updating %s\n", fn) 324 err = os.WriteFile(fn, expected, 0o644) 325 if err != nil { 326 panic(err) 327 } 328 } 329 } 330 }