github.com/dfcfw/lua@v0.0.0-20230325031207-0cc7ffb7b8b9/luar/cache.go (about)

     1  package luar
     2  
import (
	"reflect"
	"sort"

	"github.com/dfcfw/lua"
)
     8  
     9  func addMethods(L *lua.LState, c *Config, vtype reflect.Type, tbl *lua.LTable, ptrReceiver bool) {
    10  	for i := 0; i < vtype.NumMethod(); i++ {
    11  		method := vtype.Method(i)
    12  		if method.PkgPath != "" {
    13  			continue
    14  		}
    15  		namesFn := c.MethodNames
    16  		if namesFn == nil {
    17  			namesFn = defaultMethodNames
    18  		}
    19  		fn := funcWrapper(L, method.Func, ptrReceiver)
    20  		for _, name := range namesFn(vtype, method) {
    21  			tbl.RawSetString(name, fn)
    22  		}
    23  	}
    24  }
    25  
    26  func collectFields(vtype reflect.Type, current []int) map[string]reflect.StructField {
    27  	m := make(map[string]reflect.StructField)
    28  
    29  	var subFields []map[string]reflect.StructField
    30  
    31  	for i, n := 0, vtype.NumField(); i < n; i++ {
    32  		field := vtype.Field(i)
    33  
    34  		if field.PkgPath == "" {
    35  			field.Index = append(current[:len(current):len(current)], i)
    36  			m[field.Name] = field
    37  		}
    38  
    39  		if field.Anonymous {
    40  			t := field.Type
    41  			if t.Kind() != reflect.Struct {
    42  				if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
    43  					continue
    44  				}
    45  				t = field.Type.Elem()
    46  			}
    47  			r := collectFields(t, append(current[:len(current):len(current)], i))
    48  			subFields = append(subFields, r)
    49  		}
    50  	}
    51  
    52  	m2 := make(map[string]reflect.StructField)
    53  	for i := 0; i < len(subFields); i++ {
    54  		for name, value := range subFields[i] {
    55  			if _, ok := m2[name]; !ok {
    56  				m2[name] = value
    57  			} else {
    58  				m2[name] = reflect.StructField{}
    59  			}
    60  		}
    61  	}
    62  
    63  	for name, value := range m2 {
    64  		if len(value.Index) > 0 {
    65  			if _, ok := m[name]; !ok {
    66  				m[name] = value
    67  			}
    68  		}
    69  	}
    70  
    71  	return m
    72  }
    73  
    74  func addFields(L *lua.LState, c *Config, vtype reflect.Type, tbl *lua.LTable) {
    75  	namesFn := c.FieldNames
    76  	if namesFn == nil {
    77  		namesFn = defaultFieldNames
    78  	}
    79  
    80  	for _, field := range collectFields(vtype, nil) {
    81  		aliases := namesFn(vtype, field)
    82  		if len(aliases) > 0 {
    83  			ud := L.NewUserData()
    84  			ud.Value = field.Index
    85  			for _, alias := range aliases {
    86  				tbl.RawSetString(alias, ud)
    87  			}
    88  		}
    89  	}
    90  }
    91  
    92  func getMetatable(L *lua.LState, vtype reflect.Type) *lua.LTable {
    93  	config := GetConfig(L)
    94  
    95  	if v := config.regular[vtype]; v != nil {
    96  		return v
    97  	}
    98  
    99  	var (
   100  		mt      *lua.LTable
   101  		methods = L.CreateTable(0, vtype.NumMethod())
   102  	)
   103  
   104  	switch vtype.Kind() {
   105  	case reflect.Array:
   106  		mt = L.CreateTable(0, 7)
   107  
   108  		mt.RawSetString("__index", L.NewFunction(arrayIndex))
   109  		mt.RawSetString("__len", L.NewFunction(arrayLen))
   110  		mt.RawSetString("__call", L.NewFunction(arrayCall))
   111  		mt.RawSetString("__eq", L.NewFunction(arrayEq))
   112  
   113  		addMethods(L, config, vtype, methods, false)
   114  	case reflect.Chan:
   115  		mt = L.CreateTable(0, 8)
   116  
   117  		mt.RawSetString("__index", L.NewFunction(chanIndex))
   118  		mt.RawSetString("__len", L.NewFunction(chanLen))
   119  		mt.RawSetString("__eq", L.NewFunction(chanEq))
   120  		mt.RawSetString("__call", L.NewFunction(chanCall))
   121  		mt.RawSetString("__unm", L.NewFunction(chanUnm))
   122  
   123  		addMethods(L, config, vtype, methods, false)
   124  	case reflect.Map:
   125  		mt = L.CreateTable(0, 7)
   126  
   127  		mt.RawSetString("__index", L.NewFunction(mapIndex))
   128  		mt.RawSetString("__newindex", L.NewFunction(mapNewIndex))
   129  		mt.RawSetString("__len", L.NewFunction(mapLen))
   130  		mt.RawSetString("__call", L.NewFunction(mapCall))
   131  
   132  		addMethods(L, config, vtype, methods, false)
   133  	case reflect.Slice:
   134  		mt = L.CreateTable(0, 8)
   135  
   136  		mt.RawSetString("__index", L.NewFunction(sliceIndex))
   137  		mt.RawSetString("__newindex", L.NewFunction(sliceNewIndex))
   138  		mt.RawSetString("__len", L.NewFunction(sliceLen))
   139  		mt.RawSetString("__call", L.NewFunction(sliceCall))
   140  		mt.RawSetString("__add", L.NewFunction(sliceAdd))
   141  
   142  		addMethods(L, config, vtype, methods, false)
   143  	case reflect.Struct:
   144  		mt = L.CreateTable(0, 6)
   145  
   146  		fields := L.CreateTable(0, vtype.NumField())
   147  		addFields(L, config, vtype, fields)
   148  		mt.RawSetString("fields", fields)
   149  
   150  		mt.RawSetString("__index", L.NewFunction(structIndex))
   151  		mt.RawSetString("__eq", L.NewFunction(structEq))
   152  
   153  		addMethods(L, config, vtype, methods, false)
   154  	case reflect.Ptr:
   155  		switch vtype.Elem().Kind() {
   156  		case reflect.Array:
   157  			mt = L.CreateTable(0, 10)
   158  
   159  			mt.RawSetString("__index", L.NewFunction(arrayPtrIndex))
   160  			mt.RawSetString("__newindex", L.NewFunction(arrayPtrNewIndex))
   161  			mt.RawSetString("__call", L.NewFunction(arrayCall)) // same as non-pointer
   162  			mt.RawSetString("__len", L.NewFunction(arrayLen))   // same as non-pointer
   163  		case reflect.Struct:
   164  			mt = L.CreateTable(0, 8)
   165  
   166  			mt.RawSetString("__index", L.NewFunction(structPtrIndex))
   167  			mt.RawSetString("__newindex", L.NewFunction(structPtrNewIndex))
   168  		default:
   169  			mt = L.CreateTable(0, 7)
   170  
   171  			mt.RawSetString("__index", L.NewFunction(ptrIndex))
   172  		}
   173  
   174  		mt.RawSetString("__eq", L.NewFunction(ptrEq))
   175  		mt.RawSetString("__pow", L.NewFunction(ptrPow))
   176  		mt.RawSetString("__unm", L.NewFunction(ptrUnm))
   177  
   178  		addMethods(L, config, vtype, methods, true)
   179  	default:
   180  		panic("unexpected kind " + vtype.Kind().String())
   181  	}
   182  
   183  	mt.RawSetString("__tostring", L.NewFunction(tostring))
   184  	mt.RawSetString("__metatable", lua.LString("gopher-luar"))
   185  	mt.RawSetString("methods", methods)
   186  
   187  	config.regular[vtype] = mt
   188  	return mt
   189  }
   190  
   191  func getTypeMetatable(L *lua.LState, t reflect.Type) *lua.LTable {
   192  	config := GetConfig(L)
   193  
   194  	if v := config.types; v != nil {
   195  		return v
   196  	}
   197  
   198  	mt := L.CreateTable(0, 3)
   199  	mt.RawSetString("__call", L.NewFunction(typeCall))
   200  	mt.RawSetString("__eq", L.NewFunction(typeEq))
   201  	mt.RawSetString("__metatable", lua.LString("gopher-luar"))
   202  
   203  	config.types = mt
   204  	return mt
   205  }