github.com/bitxmesh/gopher-lua@v0.0.0-20190327085718-93c344ef97a4/state_test.go (about)

     1  package lua
     2  
     3  import (
     4  	"context"
     5  	"strings"
     6  	"testing"
     7  	"time"
     8  )
     9  
    10  func TestCallStackOverflow(t *testing.T) {
    11  	L := NewState(Options{
    12  		CallStackSize: 3,
    13  	})
    14  	defer L.Close()
    15  	errorIfScriptNotFail(t, L, `
    16      local function a()
    17      end
    18      local function b()
    19        a()
    20      end
    21      local function c()
    22        print(_printregs())
    23        b()
    24      end
    25      c()
    26      `, "stack overflow")
    27  }
    28  
    29  func TestSkipOpenLibs(t *testing.T) {
    30  	L := NewState(Options{SkipOpenLibs: true})
    31  	defer L.Close()
    32  	errorIfScriptNotFail(t, L, `print("")`,
    33  		"attempt to call a non-function object")
    34  	L2 := NewState()
    35  	defer L2.Close()
    36  	errorIfScriptFail(t, L2, `print("")`)
    37  }
    38  
// TestGetAndReplace exercises LState.Get and LState.Replace on:
//   - positive, negative, zero and out-of-range stack indices;
//   - the pseudo-indices RegistryIndex, EnvironIndex and GlobalsIndex,
//     including the errors expected when LNil is stored there;
//   - closure upvalues via UpvalueIndex.
func TestGetAndReplace(t *testing.T) {
	L := NewState()
	defer L.Close()
	L.Push(LString("a"))
	L.Replace(1, LString("b"))
	L.Replace(0, LString("c"))
	// Index 0 and an index below the bottom of the stack are both expected
	// to read as nil, i.e. the Replace(0, ...) above leaves no visible value.
	errorIfNotEqual(t, LNil, L.Get(0))
	errorIfNotEqual(t, LNil, L.Get(-10))
	errorIfNotEqual(t, L.Env, L.Get(EnvironIndex))
	errorIfNotEqual(t, LString("b"), L.Get(1))
	L.Push(LString("c"))
	L.Push(LString("d"))
	// The stack is now b, c, d; -2 presumably addresses "c" (second from top).
	L.Replace(-2, LString("e"))
	errorIfNotEqual(t, LString("e"), L.Get(-2))
	registry := L.NewTable()
	L.Replace(RegistryIndex, registry)
	// NOTE(review): G.Registry is also assigned directly; presumably this keeps
	// the global state in sync with the Replace above. Confirm whether
	// Replace(RegistryIndex, ...) already does this on its own.
	L.G.Registry = registry
	// The following Go functions are passed to the errorIfGFunc* helpers:
	// storing a table at a pseudo-index must succeed, and storing LNil must
	// fail with the given message.
	errorIfGFuncNotFail(t, L, func(L *LState) int {
		L.Replace(RegistryIndex, LNil)
		return 0
	}, "registry must be a table")
	errorIfGFuncFail(t, L, func(L *LState) int {
		env := L.NewTable()
		L.Replace(EnvironIndex, env)
		errorIfNotEqual(t, env, L.Get(EnvironIndex))
		return 0
	})
	errorIfGFuncNotFail(t, L, func(L *LState) int {
		L.Replace(EnvironIndex, LNil)
		return 0
	}, "environment must be a table")
	errorIfGFuncFail(t, L, func(L *LState) int {
		gbl := L.NewTable()
		L.Replace(GlobalsIndex, gbl)
		// Replacing GlobalsIndex is expected to update G.Global.
		errorIfNotEqual(t, gbl, L.G.Global)
		return 0
	})
	errorIfGFuncNotFail(t, L, func(L *LState) int {
		L.Replace(GlobalsIndex, LNil)
		return 0
	}, "_G must be a table")

	// Upvalues: a closure created with upvalues 1 and 2 replaces upvalue 1
	// from inside its own body and reads the new value back.
	L2 := NewState()
	defer L2.Close()
	clo := L2.NewClosure(func(L2 *LState) int {
		L2.Replace(UpvalueIndex(1), LNumber(3))
		errorIfNotEqual(t, LNumber(3), L2.Get(UpvalueIndex(1)))
		return 0
	}, LNumber(1), LNumber(2))
	L2.SetGlobal("clo", clo)
	errorIfScriptFail(t, L2, `clo()`)
}
    91  
    92  func TestRemove(t *testing.T) {
    93  	L := NewState()
    94  	defer L.Close()
    95  	L.Push(LString("a"))
    96  	L.Push(LString("b"))
    97  	L.Push(LString("c"))
    98  
    99  	L.Remove(4)
   100  	errorIfNotEqual(t, LString("a"), L.Get(1))
   101  	errorIfNotEqual(t, LString("b"), L.Get(2))
   102  	errorIfNotEqual(t, LString("c"), L.Get(3))
   103  	errorIfNotEqual(t, 3, L.GetTop())
   104  
   105  	L.Remove(3)
   106  	errorIfNotEqual(t, LString("a"), L.Get(1))
   107  	errorIfNotEqual(t, LString("b"), L.Get(2))
   108  	errorIfNotEqual(t, LNil, L.Get(3))
   109  	errorIfNotEqual(t, 2, L.GetTop())
   110  	L.Push(LString("c"))
   111  
   112  	L.Remove(-10)
   113  	errorIfNotEqual(t, LString("a"), L.Get(1))
   114  	errorIfNotEqual(t, LString("b"), L.Get(2))
   115  	errorIfNotEqual(t, LString("c"), L.Get(3))
   116  	errorIfNotEqual(t, 3, L.GetTop())
   117  
   118  	L.Remove(2)
   119  	errorIfNotEqual(t, LString("a"), L.Get(1))
   120  	errorIfNotEqual(t, LString("c"), L.Get(2))
   121  	errorIfNotEqual(t, LNil, L.Get(3))
   122  	errorIfNotEqual(t, 2, L.GetTop())
   123  }
   124  
   125  func TestToInt(t *testing.T) {
   126  	L := NewState()
   127  	defer L.Close()
   128  	L.Push(LNumber(10))
   129  	L.Push(LString("99.9"))
   130  	L.Push(L.NewTable())
   131  	errorIfNotEqual(t, 10, L.ToInt(1))
   132  	errorIfNotEqual(t, 99, L.ToInt(2))
   133  	errorIfNotEqual(t, 0, L.ToInt(3))
   134  }
   135  
   136  func TestToInt64(t *testing.T) {
   137  	L := NewState()
   138  	defer L.Close()
   139  	L.Push(LNumber(10))
   140  	L.Push(LString("99.9"))
   141  	L.Push(L.NewTable())
   142  	errorIfNotEqual(t, int64(10), L.ToInt64(1))
   143  	errorIfNotEqual(t, int64(99), L.ToInt64(2))
   144  	errorIfNotEqual(t, int64(0), L.ToInt64(3))
   145  }
   146  
   147  func TestToNumber(t *testing.T) {
   148  	L := NewState()
   149  	defer L.Close()
   150  	L.Push(LNumber(10))
   151  	L.Push(LString("99.9"))
   152  	L.Push(L.NewTable())
   153  	errorIfNotEqual(t, LNumber(10), L.ToNumber(1))
   154  	errorIfNotEqual(t, LNumber(99.9), L.ToNumber(2))
   155  	errorIfNotEqual(t, LNumber(0), L.ToNumber(3))
   156  }
   157  
   158  func TestToString(t *testing.T) {
   159  	L := NewState()
   160  	defer L.Close()
   161  	L.Push(LNumber(10))
   162  	L.Push(LString("99.9"))
   163  	L.Push(L.NewTable())
   164  	errorIfNotEqual(t, "10", L.ToString(1))
   165  	errorIfNotEqual(t, "99.9", L.ToString(2))
   166  	errorIfNotEqual(t, "", L.ToString(3))
   167  }
   168  
   169  func TestToTable(t *testing.T) {
   170  	L := NewState()
   171  	defer L.Close()
   172  	L.Push(LNumber(10))
   173  	L.Push(LString("99.9"))
   174  	L.Push(L.NewTable())
   175  	errorIfFalse(t, L.ToTable(1) == nil, "index 1 must be nil")
   176  	errorIfFalse(t, L.ToTable(2) == nil, "index 2 must be nil")
   177  	errorIfNotEqual(t, L.Get(3), L.ToTable(3))
   178  }
   179  
   180  func TestToFunction(t *testing.T) {
   181  	L := NewState()
   182  	defer L.Close()
   183  	L.Push(LNumber(10))
   184  	L.Push(LString("99.9"))
   185  	L.Push(L.NewFunction(func(L *LState) int { return 0 }))
   186  	errorIfFalse(t, L.ToFunction(1) == nil, "index 1 must be nil")
   187  	errorIfFalse(t, L.ToFunction(2) == nil, "index 2 must be nil")
   188  	errorIfNotEqual(t, L.Get(3), L.ToFunction(3))
   189  }
   190  
   191  func TestToUserData(t *testing.T) {
   192  	L := NewState()
   193  	defer L.Close()
   194  	L.Push(LNumber(10))
   195  	L.Push(LString("99.9"))
   196  	L.Push(L.NewUserData())
   197  	errorIfFalse(t, L.ToUserData(1) == nil, "index 1 must be nil")
   198  	errorIfFalse(t, L.ToUserData(2) == nil, "index 2 must be nil")
   199  	errorIfNotEqual(t, L.Get(3), L.ToUserData(3))
   200  }
   201  
   202  func TestToChannel(t *testing.T) {
   203  	L := NewState()
   204  	defer L.Close()
   205  	L.Push(LNumber(10))
   206  	L.Push(LString("99.9"))
   207  	var ch chan LValue
   208  	L.Push(LChannel(ch))
   209  	errorIfFalse(t, L.ToChannel(1) == nil, "index 1 must be nil")
   210  	errorIfFalse(t, L.ToChannel(2) == nil, "index 2 must be nil")
   211  	errorIfNotEqual(t, ch, L.ToChannel(3))
   212  }
   213  
   214  func TestObjLen(t *testing.T) {
   215  	L := NewState()
   216  	defer L.Close()
   217  	errorIfNotEqual(t, 3, L.ObjLen(LString("abc")))
   218  	tbl := L.NewTable()
   219  	tbl.Append(LTrue)
   220  	tbl.Append(LTrue)
   221  	errorIfNotEqual(t, 2, L.ObjLen(tbl))
   222  	mt := L.NewTable()
   223  	L.SetField(mt, "__len", L.NewFunction(func(L *LState) int {
   224  		tbl := L.CheckTable(1)
   225  		L.Push(LNumber(tbl.Len() + 1))
   226  		return 1
   227  	}))
   228  	L.SetMetatable(tbl, mt)
   229  	errorIfNotEqual(t, 3, L.ObjLen(tbl))
   230  	errorIfNotEqual(t, 0, L.ObjLen(LNumber(10)))
   231  }
   232  
   233  func TestConcat(t *testing.T) {
   234  	L := NewState()
   235  	defer L.Close()
   236  	errorIfNotEqual(t, "a1c", L.Concat(LString("a"), LNumber(1), LString("c")))
   237  }
   238  
   239  func TestPCall(t *testing.T) {
   240  	L := NewState()
   241  	defer L.Close()
   242  	L.Register("f1", func(L *LState) int {
   243  		panic("panic!")
   244  	})
   245  	errorIfScriptNotFail(t, L, `f1()`, "panic!")
   246  	L.Push(L.GetGlobal("f1"))
   247  	err := L.PCall(0, 0, L.NewFunction(func(L *LState) int {
   248  		L.Push(LString("by handler"))
   249  		return 1
   250  	}))
   251  	errorIfFalse(t, strings.Contains(err.Error(), "by handler"), "")
   252  
   253  	err = L.PCall(0, 0, L.NewFunction(func(L *LState) int {
   254  		L.RaiseError("error!")
   255  		return 1
   256  	}))
   257  	errorIfFalse(t, strings.Contains(err.Error(), "error!"), "")
   258  
   259  	err = L.PCall(0, 0, L.NewFunction(func(L *LState) int {
   260  		panic("panicc!")
   261  	}))
   262  	errorIfFalse(t, strings.Contains(err.Error(), "panicc!"), "")
   263  }
   264  
   265  func TestCoroutineApi1(t *testing.T) {
   266  	L := NewState()
   267  	defer L.Close()
   268  	co, _ := L.NewThread()
   269  	errorIfScriptFail(t, L, `
   270        function coro(v)
   271          assert(v == 10)
   272          local ret1, ret2 = coroutine.yield(1,2,3)
   273          assert(ret1 == 11)
   274          assert(ret2 == 12)
   275          coroutine.yield(4)
   276          return 5
   277        end
   278      `)
   279  	fn := L.GetGlobal("coro").(*LFunction)
   280  	st, err, values := L.Resume(co, fn, LNumber(10))
   281  	errorIfNotEqual(t, ResumeYield, st)
   282  	errorIfNotNil(t, err)
   283  	errorIfNotEqual(t, 3, len(values))
   284  	errorIfNotEqual(t, LNumber(1), values[0].(LNumber))
   285  	errorIfNotEqual(t, LNumber(2), values[1].(LNumber))
   286  	errorIfNotEqual(t, LNumber(3), values[2].(LNumber))
   287  
   288  	st, err, values = L.Resume(co, fn, LNumber(11), LNumber(12))
   289  	errorIfNotEqual(t, ResumeYield, st)
   290  	errorIfNotNil(t, err)
   291  	errorIfNotEqual(t, 1, len(values))
   292  	errorIfNotEqual(t, LNumber(4), values[0].(LNumber))
   293  
   294  	st, err, values = L.Resume(co, fn)
   295  	errorIfNotEqual(t, ResumeOK, st)
   296  	errorIfNotNil(t, err)
   297  	errorIfNotEqual(t, 1, len(values))
   298  	errorIfNotEqual(t, LNumber(5), values[0].(LNumber))
   299  
   300  	L.Register("myyield", func(L *LState) int {
   301  		return L.Yield(L.ToNumber(1))
   302  	})
   303  	errorIfScriptFail(t, L, `
   304        function coro_error()
   305          coroutine.yield(1,2,3)
   306          myyield(4)
   307          assert(false, "--failed--")
   308        end
   309      `)
   310  	fn = L.GetGlobal("coro_error").(*LFunction)
   311  	co, _ = L.NewThread()
   312  	st, err, values = L.Resume(co, fn)
   313  	errorIfNotEqual(t, ResumeYield, st)
   314  	errorIfNotNil(t, err)
   315  	errorIfNotEqual(t, 3, len(values))
   316  	errorIfNotEqual(t, LNumber(1), values[0].(LNumber))
   317  	errorIfNotEqual(t, LNumber(2), values[1].(LNumber))
   318  	errorIfNotEqual(t, LNumber(3), values[2].(LNumber))
   319  
   320  	st, err, values = L.Resume(co, fn)
   321  	errorIfNotEqual(t, ResumeYield, st)
   322  	errorIfNotNil(t, err)
   323  	errorIfNotEqual(t, 1, len(values))
   324  	errorIfNotEqual(t, LNumber(4), values[0].(LNumber))
   325  
   326  	st, err, values = L.Resume(co, fn)
   327  	errorIfNotEqual(t, ResumeError, st)
   328  	errorIfNil(t, err)
   329  	errorIfFalse(t, strings.Contains(err.Error(), "--failed--"), "error message must be '--failed--'")
   330  	st, err, values = L.Resume(co, fn)
   331  	errorIfNotEqual(t, ResumeError, st)
   332  	errorIfNil(t, err)
   333  	errorIfFalse(t, strings.Contains(err.Error(), "can not resume a dead thread"), "can not resume a dead thread")
   334  
   335  }
   336  
   337  func TestContextTimeout(t *testing.T) {
   338  	L := NewState()
   339  	defer L.Close()
   340  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   341  	defer cancel()
   342  	L.SetContext(ctx)
   343  	errorIfNotEqual(t, ctx, L.Context())
   344  	err := L.DoString(`
   345  	  local clock = os.clock
   346        function sleep(n)  -- seconds
   347          local t0 = clock()
   348          while clock() - t0 <= n do end
   349        end
   350  	  sleep(3)
   351  	`)
   352  	errorIfNil(t, err)
   353  	errorIfFalse(t, strings.Contains(err.Error(), "context deadline exceeded"), "execution must be canceled")
   354  
   355  	oldctx := L.RemoveContext()
   356  	errorIfNotEqual(t, ctx, oldctx)
   357  	errorIfNotNil(t, L.ctx)
   358  }
   359  
   360  func TestContextCancel(t *testing.T) {
   361  	L := NewState()
   362  	defer L.Close()
   363  	ctx, cancel := context.WithCancel(context.Background())
   364  	errch := make(chan error, 1)
   365  	L.SetContext(ctx)
   366  	go func() {
   367  		errch <- L.DoString(`
   368  	    local clock = os.clock
   369          function sleep(n)  -- seconds
   370            local t0 = clock()
   371            while clock() - t0 <= n do end
   372          end
   373  	    sleep(3)
   374  	  `)
   375  	}()
   376  	time.Sleep(1 * time.Second)
   377  	cancel()
   378  	err := <-errch
   379  	errorIfNil(t, err)
   380  	errorIfFalse(t, strings.Contains(err.Error(), "context canceled"), "execution must be canceled")
   381  }
   382  
   383  func TestContextWithCroutine(t *testing.T) {
   384  	L := NewState()
   385  	defer L.Close()
   386  	ctx, cancel := context.WithCancel(context.Background())
   387  	L.SetContext(ctx)
   388  	defer cancel()
   389  	L.DoString(`
   390  	    function coro()
   391  		  local i = 0
   392  		  while true do
   393  		    coroutine.yield(i)
   394  			i = i+1
   395  		  end
   396  		  return i
   397  	    end
   398  	`)
   399  	co, cocancel := L.NewThread()
   400  	defer cocancel()
   401  	fn := L.GetGlobal("coro").(*LFunction)
   402  	_, err, values := L.Resume(co, fn)
   403  	errorIfNotNil(t, err)
   404  	errorIfNotEqual(t, LNumber(0), values[0])
   405  	// cancel the parent context
   406  	cancel()
   407  	_, err, values = L.Resume(co, fn)
   408  	errorIfNil(t, err)
   409  	errorIfFalse(t, strings.Contains(err.Error(), "context canceled"), "coroutine execution must be canceled when the parent context is canceled")
   410  
   411  }
   412  
   413  func TestPCallAfterFail(t *testing.T) {
   414  	L := NewState()
   415  	defer L.Close()
   416  	errFn := L.NewFunction(func(L *LState) int {
   417  		L.RaiseError("error!")
   418  		return 0
   419  	})
   420  	changeError := L.NewFunction(func(L *LState) int {
   421  		L.Push(errFn)
   422  		err := L.PCall(0, 0, nil)
   423  		if err != nil {
   424  			L.RaiseError("A New Error")
   425  		}
   426  		return 0
   427  	})
   428  	L.Push(changeError)
   429  	err := L.PCall(0, 0, nil)
   430  	errorIfFalse(t, strings.Contains(err.Error(), "A New Error"), "error not propogated correctly")
   431  }