github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/cmd.go (about) 1 package main 2 3 import ( 4 "bufio" 5 "bytes" 6 "flag" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "os" 11 "strings" 12 13 "github.com/arnodel/golua/ast" 14 "github.com/arnodel/golua/lib" 15 "github.com/arnodel/golua/lib/base" 16 "github.com/arnodel/golua/lib/debuglib" 17 "github.com/arnodel/golua/lib/iolib" 18 rt "github.com/arnodel/golua/runtime" 19 ) 20 21 type luaCmd struct { 22 disFlag bool 23 astFlag bool 24 unbufferedFlag bool 25 cpuLimit uint64 26 memLimit uint64 27 flags string 28 exec execFlags 29 30 complianceFlags rt.ComplianceFlags 31 } 32 33 func (c *luaCmd) setFlags() { 34 flag.BoolVar(&c.disFlag, "dis", false, "Disassemble source instead of running it") 35 flag.BoolVar(&c.astFlag, "ast", false, "Print AST instead of running code") 36 flag.BoolVar(&c.unbufferedFlag, "u", false, "Force unbuffered output") 37 flag.Var(&c.exec, "e", "statement to execute") 38 39 if rt.QuotasAvailable { 40 flag.Uint64Var(&c.cpuLimit, "cpulimit", 0, "CPU limit") 41 flag.Uint64Var(&c.memLimit, "memlimit", 0, "memory limit") 42 flag.StringVar(&c.flags, "flags", "", "compliance flags turned on") 43 } 44 } 45 46 func (c *luaCmd) run() (retcode int) { 47 var ( 48 chunkName string 49 chunk []byte 50 err error 51 args []string 52 readStdin bool 53 repl bool 54 ) 55 56 buffered := !isaTTY(os.Stdin) || flag.NArg() > 0 57 if c.unbufferedFlag { 58 buffered = false 59 } 60 iolib.BufferedStdFiles = buffered 61 62 if c.flags != "" { 63 for _, name := range strings.Split(c.flags, ",") { 64 var ok bool 65 c.complianceFlags, ok = c.complianceFlags.AddFlagWithName(name) 66 if !ok { 67 return fatal("Unknown flag: %s", name) 68 } 69 } 70 } 71 72 // Get a Lua runtime 73 r := rt.New(nil) 74 c.pushContext(r) 75 76 cleanup := lib.LoadAll(r) 77 defer cleanup() 78 79 // Run finalizers before we exit 80 defer r.Close(nil) 81 82 if len(c.exec) == 0 && flag.NArg() == 0 { 83 chunkName = "<stdin>" 84 readStdin = true 85 repl = isaTTY(os.Stdin) 86 } 87 if flag.NArg() > 0 { 88 chunkName = flag.Arg(0) 89 chunk, err = ioutil.ReadFile(chunkName) 90 if err != nil { 91 return fatal("Error reading '%s': %s", chunkName, err) 92 } 93 args = flag.Args()[1:] 94 } 95 96 var argVals []rt.Value 97 if len(args) > 0 { 98 argTable := rt.NewTable() 99 argVals = make([]rt.Value, len(args)) 100 for i, arg := range args { 101 argVal := rt.StringValue(arg) 102 r.SetTable(argTable, rt.IntValue(int64(i+1)), argVal) 103 argVals[i] = argVal 104 } 105 r.SetTable(r.GlobalEnv(), rt.StringValue("arg"), rt.TableValue(argTable)) 106 } 107 108 for _, src := range c.exec { 109 unit, _, err := r.CompileLuaChunk("<exec>", []byte(src)) 110 if err != nil { 111 return fatal("Error parsing %q: %s", src, err) 112 } 113 clos := r.LoadLuaUnit(unit, rt.TableValue(r.GlobalEnv())) 114 cerr := rt.Call(r.MainThread(), rt.FunctionValue(clos), argVals, rt.NewTerminationWith(nil, 0, false)) 115 if cerr != nil { 116 return fatal("!!! %s", cerr.Error()) 117 } 118 } 119 120 if readStdin { 121 if repl { 122 return c.repl(r) 123 } 124 chunk, err = ioutil.ReadAll(os.Stdin) 125 if err != nil { 126 return fatal("Error reading <stdin>: %s", err) 127 } 128 } 129 130 if c.astFlag { 131 stat, _, err := r.ParseLuaChunk(chunkName, chunk) 132 if err != nil { 133 return fatal("Error parsing %s: %s", chunkName, err) 134 } 135 w := ast.NewIndentWriter(os.Stdout) 136 stat.HWrite(w) 137 return 0 138 } 139 140 if c.disFlag { 141 unit, _, err := r.CompileLuaChunk(chunkName, chunk) 142 if err != nil { 143 return fatal("Error parsing %s: %s", chunkName, err) 144 } 145 unit.Disassemble(os.Stdout) 146 return 0 147 } 148 149 defer func() { 150 if rec := recover(); rec != nil { 151 quotaExceeded, ok := rec.(rt.ContextTerminationError) 152 if !ok { 153 panic(r) 154 } 155 fmt.Fprintf(os.Stderr, "%s\n", quotaExceeded) 156 retcode = 2 157 } 158 }() 159 160 clos, err := r.LoadFromSourceOrCode(chunkName, chunk, "bt", rt.TableValue(r.GlobalEnv()), true) 161 if err != nil { 162 return fatal("Error loading %s: %s", chunkName, err) 163 } 164 cerr := rt.Call(r.MainThread(), rt.FunctionValue(clos), argVals, rt.NewTerminationWith(nil, 0, false)) 165 if cerr != nil { 166 return fatal("!!! %s", cerr.Error()) 167 } 168 return 0 169 } 170 171 func fatal(tpl string, args ...interface{}) int { 172 fmt.Fprintf(os.Stderr, tpl+"\n", args...) 173 return 1 174 } 175 176 func isaTTY(f *os.File) bool { 177 fi, _ := f.Stat() 178 return fi.Mode()&os.ModeCharDevice != 0 179 } 180 181 func (c *luaCmd) repl(r *rt.Runtime) int { 182 reader := bufio.NewReader(os.Stdin) 183 w := new(bytes.Buffer) 184 for { 185 if len(w.Bytes()) == 0 { 186 fmt.Print("> ") 187 } else { 188 fmt.Print("| ") 189 } 190 line, err := reader.ReadBytes('\n') 191 line = bytes.TrimLeft(line, ">|") 192 if err == io.EOF { 193 w.WriteTo(os.Stdout) 194 fmt.Print(string(line)) 195 return 0 196 } 197 _, err = w.Write(line) 198 if err != nil { 199 return fatal("error: %s", err) 200 } 201 more, err := c.runChunk(r, w.Bytes()) 202 if !more { 203 w = new(bytes.Buffer) 204 if err != nil { 205 fmt.Printf("!!! %s\n", err) 206 if _, ok := err.(rt.ContextTerminationError); ok { 207 fmt.Print("Reset limits and continue? [yN] ") 208 line, err := reader.ReadString('\n') 209 if err == io.EOF || strings.TrimSpace(line) != "y" { 210 return 0 211 } 212 r.PopContext() 213 c.pushContext(r) 214 } 215 } 216 } 217 } 218 } 219 220 func (c *luaCmd) runChunk(r *rt.Runtime, source []byte) (more bool, err error) { 221 defer func() { 222 if rec := recover(); rec != nil { 223 quotaExceeded, ok := rec.(rt.ContextTerminationError) 224 if !ok { 225 panic(r) 226 } 227 err = quotaExceeded 228 more = false 229 } 230 }() 231 clos, err := r.CompileAndLoadLuaChunkOrExp("<stdin>", source, rt.TableValue(r.GlobalEnv())) 232 if err != nil { 233 return rt.ErrorIsUnexpectedEOF(err), err 234 } 235 t := r.MainThread() 236 term := rt.NewTerminationWith(nil, 0, true) 237 cerr := rt.Call(t, rt.FunctionValue(clos), nil, term) 238 if cerr == nil { 239 if len(term.Etc()) > 0 { 240 cerr = base.Print(t, term.Etc()) 241 if cerr != nil { 242 return false, cerr 243 } 244 } 245 return false, nil 246 } 247 return false, cerr 248 } 249 250 func (c *luaCmd) pushContext(r *rt.Runtime) { 251 r.PushContext(rt.RuntimeContextDef{ 252 HardLimits: rt.RuntimeResources{ 253 Cpu: c.cpuLimit, 254 Memory: c.memLimit, 255 }, 256 RequiredFlags: c.complianceFlags, 257 MessageHandler: debuglib.Traceback, 258 }) 259 } 260 261 type execFlags []string 262 263 func (e *execFlags) String() string { 264 return strings.Join(*e, "; ") 265 } 266 267 func (e *execFlags) Set(value string) error { 268 *e = append(*e, value) 269 return nil 270 }