github.com/dfcfw/lua@v0.0.0-20230325031207-0cc7ffb7b8b9/luar/func.go (about)

     1  package luar
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/dfcfw/lua"
     7  )
     8  
     9  // LState is an wrapper for gopher-lua's LState. It should be used when you
    10  // wish to have a function/method with the standard "func(*lua.LState) int"
    11  // signature.
    12  type LState struct {
    13  	*lua.LState
    14  }
    15  
    16  var (
    17  	refTypeLStatePtr  = reflect.TypeOf((*LState)(nil))
    18  	refTypeLuaLValue  = reflect.TypeOf((*lua.LValue)(nil)).Elem()
    19  	refTypeInt        = reflect.TypeOf(int(0))
    20  	refTypeEmptyIface = reflect.TypeOf((*interface{})(nil)).Elem()
    21  )
    22  
    23  func getFunc(L *lua.LState) (ref reflect.Value, refType reflect.Type) {
    24  	ref = L.Get(lua.UpvalueIndex(1)).(*lua.LUserData).Value.(reflect.Value)
    25  	refType = ref.Type()
    26  	return
    27  }
    28  
    29  func isPtrReceiverMethod(L *lua.LState) bool {
    30  	return bool(L.Get(lua.UpvalueIndex(2)).(lua.LBool))
    31  }
    32  
    33  func funcIsBypass(t reflect.Type) bool {
    34  	if t.NumIn() == 1 && t.NumOut() == 1 && t.In(0) == refTypeLStatePtr && t.Out(0) == refTypeInt {
    35  		return true
    36  	}
    37  	if t.NumIn() == 2 && t.NumOut() == 1 && t.In(1) == refTypeLStatePtr && t.Out(0) == refTypeInt {
    38  		return true
    39  	}
    40  	return false
    41  }
    42  
    43  func funcBypass(L *lua.LState) int {
    44  	ref, refType := getFunc(L)
    45  
    46  	convertedPtr := false
    47  	var receiver reflect.Value
    48  	var ud lua.LValue
    49  
    50  	luarState := LState{L}
    51  	args := make([]reflect.Value, 0, 2)
    52  	if refType.NumIn() == 2 {
    53  		receiverHint := refType.In(0)
    54  		ud = L.Get(1)
    55  		var err error
    56  		if isPtrReceiverMethod(L) {
    57  			receiver, err = lValueToReflect(L, ud, receiverHint, &convertedPtr)
    58  		} else {
    59  			receiver, err = lValueToReflect(L, ud, receiverHint, nil)
    60  		}
    61  		if err != nil {
    62  			L.ArgError(1, err.Error())
    63  		}
    64  		args = append(args, receiver)
    65  		L.Remove(1)
    66  	}
    67  	args = append(args, reflect.ValueOf(&luarState))
    68  	ret := ref.Call(args)[0].Interface().(int)
    69  	if convertedPtr {
    70  		ud.(*lua.LUserData).Value = receiver.Elem().Interface()
    71  	}
    72  	return ret
    73  }
    74  
    75  func funcRegular(L *lua.LState) int {
    76  	ref, refType := getFunc(L)
    77  
    78  	top := L.GetTop()
    79  	expected := refType.NumIn()
    80  	variadic := refType.IsVariadic()
    81  	if !variadic && top != expected {
    82  		L.RaiseError("invalid number of function arguments (%d expected, got %d)", expected, top)
    83  	}
    84  	if variadic && top < expected-1 {
    85  		L.RaiseError("invalid number of function arguments (%d or more expected, got %d)", expected-1, top)
    86  	}
    87  
    88  	convertedPtr := false
    89  	var receiver reflect.Value
    90  	var ud lua.LValue
    91  
    92  	args := make([]reflect.Value, top)
    93  	for i := 0; i < L.GetTop(); i++ {
    94  		var hint reflect.Type
    95  		if variadic && i >= expected-1 {
    96  			hint = refType.In(expected - 1).Elem()
    97  		} else {
    98  			hint = refType.In(i)
    99  		}
   100  		var arg reflect.Value
   101  		var err error
   102  		if i == 0 && isPtrReceiverMethod(L) {
   103  			ud = L.Get(1)
   104  			v := ud
   105  			arg, err = lValueToReflect(L, v, hint, &convertedPtr)
   106  			if err != nil {
   107  				L.ArgError(1, err.Error())
   108  			}
   109  			receiver = arg
   110  		} else {
   111  			v := L.Get(i + 1)
   112  			arg, err = lValueToReflect(L, v, hint, nil)
   113  			if err != nil {
   114  				L.ArgError(i+1, err.Error())
   115  			}
   116  		}
   117  		args[i] = arg
   118  	}
   119  	ret := ref.Call(args)
   120  
   121  	if convertedPtr {
   122  		ud.(*lua.LUserData).Value = receiver.Elem().Interface()
   123  	}
   124  
   125  	for _, val := range ret {
   126  		L.Push(New(L, val.Interface()))
   127  	}
   128  	return len(ret)
   129  }
   130  
   131  func funcWrapper(L *lua.LState, fn reflect.Value, isPtrReceiverMethod bool) *lua.LFunction {
   132  	up := L.NewUserData()
   133  	up.Value = fn
   134  
   135  	if funcIsBypass(fn.Type()) {
   136  		return L.NewClosure(funcBypass, up, lua.LBool(isPtrReceiverMethod))
   137  	}
   138  	return L.NewClosure(funcRegular, up, lua.LBool(isPtrReceiverMethod))
   139  }