github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/database/sqlite3/sqlite3.go (about)

     1  package sqlite3
     2  
     3  /*
     4  #include <sqlite3.h>
     5  #include <stdlib.h>
     6  #include <string.h>
     7  
     8  #ifdef __CYGWIN__
     9  # include <errno.h>
    10  #endif
    11  
    12  #ifndef SQLITE_OPEN_READWRITE
    13  # define SQLITE_OPEN_READWRITE 0
    14  #endif
    15  
    16  #ifndef SQLITE_OPEN_FULLMUTEX
    17  # define SQLITE_OPEN_FULLMUTEX 0
    18  #endif
    19  
    20  static int
    21  _sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
    22  #ifdef SQLITE_OPEN_URI
    23    return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs);
    24  #else
    25    return sqlite3_open_v2(filename, ppDb, flags, zVfs);
    26  #endif
    27  }
    28  
    29  static int
    30  _sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
    31    return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
    32  }
    33  
    34  static int
    35  _sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
    36    return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
    37  }
    38  
    39  #include <stdio.h>
    40  #include <stdint.h>
    41  
    42  static long
    43  _sqlite3_last_insert_rowid(sqlite3* db) {
    44    return (long) sqlite3_last_insert_rowid(db);
    45  }
    46  
    47  static long
    48  _sqlite3_changes(sqlite3* db) {
    49    return (long) sqlite3_changes(db);
    50  }
    51  
    52  */
    53  import "C"
    54  import (
    55  	"database/sql"
    56  	"database/sql/driver"
    57  	"errors"
    58  	"io"
    59  	"strings"
    60  	"time"
    61  	"unsafe"
    62  )
    63  
    64  // Timestamp formats understood by both this module and SQLite.
    65  // The first format in the slice will be used when saving time values
    66  // into the database. When parsing a string from a timestamp or
    67  // datetime column, the formats are tried in order.
    68  var SQLiteTimestampFormats = []string{
    69  	"2006-01-02 15:04:05.999999999",
    70  	"2006-01-02T15:04:05.999999999",
    71  	"2006-01-02 15:04:05",
    72  	"2006-01-02T15:04:05",
    73  	"2006-01-02 15:04",
    74  	"2006-01-02T15:04",
    75  	"2006-01-02",
    76  }
    77  
    78  func init() {
    79  	sql.Register("sqlite3", &SQLiteDriver{})
    80  }
    81  
    82  // Driver struct.
    83  type SQLiteDriver struct {
    84  	Extensions  []string
    85  	ConnectHook func(*SQLiteConn) error
    86  }
    87  
    88  // Conn struct.
    89  type SQLiteConn struct {
    90  	db *C.sqlite3
    91  }
    92  
    93  // Tx struct.
    94  type SQLiteTx struct {
    95  	c *SQLiteConn
    96  }
    97  
    98  // Stmt struct.
    99  type SQLiteStmt struct {
   100  	c      *SQLiteConn
   101  	s      *C.sqlite3_stmt
   102  	t      string
   103  	closed bool
   104  }
   105  
   106  // Result struct.
   107  type SQLiteResult struct {
   108  	id      int64
   109  	changes int64
   110  }
   111  
   112  // Rows struct.
   113  type SQLiteRows struct {
   114  	s        *SQLiteStmt
   115  	nc       int
   116  	cols     []string
   117  	decltype []string
   118  }
   119  
   120  // Commit transaction.
   121  func (tx *SQLiteTx) Commit() error {
   122  	if err := tx.c.exec("COMMIT"); err != nil {
   123  		return err
   124  	}
   125  	return nil
   126  }
   127  
   128  // Rollback transaction.
   129  func (tx *SQLiteTx) Rollback() error {
   130  	if err := tx.c.exec("ROLLBACK"); err != nil {
   131  		return err
   132  	}
   133  	return nil
   134  }
   135  
   136  // AutoCommit return which currently auto commit or not.
   137  func (c *SQLiteConn) AutoCommit() bool {
   138  	return int(C.sqlite3_get_autocommit(c.db)) != 0
   139  }
   140  
   141  // TODO: Execer & Queryer currently disabled
   142  // https://github.com/mattn/go-sqlite3/issues/82
   143  //// Implements Execer
   144  //func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   145  //	tx, err := c.Begin()
   146  //	if err != nil {
   147  //		return nil, err
   148  //	}
   149  //	for {
   150  //		s, err := c.Prepare(query)
   151  //		if err != nil {
   152  //			tx.Rollback()
   153  //			return nil, err
   154  //		}
   155  //		na := s.NumInput()
   156  //		res, err := s.Exec(args[:na])
   157  //		if err != nil && err != driver.ErrSkip {
   158  //			tx.Rollback()
   159  //			s.Close()
   160  //			return nil, err
   161  //		}
   162  //		args = args[na:]
   163  //		tail := s.(*SQLiteStmt).t
   164  //		if tail == "" {
   165  //			tx.Commit()
   166  //			return res, nil
   167  //		}
   168  //		s.Close()
   169  //		query = tail
   170  //	}
   171  //}
   172  //
   173  //// Implements Queryer
   174  //func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   175  //	tx, err := c.Begin()
   176  //	if err != nil {
   177  //		return nil, err
   178  //	}
   179  //	for {
   180  //		s, err := c.Prepare(query)
   181  //		if err != nil {
   182  //			tx.Rollback()
   183  //			return nil, err
   184  //		}
   185  //		na := s.NumInput()
   186  //		rows, err := s.Query(args[:na])
   187  //		if err != nil && err != driver.ErrSkip {
   188  //			tx.Rollback()
   189  //			s.Close()
   190  //			return nil, err
   191  //		}
   192  //		args = args[na:]
   193  //		tail := s.(*SQLiteStmt).t
   194  //		if tail == "" {
   195  //			tx.Commit()
   196  //			return rows, nil
   197  //		}
   198  //		s.Close()
   199  //		query = tail
   200  //	}
   201  //}
   202  
   203  func (c *SQLiteConn) exec(cmd string) error {
   204  	pcmd := C.CString(cmd)
   205  	defer C.free(unsafe.Pointer(pcmd))
   206  	rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
   207  	if rv != C.SQLITE_OK {
   208  		return ErrNo(rv)
   209  	}
   210  	return nil
   211  }
   212  
   213  // Begin transaction.
   214  func (c *SQLiteConn) Begin() (driver.Tx, error) {
   215  	if err := c.exec("BEGIN"); err != nil {
   216  		return nil, err
   217  	}
   218  	return &SQLiteTx{c}, nil
   219  }
   220  
   221  func errorString(err ErrNo) string {
   222  	return C.GoString(C.sqlite3_errstr(C.int(err)))
   223  }
   224  
   225  // Open database and return a new connection.
   226  // You can specify DSN string with URI filename.
   227  //   test.db
   228  //   file:test.db?cache=shared&mode=memory
   229  //   :memory:
   230  //   file::memory:
   231  func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
   232  	if C.sqlite3_threadsafe() == 0 {
   233  		return nil, errors.New("sqlite library was not compiled for thread-safe operation")
   234  	}
   235  
   236  	var db *C.sqlite3
   237  	name := C.CString(dsn)
   238  	defer C.free(unsafe.Pointer(name))
   239  	rv := C._sqlite3_open_v2(name, &db,
   240  		C.SQLITE_OPEN_FULLMUTEX|
   241  			C.SQLITE_OPEN_READWRITE|
   242  			C.SQLITE_OPEN_CREATE,
   243  		nil)
   244  	if rv != 0 {
   245  		return nil, ErrNo(rv)
   246  	}
   247  	if db == nil {
   248  		return nil, errors.New("sqlite succeeded without returning a database")
   249  	}
   250  
   251  	rv = C.sqlite3_busy_timeout(db, 5000)
   252  	if rv != C.SQLITE_OK {
   253  		return nil, ErrNo(rv)
   254  	}
   255  
   256  	conn := &SQLiteConn{db}
   257  
   258  	if len(d.Extensions) > 0 {
   259  		rv = C.sqlite3_enable_load_extension(db, 1)
   260  		if rv != C.SQLITE_OK {
   261  			return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
   262  		}
   263  
   264  		stmt, err := conn.Prepare("SELECT load_extension(?);")
   265  		if err != nil {
   266  			return nil, err
   267  		}
   268  
   269  		for _, extension := range d.Extensions {
   270  			if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
   271  				return nil, err
   272  			}
   273  		}
   274  
   275  		if err = stmt.Close(); err != nil {
   276  			return nil, err
   277  		}
   278  
   279  		rv = C.sqlite3_enable_load_extension(db, 0)
   280  		if rv != C.SQLITE_OK {
   281  			return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
   282  		}
   283  	}
   284  
   285  	if d.ConnectHook != nil {
   286  		if err := d.ConnectHook(conn); err != nil {
   287  			return nil, err
   288  		}
   289  	}
   290  
   291  	return conn, nil
   292  }
   293  
   294  // Close the connection.
   295  func (c *SQLiteConn) Close() error {
   296  	s := C.sqlite3_next_stmt(c.db, nil)
   297  	for s != nil {
   298  		C.sqlite3_finalize(s)
   299  		s = C.sqlite3_next_stmt(c.db, nil)
   300  	}
   301  	rv := C.sqlite3_close(c.db)
   302  	if rv != C.SQLITE_OK {
   303  		return ErrNo(rv)
   304  	}
   305  	c.db = nil
   306  	return nil
   307  }
   308  
   309  // Prepare query string. Return a new statement.
   310  func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
   311  	pquery := C.CString(query)
   312  	defer C.free(unsafe.Pointer(pquery))
   313  	var s *C.sqlite3_stmt
   314  	var tail *C.char
   315  	rv := C.sqlite3_prepare_v2(c.db, pquery, -1, &s, &tail)
   316  	if rv != C.SQLITE_OK {
   317  		return nil, ErrNo(rv)
   318  	}
   319  	var t string
   320  	if tail != nil && C.strlen(tail) > 0 {
   321  		t = strings.TrimSpace(C.GoString(tail))
   322  	}
   323  	return &SQLiteStmt{c: c, s: s, t: t}, nil
   324  }
   325  
   326  // Close the statement.
   327  func (s *SQLiteStmt) Close() error {
   328  	if s.closed {
   329  		return nil
   330  	}
   331  	s.closed = true
   332  	if s.c == nil || s.c.db == nil {
   333  		return errors.New("sqlite statement with already closed database connection")
   334  	}
   335  	rv := C.sqlite3_finalize(s.s)
   336  	if rv != C.SQLITE_OK {
   337  		return ErrNo(rv)
   338  	}
   339  	return nil
   340  }
   341  
   342  // Return a number of parameters.
   343  func (s *SQLiteStmt) NumInput() int {
   344  	return int(C.sqlite3_bind_parameter_count(s.s))
   345  }
   346  
   347  func (s *SQLiteStmt) bind(args []driver.Value) error {
   348  	rv := C.sqlite3_reset(s.s)
   349  	if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
   350  		return ErrNo(rv)
   351  	}
   352  
   353  	for i, v := range args {
   354  		n := C.int(i + 1)
   355  		switch v := v.(type) {
   356  		case nil:
   357  			rv = C.sqlite3_bind_null(s.s, n)
   358  		case string:
   359  			if len(v) == 0 {
   360  				b := []byte{0}
   361  				rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
   362  			} else {
   363  				b := []byte(v)
   364  				rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
   365  			}
   366  		case int:
   367  			rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
   368  		case int32:
   369  			rv = C.sqlite3_bind_int(s.s, n, C.int(v))
   370  		case int64:
   371  			rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
   372  		case byte:
   373  			rv = C.sqlite3_bind_int(s.s, n, C.int(v))
   374  		case bool:
   375  			if bool(v) {
   376  				rv = C.sqlite3_bind_int(s.s, n, 1)
   377  			} else {
   378  				rv = C.sqlite3_bind_int(s.s, n, 0)
   379  			}
   380  		case float32:
   381  			rv = C.sqlite3_bind_double(s.s, n, C.double(v))
   382  		case float64:
   383  			rv = C.sqlite3_bind_double(s.s, n, C.double(v))
   384  		case []byte:
   385  			var p *byte
   386  			if len(v) > 0 {
   387  				p = &v[0]
   388  			}
   389  			rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
   390  		case time.Time:
   391  			b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
   392  			rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
   393  		}
   394  		if rv != C.SQLITE_OK {
   395  			return ErrNo(rv)
   396  		}
   397  	}
   398  	return nil
   399  }
   400  
   401  // Query the statment with arguments. Return records.
   402  func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
   403  	if err := s.bind(args); err != nil {
   404  		return nil, err
   405  	}
   406  	return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil}, nil
   407  }
   408  
   409  // Return last inserted ID.
   410  func (r *SQLiteResult) LastInsertId() (int64, error) {
   411  	return r.id, nil
   412  }
   413  
   414  // Return how many rows affected.
   415  func (r *SQLiteResult) RowsAffected() (int64, error) {
   416  	return r.changes, nil
   417  }
   418  
   419  // Execute the statement with arguments. Return result object.
   420  func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
   421  	if err := s.bind(args); err != nil {
   422  		return nil, err
   423  	}
   424  	rv := C.sqlite3_step(s.s)
   425  	if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
   426  		return nil, ErrNo(rv)
   427  	}
   428  
   429  	res := &SQLiteResult{
   430  		int64(C._sqlite3_last_insert_rowid(s.c.db)),
   431  		int64(C._sqlite3_changes(s.c.db)),
   432  	}
   433  	return res, nil
   434  }
   435  
   436  // Close the rows.
   437  func (rc *SQLiteRows) Close() error {
   438  	if rc.s.closed {
   439  		return nil
   440  	}
   441  	rv := C.sqlite3_reset(rc.s.s)
   442  	if rv != C.SQLITE_OK {
   443  		return ErrNo(rv)
   444  	}
   445  	return nil
   446  }
   447  
   448  // Return column names.
   449  func (rc *SQLiteRows) Columns() []string {
   450  	if rc.nc != len(rc.cols) {
   451  		rc.cols = make([]string, rc.nc)
   452  		for i := 0; i < rc.nc; i++ {
   453  			rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
   454  		}
   455  	}
   456  	return rc.cols
   457  }
   458  
   459  // Move cursor to next.
   460  func (rc *SQLiteRows) Next(dest []driver.Value) error {
   461  	rv := C.sqlite3_step(rc.s.s)
   462  	if rv == C.SQLITE_DONE {
   463  		return io.EOF
   464  	}
   465  	if rv != C.SQLITE_ROW {
   466  		rv = C.sqlite3_reset(rc.s.s)
   467  		if rv != C.SQLITE_OK {
   468  			return ErrNo(rv)
   469  		}
   470  		return nil
   471  	}
   472  
   473  	if rc.decltype == nil {
   474  		rc.decltype = make([]string, rc.nc)
   475  		for i := 0; i < rc.nc; i++ {
   476  			rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
   477  		}
   478  	}
   479  
   480  	for i := range dest {
   481  		switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
   482  		case C.SQLITE_INTEGER:
   483  			val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
   484  			switch rc.decltype[i] {
   485  			case "timestamp", "datetime":
   486  				dest[i] = time.Unix(val, 0)
   487  			case "boolean":
   488  				dest[i] = val > 0
   489  			default:
   490  				dest[i] = val
   491  			}
   492  		case C.SQLITE_FLOAT:
   493  			dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
   494  		case C.SQLITE_BLOB:
   495  			p := C.sqlite3_column_blob(rc.s.s, C.int(i))
   496  			n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
   497  			switch dest[i].(type) {
   498  			case sql.RawBytes:
   499  				dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
   500  			default:
   501  				slice := make([]byte, n)
   502  				copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
   503  				dest[i] = slice
   504  			}
   505  		case C.SQLITE_NULL:
   506  			dest[i] = nil
   507  		case C.SQLITE_TEXT:
   508  			var err error
   509  			s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
   510  
   511  			switch rc.decltype[i] {
   512  			case "timestamp", "datetime":
   513  				for _, format := range SQLiteTimestampFormats {
   514  					if dest[i], err = time.Parse(format, s); err == nil {
   515  						break
   516  					}
   517  				}
   518  				if err != nil {
   519  					// The column is a time value, so return the zero time on parse failure.
   520  					dest[i] = time.Time{}
   521  				}
   522  			default:
   523  				dest[i] = s
   524  			}
   525  
   526  		}
   527  	}
   528  	return nil
   529  }