github.com/cloudflare/circl@v1.5.0/sign/dilithium/gen.go (about)

     1  //go:build ignore
     2  // +build ignore
     3  
     4  // Autogenerates wrappers from templates to prevent too much duplicated code
     5  // between the code for different modes.
     6  package main
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  	"go/format"
    12  	"os"
    13  	"path"
    14  	"strings"
    15  	"text/template"
    16  
    17  	"github.com/cloudflare/circl/sign/internal/dilithium/params"
    18  )
    19  
    20  type Mode struct {
    21  	Name          string
    22  	K             int
    23  	L             int
    24  	Eta           int
    25  	DoubleEtaBits int
    26  	Omega         int
    27  	Tau           int
    28  	Gamma1Bits    int
    29  	Gamma2        int
    30  	TRSize        int
    31  	CTildeSize    int
    32  }
    33  
    34  func (m Mode) Pkg() string {
    35  	return strings.ToLower(m.Mode())
    36  }
    37  
    38  func (m Mode) PkgPath() string {
    39  	if m.NIST() {
    40  		return path.Join("..", "mldsa", m.Pkg())
    41  	}
    42  
    43  	return m.Pkg()
    44  }
    45  
    46  func (m Mode) Impl() string {
    47  	return "impl" + m.Mode()
    48  }
    49  
    50  func (m Mode) Mode() string {
    51  	if m.NIST() {
    52  		return strings.ReplaceAll(m.Name, "-", "")
    53  	}
    54  
    55  	return strings.ReplaceAll(m.Name, "Dilithium", "Mode")
    56  }
    57  
    58  func (m Mode) NIST() bool {
    59  	return strings.HasPrefix(m.Name, "ML-DSA-")
    60  }
    61  
    62  var (
    63  	Modes = []Mode{
    64  		{
    65  			Name:          "Dilithium2",
    66  			K:             4,
    67  			L:             4,
    68  			Eta:           2,
    69  			DoubleEtaBits: 3,
    70  			Omega:         80,
    71  			Tau:           39,
    72  			Gamma1Bits:    17,
    73  			Gamma2:        (params.Q - 1) / 88,
    74  			TRSize:        32,
    75  			CTildeSize:    32,
    76  		},
    77  		{
    78  			Name:          "Dilithium3",
    79  			K:             6,
    80  			L:             5,
    81  			Eta:           4,
    82  			DoubleEtaBits: 4,
    83  			Omega:         55,
    84  			Tau:           49,
    85  			Gamma1Bits:    19,
    86  			Gamma2:        (params.Q - 1) / 32,
    87  			TRSize:        32,
    88  			CTildeSize:    32,
    89  		},
    90  		{
    91  			Name:          "Dilithium5",
    92  			K:             8,
    93  			L:             7,
    94  			Eta:           2,
    95  			DoubleEtaBits: 3,
    96  			Omega:         75,
    97  			Tau:           60,
    98  			Gamma1Bits:    19,
    99  			Gamma2:        (params.Q - 1) / 32,
   100  			TRSize:        32,
   101  			CTildeSize:    32,
   102  		},
   103  		{
   104  			Name:          "ML-DSA-44",
   105  			K:             4,
   106  			L:             4,
   107  			Eta:           2,
   108  			DoubleEtaBits: 3,
   109  			Omega:         80,
   110  			Tau:           39,
   111  			Gamma1Bits:    17,
   112  			Gamma2:        (params.Q - 1) / 88,
   113  			TRSize:        64,
   114  			CTildeSize:    32,
   115  		},
   116  		{
   117  			Name:          "ML-DSA-65",
   118  			K:             6,
   119  			L:             5,
   120  			Eta:           4,
   121  			DoubleEtaBits: 4,
   122  			Omega:         55,
   123  			Tau:           49,
   124  			Gamma1Bits:    19,
   125  			Gamma2:        (params.Q - 1) / 32,
   126  			TRSize:        64,
   127  			CTildeSize:    48,
   128  		},
   129  		{
   130  			Name:          "ML-DSA-87",
   131  			K:             8,
   132  			L:             7,
   133  			Eta:           2,
   134  			DoubleEtaBits: 3,
   135  			Omega:         75,
   136  			Tau:           60,
   137  			Gamma1Bits:    19,
   138  			Gamma2:        (params.Q - 1) / 32,
   139  			TRSize:        64,
   140  			CTildeSize:    64,
   141  		},
   142  	}
   143  	TemplateWarning = "// Code generated from"
   144  )
   145  
   146  func main() {
   147  	generateModePackageFiles()
   148  	generateACVPTest()
   149  	generateParamsFiles()
   150  	generateSourceFiles()
   151  }
   152  
   153  // Generates modeX/internal/params.go from templates/params.templ.go
   154  func generateParamsFiles() {
   155  	tl, err := template.ParseFiles("templates/params.templ.go")
   156  	if err != nil {
   157  		panic(err)
   158  	}
   159  
   160  	for _, mode := range Modes {
   161  		buf := new(bytes.Buffer)
   162  		err := tl.Execute(buf, mode)
   163  		if err != nil {
   164  			panic(err)
   165  		}
   166  
   167  		// Formating output code
   168  		code, err := format.Source(buf.Bytes())
   169  		if err != nil {
   170  			panic("error formating code")
   171  		}
   172  
   173  		res := string(code)
   174  		offset := strings.Index(res, TemplateWarning)
   175  		if offset == -1 {
   176  			panic("Missing template warning in params.templ.go")
   177  		}
   178  		err = os.WriteFile(mode.PkgPath()+"/internal/params.go",
   179  			[]byte(res[offset:]), 0o644)
   180  		if err != nil {
   181  			panic(err)
   182  		}
   183  	}
   184  }
   185  
   186  // Generates modeX/dilithium.go from templates/pkg.templ.go
   187  func generateModePackageFiles() {
   188  	tl, err := template.ParseFiles("templates/pkg.templ.go")
   189  	if err != nil {
   190  		panic(err)
   191  	}
   192  
   193  	for _, mode := range Modes {
   194  		buf := new(bytes.Buffer)
   195  		err := tl.Execute(buf, mode)
   196  		if err != nil {
   197  			panic(err)
   198  		}
   199  
   200  		res, err := format.Source(buf.Bytes())
   201  		if err != nil {
   202  			panic("error formating code")
   203  		}
   204  
   205  		offset := strings.Index(string(res), TemplateWarning)
   206  		if offset == -1 {
   207  			panic("Missing template warning in pkg.templ.go")
   208  		}
   209  		err = os.WriteFile(mode.PkgPath()+"/dilithium.go", res[offset:], 0o644)
   210  		if err != nil {
   211  			panic(err)
   212  		}
   213  	}
   214  }
   215  
   216  // Generates modeX/dilithium.go from templates/pkg.templ.go
   217  func generateACVPTest() {
   218  	tl, err := template.ParseFiles("templates/acvp.templ.go")
   219  	if err != nil {
   220  		panic(err)
   221  	}
   222  
   223  	for _, mode := range Modes {
   224  		if !strings.HasPrefix(mode.Name, "ML-DSA") {
   225  			continue
   226  		}
   227  
   228  		buf := new(bytes.Buffer)
   229  		err := tl.Execute(buf, mode)
   230  		if err != nil {
   231  			panic(err)
   232  		}
   233  
   234  		res, err := format.Source(buf.Bytes())
   235  		if err != nil {
   236  			panic("error formating code")
   237  		}
   238  
   239  		offset := strings.Index(string(res), TemplateWarning)
   240  		if offset == -1 {
   241  			panic("Missing template warning in pkg.templ.go")
   242  		}
   243  		err = os.WriteFile(mode.PkgPath()+"/acvp_test.go", res[offset:], 0o644)
   244  		if err != nil {
   245  			panic(err)
   246  		}
   247  	}
   248  }
   249  
   250  // Copies mode3 source files to other modes
   251  func generateSourceFiles() {
   252  	files := make(map[string][]byte)
   253  
   254  	// Ignore mode specific files.
   255  	ignored := func(x string) bool {
   256  		return x == "params.go" || x == "params_test.go" ||
   257  			strings.HasSuffix(x, ".swp")
   258  	}
   259  
   260  	fs, err := os.ReadDir("mode3/internal")
   261  	if err != nil {
   262  		panic(err)
   263  	}
   264  
   265  	// Read files
   266  	for _, f := range fs {
   267  		name := f.Name()
   268  		if ignored(name) {
   269  			continue
   270  		}
   271  		files[name], err = os.ReadFile(path.Join("mode3/internal", name))
   272  		if err != nil {
   273  			panic(err)
   274  		}
   275  	}
   276  
   277  	// Go over modes
   278  	for _, mode := range Modes {
   279  		if mode.Name == "Dilithium3" {
   280  			continue
   281  		}
   282  
   283  		fs, err = os.ReadDir(path.Join(mode.PkgPath(), "internal"))
   284  		for _, f := range fs {
   285  			name := f.Name()
   286  			fn := path.Join(mode.PkgPath(), "internal", name)
   287  			if ignored(name) {
   288  				continue
   289  			}
   290  			_, ok := files[name]
   291  			if !ok {
   292  				fmt.Printf("Removing superfluous file: %s\n", fn)
   293  				err = os.Remove(fn)
   294  				if err != nil {
   295  					panic(err)
   296  				}
   297  			}
   298  			if f.IsDir() {
   299  				panic(fmt.Sprintf("%s: is a directory", fn))
   300  			}
   301  			if f.Type()&os.ModeSymlink != 0 {
   302  				fmt.Printf("Removing symlink: %s\n", fn)
   303  				err = os.Remove(fn)
   304  				if err != nil {
   305  					panic(err)
   306  				}
   307  			}
   308  		}
   309  		for name, expected := range files {
   310  			fn := path.Join(mode.PkgPath(), "internal", name)
   311  			expected = []byte(fmt.Sprintf(
   312  				"%s mode3/internal/%s by gen.go\n\n%s",
   313  				TemplateWarning,
   314  				name,
   315  				string(expected),
   316  			))
   317  			got, err := os.ReadFile(fn)
   318  			if err == nil {
   319  				if bytes.Equal(got, expected) {
   320  					continue
   321  				}
   322  			}
   323  			fmt.Printf("Updating %s\n", fn)
   324  			err = os.WriteFile(fn, expected, 0o644)
   325  			if err != nil {
   326  				panic(err)
   327  			}
   328  		}
   329  	}
   330  }