github.com/koko1123/flow-go-1@v0.29.6/fvm/environment/generate-wrappers/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/format"
     7  	"os"
     8  	"path/filepath"
     9  	"strings"
    10  )
    11  
    12  const header = `// AUTO-GENERATED BY %s.  DO NOT MODIFY.
    13  
    14  package environment
    15  
    16  import (
    17      "github.com/koko1123/flow-go-1/fvm/errors"
    18      "github.com/koko1123/flow-go-1/fvm/state"
    19  	"github.com/koko1123/flow-go-1/module/trace"
    20  )
    21  
    22  func parseRestricted(
    23      txnState *state.TransactionState,
    24      spanName trace.SpanName,
    25  ) error {
    26      if txnState.IsParseRestricted() {
    27          return errors.NewParseRestrictedModeInvalidAccessFailure(spanName)
    28      }
    29  
    30      return nil
    31  }
    32  
    33  // Utility functions used for checking unexpected operation access while
    34  // cadence is parsing programs.
    35  //
    36  // The generic functions are of the form
    37  //      parseRestrict<x>Arg<y>Ret(txnState, spanName, callback, arg1, ..., argX)
    38  // where the callback expects <x> number of arguments, and <y> number of
    39  // return values (not counting error). If the callback expects no argument,
    40  // <x>Arg is omitted, and similarly for return value.`
    41  
    42  func generateWrapper(numArgs int, numRets int, content *FileContent) {
    43  	l := content.Line
    44  	push := content.PushIndent
    45  	pop := content.PopIndent
    46  
    47  	argsFuncSuffix := ""
    48  	if numArgs > 0 {
    49  		argsFuncSuffix = fmt.Sprintf("%dArg", numArgs)
    50  	}
    51  
    52  	argTypes := []string{}
    53  	argNames := []string{}
    54  	for i := 0; i < numArgs; i++ {
    55  		argTypes = append(argTypes, fmt.Sprintf("Arg%dT", i))
    56  		argNames = append(argNames, fmt.Sprintf("arg%d", i))
    57  	}
    58  
    59  	retsFuncSuffix := ""
    60  	if numRets > 0 {
    61  		retsFuncSuffix = fmt.Sprintf("%dRet", numRets)
    62  	}
    63  
    64  	retTypes := []string{}
    65  	retNames := []string{}
    66  	for i := 0; i < numRets; i++ {
    67  		retTypes = append(retTypes, fmt.Sprintf("Ret%dT", i))
    68  		retNames = append(retNames, fmt.Sprintf("value%d", i))
    69  	}
    70  
    71  	//
    72  	// Generate function signature
    73  	//
    74  
    75  	l("")
    76  	l("func parseRestrict%s%s[", argsFuncSuffix, retsFuncSuffix)
    77  	push()
    78  
    79  	for _, typeName := range append(argTypes, retTypes...) {
    80  		l("%s any,", typeName)
    81  	}
    82  
    83  	pop()
    84  	l("](")
    85  	push()
    86  
    87  	l("txnState *state.TransactionState,")
    88  	l("spanName trace.SpanName,")
    89  
    90  	callbackRet := "error"
    91  	if numRets > 0 {
    92  		callbackRet = "(" + strings.Join(append(retTypes, "error"), ", ") + ")"
    93  	}
    94  
    95  	l("callback func(%s) %s,", strings.Join(argTypes, ", "), callbackRet)
    96  
    97  	for i, argType := range argTypes {
    98  		l("%s %s,", argNames[i], argType)
    99  	}
   100  
   101  	pop()
   102  	if numRets == 0 {
   103  		l(") error {")
   104  	} else {
   105  		l(") (")
   106  		push()
   107  
   108  		for _, retType := range retTypes {
   109  			l("%s,", retType)
   110  		}
   111  		l("error,")
   112  
   113  		pop()
   114  		l(") {")
   115  	}
   116  	push()
   117  
   118  	//
   119  	// Generate parse restrict check
   120  	//
   121  
   122  	l("err := parseRestricted(txnState, spanName)")
   123  	l("if err != nil {")
   124  	push()
   125  
   126  	for i, retType := range retTypes {
   127  		l("var %s %s", retNames[i], retType)
   128  	}
   129  
   130  	l("return %s", strings.Join(append(retNames, "err"), ", "))
   131  
   132  	pop()
   133  	l("}")
   134  
   135  	//
   136  	// Generate callback invocation
   137  	//
   138  
   139  	l("")
   140  	l("return callback(%s)", strings.Join(argNames, ", "))
   141  
   142  	pop()
   143  	l("}")
   144  }
   145  
   146  func main() {
   147  	if len(os.Args) != 2 {
   148  		fmt.Printf("USAGE: %s <output file>\n", filepath.Base(os.Args[0]))
   149  		os.Exit(1)
   150  	}
   151  
   152  	cmd := append([]string{filepath.Base(os.Args[0])}, os.Args[1:]...)
   153  
   154  	content := NewFileContent()
   155  	content.Section(header, strings.Join(cmd, " "))
   156  
   157  	for numArgs := 1; numArgs < 4; numArgs++ {
   158  		generateWrapper(numArgs, 0, content)
   159  	}
   160  
   161  	for _, numArgs := range []int{0, 1, 2, 3, 4, 6} {
   162  		generateWrapper(numArgs, 1, content)
   163  	}
   164  
   165  	generateWrapper(1, 2, content)
   166  
   167  	buffer := &bytes.Buffer{}
   168  	_, err := content.WriteTo(buffer)
   169  	if err != nil {
   170  		panic(err) // This should never happen
   171  	}
   172  
   173  	source, formatErr := format.Source(buffer.Bytes())
   174  
   175  	// NOTE: formatting error can occur if the generated code has syntax
   176  	// errors.  We still want to write out the unformatted source for debugging
   177  	// purpose.
   178  	if formatErr != nil {
   179  		source = buffer.Bytes()
   180  	}
   181  
   182  	writeErr := os.WriteFile(os.Args[1], source, 0644)
   183  	if writeErr != nil {
   184  		panic(writeErr)
   185  	}
   186  
   187  	if formatErr != nil {
   188  		panic(formatErr)
   189  	}
   190  }