github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/actions/lua/load.go (about)

     1  package lua
     2  
     3  import (
     4  	"bytes"
     5  	"embed"
     6  	"errors"
     7  	"fmt"
     8  	"io/fs"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  
    13  	"github.com/Shopify/go-lua"
    14  	"github.com/hashicorp/go-multierror"
    15  )
    16  
    17  const (
    18  	pathListSeparator = ';'
    19  	defaultPath       = "?.lua"
    20  )
    21  
    22  //go:embed lakefs/catalogexport/*.lua
    23  var luaEmbeddedCode embed.FS
    24  
    25  var ErrNoFile = errors.New("no file")
    26  
    27  func findLoader(l *lua.State, name string) {
    28  	var msg string
    29  	if l.Field(lua.UpValueIndex(1), "searchers"); !l.IsTable(3) {
    30  		lua.Errorf(l, "'package.searchers' must be a table")
    31  	}
    32  	for i := 1; ; i++ {
    33  		if l.RawGetInt(3, i); l.IsNil(-1) {
    34  			l.Pop(1)
    35  			l.PushString(msg)
    36  			lua.Errorf(l, "module '%s' not found: %s", name, msg)
    37  		}
    38  		l.PushString(name)
    39  		if l.Call(1, 2); l.IsFunction(-2) {
    40  			return
    41  		} else if l.IsString(-2) {
    42  			msg += lua.CheckString(l, -2)
    43  		}
    44  		l.Pop(2)
    45  	}
    46  }
    47  
    48  func findFile(l *lua.State, name, field, dirSep string) (string, error) {
    49  	l.Field(lua.UpValueIndex(1), field)
    50  	path, ok := l.ToString(-1)
    51  	if !ok {
    52  		lua.Errorf(l, "'package.%s' must be a string", field)
    53  	}
    54  	return searchPath(name, path, ".", dirSep)
    55  }
    56  
    57  func checkLoad(l *lua.State, loaded bool, fileName string) int {
    58  	if loaded { // Module loaded successfully?
    59  		l.PushString(fileName) // Second argument to module.
    60  		return 2               // Return open function & file name.
    61  	}
    62  	m := lua.CheckString(l, 1)
    63  	e := lua.CheckString(l, -1)
    64  	lua.Errorf(l, "error loading module '%s' from file '%s':\n\t%s", m, fileName, e)
    65  	panic("unreachable")
    66  }
    67  
    68  func searcherLua(l *lua.State) int {
    69  	name := lua.CheckString(l, 1)
    70  	filename, err := findFile(l, name, "path", string(filepath.Separator))
    71  	if err != nil {
    72  		return 1 // Module isn't found in this path.
    73  	}
    74  
    75  	return checkLoad(l, loadFile(l, filename, "") == nil, filename)
    76  }
    77  
    78  func loadFile(l *lua.State, fileName, mode string) error {
    79  	fileNameIndex := l.Top() + 1
    80  	fileError := func(what string) error {
    81  		fileName, _ := l.ToString(fileNameIndex)
    82  		l.PushFString("cannot %s %s", what, fileName[1:])
    83  		l.Remove(fileNameIndex)
    84  		return lua.FileError
    85  	}
    86  	l.PushString("@" + fileName)
    87  	data, err := luaEmbeddedCode.ReadFile(fileName)
    88  	if err != nil {
    89  		return fileError("open")
    90  	}
    91  	s, _ := l.ToString(-1)
    92  	err = l.Load(bytes.NewReader(data), s, mode)
    93  	switch {
    94  	case err == nil, errors.Is(err, lua.SyntaxError), errors.Is(err, lua.MemoryError): // do nothing
    95  	default:
    96  		l.SetTop(fileNameIndex)
    97  		return fileError("read")
    98  	}
    99  	l.Remove(fileNameIndex)
   100  	return err
   101  }
   102  
   103  func searcherPreload(l *lua.State) int {
   104  	name := lua.CheckString(l, 1)
   105  	l.Field(lua.RegistryIndex, "_PRELOAD")
   106  	l.Field(-1, name)
   107  	if l.IsNil(-1) {
   108  		l.PushString(fmt.Sprintf("\n\tno field package.preload['%s']", name))
   109  	}
   110  	return 1
   111  }
   112  
   113  func createSearchersTable(l *lua.State) {
   114  	searchers := []lua.Function{searcherPreload, searcherLua}
   115  	l.CreateTable(len(searchers), 0)
   116  	for i, s := range searchers {
   117  		l.PushValue(-2)
   118  		l.PushGoClosure(s, 1)
   119  		l.RawSetInt(-2, i+1)
   120  	}
   121  }
   122  
   123  func searchPath(name, path, sep, dirSep string) (string, error) {
   124  	var err error
   125  	if sep != "" {
   126  		name = strings.ReplaceAll(name, sep, dirSep) // Replace sep by dirSep.
   127  	}
   128  	path = strings.ReplaceAll(path, string(pathListSeparator), string(filepath.ListSeparator))
   129  	for _, template := range filepath.SplitList(path) {
   130  		if template == "" {
   131  			continue
   132  		}
   133  		filename := strings.ReplaceAll(template, "?", name)
   134  		if readable(filename) {
   135  			return filename, nil
   136  		}
   137  		err = multierror.Append(err, fmt.Errorf("%w %s", ErrNoFile, filename))
   138  	}
   139  	return "", err
   140  }
   141  
   142  func readable(name string) bool {
   143  	if !fs.ValidPath(name) {
   144  		return false
   145  	}
   146  	info, err := fs.Stat(luaEmbeddedCode, name)
   147  	return err == nil && !info.IsDir()
   148  }
   149  
   150  func noEnv(l *lua.State) bool {
   151  	l.Field(lua.RegistryIndex, "LUA_NOENV")
   152  	b := l.ToBoolean(-1)
   153  	l.Pop(1)
   154  	return b
   155  }
   156  
   157  func setPath(l *lua.State, field, env, def string) {
   158  	if path := os.Getenv(env); path == "" || noEnv(l) {
   159  		l.PushString(def)
   160  	} else {
   161  		o := fmt.Sprintf("%c%c", pathListSeparator, pathListSeparator)
   162  		n := fmt.Sprintf("%c%s%c", pathListSeparator, def, pathListSeparator)
   163  		path = strings.ReplaceAll(path, o, n)
   164  		l.PushString(path)
   165  	}
   166  	l.SetField(-2, field)
   167  }
   168  
   169  var packageLibrary = []lua.RegistryFunction{
   170  	{Name: "loadlib", Function: func(l *lua.State) int {
   171  		_ = lua.CheckString(l, 1) // path
   172  		_ = lua.CheckString(l, 2) // init
   173  		l.PushNil()
   174  		l.PushString("dynamic libraries not enabled; check your Lua installation")
   175  		l.PushString("absent")
   176  		return 3 // Return nil, error message, and where.
   177  	}},
   178  	{Name: "searchpath", Function: func(l *lua.State) int {
   179  		name := lua.CheckString(l, 1)
   180  		path := lua.CheckString(l, 2)
   181  		sep := lua.OptString(l, 3, ".")
   182  		dirSep := lua.OptString(l, 4, string(filepath.Separator))
   183  		f, err := searchPath(name, path, sep, dirSep)
   184  		if err != nil {
   185  			l.PushNil()
   186  			l.PushString(err.Error())
   187  			return 2
   188  		}
   189  		l.PushString(f)
   190  		return 1
   191  	}},
   192  }
   193  
   194  // PackageOpen opens the package library. Usually passed to Require.
   195  func PackageOpen(l *lua.State) int {
   196  	lua.NewLibrary(l, packageLibrary)
   197  	createSearchersTable(l)
   198  	l.SetField(-2, "searchers")
   199  	setPath(l, "path", "LUA_PATH", defaultPath)
   200  	l.PushString(fmt.Sprintf("%c\n%c\n?\n!\n-\n", filepath.Separator, pathListSeparator))
   201  	l.SetField(-2, "config")
   202  	lua.SubTable(l, lua.RegistryIndex, "_LOADED")
   203  	l.SetField(-2, "loaded")
   204  	lua.SubTable(l, lua.RegistryIndex, "_PRELOAD")
   205  	l.SetField(-2, "preload")
   206  	l.PushGlobalTable()
   207  	l.PushValue(-2)
   208  	lua.SetFunctions(l, []lua.RegistryFunction{
   209  		{Name: "require", Function: func(l *lua.State) int {
   210  			name := lua.CheckString(l, 1)
   211  			l.SetTop(1)
   212  			l.Field(lua.RegistryIndex, "_LOADED")
   213  			l.Field(2, name)
   214  			if l.ToBoolean(-1) {
   215  				return 1
   216  			}
   217  			l.Pop(1)
   218  			findLoader(l, name)
   219  			l.PushString(name)
   220  			l.Insert(-2)
   221  			l.Call(2, 1)
   222  			if !l.IsNil(-1) {
   223  				l.SetField(2, name)
   224  			}
   225  			l.Field(2, name)
   226  			if l.IsNil(-1) {
   227  				l.PushBoolean(true)
   228  				l.PushValue(-1)
   229  				l.SetField(2, name)
   230  			}
   231  			return 1
   232  		}},
   233  	}, 1)
   234  	l.Pop(1)
   235  	return 1
   236  }