zombiezen.com/go/lua@v0.0.0-20231013005828-290725fb9140/cmd/zombiezen-lua/lua.go (about)

     1  // Copyright 2023 Ross Light
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy of
     4  // this software and associated documentation files (the “Software”), to deal in
     5  // the Software without restriction, including without limitation the rights to
     6  // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
     7  // the Software, and to permit persons to whom the Software is furnished to do so,
     8  // subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in all
    11  // copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
    15  // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
    16  // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
    17  // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
    18  // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
    19  //
    20  // SPDX-License-Identifier: MIT
    21  
    22  // zombiezen-lua is a standalone Lua interpreter.
    23  package main
    24  
    25  import (
    26  	"bufio"
    27  	"errors"
    28  	"flag"
    29  	"fmt"
    30  	"io"
    31  	"math"
    32  	"os"
    33  	"path/filepath"
    34  	"strings"
    35  
    36  	"zombiezen.com/go/lua"
    37  )
    38  
    39  func main() {
    40  	programName := "zombiezen-lua"
    41  	if len(os.Args) > 0 && os.Args[0] != "" {
    42  		programName = filepath.Base(os.Args[0])
    43  	}
    44  	err := run(programName)
    45  	if err != nil {
    46  		fmt.Fprintf(os.Stderr, "%s: %v\n", programName, err)
    47  	}
    48  	if err != nil {
    49  		os.Exit(1)
    50  	}
    51  }
    52  
    53  func run(programName string) error {
    54  	var exprArgs []exprArg
    55  	flag.Usage = func() {
    56  		fmt.Fprintf(os.Stderr, "usage: %s [options] [script [args]]\n", programName)
    57  		flag.PrintDefaults()
    58  	}
    59  	flag.Var(exprArgFlag{'e', &exprArgs}, "e", "execute string '`stat`'")
    60  	flag.Var(exprArgFlag{'l', &exprArgs}, "l", "for `g=mod`, require library 'mod' into global 'g'")
    61  	interactive := flag.Bool("i", false, "enter interactive mode after executing 'script'")
    62  	showVersion := flag.Bool("v", false, "show version information")
    63  	noEnv := flag.Bool("E", false, "ignore environment variables")
    64  	flag.Parse()
    65  
    66  	if *showVersion || *interactive {
    67  		fmt.Println(lua.Copyright)
    68  	}
    69  
    70  	l := new(lua.State)
    71  	if *noEnv {
    72  		l.PushBoolean(true)
    73  		l.RawSetField(lua.RegistryIndex, "LUA_NOENV")
    74  	}
    75  	if err := lua.OpenLibraries(l); err != nil {
    76  		return err
    77  	}
    78  
    79  	var script int
    80  	if len(os.Args) == 0 {
    81  		script = -1
    82  	} else if flag.NArg() == 0 {
    83  		script = 0
    84  	} else {
    85  		script = len(os.Args) - flag.NArg()
    86  	}
    87  	if err := createArgTable(l, os.Args, script); err != nil {
    88  		return err
    89  	}
    90  
    91  	if !*noEnv {
    92  		if err := handleInit(l); err != nil {
    93  			return err
    94  		}
    95  	}
    96  	for _, arg := range exprArgs {
    97  		switch arg.c {
    98  		case 'e':
    99  			if err := doString(l, arg.val, "=(command line)"); err != nil {
   100  				fmt.Fprintf(os.Stderr, "%s: %v\n", programName, err)
   101  			}
   102  		case 'l':
   103  			if err := doLibrary(l, arg.val); err != nil {
   104  				fmt.Fprintf(os.Stderr, "%s: %v\n", programName, err)
   105  			}
   106  		default:
   107  			panic("unreachable")
   108  		}
   109  	}
   110  	if flag.NArg() > 0 {
   111  		if err := handleScript(l, flag.Args()); err != nil {
   112  			return err
   113  		}
   114  	}
   115  	if *interactive {
   116  		return doREPL(l)
   117  	}
   118  	hasE := false
   119  	for _, arg := range exprArgs {
   120  		if arg.c == 'e' {
   121  			hasE = true
   122  			break
   123  		}
   124  	}
   125  	if flag.NArg() == 0 && !*showVersion && !hasE {
   126  		// No active option.
   127  		// TODO(someday): Check whether stdin is a tty.
   128  		fmt.Println(lua.Copyright)
   129  		return doREPL(l)
   130  	}
   131  	return nil
   132  }
   133  
   134  func doREPL(l *lua.State) error {
   135  	s := bufio.NewScanner(os.Stdin)
   136  	for {
   137  		if err := loadLine(l, s); errors.As(err, new(inputError)) {
   138  			if errors.Is(err, io.EOF) {
   139  				return nil
   140  			}
   141  			return err
   142  		} else if err != nil {
   143  			fmt.Fprintln(os.Stderr, err)
   144  			continue
   145  		}
   146  		if err := doCall(l, 0, lua.MultipleReturns); err != nil {
   147  			fmt.Fprintln(os.Stderr, err)
   148  			continue
   149  		}
   150  		print(l, "")
   151  	}
   152  }
   153  
   154  func print(l *lua.State, errPrefix string) {
   155  	n := l.Top()
   156  	if n == 0 {
   157  		return
   158  	}
   159  	if !l.CheckStack(20) {
   160  		fmt.Fprintf(os.Stderr, "%stoo many results (%d) to print\n", errPrefix, n)
   161  		return
   162  	}
   163  	if _, err := l.Global("print", 0); err != nil {
   164  		fmt.Fprintf(os.Stderr, "%s%v\n", errPrefix, err)
   165  		return
   166  	}
   167  	l.Insert(1)
   168  	if err := l.Call(n, 0, 0); err != nil {
   169  		fmt.Fprintf(os.Stderr, "%serror calling 'print' (%v)\n", errPrefix, err)
   170  		return
   171  	}
   172  }
   173  
   174  func handleInit(l *lua.State) error {
   175  	name := fmt.Sprintf("=LUA_INIT_%s_%s", lua.VersionMajor, lua.VersionMinor)
   176  	init, ok := os.LookupEnv(name[1:])
   177  	if !ok {
   178  		name = "=LUA_INIT"
   179  		init, ok = os.LookupEnv(name[1:])
   180  		if !ok {
   181  			return nil
   182  		}
   183  	}
   184  	if filename, ok := strings.CutPrefix(init, "@"); ok {
   185  		return doFile(l, filename)
   186  	}
   187  	return doString(l, init, name)
   188  }
   189  
   190  func handleScript(l *lua.State, args []string) error {
   191  	var r io.ReadCloser
   192  	name := args[0]
   193  	if name == "-" {
   194  		r = io.NopCloser(os.Stdin)
   195  		name = "=stdin"
   196  	} else {
   197  		var err error
   198  		r, err = os.Open(name)
   199  		if err != nil {
   200  			return err
   201  		}
   202  		name = "@" + name
   203  	}
   204  	err := l.Load(r, name, "bt")
   205  	r.Close()
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	nArgs, err := pushArgs(l)
   211  	if err != nil {
   212  		return err
   213  	}
   214  	return doCall(l, nArgs, 0)
   215  }
   216  
   217  func pushArgs(l *lua.State) (int, error) {
   218  	if tp, err := l.Global("arg", 0); err != nil {
   219  		return 0, err
   220  	} else if tp != lua.TypeTable {
   221  		return 0, fmt.Errorf("'arg' (%v) is not a table", tp)
   222  	}
   223  	argIndex := l.AbsIndex(-1)
   224  	n, err := lua.Len(l, argIndex)
   225  	if err != nil {
   226  		return 0, err
   227  	}
   228  	if n > math.MaxInt || !l.CheckStack(int(n)+3) {
   229  		return 0, fmt.Errorf("too many arguments (%d) to script", n)
   230  	}
   231  	for i := int64(1); i <= n; i++ {
   232  		l.RawIndex(argIndex, i)
   233  	}
   234  	l.Remove(argIndex)
   235  	return int(n), nil
   236  }
   237  
   238  func doLibrary(l *lua.State, globname string) error {
   239  	globname, modname, ok := strings.Cut(globname, "=")
   240  	if !ok {
   241  		modname = globname
   242  	}
   243  	if _, err := l.Global("require", 0); err != nil {
   244  		return err
   245  	}
   246  	l.PushString(modname)
   247  	if err := doCall(l, 1, 1); err != nil {
   248  		return err
   249  	}
   250  	if err := l.SetGlobal(globname, 0); err != nil {
   251  		return err
   252  	}
   253  	return nil
   254  }
   255  
   256  func doString(l *lua.State, s string, chunkName string) error {
   257  	if err := l.LoadString(s, chunkName, "t"); err != nil {
   258  		l.Pop(1)
   259  		return err
   260  	}
   261  	return doCall(l, 0, 0)
   262  }
   263  
   264  func doFile(l *lua.State, name string) error {
   265  	f, err := os.Open(name)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	err = l.Load(f, "@"+name, "bt")
   270  	f.Close()
   271  	if err != nil {
   272  		l.Pop(1)
   273  		return err
   274  	}
   275  	return doCall(l, 0, 0)
   276  }
   277  
   278  func doCall(l *lua.State, nArgs, nResults int) error {
   279  	base := l.Top() - nArgs
   280  	l.PushClosure(0, msgHandler)
   281  	l.Insert(base)
   282  	// TODO(someday): Catch signals.
   283  	err := l.Call(nArgs, nResults, base)
   284  	if err != nil {
   285  		l.Pop(1)
   286  	}
   287  	l.Remove(base)
   288  	return err
   289  }
   290  
   291  func msgHandler(l *lua.State) (int, error) {
   292  	msg, ok := l.ToString(1)
   293  	if !ok {
   294  		if called, err := lua.CallMeta(l, 1, "__tostring"); called && err == nil && l.IsString(-1) {
   295  			// Already pushed onto stack and it's a string.
   296  			return 1, nil
   297  		}
   298  		msg = fmt.Sprintf("(error object is a %v value)", l.Type(1))
   299  	}
   300  	// TODO(soon): Append a standard traceback.
   301  	l.PushString(msg)
   302  	return 1, nil
   303  }
   304  
   305  func createArgTable(l *lua.State, args []string, script int) error {
   306  	nArg := len(args) - (script + 1)
   307  	l.CreateTable(nArg, script+1)
   308  	for i, arg := range args {
   309  		l.PushString(arg)
   310  		l.RawSetIndex(-2, int64(i-script))
   311  	}
   312  	if err := l.SetGlobal("arg", 0); err != nil {
   313  		return fmt.Errorf("create arg table: %v", err)
   314  	}
   315  
   316  	return nil
   317  }
   318  
   319  // loadLine reads a line and tries to compile it as an expression or statement.
   320  func loadLine(l *lua.State, s *bufio.Scanner) error {
   321  	l.SetTop(0)
   322  	line, err := readLine(l, s, true)
   323  	if err != nil {
   324  		return err
   325  	}
   326  	if err := addReturn(l, line); err == nil {
   327  		return nil
   328  	}
   329  	for {
   330  		err := l.LoadString(line, "=stdin", "t")
   331  		if err == nil {
   332  			return nil
   333  		}
   334  		if !isIncomplete(err) {
   335  			l.Pop(1)
   336  			return err
   337  		}
   338  		newLine, err := readLine(l, s, false)
   339  		if err != nil {
   340  			return err
   341  		}
   342  		line += "\n" + newLine
   343  	}
   344  }
   345  
   346  func readLine(l *lua.State, s *bufio.Scanner, firstLine bool) (string, error) {
   347  	p, err := prompt(l, firstLine)
   348  	if err != nil {
   349  		return "", inputError{fmt.Errorf("read line: %v", err)}
   350  	}
   351  	os.Stdout.WriteString(p)
   352  	if !s.Scan() {
   353  		err := s.Err()
   354  		if err == nil {
   355  			err = io.EOF
   356  		}
   357  		return "", inputError{fmt.Errorf("read line: %w", err)}
   358  	}
   359  	line := s.Text()
   360  	if firstLine && strings.HasPrefix(line, "=") {
   361  		line = "return " + line
   362  	}
   363  	return line, nil
   364  }
   365  
   366  type inputError struct {
   367  	err error
   368  }
   369  
   370  func (e inputError) Error() string {
   371  	return e.err.Error()
   372  }
   373  
   374  func (e inputError) Unwrap() error {
   375  	return e.err
   376  }
   377  
   378  func prompt(l *lua.State, firstLine bool) (string, error) {
   379  	if firstLine {
   380  		if tp, err := l.Global("_PROMPT", 0); err != nil {
   381  			l.Pop(1)
   382  			return "", err
   383  		} else if tp == lua.TypeNil {
   384  			l.Pop(1)
   385  			return "> ", nil
   386  		}
   387  	} else {
   388  		if tp, err := l.Global("_PROMPT2", 0); err != nil {
   389  			l.Pop(1)
   390  			return "", err
   391  		} else if tp == lua.TypeNil {
   392  			l.Pop(1)
   393  			return ">> ", nil
   394  		}
   395  	}
   396  	p, err := lua.ToString(l, -1)
   397  	l.Pop(1)
   398  	if err != nil {
   399  		return "", fmt.Errorf("custom prompt: %v", err)
   400  	}
   401  	return p, nil
   402  }
   403  
   404  func addReturn(l *lua.State, line string) error {
   405  	retLine := "return " + line + ";"
   406  	if err := l.LoadString(retLine, "=stdin", "t"); err != nil {
   407  		l.Pop(1)
   408  		return err
   409  	}
   410  	return nil
   411  }
   412  
   413  func isIncomplete(err error) bool {
   414  	if err == nil {
   415  		return false
   416  	}
   417  	return lua.IsSyntax(err) && strings.Contains(err.Error(), "<eof>")
   418  }
   419  
   420  type exprArg struct {
   421  	c   byte
   422  	val string
   423  }
   424  
   425  type exprArgFlag struct {
   426  	c     byte
   427  	slice *[]exprArg
   428  }
   429  
   430  func (f exprArgFlag) String() string {
   431  	if f.slice == nil {
   432  		return ""
   433  	}
   434  	first := true
   435  	sb := new(strings.Builder)
   436  	for _, arg := range *f.slice {
   437  		if arg.c != f.c {
   438  			continue
   439  		}
   440  		if first {
   441  			first = false
   442  		} else {
   443  			sb.WriteString(",")
   444  		}
   445  		sb.WriteString(arg.val)
   446  	}
   447  	return sb.String()
   448  }
   449  
   450  func (f exprArgFlag) Set(s string) error {
   451  	*f.slice = append(*f.slice, exprArg{
   452  		c:   f.c,
   453  		val: s,
   454  	})
   455  	return nil
   456  }
   457  
   458  func (f exprArgFlag) Get() any {
   459  	return *f.slice
   460  }