github.com/aergoio/aergo@v1.3.1/contract/db_module.c (about)

     1  #include <string.h>
     2  #include <stdlib.h>
     3  #include <ctype.h>
     4  #include <time.h>
     5  #include <sqlite3-binding.h>
     6  #include "vm.h"
     7  #include "sqlcheck.h"
     8  #include "lgmp.h"
     9  #include "util.h"
    10  #include "_cgo_export.h"
    11  
    12  #define LAST_ERROR(L,db,rc)                         \
    13      do {                                            \
    14          if ((rc) != SQLITE_OK) {                    \
    15              luaL_error((L), sqlite3_errmsg((db)));  \
    16          }                                           \
    17      } while(0)
    18  
/*
 * Names of the registry sub-tables in which live prepared statements and
 * result sets are recorded (see append_resource and
 * lua_db_release_resource).
 */
#define RESOURCE_PSTMT_KEY "_RESOURCE_PSTMT_KEY_"
#define RESOURCE_RS_KEY "_RESOURCE_RS_KEY_"

/* Defined outside this file; NOTE(review): presumably raises a Lua error
 * when no execution context is available - confirm against its definition. */
extern const int *getLuaExecContext(lua_State *L);
/* Pushes a table describing the columns of `stmt' (defined below). */
static void get_column_meta(lua_State *L, sqlite3_stmt* stmt);
    24  
    25  static int append_resource(lua_State *L, const char *key, void *data)
    26  {
    27      int refno;
    28      if (luaL_findtable(L, LUA_REGISTRYINDEX, key, 0) != NULL) {
    29          luaL_error(L, "cannot find the environment of the db module");
    30      }
    31      /* tab */
    32      lua_pushlightuserdata(L, data);     /* tab pstmt */
    33      refno = luaL_ref(L, -2);            /* tab */
    34      lua_pop(L, 1);                      /* remove tab */
    35      return refno;
    36  }
    37  
/* Metatable name identifying prepared-statement userdata. */
#define DB_PSTMT_ID "__db_pstmt__"

/* Userdata payload of a prepared statement created by db_prepare. */
typedef struct {
    sqlite3 *db;        /* connection the statement was prepared on */
    sqlite3_stmt *s;    /* statement handle; finalized by db_pstmt_close */
    int closed;         /* set to 1 once db_pstmt_close has run */
    int refno;          /* reference in the RESOURCE_PSTMT_KEY registry table */
} db_pstmt_t;
    46  
/* Metatable name identifying result-set userdata. */
#define DB_RS_ID "__db_rs__"

/* Userdata payload of a result set created by db_query or db_pstmt_query. */
typedef struct {
    sqlite3 *db;        /* connection the statement belongs to */
    sqlite3_stmt *s;    /* statement being stepped by db_rs_next */
    int closed;         /* set to 1 once db_rs_close has run */
    int nc;             /* column count, from sqlite3_column_count */
    int shared_stmt;    /* 1: `s' belongs to a prepared statement and is not
                           finalized here; 0: finalized by db_rs_close */
    char **decltypes;   /* per-column lower-cased declared type (or NULL);
                           allocated on the first row returned by db_rs_next */
    int refno;          /* reference in the RESOURCE_RS_KEY registry table */
} db_rs_t;
    58  
    59  static db_rs_t *get_db_rs(lua_State *L, int pos)
    60  {
    61      db_rs_t *rs = luaL_checkudata(L, pos, DB_RS_ID);
    62      if (rs->closed) {
    63          luaL_error(L, "resultset is closed");
    64      }
    65      return rs;
    66  }
    67  
    68  static int db_rs_tostr(lua_State *L)
    69  {
    70      db_rs_t *rs = luaL_checkudata(L, 1, DB_RS_ID);
    71      if (rs->closed) {
    72          lua_pushfstring(L, "resultset is closed");
    73      } else {
    74          lua_pushfstring(L, "resultset{handle=%p}", rs->s);
    75      }
    76      return 1;
    77  }
    78  
    79  static char *dup_decltype(const char *decltype)
    80  {
    81      int n;
    82      char *p;
    83      char *c;
    84  
    85      if (decltype == NULL) {
    86          return NULL;
    87      }
    88  
    89      p = c = malloc(strlen(decltype)+1);
    90      while ((*c++ = tolower(*decltype++))) ;
    91  
    92      if (strcmp(p, "date") == 0 || strcmp(p, "datetime") == 0 || strcmp(p, "timestamp") == 0 ||
    93          strcmp(p, "boolean") == 0) {
    94          return p;
    95      }
    96      free(p);
    97      return NULL;
    98  }
    99  
   100  static void free_decltypes(db_rs_t *rs)
   101  {
   102      int i;
   103      for (i = 0; i < rs->nc; i++) {
   104          if (rs->decltypes[i] != NULL)
   105              free(rs->decltypes[i]);
   106      }
   107      free(rs->decltypes);
   108      rs->decltypes = NULL;
   109  }
   110  
   111  static int db_rs_get(lua_State *L)
   112  {
   113      db_rs_t *rs = get_db_rs(L, 1);
   114      int i;
   115      sqlite3_int64 d;
   116      double f;
   117      int n;
   118      const unsigned char *s;
   119  
   120      if (rs->decltypes == NULL) {
   121          luaL_error(L, "`get' called without calling `next'");
   122      }
   123      for (i = 0; i < rs->nc; i++) {
   124          switch (sqlite3_column_type(rs->s, i)) {
   125          case SQLITE_INTEGER:
   126              d = sqlite3_column_int64(rs->s, i);
   127              if (rs->decltypes[i] == NULL)  {
   128                  lua_pushinteger(L, d);
   129              } else if (strcmp(rs->decltypes[i], "boolean") == 0) {
   130                  if (d != 0) {
   131                      lua_pushboolean(L, 1);
   132                  } else {
   133                      lua_pushboolean(L, 0);
   134                  }
   135              } else { // date, datetime, timestamp
   136                  char buf[80];
   137                  strftime(buf, 80, "%Y-%m-%d %H:%M:%S", gmtime((time_t *)&d));
   138                  lua_pushlstring(L, (const char *)buf, strlen(buf));
   139              }
   140              break;
   141          case SQLITE_FLOAT:
   142              f = sqlite3_column_double(rs->s, i);
   143              lua_pushnumber(L, f);
   144              break;
   145          case SQLITE_TEXT:
   146              n = sqlite3_column_bytes(rs->s, i);
   147              s = sqlite3_column_text(rs->s, i);
   148              lua_pushlstring(L, (const char *)s, n);
   149              break;
   150          case SQLITE_NULL:
   151              /* fallthrough */
   152          default: /* unsupported types */
   153              lua_pushnil(L);
   154          }
   155      }
   156      return rs->nc;
   157  }
   158  
   159  static int db_rs_colcnt(lua_State *L)
   160  {
   161      db_rs_t *rs = get_db_rs(L, 1);
   162  
   163      lua_pushinteger(L, rs->nc);
   164      return 1;
   165  }
   166  
   167  static void db_rs_close(lua_State *L, db_rs_t *rs, int remove)
   168  {
   169      if (rs->closed) {
   170          return;
   171      }
   172      rs->closed = 1;
   173      if (rs->decltypes) {
   174          free_decltypes(rs);
   175      }
   176      if (rs->shared_stmt == 0) {
   177          sqlite3_finalize(rs->s);
   178      }
   179      if (remove) {
   180          if (luaL_findtable(L, LUA_REGISTRYINDEX, RESOURCE_RS_KEY, 0) != NULL) {
   181              luaL_error(L, "cannot find the environment of the db module");
   182          }
   183          luaL_unref(L, -1, rs->refno);
   184          lua_pop(L, 1);
   185      }
   186  }
   187  
   188  static int db_rs_next(lua_State *L)
   189  {
   190      db_rs_t *rs = get_db_rs(L, 1);
   191      int rc;
   192  
   193      rc = sqlite3_step(rs->s);
   194      if (rc == SQLITE_DONE) {
   195          db_rs_close(L, rs, 1);
   196          lua_pushboolean(L, 0);
   197      } else if (rc != SQLITE_ROW) {
   198          rc = sqlite3_reset(rs->s);
   199          LAST_ERROR(L, rs->db, rc);
   200          db_rs_close(L, rs, 1);
   201          lua_pushboolean(L, 0);
   202      } else {
   203          if (rs->decltypes == NULL) {
   204              int i;
   205              rs->decltypes = malloc(sizeof(char *) * rs->nc);
   206              for (i = 0; i < rs->nc; i++) {
   207                  rs->decltypes[i] = dup_decltype(sqlite3_column_decltype(rs->s, i));
   208              }
   209          }
   210          lua_pushboolean(L, 1);
   211      }
   212      return 1;
   213  }
   214  
   215  static int db_rs_gc(lua_State *L)
   216  {
   217      db_rs_close(L, luaL_checkudata(L, 1, DB_RS_ID), 1);
   218      return 0;
   219  }
   220  
   221  static db_pstmt_t *get_db_pstmt(lua_State *L, int pos)
   222  {
   223      db_pstmt_t *pstmt = luaL_checkudata(L, pos, DB_PSTMT_ID);
   224      if (pstmt->closed) {
   225          luaL_error(L, "prepared statement is closed");
   226      }
   227      return pstmt;
   228  }
   229  
   230  static int db_pstmt_tostr(lua_State *L)
   231  {
   232      db_pstmt_t *pstmt = luaL_checkudata(L, 1, DB_PSTMT_ID);
   233      if (pstmt->closed) {
   234          lua_pushfstring(L, "prepared statement is closed");
   235      } else {
   236          lua_pushfstring(L, "prepared statement{handle=%p}", pstmt->s);
   237      }
   238      return 1;
   239  }
   240  
   241  static int bind(lua_State *L, sqlite3 *db, sqlite3_stmt *pstmt)
   242  {
   243      int rc, i;
   244      int argc = lua_gettop(L) - 1;
   245      int param_count;
   246  
   247      param_count = sqlite3_bind_parameter_count(pstmt);
   248      if (argc != param_count) {
   249          lua_pushfstring(L, "parameter count mismatch: want %d got %d", param_count, argc);
   250          return -1;
   251      }
   252  
   253      rc = sqlite3_reset(pstmt);
   254      sqlite3_clear_bindings(pstmt);
   255      if (rc != SQLITE_ROW && rc != SQLITE_OK && rc != SQLITE_DONE) {
   256          lua_pushfstring(L, sqlite3_errmsg(db));
   257          return -1;
   258      }
   259  
   260      for (i = 1; i <= argc; i++) {
   261          int t, b, n = i + 1;
   262          const char *s;
   263          size_t l;
   264  
   265          luaL_checkany(L, n);
   266          t = lua_type(L, n);
   267  
   268          switch (t) {
   269          case LUA_TNUMBER:
   270              if (luaL_isinteger(L, n)) {
   271                  lua_Integer d = lua_tointeger(L, n);
   272                  rc = sqlite3_bind_int64(pstmt, i, (sqlite3_int64)d);
   273              } else {
   274                  lua_Number d = lua_tonumber(L, n);
   275                  rc = sqlite3_bind_double(pstmt, i, (double)d);
   276              }
   277              break;
   278          case LUA_TSTRING:
   279              s = lua_tolstring(L, n, &l);
   280              rc = sqlite3_bind_text(pstmt, i, s, l, SQLITE_TRANSIENT);
   281              break;
   282          case LUA_TBOOLEAN:
   283              b = lua_toboolean(L, i+1);
   284              if (b) {
   285                  rc = sqlite3_bind_int(pstmt, i, 1);
   286              } else {
   287                  rc = sqlite3_bind_int(pstmt, i, 0);
   288              }
   289              break;
   290          case LUA_TNIL:
   291              rc = sqlite3_bind_null(pstmt, i);
   292              break;
   293          case LUA_TUSERDATA:
   294          {
   295              if (lua_isbignumber(L, n)) {
   296                  long int d = lua_get_bignum_si(L, n);
   297                  if (d == 0 && lua_bignum_is_zero(L, n) != 0) {
   298                      char *s = lua_get_bignum_str(L, n);
   299                      if (s != NULL) {
   300                          lua_pushfstring(L, "bignum value overflow for binding %s", s);
   301                          free(s);
   302                      }
   303                      return -1;
   304                  }
   305                  rc = sqlite3_bind_int64(pstmt, i, (sqlite3_int64)d);
   306                  break;
   307              }
   308          }
   309          default:
   310              lua_pushfstring(L, "unsupported type: %s", lua_typename(L, n));
   311              return -1;
   312          }
   313          if (rc != SQLITE_OK) {
   314              lua_pushfstring(L, sqlite3_errmsg(db));
   315              return -1;
   316          }
   317      }
   318      
   319      return 0;
   320  }
   321  
   322  static int db_pstmt_exec(lua_State *L)
   323  {
   324      int rc, n;
   325      db_pstmt_t *pstmt = get_db_pstmt(L, 1);
   326  
   327      /*check for exec in function */
   328  	getLuaExecContext(L);
   329  
   330      rc = bind(L, pstmt->db, pstmt->s);
   331      if (rc == -1) {
   332          sqlite3_reset(pstmt->s);
   333          sqlite3_clear_bindings(pstmt->s);
   334          luaL_error(L, lua_tostring(L, -1));
   335      }
   336      rc = sqlite3_step(pstmt->s);
   337      if (rc != SQLITE_ROW && rc != SQLITE_OK && rc != SQLITE_DONE) {
   338          sqlite3_reset(pstmt->s);
   339          sqlite3_clear_bindings(pstmt->s);
   340          luaL_error(L, sqlite3_errmsg(pstmt->db));
   341      }
   342      n = sqlite3_changes(pstmt->db);
   343      lua_pushinteger(L, n);
   344      return 1;
   345  }
   346  
   347  static int db_pstmt_query(lua_State *L)
   348  {
   349      int rc;
   350      db_pstmt_t *pstmt = get_db_pstmt(L, 1);
   351      db_rs_t *rs;
   352  
   353  	getLuaExecContext(L);
   354      if (!sqlite3_stmt_readonly(pstmt->s)) {
   355          luaL_error(L, "invalid sql command(permitted readonly)");
   356      }
   357      rc = bind(L, pstmt->db, pstmt->s);
   358      if (rc != 0) {
   359          sqlite3_reset(pstmt->s);
   360          sqlite3_clear_bindings(pstmt->s);
   361          luaL_error(L, lua_tostring(L, -1));
   362      }
   363  
   364      rs = (db_rs_t *)lua_newuserdata(L, sizeof(db_rs_t));
   365      luaL_getmetatable(L, DB_RS_ID);
   366      lua_setmetatable(L, -2);
   367      rs->db = pstmt->db;
   368      rs->s = pstmt->s;
   369      rs->closed = 0;
   370      rs->nc = sqlite3_column_count(pstmt->s);
   371      rs->shared_stmt = 1;
   372      rs->decltypes = NULL;
   373      rs->refno = append_resource(L, RESOURCE_RS_KEY, (void *)rs);
   374  
   375      return 1;
   376  }
   377  
   378  static void get_column_meta(lua_State *L, sqlite3_stmt* stmt)
   379  {
   380      const char *name, *decltype;
   381      int type;
   382      int colcnt = sqlite3_column_count(stmt);
   383      int i;
   384  
   385      lua_createtable(L, 0, 2);
   386      lua_pushinteger(L, colcnt);
   387      lua_setfield(L, -2, "colcnt");
   388      if (colcnt > 0) {
   389          lua_createtable(L, colcnt, 0);  /* colinfos names */
   390          lua_createtable(L, colcnt, 0);  /* colinfos names decltypes */
   391      }
   392      else {
   393          lua_pushnil(L);
   394          lua_pushnil(L);
   395      }
   396      for (i = 0; i < colcnt; i++) {
   397          name = sqlite3_column_name(stmt, i);
   398          if (name == NULL)
   399              lua_pushstring(L, "");
   400          else
   401              lua_pushstring(L, name);
   402          lua_rawseti(L, -3, i+1);
   403  
   404          decltype = sqlite3_column_decltype(stmt, i);
   405          if (decltype == NULL)
   406              lua_pushstring(L, "");
   407           else
   408              lua_pushstring(L, decltype);
   409          lua_rawseti(L, -2, i+1);
   410      }
   411      lua_setfield(L, -3, "decltypes");
   412      lua_setfield(L, -2, "names");
   413  }
   414  
   415  static int db_pstmt_column_info(lua_State *L)
   416  {
   417      int colcnt;
   418      db_pstmt_t *pstmt = get_db_pstmt(L, 1);
   419  	getLuaExecContext(L);
   420  
   421      get_column_meta(L, pstmt->s);
   422      return 1;
   423  }
   424  
   425  static int db_pstmt_bind_param_cnt(lua_State *L)
   426  {
   427      db_pstmt_t *pstmt = get_db_pstmt(L, 1);
   428  	getLuaExecContext(L);
   429  
   430  	lua_pushinteger(L, sqlite3_bind_parameter_count(pstmt->s));
   431  
   432  	return 1;
   433  }
   434  
   435  static void db_pstmt_close(lua_State *L, db_pstmt_t *pstmt, int remove)
   436  {
   437      if (pstmt->closed)
   438          return;
   439      pstmt->closed = 1;
   440      sqlite3_finalize(pstmt->s);
   441      if (remove) {
   442          if (luaL_findtable(L, LUA_REGISTRYINDEX, RESOURCE_PSTMT_KEY, 0) != NULL) {
   443              luaL_error(L, "cannot find the environment of the db module");
   444          }
   445          luaL_unref(L, -1, pstmt->refno);
   446          lua_pop(L, 1);
   447      }
   448  }
   449  
   450  static int db_pstmt_gc(lua_State *L)
   451  {
   452      db_pstmt_close(L, luaL_checkudata(L, 1, DB_PSTMT_ID), 1);
   453      return 0;
   454  }
   455  
   456  static int db_exec(lua_State *L)
   457  {
   458      const char *cmd;
   459      sqlite3 *db;
   460      sqlite3_stmt *s;
   461      int rc;
   462  
   463      /*check for exec in function */
   464  	getLuaExecContext(L);
   465      cmd = luaL_checkstring(L, 1);
   466      if (!sqlcheck_is_permitted_sql(cmd)) {
   467          luaL_error(L, "invalid sql command");
   468      }
   469      db = vm_get_db(L);
   470      rc = sqlite3_prepare_v2(db, cmd, -1, &s, NULL);
   471      LAST_ERROR(L, db, rc);
   472  
   473      rc = bind(L, db, s);
   474      if (rc == -1) {
   475          sqlite3_finalize(s);
   476          luaL_error(L, lua_tostring(L, -1));
   477      }
   478  
   479      rc = sqlite3_step(s);
   480      if (rc != SQLITE_ROW && rc != SQLITE_OK && rc != SQLITE_DONE) {
   481          sqlite3_finalize(s);
   482          luaL_error(L, sqlite3_errmsg(db));
   483      }
   484      sqlite3_finalize(s);
   485  
   486      lua_pushinteger(L, sqlite3_changes(db));
   487      return 1;
   488  }
   489  
   490  static int db_query(lua_State *L)
   491  {
   492      const char *query;
   493      int rc;
   494      sqlite3 *db;
   495      sqlite3_stmt *s;
   496      db_rs_t *rs;
   497  
   498  	getLuaExecContext(L);
   499      query = luaL_checkstring(L, 1);
   500      if (!sqlcheck_is_readonly_sql(query)) {
   501          luaL_error(L, "invalid sql command(permitted readonly)");
   502      }
   503      db = vm_get_db(L);
   504      rc = sqlite3_prepare_v2(db, query, -1, &s, NULL);
   505      LAST_ERROR(L, db, rc);
   506  
   507      rc = bind(L, db, s);
   508      if (rc == -1) {
   509          sqlite3_finalize(s);
   510          luaL_error(L, lua_tostring(L, -1));
   511      }
   512  
   513      rs = (db_rs_t *)lua_newuserdata(L, sizeof(db_rs_t));
   514      luaL_getmetatable(L, DB_RS_ID);
   515      lua_setmetatable(L, -2);
   516      rs->db = db;
   517      rs->s = s;
   518      rs->closed = 0;
   519      rs->nc = sqlite3_column_count(s);
   520      rs->shared_stmt = 0;
   521      rs->decltypes = NULL;
   522      rs->refno = append_resource(L, RESOURCE_RS_KEY, (void *)rs);
   523  
   524      return 1;
   525  }
   526  
   527  static int db_prepare(lua_State *L)
   528  {
   529      const char *sql;
   530      int rc;
   531      int ref;
   532      sqlite3 *db;
   533      sqlite3_stmt *s;
   534      db_pstmt_t *pstmt;
   535  
   536      sql = luaL_checkstring(L, 1);
   537      if (!sqlcheck_is_permitted_sql(sql)) {
   538          luaL_error(L, "invalid sql command");
   539      }
   540      db = vm_get_db(L);
   541      rc = sqlite3_prepare_v2(db, sql, -1, &s, NULL);
   542      LAST_ERROR(L, db, rc);
   543  
   544      pstmt = (db_pstmt_t *)lua_newuserdata(L, sizeof(db_pstmt_t));
   545      luaL_getmetatable(L, DB_PSTMT_ID);
   546      lua_setmetatable(L, -2);
   547      pstmt->db = db;
   548      pstmt->s = s;
   549      pstmt->closed = 0;
   550      pstmt->refno = append_resource(L, RESOURCE_PSTMT_KEY, (void *)pstmt);
   551  
   552      return 1;
   553  }
   554  
   555  static int db_get_snapshot(lua_State *L)
   556  {
   557      char *snapshot;
   558      int *service = (int *)getLuaExecContext(L);
   559  
   560      snapshot = LuaGetDbSnapshot(service);
   561      strPushAndRelease(L, snapshot);
   562  
   563      return 1;
   564  }
   565  
   566  static int db_open_with_snapshot(lua_State *L)
   567  {
   568      char *snapshot = (char *)luaL_checkstring(L, 1);
   569      char *errStr;
   570      int *service = (int *)getLuaExecContext(L);
   571  
   572      errStr = LuaGetDbHandleSnap(service, snapshot);
   573      if (errStr != NULL) {
   574          strPushAndRelease(L, errStr);
   575          luaL_throwerror(L);
   576      }
   577      return 1;
   578  }
   579  
   580  int lua_db_release_resource(lua_State *L)
   581  {
   582      lua_getfield(L, LUA_REGISTRYINDEX, RESOURCE_RS_KEY);
   583      if (lua_istable(L, -1)) {
   584          /* T */
   585          lua_pushnil(L); /* T nil(key) */
   586          while (lua_next(L, -2)) {
   587              if (lua_islightuserdata(L, -1))
   588                  db_rs_close(L, (db_rs_t *)lua_topointer(L, -1), 0);
   589              lua_pop(L, 1);
   590          }
   591          lua_pop(L, 1);
   592      }
   593      lua_getfield(L, LUA_REGISTRYINDEX, RESOURCE_PSTMT_KEY);
   594      if (lua_istable(L, -1)) {
   595          /* T */
   596          lua_pushnil(L); /* T nil(key) */
   597          while (lua_next(L, -2)) {
   598              if (lua_islightuserdata(L, -1))
   599                  db_pstmt_close(L, (db_pstmt_t *)lua_topointer(L, -1), 0);
   600              lua_pop(L, 1);
   601          }
   602          lua_pop(L, 1);
   603      }
   604      return 0;
   605  }
   606  
   607  int luaopen_db(lua_State *L)
   608  {
   609      static const luaL_Reg rs_methods[] = {
   610          {"next",  db_rs_next},
   611          {"get", db_rs_get},
   612          {"colcnt", db_rs_colcnt},
   613          {"__tostring", db_rs_tostr},
   614          {"__gc", db_rs_gc},
   615          {NULL, NULL}
   616      };
   617  
   618      static const luaL_Reg pstmt_methods[] = {
   619          {"exec",  db_pstmt_exec},
   620          {"query", db_pstmt_query},
   621          {"column_info", db_pstmt_column_info},
   622          {"bind_param_cnt", db_pstmt_bind_param_cnt},
   623          {"__tostring", db_pstmt_tostr},
   624          {"__gc", db_pstmt_gc},
   625          {NULL, NULL}
   626      };
   627  
   628      static const luaL_Reg db_lib[] = {
   629          {"exec", db_exec},
   630          {"query", db_query},
   631          {"prepare", db_prepare},
   632          {"getsnap", db_get_snapshot},
   633          {"open_with_snapshot", db_open_with_snapshot},
   634          {NULL, NULL}
   635      };
   636  
   637      luaL_newmetatable(L, DB_RS_ID);
   638      lua_pushvalue(L, -1);
   639      lua_setfield(L, -2, "__index");
   640      luaL_register(L, NULL, rs_methods);
   641  
   642      luaL_newmetatable(L, DB_PSTMT_ID);
   643      lua_pushvalue(L, -1);
   644      lua_setfield(L, -2, "__index");
   645      luaL_register(L, NULL, pstmt_methods);
   646  
   647  	luaL_register(L, "db", db_lib);
   648  	lua_pop(L, 3);
   649  	return 1;
   650  }