github.com/TykTechnologies/tyk@v2.9.5+incompatible/gateway/coprocess_lua.go (about)

     1  // +build lua
     2  
     3  package gateway
     4  
     5  /*
     6  #cgo pkg-config: luajit
     7  
     8  #include <stdio.h>
     9  #include <stdlib.h>
    10  #include <string.h>
    11  
    12  #include "../coprocess/api.h"
    13  
    14  #include "../coprocess/lua/binding.h"
    15  
    16  #include <lua.h>
    17  #include <lualib.h>
    18  #include <lauxlib.h>
    19  
    20  static void LoadMiddleware(char* middleware_file, char* middleware_contents) {
    21  }
    22  
    23  static void LoadMiddlewareIntoState(lua_State* L, char* middleware_name, char* middleware_contents) {
    24  	luaL_dostring(L, middleware_contents);
    25  }
    26  
    27  static int LuaDispatchHook(struct CoProcessMessage* object, struct CoProcessMessage* outputObject) {
    28  	lua_State *L = luaL_newstate();
    29  
    30  	luaL_openlibs(L);
    31  	// luaL_dofile(L, "coprocess/lua/tyk/core.lua");
    32  	LoadCachedModules(L);
    33  
    34  	LoadCachedMiddleware(L);
    35  	lua_getglobal(L, "dispatch");
    36  
    37  	lua_pushlstring(L, object->p_data, object->length);
    38  	int call_result = lua_pcall(L, 1, 2, 0);
    39  
    40  	size_t lua_output_length = lua_tointeger(L, -1);
    41  	const char* lua_output_data = lua_tolstring(L, 0, &lua_output_length);
    42  
    43  	char* output = malloc(lua_output_length);
    44  	memmove(output, lua_output_data, lua_output_length);
    45  
    46  	lua_close(L);
    47  
    48  	outputObject->p_data = (void*)output;
    49  	outputObject->length = lua_output_length;
    50  
    51  	return 0;
    52  }
    53  
    54  static void LuaDispatchEvent(char* event_json) {
    55  	lua_State *L = luaL_newstate();
    56  	luaL_openlibs(L);
    57  	luaL_dofile(L, "coprocess/lua/tyk/core.lua");
    58  
    59  	lua_getglobal(L, "dispatch_event");
    60  	// lua_pushlstring(L, object->p_data, object->length);
    61  	int call_result = lua_pcall(L, 1, 1, 0);
    62  
    63  	lua_close(L);
    64  }
    65  */
    66  import "C"
    67  
    68  import (
    69  	"encoding/json"
    70  	"errors"
    71  	"io/ioutil"
    72  	"path/filepath"
    73  	"unsafe"
    74  
    75  	"github.com/sirupsen/logrus"
    76  
    77  	"github.com/TykTechnologies/tyk/apidef"
    78  	"github.com/TykTechnologies/tyk/coprocess"
    79  )
    80  
    81  const (
    82  	// ModuleBasePath points to the Tyk modules path.
    83  	ModuleBasePath = "coprocess/lua"
    84  	// MiddlewareBasePath points to the custom middleware path.
    85  	MiddlewareBasePath = "middleware/lua"
    86  )
    87  
    88  func init() {
    89  	var err error
    90  	loadedDrivers[apidef.LuaDriver], err = NewLuaDispatcher()
    91  	if err == nil {
    92  		log.WithFields(logrus.Fields{
    93  			"prefix": "coprocess",
    94  		}).Info("Lua dispatcher was initialized")
    95  	} else {
    96  		log.WithFields(logrus.Fields{
    97  			"prefix": "coprocess",
    98  		}).WithError(err).Error("Couldn't load Lua dispatcher")
    99  	}
   100  }
   101  
   102  // gMiddlewareCache will hold LuaDispatcher.gMiddlewareCache.
   103  var gMiddlewareCache map[string]string
   104  var gModuleCache map[string]string
   105  
   106  // LuaDispatcher implements a coprocess.Dispatcher
   107  type LuaDispatcher struct {
   108  	// LuaDispatcher implements the coprocess.Dispatcher interface.
   109  	coprocess.Dispatcher
   110  	// MiddlewareCache will keep the middleware file name and contents in memory, the contents will be accessed when a Lua state is initialized.
   111  	MiddlewareCache map[string]string
   112  	ModuleCache     map[string]string
   113  }
   114  
   115  // Dispatch takes a CoProcessMessage and sends it to the CP.
   116  func (d *LuaDispatcher) NativeDispatch(objectPtr unsafe.Pointer, newObjectPtr unsafe.Pointer) error {
   117  	object := (*C.struct_CoProcessMessage)(objectPtr)
   118  	newObject := (*C.struct_CoProcessMessage)(newObjectPtr)
   119  	if result := C.LuaDispatchHook(object, newObject); result != 0 {
   120  		return errors.New("Dispatch error")
   121  	}
   122  	return nil
   123  }
   124  
   125  func (d *LuaDispatcher) Dispatch(object *coprocess.Object) (*coprocess.Object, error) {
   126  	objectMsg, err := json.Marshal(object)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	objectMsgStr := string(objectMsg)
   132  	CObjectStr := C.CString(objectMsgStr)
   133  
   134  	objectPtr := (*C.struct_CoProcessMessage)(C.malloc(C.size_t(unsafe.Sizeof(C.struct_CoProcessMessage{}))))
   135  	objectPtr.p_data = unsafe.Pointer(CObjectStr)
   136  	objectPtr.length = C.int(len(objectMsg))
   137  
   138  	newObjectPtr := (*C.struct_CoProcessMessage)(C.malloc(C.size_t(unsafe.Sizeof(C.struct_CoProcessMessage{}))))
   139  
   140  	// Call the dispatcher (objectPtr is freed during this call):
   141  	if err = d.NativeDispatch(unsafe.Pointer(objectPtr), unsafe.Pointer(newObjectPtr)); err != nil {
   142  		return nil, err
   143  	}
   144  	newObjectBytes := C.GoBytes(newObjectPtr.p_data, newObjectPtr.length)
   145  
   146  	newObject := &coprocess.Object{}
   147  
   148  	if err := json.Unmarshal(newObjectBytes, newObject); err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	// Free the returned object memory:
   153  	C.free(unsafe.Pointer(newObjectPtr.p_data))
   154  	C.free(unsafe.Pointer(newObjectPtr))
   155  
   156  	return newObject, nil
   157  }
   158  
   159  // Reload will perform a middleware reload when a hot reload is triggered.
   160  func (d *LuaDispatcher) Reload() {
   161  	files, _ := ioutil.ReadDir(MiddlewareBasePath)
   162  
   163  	if d.MiddlewareCache == nil {
   164  		d.MiddlewareCache = make(map[string]string, len(files))
   165  		gMiddlewareCache = d.MiddlewareCache
   166  	} else {
   167  		for k := range d.MiddlewareCache {
   168  			delete(d.MiddlewareCache, k)
   169  		}
   170  	}
   171  
   172  	for _, f := range files {
   173  		middlewarePath := filepath.Join(MiddlewareBasePath, f.Name())
   174  		contents, err := ioutil.ReadFile(middlewarePath)
   175  		if err != nil {
   176  			log.WithFields(logrus.Fields{
   177  				"prefix": "coprocess",
   178  			}).Error("Failed to read middleware file: ", err)
   179  		}
   180  
   181  		d.MiddlewareCache[f.Name()] = string(contents)
   182  	}
   183  }
   184  
   185  func (d *LuaDispatcher) HandleMiddlewareCache(b *apidef.BundleManifest, basePath string) {
   186  	for _, f := range b.FileList {
   187  		fullPath := filepath.Join(basePath, f)
   188  		contents, err := ioutil.ReadFile(fullPath)
   189  		if err == nil {
   190  			d.ModuleCache[f] = string(contents)
   191  		} else {
   192  			log.WithFields(logrus.Fields{
   193  				"prefix": "coprocess",
   194  			}).Error("Failed to read bundle file: ", err)
   195  		}
   196  	}
   197  }
   198  
   199  func (d *LuaDispatcher) LoadModules() {
   200  	log.WithFields(logrus.Fields{
   201  		"prefix": "coprocess",
   202  	}).Info("Loading Tyk/Lua modules.")
   203  
   204  	if d.ModuleCache == nil {
   205  		d.ModuleCache = make(map[string]string, 0)
   206  		gModuleCache = d.ModuleCache
   207  	}
   208  
   209  	middlewarePath := filepath.Join(ModuleBasePath, "bundle.lua")
   210  	contents, err := ioutil.ReadFile(middlewarePath)
   211  
   212  	if err == nil {
   213  		d.ModuleCache["bundle.lua"] = string(contents)
   214  	} else {
   215  		log.WithFields(logrus.Fields{
   216  			"prefix": "coprocess",
   217  		}).Error("Failed to read bundle file: ", err)
   218  	}
   219  }
   220  
   221  //export LoadCachedModules
   222  func LoadCachedModules(luaState unsafe.Pointer) {
   223  	for moduleName, moduleContents := range gModuleCache {
   224  		cModuleName := C.CString(moduleName)
   225  		cModuleContents := C.CString(moduleContents)
   226  		C.LoadMiddlewareIntoState((*C.struct_lua_State)(luaState), cModuleName, cModuleContents)
   227  		C.free(unsafe.Pointer(cModuleName))
   228  		C.free(unsafe.Pointer(cModuleContents))
   229  	}
   230  }
   231  
   232  //export LoadCachedMiddleware
   233  func LoadCachedMiddleware(luaState unsafe.Pointer) {
   234  	for middlewareName, middlewareContents := range gMiddlewareCache {
   235  		cMiddlewareName := C.CString(middlewareName)
   236  		cMiddlewareContents := C.CString(middlewareContents)
   237  		C.LoadMiddlewareIntoState((*C.struct_lua_State)(luaState), cMiddlewareName, cMiddlewareContents)
   238  		C.free(unsafe.Pointer(cMiddlewareName))
   239  		C.free(unsafe.Pointer(cMiddlewareContents))
   240  	}
   241  }
   242  
   243  func (d *LuaDispatcher) DispatchEvent(eventJSON []byte) {
   244  	CEventJSON := C.CString(string(eventJSON))
   245  	C.LuaDispatchEvent(CEventJSON)
   246  	C.free(unsafe.Pointer(CEventJSON))
   247  }
   248  
   249  // NewCoProcessDispatcher wraps all the actions needed for this CP.
   250  func NewLuaDispatcher() (coprocess.Dispatcher, error) {
   251  	dispatcher := &LuaDispatcher{}
   252  	dispatcher.LoadModules()
   253  	dispatcher.Reload()
   254  	return dispatcher, nil
   255  }