github.com/mavryk-network/mvgo@v1.19.9/cmd/tzgen/root.go (about) 1 // Copyright (c) 2023 Blockwatch Data Inc. 2 // Authors 3 // - jean.schmitt@ubisoft.com 4 // - abdul@blockwatch.cc 5 6 package main 7 8 import ( 9 "bytes" 10 "flag" 11 "fmt" 12 "io" 13 "net/http" 14 "net/url" 15 "os" 16 17 "github.com/iancoleman/strcase" 18 "github.com/mavryk-network/mvgo/internal/generate" 19 "github.com/mavryk-network/mvgo/internal/parse" 20 "github.com/pkg/errors" 21 "gopkg.in/yaml.v3" 22 ) 23 24 var ( 25 errExit = errors.New("exit") 26 27 endpointFlag string 28 addressFlag string 29 srcFlag string 30 nameFlag string 31 pkgFlag string 32 outFlag string 33 fixupFileFlag string 34 ) 35 36 func init() { 37 flag.StringVar(&endpointFlag, "endpoint", "https://rpc.tzstats.com", "rpc endpoint") 38 flag.StringVar(&addressFlag, "address", "", "address of the contract. required if -src is not set") 39 flag.StringVar(&srcFlag, "src", "", "json file containing the contracts's script") 40 flag.StringVar(&nameFlag, "name", "", "name of the contract") 41 flag.StringVar(&pkgFlag, "pkg", "", "package name of the output go code") 42 flag.StringVar(&outFlag, "out", "", "output file. Prints to Stdout if not set") 43 flag.StringVar(&fixupFileFlag, "fixup", "", "yaml file to fix generated go code for automatically generated functions/variable names") 44 } 45 46 func parseFlags() error { 47 if len(os.Args) >= 2 { 48 switch os.Args[1] { 49 case "version": 50 printVersion() 51 return errExit 52 case "help": 53 fmt.Printf("Usage: %s [flags]\n", appName) 54 fmt.Println("\nFlags") 55 flag.PrintDefaults() 56 } 57 } 58 flag.Parse() 59 return nil 60 } 61 62 func runCommand() error { 63 if pkgFlag == "" { 64 return errors.New("-pkg is required, to get package name") 65 } 66 if nameFlag == "" { 67 return errors.New("-name is required to set name of contract") 68 } 69 src, err := getSrc() 70 if err != nil { 71 return errors.Wrap(err, "failed to get contract script") 72 } 73 generated, err := generateBindings(src) 74 if err != nil { 75 return errors.Wrap(err, "failed to generate bindings") 76 } 77 err = writeResult(generated) 78 if err != nil { 79 return errors.Wrap(err, "failed to write generated code to file") 80 } 81 return nil 82 } 83 84 func generateBindings(script []byte) ([]byte, error) { 85 var err error 86 data := generate.Data{ 87 Address: addressFlag, 88 Package: pkgFlag, 89 } 90 data.Contract, data.Structs, err = parse.Parse(script, nameFlag) 91 if err != nil { 92 return nil, err 93 } 94 if fixupFileFlag != "" { 95 fixupFile, err := os.ReadFile(fixupFileFlag) 96 if err != nil { 97 return nil, err 98 } 99 100 var fixupCfg parse.FixupConfig 101 err = yaml.NewDecoder(bytes.NewReader(fixupFile)).Decode(&fixupCfg) 102 if err != nil { 103 return nil, err 104 } 105 106 data.Structs = parse.Fixup(fixupCfg, data.Structs, strcase.ToCamel) 107 } 108 109 return generate.Render(&data) 110 } 111 112 func getSrc() ([]byte, error) { 113 if srcFlag != "" { 114 return os.ReadFile(srcFlag) 115 } 116 117 // Get source from RPC 118 // At this point, addressFlag is required 119 if addressFlag == "" { 120 return nil, errors.New("-address is required when getting script from rpc") 121 } 122 123 u, err := url.JoinPath(endpointFlag, "chains/main/blocks/head/context/contracts", addressFlag, "script") 124 if err != nil { 125 return nil, err 126 } 127 res, err := http.Get(u) 128 if err != nil { 129 return nil, err 130 } 131 defer res.Body.Close() 132 133 if res.StatusCode != http.StatusOK { 134 return nil, errors.Errorf("failed to get contract script at url %s: %v", u, res.Status) 135 } 136 return io.ReadAll(res.Body) 137 } 138 139 func writeResult(out []byte) error { 140 if outFlag == "" { 141 _, err := os.Stdout.Write(out) 142 if err != nil { 143 return err 144 } 145 return nil 146 } 147 return os.WriteFile(outFlag, out, 0o644) 148 }