github.com/mavryk-network/mvgo@v1.19.9/cmd/tzgen/root.go (about)

     1  // Copyright (c) 2023 Blockwatch Data Inc.
     2  // Authors
     3  // - jean.schmitt@ubisoft.com
     4  // - abdul@blockwatch.cc
     5  
     6  package main
     7  
import (
	"bytes"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"time"

	"github.com/iancoleman/strcase"
	"github.com/mavryk-network/mvgo/internal/generate"
	"github.com/mavryk-network/mvgo/internal/parse"
	"github.com/pkg/errors"
	"gopkg.in/yaml.v3"
)
    23  
    24  var (
    25  	errExit = errors.New("exit")
    26  
    27  	endpointFlag  string
    28  	addressFlag   string
    29  	srcFlag       string
    30  	nameFlag      string
    31  	pkgFlag       string
    32  	outFlag       string
    33  	fixupFileFlag string
    34  )
    35  
    36  func init() {
    37  	flag.StringVar(&endpointFlag, "endpoint", "https://rpc.tzstats.com", "rpc endpoint")
    38  	flag.StringVar(&addressFlag, "address", "", "address of the contract. required if -src is not set")
    39  	flag.StringVar(&srcFlag, "src", "", "json file containing the contracts's script")
    40  	flag.StringVar(&nameFlag, "name", "", "name of the contract")
    41  	flag.StringVar(&pkgFlag, "pkg", "", "package name of the output go code")
    42  	flag.StringVar(&outFlag, "out", "", "output file. Prints to Stdout if not set")
    43  	flag.StringVar(&fixupFileFlag, "fixup", "", "yaml file to fix generated go code for automatically generated functions/variable names")
    44  }
    45  
    46  func parseFlags() error {
    47  	if len(os.Args) >= 2 {
    48  		switch os.Args[1] {
    49  		case "version":
    50  			printVersion()
    51  			return errExit
    52  		case "help":
    53  			fmt.Printf("Usage: %s [flags]\n", appName)
    54  			fmt.Println("\nFlags")
    55  			flag.PrintDefaults()
    56  		}
    57  	}
    58  	flag.Parse()
    59  	return nil
    60  }
    61  
    62  func runCommand() error {
    63  	if pkgFlag == "" {
    64  		return errors.New("-pkg is required, to get package name")
    65  	}
    66  	if nameFlag == "" {
    67  		return errors.New("-name is required to set name of contract")
    68  	}
    69  	src, err := getSrc()
    70  	if err != nil {
    71  		return errors.Wrap(err, "failed to get contract script")
    72  	}
    73  	generated, err := generateBindings(src)
    74  	if err != nil {
    75  		return errors.Wrap(err, "failed to generate bindings")
    76  	}
    77  	err = writeResult(generated)
    78  	if err != nil {
    79  		return errors.Wrap(err, "failed to write generated code to file")
    80  	}
    81  	return nil
    82  }
    83  
    84  func generateBindings(script []byte) ([]byte, error) {
    85  	var err error
    86  	data := generate.Data{
    87  		Address: addressFlag,
    88  		Package: pkgFlag,
    89  	}
    90  	data.Contract, data.Structs, err = parse.Parse(script, nameFlag)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if fixupFileFlag != "" {
    95  		fixupFile, err := os.ReadFile(fixupFileFlag)
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  
   100  		var fixupCfg parse.FixupConfig
   101  		err = yaml.NewDecoder(bytes.NewReader(fixupFile)).Decode(&fixupCfg)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  
   106  		data.Structs = parse.Fixup(fixupCfg, data.Structs, strcase.ToCamel)
   107  	}
   108  
   109  	return generate.Render(&data)
   110  }
   111  
   112  func getSrc() ([]byte, error) {
   113  	if srcFlag != "" {
   114  		return os.ReadFile(srcFlag)
   115  	}
   116  
   117  	// Get source from RPC
   118  	// At this point, addressFlag is required
   119  	if addressFlag == "" {
   120  		return nil, errors.New("-address is required when getting script from rpc")
   121  	}
   122  
   123  	u, err := url.JoinPath(endpointFlag, "chains/main/blocks/head/context/contracts", addressFlag, "script")
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	res, err := http.Get(u)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	defer res.Body.Close()
   132  
   133  	if res.StatusCode != http.StatusOK {
   134  		return nil, errors.Errorf("failed to get contract script at url %s: %v", u, res.Status)
   135  	}
   136  	return io.ReadAll(res.Body)
   137  }
   138  
   139  func writeResult(out []byte) error {
   140  	if outFlag == "" {
   141  		_, err := os.Stdout.Write(out)
   142  		if err != nil {
   143  			return err
   144  		}
   145  		return nil
   146  	}
   147  	return os.WriteFile(outFlag, out, 0o644)
   148  }