github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/lib/tablelib/tablelib.go (about)

     1  package tablelib
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math"
     7  	"sort"
     8  	"strings"
     9  
    10  	"github.com/arnodel/golua/lib/packagelib"
    11  
    12  	rt "github.com/arnodel/golua/runtime"
    13  )
    14  
    15  // LibLoader can load the table lib.
    16  var LibLoader = packagelib.Loader{
    17  	Load: load,
    18  	Name: "table",
    19  }
    20  
    21  func load(r *rt.Runtime) (rt.Value, func()) {
    22  	pkg := rt.NewTable()
    23  
    24  	rt.SolemnlyDeclareCompliance(
    25  		rt.ComplyCpuSafe|rt.ComplyMemSafe|rt.ComplyTimeSafe|rt.ComplyIoSafe,
    26  
    27  		r.SetEnvGoFunc(pkg, "concat", concat, 4, false),
    28  		r.SetEnvGoFunc(pkg, "insert", insert, 3, false),
    29  		r.SetEnvGoFunc(pkg, "move", move, 5, false),
    30  		r.SetEnvGoFunc(pkg, "pack", pack, 0, true),
    31  		r.SetEnvGoFunc(pkg, "remove", remove, 2, false),
    32  		r.SetEnvGoFunc(pkg, "sort", sortf, 2, false),
    33  		r.SetEnvGoFunc(pkg, "unpack", unpack, 3, false),
    34  	)
    35  
    36  	return rt.TableValue(pkg), nil
    37  }
    38  
    39  func concat(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    40  	if err := c.Check1Arg(); err != nil {
    41  		return nil, err
    42  	}
    43  	_, err := c.TableArg(0)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	tblVal := c.Arg(0)
    48  	var (
    49  		sep string
    50  		i   int64 = 1
    51  	)
    52  	j, err := rt.IntLen(t, tblVal)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	switch nargs := c.NArgs(); {
    57  	case nargs >= 4:
    58  		j, err = c.IntArg(3)
    59  		if err != nil {
    60  			break
    61  		}
    62  		fallthrough
    63  	case nargs >= 3:
    64  		i, err = c.IntArg(2)
    65  		if err != nil {
    66  			break
    67  		}
    68  		fallthrough
    69  	case nargs >= 2:
    70  		sep, err = c.StringArg(1)
    71  		if err != nil {
    72  			break
    73  		}
    74  		fallthrough
    75  	default:
    76  		var item rt.Value
    77  		if i > j {
    78  			return c.PushingNext1(t.Runtime, rt.StringValue("")), nil
    79  		}
    80  		item, err = rt.Index(t, tblVal, rt.IntValue(i))
    81  		if err != nil {
    82  			break
    83  		}
    84  		var sb strings.Builder
    85  		s, ok := item.ToString()
    86  		if !ok {
    87  			return nil, errInvalidConcatValue(item, i)
    88  		}
    89  		t.RequireBytes(len(s))
    90  		sb.WriteString(s)
    91  		for {
    92  			// Don't require CPU because rt.Index will do
    93  			if i == math.MaxInt64 {
    94  				break
    95  			}
    96  			i++
    97  			if i > j {
    98  				break
    99  			}
   100  			t.RequireBytes(len(sep))
   101  			sb.WriteString(sep)
   102  			item, err = rt.Index(t, tblVal, rt.IntValue(i))
   103  			if err != nil {
   104  				return nil, err
   105  			}
   106  			s, ok = item.ToString()
   107  			if !ok {
   108  				return nil, errInvalidConcatValue(item, i)
   109  			}
   110  			t.RequireBytes(len(s))
   111  			sb.WriteString(s)
   112  		}
   113  		return c.PushingNext1(t.Runtime, rt.StringValue(sb.String())), nil
   114  	}
   115  	return nil, err
   116  }
   117  
   118  func errInvalidConcatValue(v rt.Value, i int64) error {
   119  	s, _ := v.ToString()
   120  	return fmt.Errorf("invalid value (%s) at index %d in table for 'concat'", s, i)
   121  }
   122  
   123  func insert(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   124  	if err := c.CheckNArgs(2); err != nil {
   125  		return nil, err
   126  	}
   127  	_, err := c.TableArg(0)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	tblVal := c.Arg(0)
   132  	var (
   133  		val rt.Value
   134  		pos int64
   135  	)
   136  	tblLen, err := rt.IntLen(t, tblVal)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	if c.NArgs() >= 3 {
   141  		pos, err = c.IntArg(1)
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  		if pos <= 0 || pos > tblLen+1 {
   146  			return nil, errors.New("#2 out of range")
   147  		}
   148  		val = c.Arg(2)
   149  	} else {
   150  		pos = tblLen + 1
   151  		val = c.Arg(1)
   152  	}
   153  	var (
   154  		oldVal rt.Value
   155  		posVal = rt.IntValue(pos)
   156  	)
   157  	for pos <= tblLen {
   158  		// Don't require CPU because rt.Index and rt.SetIndex will do
   159  		oldVal, err = rt.Index(t, tblVal, posVal)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		err = rt.SetIndex(t, tblVal, posVal, val)
   164  		if err != nil {
   165  			return nil, err
   166  		}
   167  		val = oldVal
   168  		pos++
   169  		posVal = rt.IntValue(pos)
   170  	}
   171  	err = rt.SetIndex(t, tblVal, posVal, val)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	return c.Next(), nil
   176  }
   177  
   178  func move(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   179  	if err := c.CheckNArgs(4); err != nil {
   180  		return nil, err
   181  	}
   182  	_, err := c.TableArg(0)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	srcVal := c.Arg(0)
   187  	srcStart, err := c.IntArg(1)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	srcEnd, err := c.IntArg(2)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	dstStart, err := c.IntArg(3)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	dstVal := srcVal
   200  	if c.NArgs() >= 5 {
   201  		_, err = c.TableArg(4)
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  		dstVal = c.Arg(4)
   206  	}
   207  	if srcStart > srcEnd || srcStart == dstStart && dstVal == srcVal {
   208  		// Nothing to do apparently!
   209  	} else if srcStart <= 0 && srcStart+math.MaxInt64 <= srcEnd {
   210  		return nil, errors.New("interval too large")
   211  	} else if dstStart >= srcStart {
   212  		// Move in descending order to avoid writing at a position
   213  		// before moving it
   214  		offset := srcEnd - srcStart // 0 <= offset < math.MaxInt64
   215  		if dstStart > math.MaxInt64-offset {
   216  			// Not enough space to move
   217  			return nil, errors.New("destination would wrap around")
   218  		}
   219  		dstStart += offset
   220  		for srcEnd >= srcStart {
   221  			// Don't require CPU because rt.Index and rt.SetIndex will do
   222  			v, err := rt.Index(t, srcVal, rt.IntValue(srcEnd))
   223  			if err == nil {
   224  				err = rt.SetIndex(t, dstVal, rt.IntValue(dstStart), v)
   225  			}
   226  			if err != nil {
   227  				return nil, err
   228  			}
   229  			if srcEnd == math.MinInt64 {
   230  				// Prevent wrapping around
   231  				break
   232  			}
   233  			srcEnd--
   234  			dstStart--
   235  		}
   236  	} else {
   237  		// Move in ascending order to avoid writing at a position
   238  		// before moving it
   239  		for srcStart <= srcEnd {
   240  			// Don't require CPU because rt.Index and rt.SetIndex will do
   241  			v, err := rt.Index(t, srcVal, rt.IntValue(srcStart))
   242  			if err == nil {
   243  				err = rt.SetIndex(t, dstVal, rt.IntValue(dstStart), v)
   244  			}
   245  			if err != nil {
   246  				return nil, err
   247  			}
   248  			if srcStart == math.MaxInt64 {
   249  				// Prevent wrapping around
   250  				break
   251  			}
   252  			srcStart++
   253  			dstStart++
   254  		}
   255  	}
   256  	return c.PushingNext1(t.Runtime, dstVal), nil
   257  }
   258  
   259  func pack(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   260  	tbl := rt.NewTable()
   261  	// We can use t.SetTable() because tbl has no metatable
   262  	for i, v := range c.Etc() {
   263  		// SetTable always consumes CPU so the loop is protected.
   264  		t.SetTable(tbl, rt.IntValue(int64(i+1)), v)
   265  	}
   266  	t.SetTable(tbl, rt.StringValue("n"), rt.IntValue(int64(len(c.Etc()))))
   267  	return c.PushingNext1(t.Runtime, rt.TableValue(tbl)), nil
   268  }
   269  
   270  func remove(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   271  	if err := c.Check1Arg(); err != nil {
   272  		return nil, err
   273  	}
   274  	_, err := c.TableArg(0)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  	tblVal := c.Arg(0)
   279  	tblLen, err := rt.IntLen(t, tblVal)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  	pos := tblLen
   284  	if c.NArgs() >= 2 {
   285  		pos, err = c.IntArg(1)
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  	}
   290  	var val rt.Value
   291  	switch {
   292  	case pos == tblLen || pos == tblLen+1:
   293  		posVal := rt.IntValue(pos)
   294  		val, err = rt.Index(t, tblVal, posVal)
   295  		if err == nil {
   296  			err = rt.SetIndex(t, tblVal, posVal, rt.NilValue)
   297  		}
   298  		if err != nil {
   299  			return nil, err
   300  		}
   301  	case pos <= 0 || pos > tblLen:
   302  		return nil, errors.New("#2 out of range")
   303  	default:
   304  		var newVal rt.Value
   305  		for pos <= tblLen {
   306  			// Don't require CPU because rt.Index and rt.SetIndex will do
   307  			tblLenVal := rt.IntValue(tblLen)
   308  			val, err = rt.Index(t, tblVal, tblLenVal)
   309  			if err == nil {
   310  				err = rt.SetIndex(t, tblVal, tblLenVal, newVal)
   311  			}
   312  			if err != nil {
   313  				return nil, err
   314  			}
   315  			tblLen--
   316  			newVal = val
   317  		}
   318  	}
   319  	return c.PushingNext1(t.Runtime, val), nil
   320  }
   321  
   322  type tableSorter struct {
   323  	len  func() int
   324  	less func(i, j int) bool
   325  	swap func(i, j int)
   326  }
   327  
   328  func (s *tableSorter) Less(i, j int) bool {
   329  	return s.less(i, j)
   330  }
   331  
   332  func (s *tableSorter) Swap(i, j int) {
   333  	s.swap(i, j)
   334  }
   335  
   336  func (s *tableSorter) Len() int {
   337  	return s.len()
   338  }
   339  
   340  const maxSortSize = 1 << 40
   341  
   342  type sortError struct {
   343  	err error
   344  }
   345  
   346  func throwSortError(err error) {
   347  	panic(sortError{err: err})
   348  }
   349  
   350  func sortf(t *rt.Thread, c *rt.GoCont) (next rt.Cont, resErr error) {
   351  	if err := c.Check1Arg(); err != nil {
   352  		return nil, err
   353  	}
   354  	_, err := c.TableArg(0)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  	tblVal := c.Arg(0)
   359  	get := func(i int) rt.Value {
   360  		x, err := rt.Index(t, tblVal, rt.IntValue(int64(i+1)))
   361  		if err != nil {
   362  			throwSortError(err)
   363  		}
   364  		return x
   365  	}
   366  	set := func(i int, x rt.Value) {
   367  		err := rt.SetIndex(t, tblVal, rt.IntValue(int64(i+1)), x)
   368  		if err != nil {
   369  			throwSortError(err)
   370  		}
   371  	}
   372  	swap := func(i, j int) {
   373  		x, y := get(i), get(j)
   374  		set(i, y)
   375  		set(j, x)
   376  	}
   377  	l, err := rt.IntLen(t, tblVal)
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  	if l >= maxSortSize {
   382  		return nil, errors.New("too big to sort")
   383  	}
   384  	if l <= 0 {
   385  		return c.Next(), nil
   386  	}
   387  	len := func() int {
   388  		return int(l)
   389  	}
   390  	var less func(i, j int) bool
   391  	if c.NArgs() >= 2 && !c.Arg(1).IsNil() {
   392  		comp := c.Arg(1)
   393  		term := rt.NewTerminationWith(c, 1, false)
   394  		less = func(i, j int) bool {
   395  			term.Reset()
   396  			err := rt.Call(t, comp, []rt.Value{get(i), get(j)}, term)
   397  			if err != nil {
   398  				throwSortError(err)
   399  			}
   400  			return rt.Truth(term.Get(0))
   401  		}
   402  	} else {
   403  		less = func(i, j int) bool {
   404  			res, err := rt.Lt(t, get(i), get(j))
   405  			if err != nil {
   406  				throwSortError(err)
   407  			}
   408  			return res
   409  		}
   410  	}
   411  	defer func() {
   412  		if r := recover(); r != nil {
   413  			next = nil
   414  			if sortErr, ok := r.(sortError); ok {
   415  				resErr = sortErr.err
   416  				return
   417  			}
   418  			panic(r)
   419  		}
   420  	}()
   421  	sorter := &tableSorter{len, less, swap}
   422  	// Because each operation on sorter consumes cpu resources, it's OK to call
   423  	// sort.Sort.
   424  	sort.Sort(sorter)
   425  	return c.Next(), nil
   426  }
   427  
   428  // Maximum number of values that can be unpacked from a table.  Lua docs don't
   429  // specify what this number should be.
   430  const maxUnpackSize = 256
   431  
   432  func unpack(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   433  	if err := c.Check1Arg(); err != nil {
   434  		return nil, err
   435  	}
   436  	_, err := c.TableArg(0)
   437  	if err != nil {
   438  		return nil, err
   439  	}
   440  	tblVal := c.Arg(0)
   441  	var (
   442  		i int64 = 1
   443  		j int64
   444  	)
   445  	nargs := c.NArgs()
   446  	if nargs >= 2 {
   447  		i, err = c.IntArg(1)
   448  		if err != nil {
   449  			return nil, err
   450  		}
   451  	}
   452  	if nargs >= 3 && !c.Arg(2).IsNil() {
   453  		j, err = c.IntArg(2)
   454  	} else {
   455  		j, err = rt.IntLen(t, tblVal)
   456  	}
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  	if i < math.MaxInt64-maxUnpackSize && i+maxUnpackSize <= j {
   461  		return nil, errors.New("too many values to unpack")
   462  	}
   463  	next := c.Next()
   464  	for ; i <= j; i++ {
   465  		// rt.Index consumes cpu so the loop is OK.
   466  		val, err := rt.Index(t, tblVal, rt.IntValue(i))
   467  		if err != nil {
   468  			return nil, err
   469  		}
   470  		t.Push1(next, val)
   471  		if i == math.MaxInt64 {
   472  			// Prevent wrap around
   473  			break
   474  		}
   475  	}
   476  	return next, nil
   477  }