github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/cmd.go (about)

     1  package main
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"os"
    11  	"strings"
    12  
    13  	"github.com/arnodel/golua/ast"
    14  	"github.com/arnodel/golua/lib"
    15  	"github.com/arnodel/golua/lib/base"
    16  	"github.com/arnodel/golua/lib/debuglib"
    17  	"github.com/arnodel/golua/lib/iolib"
    18  	rt "github.com/arnodel/golua/runtime"
    19  )
    20  
    21  type luaCmd struct {
    22  	disFlag        bool
    23  	astFlag        bool
    24  	unbufferedFlag bool
    25  	cpuLimit       uint64
    26  	memLimit       uint64
    27  	flags          string
    28  	exec           execFlags
    29  
    30  	complianceFlags rt.ComplianceFlags
    31  }
    32  
    33  func (c *luaCmd) setFlags() {
    34  	flag.BoolVar(&c.disFlag, "dis", false, "Disassemble source instead of running it")
    35  	flag.BoolVar(&c.astFlag, "ast", false, "Print AST instead of running code")
    36  	flag.BoolVar(&c.unbufferedFlag, "u", false, "Force unbuffered output")
    37  	flag.Var(&c.exec, "e", "statement to execute")
    38  
    39  	if rt.QuotasAvailable {
    40  		flag.Uint64Var(&c.cpuLimit, "cpulimit", 0, "CPU limit")
    41  		flag.Uint64Var(&c.memLimit, "memlimit", 0, "memory limit")
    42  		flag.StringVar(&c.flags, "flags", "", "compliance flags turned on")
    43  	}
    44  }
    45  
    46  func (c *luaCmd) run() (retcode int) {
    47  	var (
    48  		chunkName string
    49  		chunk     []byte
    50  		err       error
    51  		args      []string
    52  		readStdin bool
    53  		repl      bool
    54  	)
    55  
    56  	buffered := !isaTTY(os.Stdin) || flag.NArg() > 0
    57  	if c.unbufferedFlag {
    58  		buffered = false
    59  	}
    60  	iolib.BufferedStdFiles = buffered
    61  
    62  	if c.flags != "" {
    63  		for _, name := range strings.Split(c.flags, ",") {
    64  			var ok bool
    65  			c.complianceFlags, ok = c.complianceFlags.AddFlagWithName(name)
    66  			if !ok {
    67  				return fatal("Unknown flag: %s", name)
    68  			}
    69  		}
    70  	}
    71  
    72  	// Get a Lua runtime
    73  	r := rt.New(nil)
    74  	c.pushContext(r)
    75  
    76  	cleanup := lib.LoadAll(r)
    77  	defer cleanup()
    78  
    79  	// Run finalizers before we exit
    80  	defer r.Close(nil)
    81  
    82  	if len(c.exec) == 0 && flag.NArg() == 0 {
    83  		chunkName = "<stdin>"
    84  		readStdin = true
    85  		repl = isaTTY(os.Stdin)
    86  	}
    87  	if flag.NArg() > 0 {
    88  		chunkName = flag.Arg(0)
    89  		chunk, err = ioutil.ReadFile(chunkName)
    90  		if err != nil {
    91  			return fatal("Error reading '%s': %s", chunkName, err)
    92  		}
    93  		args = flag.Args()[1:]
    94  	}
    95  
    96  	var argVals []rt.Value
    97  	if len(args) > 0 {
    98  		argTable := rt.NewTable()
    99  		argVals = make([]rt.Value, len(args))
   100  		for i, arg := range args {
   101  			argVal := rt.StringValue(arg)
   102  			r.SetTable(argTable, rt.IntValue(int64(i+1)), argVal)
   103  			argVals[i] = argVal
   104  		}
   105  		r.SetTable(r.GlobalEnv(), rt.StringValue("arg"), rt.TableValue(argTable))
   106  	}
   107  
   108  	for _, src := range c.exec {
   109  		unit, _, err := r.CompileLuaChunk("<exec>", []byte(src))
   110  		if err != nil {
   111  			return fatal("Error parsing %q: %s", src, err)
   112  		}
   113  		clos := r.LoadLuaUnit(unit, rt.TableValue(r.GlobalEnv()))
   114  		cerr := rt.Call(r.MainThread(), rt.FunctionValue(clos), argVals, rt.NewTerminationWith(nil, 0, false))
   115  		if cerr != nil {
   116  			return fatal("!!! %s", cerr.Error())
   117  		}
   118  	}
   119  
   120  	if readStdin {
   121  		if repl {
   122  			return c.repl(r)
   123  		}
   124  		chunk, err = ioutil.ReadAll(os.Stdin)
   125  		if err != nil {
   126  			return fatal("Error reading <stdin>: %s", err)
   127  		}
   128  	}
   129  
   130  	if c.astFlag {
   131  		stat, _, err := r.ParseLuaChunk(chunkName, chunk)
   132  		if err != nil {
   133  			return fatal("Error parsing %s: %s", chunkName, err)
   134  		}
   135  		w := ast.NewIndentWriter(os.Stdout)
   136  		stat.HWrite(w)
   137  		return 0
   138  	}
   139  
   140  	if c.disFlag {
   141  		unit, _, err := r.CompileLuaChunk(chunkName, chunk)
   142  		if err != nil {
   143  			return fatal("Error parsing %s: %s", chunkName, err)
   144  		}
   145  		unit.Disassemble(os.Stdout)
   146  		return 0
   147  	}
   148  
   149  	defer func() {
   150  		if rec := recover(); rec != nil {
   151  			quotaExceeded, ok := rec.(rt.ContextTerminationError)
   152  			if !ok {
   153  				panic(r)
   154  			}
   155  			fmt.Fprintf(os.Stderr, "%s\n", quotaExceeded)
   156  			retcode = 2
   157  		}
   158  	}()
   159  
   160  	clos, err := r.LoadFromSourceOrCode(chunkName, chunk, "bt", rt.TableValue(r.GlobalEnv()), true)
   161  	if err != nil {
   162  		return fatal("Error loading %s: %s", chunkName, err)
   163  	}
   164  	cerr := rt.Call(r.MainThread(), rt.FunctionValue(clos), argVals, rt.NewTerminationWith(nil, 0, false))
   165  	if cerr != nil {
   166  		return fatal("!!! %s", cerr.Error())
   167  	}
   168  	return 0
   169  }
   170  
   171  func fatal(tpl string, args ...interface{}) int {
   172  	fmt.Fprintf(os.Stderr, tpl+"\n", args...)
   173  	return 1
   174  }
   175  
   176  func isaTTY(f *os.File) bool {
   177  	fi, _ := f.Stat()
   178  	return fi.Mode()&os.ModeCharDevice != 0
   179  }
   180  
   181  func (c *luaCmd) repl(r *rt.Runtime) int {
   182  	reader := bufio.NewReader(os.Stdin)
   183  	w := new(bytes.Buffer)
   184  	for {
   185  		if len(w.Bytes()) == 0 {
   186  			fmt.Print("> ")
   187  		} else {
   188  			fmt.Print("| ")
   189  		}
   190  		line, err := reader.ReadBytes('\n')
   191  		line = bytes.TrimLeft(line, ">|")
   192  		if err == io.EOF {
   193  			w.WriteTo(os.Stdout)
   194  			fmt.Print(string(line))
   195  			return 0
   196  		}
   197  		_, err = w.Write(line)
   198  		if err != nil {
   199  			return fatal("error: %s", err)
   200  		}
   201  		more, err := c.runChunk(r, w.Bytes())
   202  		if !more {
   203  			w = new(bytes.Buffer)
   204  			if err != nil {
   205  				fmt.Printf("!!! %s\n", err)
   206  				if _, ok := err.(rt.ContextTerminationError); ok {
   207  					fmt.Print("Reset limits and continue? [yN] ")
   208  					line, err := reader.ReadString('\n')
   209  					if err == io.EOF || strings.TrimSpace(line) != "y" {
   210  						return 0
   211  					}
   212  					r.PopContext()
   213  					c.pushContext(r)
   214  				}
   215  			}
   216  		}
   217  	}
   218  }
   219  
   220  func (c *luaCmd) runChunk(r *rt.Runtime, source []byte) (more bool, err error) {
   221  	defer func() {
   222  		if rec := recover(); rec != nil {
   223  			quotaExceeded, ok := rec.(rt.ContextTerminationError)
   224  			if !ok {
   225  				panic(r)
   226  			}
   227  			err = quotaExceeded
   228  			more = false
   229  		}
   230  	}()
   231  	clos, err := r.CompileAndLoadLuaChunkOrExp("<stdin>", source, rt.TableValue(r.GlobalEnv()))
   232  	if err != nil {
   233  		return rt.ErrorIsUnexpectedEOF(err), err
   234  	}
   235  	t := r.MainThread()
   236  	term := rt.NewTerminationWith(nil, 0, true)
   237  	cerr := rt.Call(t, rt.FunctionValue(clos), nil, term)
   238  	if cerr == nil {
   239  		if len(term.Etc()) > 0 {
   240  			cerr = base.Print(t, term.Etc())
   241  			if cerr != nil {
   242  				return false, cerr
   243  			}
   244  		}
   245  		return false, nil
   246  	}
   247  	return false, cerr
   248  }
   249  
   250  func (c *luaCmd) pushContext(r *rt.Runtime) {
   251  	r.PushContext(rt.RuntimeContextDef{
   252  		HardLimits: rt.RuntimeResources{
   253  			Cpu:    c.cpuLimit,
   254  			Memory: c.memLimit,
   255  		},
   256  		RequiredFlags:  c.complianceFlags,
   257  		MessageHandler: debuglib.Traceback,
   258  	})
   259  }
   260  
   261  type execFlags []string
   262  
   263  func (e *execFlags) String() string {
   264  	return strings.Join(*e, "; ")
   265  }
   266  
   267  func (e *execFlags) Set(value string) error {
   268  	*e = append(*e, value)
   269  	return nil
   270  }