github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/runtime/marshal.go (about)

     1  package runtime
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  
     9  	"github.com/arnodel/golua/code"
    10  )
    11  
    12  var marshalPrefix = []byte{6, 0, 4}
    13  var ErrInvalidMarshalPrefix = errors.New("Invalid marshal prefix")
    14  
    15  // HasMarshalPrefix returns true if the byte slice passed starts witht the magic
    16  // prefix for Lua marshalled values.
    17  func HasMarshalPrefix(bs []byte) bool {
    18  	return len(bs) >= len(marshalPrefix) && bytes.Equal(marshalPrefix, bs[:len(marshalPrefix)])
    19  }
    20  
    21  // MarshalConst serializes a const value to the writer w.
    22  func MarshalConst(w io.Writer, c Value, budget uint64) (used uint64, err error) {
    23  	defer func() {
    24  		if r := recover(); r == budgetConsumed {
    25  			used = budget
    26  		}
    27  	}()
    28  	if _, err := w.Write(marshalPrefix); err != nil {
    29  		return 0, err
    30  	}
    31  	bw := bwriter{w: w, budget: budget}
    32  	bw.writeConst(c)
    33  	return budget - bw.budget, bw.err
    34  }
    35  
    36  // UnmarshalConst reads from r to deserialize a const value.
    37  func UnmarshalConst(r io.Reader, budget uint64) (v Value, used uint64, err error) {
    38  	defer func() {
    39  		if r := recover(); r == budgetConsumed {
    40  			used = budget
    41  		}
    42  	}()
    43  	pfx := make([]byte, len(marshalPrefix))
    44  	_, err = r.Read(pfx)
    45  	if !bytes.Equal(pfx, marshalPrefix) {
    46  		err = ErrInvalidMarshalPrefix
    47  	}
    48  	if err != nil {
    49  		return
    50  	}
    51  	br := breader{r: r, budget: budget}
    52  	v = br.readConst()
    53  	return v, budget - br.budget, br.err
    54  }
    55  
    56  //
    57  // bwriter: helper data struture to serialise values
    58  //
    59  type bwriter struct {
    60  	w   io.Writer
    61  	err error
    62  
    63  	budget uint64
    64  }
    65  
    66  func (w *bwriter) writeConst(c Value) {
    67  	switch c.Type() {
    68  	case IntType:
    69  		w.consumeBudget(1 + 8)
    70  		w.write(IntType, c.AsInt())
    71  	case FloatType:
    72  		w.consumeBudget(1 + 8)
    73  		w.write(FloatType, c.AsFloat())
    74  	case StringType:
    75  		w.consumeBudget(1 + 0) // w.writeString will consume the string budget
    76  		w.write(StringType, c.AsString())
    77  	case CodeType:
    78  		w.writeCode(c.AsCode())
    79  	// Booleans and nil are inlined so this shouldn't be neeeded.  Keeping
    80  	// around in case this is reversed
    81  	//
    82  	//  case BoolType:
    83  	//  w.consumeBudget(1 + 1)
    84  	//  w.write(BoolType, c.AsBool())
    85  	// case NilType:
    86  	//  w.consumeBudget(1)
    87  	//  w.write(NilType)
    88  	default:
    89  		w.err = errInvalidValueType
    90  	}
    91  }
    92  
    93  func (w *bwriter) writeCode(c *Code) {
    94  	w.consumeBudget(1 + 0 + 0 + 8 + 8 + 8)
    95  	w.write(
    96  		CodeType,
    97  		c.source,
    98  		c.name,
    99  		int64(len(c.code)), c.code,
   100  		int64(len(c.lines)), c.lines,
   101  		int64(len(c.consts)),
   102  	)
   103  	for _, k := range c.consts {
   104  		w.writeConst(k)
   105  	}
   106  	w.consumeBudget(2 + 2 + 2 + 8)
   107  	w.write(
   108  		c.UpvalueCount,
   109  		c.RegCount,
   110  		c.CellCount,
   111  		int64(len(c.UpNames)),
   112  	)
   113  	for _, n := range c.UpNames {
   114  		w.writeString(n)
   115  	}
   116  }
   117  
   118  func (w *bwriter) write(xs ...interface{}) {
   119  	if w.err != nil {
   120  		return
   121  	}
   122  	for _, x := range xs {
   123  		switch xx := x.(type) {
   124  		case string:
   125  			w.writeString(xx)
   126  		default:
   127  			w.err = binary.Write(w.w, binary.LittleEndian, x)
   128  		}
   129  		if w.err != nil {
   130  			return
   131  		}
   132  	}
   133  }
   134  
   135  func (w *bwriter) writeString(s string) {
   136  	w.consumeBudget(uint64(8 + len(s)))
   137  	w.write(int64(len(s)))
   138  	if w.err == nil {
   139  		_, w.err = w.w.Write([]byte(s))
   140  	}
   141  }
   142  
   143  func (w *bwriter) consumeBudget(amount uint64) {
   144  	if w.budget == 0 {
   145  		return
   146  	}
   147  	if w.budget < amount {
   148  		panic(budgetConsumed)
   149  	}
   150  	w.budget -= amount
   151  }
   152  
   153  var budgetConsumed interface{} = "budget consumed"
   154  
   155  //
   156  // breader: helper datastructure to deserialize values
   157  //
   158  
   159  type breader struct {
   160  	r   io.Reader
   161  	err error
   162  
   163  	budget uint64
   164  }
   165  
   166  func (r *breader) readConst() (v Value) {
   167  	var tp ValueType
   168  	r.read(1, &tp)
   169  	if r.err != nil {
   170  		return
   171  	}
   172  	switch tp {
   173  	case IntType:
   174  		var x int64
   175  		r.read(8, &x)
   176  		v = IntValue(x)
   177  	case FloatType:
   178  		var x float64
   179  		r.read(8, &x)
   180  		v = FloatValue(x)
   181  	case StringType:
   182  		s := r.readString()
   183  		v = StringValue(s)
   184  	case CodeType:
   185  		x := new(Code)
   186  		r.readCode(x)
   187  		v = CodeValue(x)
   188  	// Booleans and nil are inlined so this shouldn't be needed.  Keeping around
   189  	// in case this is reversed.
   190  	//
   191  	// case BoolType:
   192  	// 	var x bool
   193  	// 	r.read(1, &x)
   194  	// 	v = BoolValue(x)
   195  	// case NilType:
   196  	// 	v = NilValue
   197  	default:
   198  		r.err = errInvalidValueType
   199  	}
   200  	if r.err != nil {
   201  		return NilValue
   202  	}
   203  	return v
   204  }
   205  
   206  func (r *breader) readCode(c *Code) {
   207  	var sz int64
   208  	r.read(
   209  		0+0+8,
   210  		&c.source,
   211  		&c.name,
   212  		&sz,
   213  	)
   214  	c.code = make([]code.Opcode, sz)
   215  	r.read(
   216  		4*uint64(sz)+8,
   217  		c.code,
   218  		&sz,
   219  	)
   220  	c.lines = make([]int32, sz)
   221  	r.read(
   222  		4*uint64(sz)+8,
   223  		c.lines,
   224  		&sz,
   225  	)
   226  	c.consts = make([]Value, sz)
   227  	for i := range c.consts {
   228  		c.consts[i] = r.readConst()
   229  	}
   230  	r.read(
   231  		2+2+2+8,
   232  		&c.UpvalueCount,
   233  		&c.RegCount,
   234  		&c.CellCount,
   235  		&sz,
   236  	)
   237  	c.UpNames = make([]string, sz)
   238  	for i := range c.UpNames {
   239  		c.UpNames[i] = r.readString()
   240  	}
   241  }
   242  
   243  func (r *breader) read(sz uint64, xs ...interface{}) {
   244  	if r.err != nil {
   245  		return
   246  	}
   247  	r.consumeBudget(sz)
   248  	for _, x := range xs {
   249  		switch xx := x.(type) {
   250  		case *string:
   251  			*xx = r.readString()
   252  		default:
   253  			r.err = binary.Read(r.r, binary.LittleEndian, x)
   254  		}
   255  		if r.err != nil {
   256  			return
   257  		}
   258  	}
   259  }
   260  
   261  func (r *breader) readString() (s string) {
   262  	if r.err != nil {
   263  		return
   264  	}
   265  	var sl int64
   266  	r.read(8, &sl)
   267  	if r.err != nil {
   268  		return
   269  	}
   270  	r.consumeBudget(uint64(sl))
   271  	b := make([]byte, sl)
   272  	_, r.err = r.r.Read(b)
   273  	if r.err == nil {
   274  		s = string(b)
   275  	}
   276  	return
   277  }
   278  
   279  func (r *breader) consumeBudget(amount uint64) {
   280  	if r.budget == 0 {
   281  		return
   282  	}
   283  	if r.budget < amount {
   284  		panic(budgetConsumed)
   285  	}
   286  	r.budget -= amount
   287  }
   288  
   289  var errInvalidValueType = errors.New("Invalid value type")