github.com/consensys/gnark-crypto@v0.14.0/field/generator/internal/addchain/addchain.go (about)

     1  // Original copyright :
     2  // BSD 3-Clause License
     3  
     4  // Copyright (c) 2019, Michael McLoughlin
     5  // All rights reserved.
     6  
     7  // Redistribution and use in source and binary forms, with or without
     8  // modification, are permitted provided that the following conditions are met:
     9  
    10  // 1. Redistributions of source code must retain the above copyright notice, this
    11  //    list of conditions and the following disclaimer.
    12  
    13  // 2. Redistributions in binary form must reproduce the above copyright notice,
    14  //    this list of conditions and the following disclaimer in the documentation
    15  //    and/or other materials provided with the distribution.
    16  
    17  // 3. Neither the name of the copyright holder nor the names of its
    18  //    contributors may be used to endorse or promote products derived from
    19  //    this software without specific prior written permission.
    20  
    21  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
    22  // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
    23  // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
    24  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
    25  // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
    26  // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
    27  // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
    28  // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
    29  // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    30  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    31  
    32  // Package addchain is derived from github.com/mmcloughlin/addchain internal packages or examples
    33  package addchain
    34  
    35  import (
    36  	"bufio"
    37  	"encoding/gob"
    38  	"log"
    39  	"math/big"
    40  	"os"
    41  	"path/filepath"
    42  	"reflect"
    43  	"strings"
    44  	"sync"
    45  
    46  	"github.com/mmcloughlin/addchain"
    47  	"github.com/mmcloughlin/addchain/acc"
    48  	"github.com/mmcloughlin/addchain/acc/ast"
    49  	"github.com/mmcloughlin/addchain/acc/ir"
    50  	"github.com/mmcloughlin/addchain/acc/pass"
    51  	"github.com/mmcloughlin/addchain/acc/printer"
    52  	"github.com/mmcloughlin/addchain/alg/ensemble"
    53  	"github.com/mmcloughlin/addchain/alg/exec"
    54  	"github.com/mmcloughlin/addchain/meta"
    55  )
    56  
    57  // most of these functions are derived from github.com/mmcloughlin/addchain internal packages or examples
    58  
    59  var (
    60  	once        sync.Once
    61  	addChainDir string
    62  	mAddchains  map[string]*AddChainData // key is big.Int.Text(16)
    63  )
    64  
    65  // GetAddChain returns template data of a short addition chain for given big.Int
    66  func GetAddChain(n *big.Int) *AddChainData {
    67  
    68  	// init the cache only once.
    69  	once.Do(initCache)
    70  
    71  	key := n.Text(16)
    72  	if r, ok := mAddchains[key]; ok {
    73  		return r
    74  	}
    75  
    76  	// Default ensemble of algorithms.
    77  	algorithms := ensemble.Ensemble()
    78  
    79  	// Use parallel executor.
    80  	ex := exec.NewParallel()
    81  	results := ex.Execute(n, algorithms)
    82  
    83  	// Output best result.
    84  	best := 0
    85  	for i, r := range results {
    86  		if r.Err != nil {
    87  			log.Fatal(r.Err)
    88  		}
    89  		if len(results[i].Program) < len(results[best].Program) {
    90  			best = i
    91  		}
    92  	}
    93  	r := results[best]
    94  	data := processSearchResult(r.Program, key)
    95  
    96  	mAddchains[key] = data
    97  	// gob encode
    98  	file := filepath.Join(addChainDir, key)
    99  	log.Println("saving addchain", file)
   100  	f, err := os.Create(file)
   101  	if err != nil {
   102  		log.Fatal(err)
   103  	}
   104  	enc := gob.NewEncoder(f)
   105  
   106  	if err := enc.Encode(r.Program); err != nil {
   107  		_ = f.Close()
   108  		log.Fatal(err)
   109  	}
   110  	_ = f.Close()
   111  
   112  	return data
   113  }
   114  
   115  func processSearchResult(_p addchain.Program, n string) *AddChainData {
   116  	p, err := acc.Decompile(_p)
   117  	if err != nil {
   118  		log.Fatal(err)
   119  	}
   120  	chain, err := acc.Build(p)
   121  	if err != nil {
   122  		log.Fatal(err)
   123  	}
   124  
   125  	data, err := prepareAddChainData(chain, n)
   126  	if err != nil {
   127  		log.Fatal(err)
   128  	}
   129  	return data
   130  }
   131  
   132  // Data provided to templates.
   133  type AddChainData struct {
   134  	// Chain is the addition chain as a list of integers.
   135  	Chain addchain.Chain
   136  
   137  	// Ops is the complete sequence of addition operations required to compute
   138  	// the addition chain.
   139  	Ops addchain.Program
   140  
   141  	// Script is the condensed representation of the addition chain computation
   142  	// in the "addition chain calculator" language.
   143  	Script *ast.Chain
   144  
   145  	// Program is the intermediate representation of the addition chain
   146  	// computation. This representation is likely the most convenient for code
   147  	// generation. It contains a sequence of add, double and shift (repeated
   148  	// doubling) instructions required to compute the chain. Temporary variable
   149  	// allocation has been performed and the list of required temporaries
   150  	// populated.
   151  	Program *ir.Program
   152  
   153  	// Metadata about the addchain project and the specific release parameters.
   154  	// Please use this to include a reference or citation back to the addchain
   155  	// project in your generated output.
   156  	Meta *meta.Properties
   157  
   158  	N string // base 16 value of the value
   159  }
   160  
   161  // PrepareData builds input template data for the given addition chain script.
   162  func prepareAddChainData(s *ast.Chain, n string) (*AddChainData, error) {
   163  	// Prepare template data.
   164  	allocator := pass.Allocator{
   165  		Input:  "x",
   166  		Output: "z",
   167  		Format: "t%d",
   168  	}
   169  	// Translate to IR.
   170  	p, err := acc.Translate(s)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	// Apply processing passes: temporary variable allocation, and computing the
   176  	// full addition chain sequence and operations.
   177  	if err := pass.Exec(p, allocator, pass.Func(pass.Eval)); err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	return &AddChainData{
   182  		Chain:   p.Chain,
   183  		Ops:     p.Program,
   184  		Script:  s,
   185  		Program: p,
   186  		Meta:    meta.Meta,
   187  		N:       n,
   188  	}, nil
   189  }
   190  
   191  // Function is a function provided to templates.
   192  type Function struct {
   193  	Name        string
   194  	Description string
   195  	Func        interface{}
   196  }
   197  
   198  // Signature returns the function signature.
   199  func (f *Function) Signature() string {
   200  	return reflect.ValueOf(f.Func).Type().String()
   201  }
   202  
   203  // Functions is the list of functions provided to templates.
   204  var Functions = []*Function{
   205  	{
   206  		Name:        "add_",
   207  		Description: "If the input operation is an `ir.Add` then return it, otherwise return `nil`",
   208  		Func: func(op ir.Op) ir.Op {
   209  			if a, ok := op.(ir.Add); ok {
   210  				return a
   211  			}
   212  			return nil
   213  		},
   214  	},
   215  	{
   216  		Name:        "double_",
   217  		Description: "If the input operation is an `ir.Double` then return it, otherwise return `nil`",
   218  		Func: func(op ir.Op) ir.Op {
   219  			if d, ok := op.(ir.Double); ok {
   220  				return d
   221  			}
   222  			return nil
   223  		},
   224  	},
   225  	{
   226  		Name:        "shift_",
   227  		Description: "If the input operation is an `ir.Shift` then return it, otherwise return `nil`",
   228  		Func: func(op ir.Op) ir.Op {
   229  			if s, ok := op.(ir.Shift); ok {
   230  				return s
   231  			}
   232  			return nil
   233  		},
   234  	},
   235  	{
   236  		Name:        "inc_",
   237  		Description: "Increment an integer",
   238  		Func:        func(n int) int { return n + 1 },
   239  	},
   240  	{
   241  		Name:        "format_",
   242  		Description: "Formats an addition chain script (`*ast.Chain`) as a string",
   243  		Func:        printer.String,
   244  	},
   245  	{
   246  		Name:        "split_",
   247  		Description: "Calls `strings.Split`",
   248  		Func:        strings.Split,
   249  	},
   250  	{
   251  		Name:        "join_",
   252  		Description: "Calls `strings.Join`",
   253  		Func:        strings.Join,
   254  	},
   255  	{
   256  		Name:        "lines_",
   257  		Description: "Split input string into lines",
   258  		Func: func(s string) []string {
   259  			var lines []string
   260  			scanner := bufio.NewScanner(strings.NewReader(s))
   261  			for scanner.Scan() {
   262  				lines = append(lines, scanner.Text())
   263  			}
   264  			return lines
   265  		},
   266  	},
   267  	{
   268  		Name:        "ptr_",
   269  		Description: "adds & if it's a value",
   270  		Func: func(s *ir.Operand) string {
   271  			if s.String() == "x" {
   272  				return "&"
   273  			}
   274  			return ""
   275  		},
   276  	},
   277  	{
   278  		Name: "last_",
   279  		Func: func(x int, a interface{}) bool {
   280  			return x == reflect.ValueOf(a).Len()-1
   281  		},
   282  	},
   283  }
   284  
   285  // to speed up code generation, we cache addchain search results on disk
   286  func initCache() {
   287  	mAddchains = make(map[string]*AddChainData)
   288  
   289  	// read existing files in addchain directory
   290  	path, err := os.Getwd()
   291  	if err != nil {
   292  		log.Fatal(err)
   293  	}
   294  	addChainDir = filepath.Join(path, "addchain")
   295  	_ = os.Mkdir(addChainDir, 0700)
   296  	files, err := os.ReadDir(addChainDir)
   297  	if err != nil {
   298  		log.Fatal(err)
   299  	}
   300  
   301  	// preload pre-computed add chains
   302  	for _, entry := range files {
   303  		if entry.IsDir() {
   304  			continue
   305  		}
   306  		f, err := os.Open(filepath.Join(addChainDir, entry.Name()))
   307  		if err != nil {
   308  			log.Fatal(err)
   309  		}
   310  
   311  		// decode the addchain.Program
   312  		dec := gob.NewDecoder(f)
   313  		var p addchain.Program
   314  		err = dec.Decode(&p)
   315  		_ = f.Close()
   316  		if err != nil {
   317  			log.Fatal(err)
   318  		}
   319  		data := processSearchResult(p, filepath.Base(f.Name()))
   320  		log.Println("read", filepath.Base(f.Name()))
   321  
   322  		// save the data
   323  		mAddchains[filepath.Base(f.Name())] = data
   324  
   325  	}
   326  
   327  }