github.com/dfklegend/cell2/utils@v0.0.0-20240402033734-a0a9f3d9335d/golua/golua.go (about)

     1  package golua
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"os"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"errors"
    12  
    13  	libs "github.com/vadv/gopher-lua-libs"
    14  	lua "github.com/yuin/gopher-lua"
    15  	"github.com/yuin/gopher-lua/parse"
    16  )
    17  
// compiledLuaProtoCache caches compiled Lua function protos. Entries are
// written by compileLuaFiles and looked up by script file name in DoLuaFile.
var compiledLuaProtoCache map[string]*lua.FunctionProto
// userLuaPath is the directory most recently passed to setUserLuaPath; an
// empty value means no user directory has been set.
var userLuaPath string
    21  
    22  func init() {
    23  	InitLuaPathAndCompile("", false)
    24  	compiledLuaProtoCache = make(map[string]*lua.FunctionProto)
    25  }
    26  
    27  // SetUserLuaPath 和编译功能分开比较合适,后续优化
    28  // 如果一个工程目录下多个不同用途的lua脚本,用不同的入口函数来做区分
    29  func setUserLuaPath(luaPath string) {
    30  	if userLuaPath == luaPath {
    31  		return
    32  	}
    33  	userLuaPath = luaPath
    34  	switch userLuaPath {
    35  	case ".":
    36  		lua.LuaPathDefault = fmt.Sprintf("%s;./?.lua", lua.LuaPathDefault)
    37  	case "":
    38  		lua.LuaPathDefault = fmt.Sprintf("%s;?.lua", lua.LuaPathDefault)
    39  	default:
    40  		lua.LuaPathDefault = fmt.Sprintf("%s;%s/?.lua", lua.LuaPathDefault, luaPath)
    41  	}
    42  	os.Setenv(lua.LuaPath, lua.LuaPathDefault)
    43  }
    44  
    45  func InitLuaPathAndCompile(luaPath string, compile bool) {
    46  	setUserLuaPath(luaPath)
    47  	if compile {
    48  		compileLuaFiles(luaPath)
    49  	}
    50  }
    51  
// CompileLuaFiles is the exported entry point for compileLuaFiles: it
// compiles the Lua scripts under luaPath into the package cache.
func CompileLuaFiles(luaPath string) {
	compileLuaFiles(luaPath)
}
    55  
// LuaEngine wraps a single gopher-lua state and exposes helpers for running
// scripts on it.
type LuaEngine struct {
	// L is the underlying Lua state; NewLuaEngine creates it and Close closes it.
	L *lua.LState
}
    59  
    60  func NewLuaEngine() *LuaEngine {
    61  
    62  	engine := LuaEngine{
    63  		L: lua.NewState(),
    64  	}
    65  	// 引入buildin库
    66  	engine.L.OpenLibs()
    67  
    68  	return &engine
    69  }
    70  
// LoadModule registers loader under the name module by calling PreloadModule
// on the engine's Lua state.
func (e *LuaEngine) LoadModule(module string, loader func(L *lua.LState) int) {
	e.L.PreloadModule(module, loader)
}
    75  
// LoadGopherLuaLibs preloads the gopher-lua-libs modules into the engine's
// Lua state via libs.Preload.
func (e *LuaEngine) LoadGopherLuaLibs() {
	libs.Preload(e.L)
}
    80  
// Close closes the underlying Lua state.
func (e *LuaEngine) Close() {
	e.L.Close()
}
    84  
// DoLuaString executes luaCode on the engine's Lua state and returns the
// error from DoString unchanged (it is not routed through checkError).
func (e *LuaEngine) DoLuaString(luaCode string) error {
	return e.L.DoString(luaCode)
}
    88  
    89  func (e *LuaEngine) DoLuaFile(luaFile string) error {
    90  
    91  	var fileName string
    92  	if strings.HasSuffix(luaFile, ".lua") {
    93  		fileName = luaFile
    94  	} else {
    95  		fileName = fmt.Sprintf("%s.lua", luaFile)
    96  	}
    97  	if IsCompiled(fileName) {
    98  		return doCompiledLuaProto(e.L, compiledLuaProtoCache[fileName])
    99  	}
   100  	if "" == userLuaPath {
   101  		luaFile = fileName
   102  	} else {
   103  		luaFile = fmt.Sprintf("%s/%s", userLuaPath, fileName)
   104  	}
   105  	return checkError(e.L.DoFile(luaFile))
   106  }
   107  
   108  // DoLuaMethod call a func in luaFile
   109  func (e *LuaEngine) DoLuaMethod(luaFile string, method string, args ...any) error {
   110  	if err := e.DoLuaFile(luaFile); err != nil {
   111  		panic(err)
   112  	}
   113  
   114  	fn := e.L.Env.RawGetString(method)
   115  
   116  	if fn == nil || fn.Type() != lua.LTFunction {
   117  		panic(errors.New(fmt.Sprintf("not found lua method! %s:%s", luaFile, method)))
   118  	}
   119  
   120  	return Call(e.L, fn, args...)
   121  	//var argsArr []lua.LValue
   122  	//
   123  	//for _, arg := range args {
   124  	//	argsArr = append(argsArr, luar.New(e.L, arg))
   125  	//}
   126  	//
   127  	//err := e.L.CallByParam(lua.P{
   128  	//	Fn:      fn,
   129  	//	NRet:    0,
   130  	//	Protect: true,
   131  	//	Handler: nil,
   132  	//}, argsArr...)
   133  	//
   134  	//return checkError(err)
   135  }
   136  
   137  func (e *LuaEngine) DoLuaMethodWithResult(luaFile string, method string, args ...any) (lua.LValue, error) {
   138  
   139  	if err := e.DoLuaFile(luaFile); err != nil {
   140  		panic(err)
   141  	}
   142  
   143  	fn := e.L.Env.RawGetString(method)
   144  
   145  	if fn == nil || fn.Type() != lua.LTFunction {
   146  		panic(errors.New(fmt.Sprintf("not found lua method! %s:%s", luaFile, method)))
   147  	}
   148  
   149  	return CallWithResult(e.L, fn, args...)
   150  	//var argsArr []lua.LValue
   151  	//
   152  	//for _, arg := range args {
   153  	//	argsArr = append(argsArr, luar.New(e.L, arg))
   154  	//}
   155  	//
   156  	//err := e.L.CallByParam(lua.P{
   157  	//	Fn:      fn,
   158  	//	NRet:    1,
   159  	//	Protect: true,
   160  	//	Handler: nil,
   161  	//}, argsArr...)
   162  	//
   163  	//if err != nil {
   164  	//	checkError(err)
   165  	//}
   166  	//
   167  	//ret := e.L.Get(-1)
   168  	//e.L.Pop(1)
   169  	//
   170  	//return ret, err
   171  }
   172  
   173  //compileLuaFiles 编译lua脚本并缓存map[luaName,luaProto]
   174  func compileLuaFiles(luaDir string) {
   175  	fileInfos, err := ioutil.ReadDir(luaDir)
   176  	if err != nil {
   177  		panic(err)
   178  	}
   179  	for _, fileInfo := range fileInfos {
   180  
   181  		fileName := fileInfo.Name()
   182  
   183  		if fileInfo.IsDir() {
   184  			compileLuaFiles(fmt.Sprintf("%s/%s", luaDir, fileName))
   185  			continue
   186  		}
   187  
   188  		if !strings.HasSuffix(fileName, ".lua") {
   189  			continue
   190  		}
   191  
   192  		luaFile := fmt.Sprintf("%s/%s", luaDir, fileName)
   193  
   194  		proto, err := compileLuaFile(luaFile)
   195  		if err != nil {
   196  			panic(err)
   197  		}
   198  
   199  		// cache lua compiled proto
   200  		fileName = strings.ReplaceAll(luaFile, fmt.Sprintf("%s/", userLuaPath), "")
   201  		compiledLuaProtoCache[fileName] = proto
   202  	}
   203  }
   204  
   205  //IsCompiled 检查lua代码缓存是否存在
   206  func IsCompiled(luaName string) bool {
   207  	return compiledLuaProtoCache[luaName] != nil
   208  }
   209  
   210  func compileLuaFile(luaFile string) (*lua.FunctionProto, error) {
   211  	file, err := os.Open(luaFile)
   212  	defer file.Close()
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	reader := bufio.NewReader(file)
   217  	chunk, err := parse.Parse(reader, luaFile)
   218  	if err != nil {
   219  		return nil, checkError(&lua.ApiError{Type: lua.ApiErrorSyntax, Object: lua.LString(err.Error()), Cause: err})
   220  	}
   221  	proto, err := lua.Compile(chunk, luaFile)
   222  	if err != nil {
   223  		return nil, checkError(err)
   224  	}
   225  	return proto, nil
   226  }
   227  
   228  func doCompiledLuaProto(L *lua.LState, proto *lua.FunctionProto) error {
   229  	lfunc := L.NewFunctionFromProto(proto)
   230  	L.Push(lfunc)
   231  	return checkError(L.PCall(0, lua.MultRet, nil))
   232  }
   233  
   234  func stackTrace(apiError *lua.ApiError, level int) string {
   235  	//got "exception/example.lua line:13(column:3) near 'end':   syntax error"
   236  	line := apiError.Object.String()
   237  	//need "exception/example.lua:12"
   238  	file := strings.Split(line, " line:")[0]
   239  	lineNum := substringBetween(line, "line:", "(column")
   240  	n, _ := strconv.Atoi(lineNum)
   241  	var buf []string
   242  	header := "stack traceback:"
   243  	buf = append(buf, fmt.Sprintf("\t%v", fmt.Sprintf("%s:%d", file, n-1))) //行号要减去1
   244  	buf = append(buf, fmt.Sprintf("\t%v: %v", "[G]", "?"))
   245  	buf = buf[intMax(0, intMin(level, len(buf))):len(buf)]
   246  	if len(buf) > 20 {
   247  		newbuf := make([]string, 0, 20)
   248  		newbuf = append(newbuf, buf[0:7]...)
   249  		newbuf = append(newbuf, "\t...")
   250  		newbuf = append(newbuf, buf[len(buf)-7:len(buf)]...)
   251  		buf = newbuf
   252  	}
   253  	return fmt.Sprintf("%s\n%s", header, strings.Join(buf, "\n"))
   254  }
   255  
   256  func intMin(a, b int) int {
   257  	if a < b {
   258  		return a
   259  	} else {
   260  		return b
   261  	}
   262  }
   263  
   264  func intMax(a, b int) int {
   265  	if a > b {
   266  		return a
   267  	} else {
   268  		return b
   269  	}
   270  }
   271  
   272  //got "exception/example.lua line:13(column:3) near 'end':   syntax error"
   273  //need "exception/example.lua:12"
   274  func substringBetween(source string, left string, right string) string {
   275  	return strings.Split(strings.Split(source, left)[1], right)[0]
   276  }
   277  
   278  func checkError(err error) error {
   279  	if err == nil {
   280  		return nil
   281  	}
   282  
   283  	var apiError *lua.ApiError
   284  	if errors.As(err, &apiError) {
   285  		if apiError.Type == lua.ApiErrorSyntax {
   286  			apiError.StackTrace = stackTrace(apiError, 0)
   287  		}
   288  		panic(apiError)
   289  	}
   290  	return err
   291  }