github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/lib/oslib/oslib.go (about)

     1  package oslib
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"time"
     7  
     8  	"github.com/arnodel/golua/lib/packagelib"
     9  	rt "github.com/arnodel/golua/runtime"
    10  	"github.com/arnodel/golua/safeio"
    11  	"github.com/arnodel/strftime"
    12  )
    13  
    14  // LibLoader can load the os lib.
    15  var LibLoader = packagelib.Loader{
    16  	Load: load,
    17  	Name: "os",
    18  }
    19  
    20  func load(r *rt.Runtime) (rt.Value, func()) {
    21  	pkg := rt.NewTable()
    22  
    23  	rt.SolemnlyDeclareCompliance(
    24  		rt.ComplyCpuSafe|rt.ComplyMemSafe|rt.ComplyTimeSafe|rt.ComplyIoSafe,
    25  
    26  		r.SetEnvGoFunc(pkg, "clock", clock, 0, false),
    27  		r.SetEnvGoFunc(pkg, "date", date, 2, false),
    28  		r.SetEnvGoFunc(pkg, "difftime", difftime, 2, false),
    29  		r.SetEnvGoFunc(pkg, "time", timef, 1, false),
    30  		r.SetEnvGoFunc(pkg, "getenv", getenv, 1, false),
    31  		r.SetEnvGoFunc(pkg, "tmpname", tmpname, 0, false),
    32  		r.SetEnvGoFunc(pkg, "remove", remove, 1, false),
    33  		r.SetEnvGoFunc(pkg, "rename", rename, 2, false),
    34  	)
    35  	// These functions are not safe - I don't know what compliance category to
    36  	// put them in.
    37  	r.SetEnvGoFunc(pkg, "setlocale", setlocale, 2, false)
    38  	r.SetEnvGoFunc(pkg, "exit", exit, 2, false)
    39  	return rt.TableValue(pkg), nil
    40  }
    41  
    42  func date(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    43  	var (
    44  		err    error
    45  		utc    bool
    46  		now    time.Time
    47  		format string
    48  		date   rt.Value
    49  	)
    50  	if c.NArgs() == 0 {
    51  		format = "%c"
    52  	} else {
    53  		format, err = c.StringArg(0)
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  	}
    58  
    59  	// If format starts with "!" it means UTC
    60  	if len(format) > 0 && format[0] == '!' {
    61  		utc = true
    62  		format = format[1:]
    63  	}
    64  
    65  	// Get the time value
    66  	if c.NArgs() > 1 {
    67  		var t int64
    68  		t, err = c.IntArg(1)
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  		now = time.Unix(t, 0)
    73  	} else {
    74  		now = time.Now()
    75  	}
    76  	if utc {
    77  		now = now.UTC()
    78  	}
    79  	switch format {
    80  	case "*t":
    81  		{
    82  			tbl := rt.NewTable()
    83  			setTableFields(t.Runtime, tbl, now)
    84  			date = rt.TableValue(tbl)
    85  		}
    86  	default:
    87  		{
    88  			dateStr, fmtErr := strftime.StrictFormat(format, now)
    89  			if fmtErr != nil {
    90  				return nil, fmtErr
    91  			}
    92  			date = rt.StringValue(dateStr)
    93  		}
    94  	}
    95  	return c.PushingNext1(t.Runtime, date), nil
    96  }
    97  
    98  func difftime(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    99  	if err := c.CheckNArgs(2); err != nil {
   100  		return nil, err
   101  	}
   102  	t2, err := c.IntArg(0)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	t1, err := c.IntArg(1)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	return c.PushingNext1(t.Runtime, rt.IntValue(t2-t1)), nil
   111  }
   112  
   113  func exit(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   114  	var (
   115  		code  = 0 // 0 for success, 1 for failure
   116  		close = false
   117  	)
   118  	if c.NArgs() > 0 {
   119  		if !rt.Truth(c.Arg(0)) {
   120  			code = 1
   121  		}
   122  	}
   123  	if c.NArgs() > 1 {
   124  		close = rt.Truth(c.Arg(1))
   125  	}
   126  	if close {
   127  		// TODO: "close" the runtime, i.e. cleanup.
   128  		_ = close
   129  	}
   130  	os.Exit(code)
   131  	return nil, nil
   132  }
   133  
   134  func timef(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   135  	if c.NArgs() == 0 {
   136  		now := time.Now().Unix()
   137  		return c.PushingNext1(t.Runtime, rt.IntValue(now)), nil
   138  	}
   139  	tbl, err := c.TableArg(0)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  	var fieldErr error
   144  	var getField = func(dest *int, name string, required bool) bool {
   145  		if fieldErr != nil {
   146  			return false
   147  		}
   148  		var val rt.Value
   149  		val, fieldErr = rt.Index(t, rt.TableValue(tbl), rt.StringValue(name))
   150  		if fieldErr != nil {
   151  			return false
   152  		}
   153  		if val == rt.NilValue {
   154  			if required {
   155  				fieldErr = fmt.Errorf("required field '%s' missing", name)
   156  				return false
   157  			}
   158  			return true
   159  		}
   160  		iVal, ok := val.TryInt()
   161  		if !ok {
   162  			fieldErr = fmt.Errorf("field '%s' is not an integer", name)
   163  			return false
   164  		}
   165  		*dest = int(iVal)
   166  		return true
   167  	}
   168  	var (
   169  		year, month, day int
   170  		hour, min, sec   = 12, 0, 0
   171  	)
   172  	ok := getField(&year, "year", true) &&
   173  		getField(&month, "month", true) &&
   174  		getField(&day, "day", true) &&
   175  		getField(&hour, "hour", false) &&
   176  		getField(&min, "min", false) &&
   177  		getField(&sec, "sec", false)
   178  	if !ok {
   179  		return nil, fieldErr
   180  	}
   181  	// TODO: deal with DST - I have no idea how to do that.
   182  
   183  	date := time.Date(year, time.Month(month), day, hour, min, sec, 0, time.Local)
   184  	setTableFields(t.Runtime, tbl, date)
   185  	return c.PushingNext1(t.Runtime, rt.IntValue(date.Unix())), nil
   186  }
   187  
   188  func setlocale(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   189  	if err := c.Check1Arg(); err != nil {
   190  		return nil, err
   191  	}
   192  	locale, err := c.StringArg(0)
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  	// Just pretend we can set the "C" locale and none other
   197  	if locale != "C" {
   198  		return c.PushingNext1(t.Runtime, rt.NilValue), nil
   199  	}
   200  	return c.PushingNext1(t.Runtime, rt.StringValue(locale)), nil
   201  }
   202  
   203  func getenv(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   204  	if err := c.Check1Arg(); err != nil {
   205  		return nil, err
   206  	}
   207  	name, err := c.StringArg(0)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  	val, ok := os.LookupEnv(name)
   212  	valV := rt.NilValue
   213  	if ok {
   214  		t.RequireBytes(len(val))
   215  		valV = rt.StringValue(val)
   216  	}
   217  	return c.PushingNext1(t.Runtime, valV), nil
   218  }
   219  
   220  func tmpname(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   221  	f, ioErr := safeio.TempFile(t.Runtime, "", "")
   222  	if ioErr != nil {
   223  		return t.ProcessIoError(c.Next(), ioErr)
   224  	}
   225  	defer f.Close()
   226  	name := f.Name()
   227  	t.RequireBytes(len(name))
   228  	return c.PushingNext1(t.Runtime, rt.StringValue(name)), nil
   229  }
   230  
   231  func remove(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   232  	if err := c.Check1Arg(); err != nil {
   233  		return nil, err
   234  	}
   235  	name, err := c.StringArg(0)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	ioErr := safeio.RemoveFile(t.Runtime, name)
   240  	if ioErr != nil {
   241  		return t.ProcessIoError(c.Next(), ioErr)
   242  	}
   243  	return c.PushingNext1(t.Runtime, rt.BoolValue(true)), nil
   244  }
   245  
   246  func rename(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   247  	if err := c.CheckNArgs(2); err != nil {
   248  		return nil, err
   249  	}
   250  	oldName, err := c.StringArg(0)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  	newName, err := c.StringArg(1)
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  	ioErr := safeio.RenameFile(t.Runtime, oldName, newName)
   259  	if ioErr != nil {
   260  		return t.ProcessIoError(c.Next(), ioErr)
   261  	}
   262  	return c.PushingNext1(t.Runtime, rt.BoolValue(true)), nil
   263  }
   264  
   265  //
   266  // Utils
   267  //
   268  
   269  func setTableFields(r *rt.Runtime, tbl *rt.Table, now time.Time) {
   270  	r.SetEnv(tbl, "year", rt.IntValue(int64(now.Year())))
   271  	r.SetEnv(tbl, "month", rt.IntValue(int64(now.Month())))
   272  	r.SetEnv(tbl, "day", rt.IntValue(int64(now.Day())))
   273  	r.SetEnv(tbl, "hour", rt.IntValue(int64(now.Hour())))
   274  	r.SetEnv(tbl, "min", rt.IntValue(int64(now.Minute())))
   275  	r.SetEnv(tbl, "sec", rt.IntValue(int64(now.Second())))
   276  	// Weeks start on Sunday according to Lua!
   277  	wday := now.Weekday() + 1
   278  	if wday == 8 {
   279  		wday = 1
   280  	}
   281  	r.SetEnv(tbl, "wday", rt.IntValue(int64(wday)))
   282  	r.SetEnv(tbl, "yday", rt.IntValue(int64(now.YearDay())))
   283  	r.SetEnv(tbl, "isdst", rt.BoolValue(now.IsDST()))
   284  
   285  }