gitlab.com/cznic/sqlite.git@v1.0.0/sqlite.go (about)

     1  // Copyright 2017 The Sqlite Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sqlite // import "modernc.org/sqlite"
     6  
     7  import (
     8  	"bytes"
     9  	"database/sql"
    10  	"database/sql/driver"
    11  	"fmt"
    12  	"io"
    13  	"math"
    14  	"os"
    15  	"runtime"
    16  	"sync"
    17  	"time"
    18  	"unsafe"
    19  
    20  	"golang.org/x/net/context"
    21  	"modernc.org/ccgo/crt"
    22  	"modernc.org/sqlite/internal/bin"
    23  )
    24  
// Compile-time assertions that the types of this package implement the
// database/sql/driver interfaces they are handed out as.
var (
	_ driver.Conn    = (*conn)(nil)
	_ driver.Driver  = (*Driver)(nil)
	_ driver.Execer  = (*conn)(nil)
	_ driver.Queryer = (*conn)(nil)
	_ driver.Result  = (*result)(nil)
	_ driver.Rows    = (*rows)(nil)
	_ driver.Stmt    = (*stmt)(nil)
	_ driver.Tx      = (*tx)(nil)
)
    35  
    36  const (
    37  	driverName = "sqlite"
    38  	ptrSize    = 1 << (^uintptr(0)>>32&1 + ^uintptr(0)>>16&1 + ^uintptr(0)>>8&1 + 3) / 8
    39  )
    40  
    41  func init() {
    42  	tls := crt.NewTLS()
    43  	crt.X__register_stdfiles(tls, bin.Xstdin, bin.Xstdout, bin.Xstderr)
    44  	if bin.Xsqlite3_threadsafe(tls) == 0 {
    45  		panic(fmt.Errorf("sqlite: thread safety configuration error"))
    46  	}
    47  
    48  	if bin.Xsqlite3_config(
    49  		tls,
    50  		bin.XSQLITE_CONFIG_LOG,
    51  		func(tls *crt.TLS, pArg unsafe.Pointer, iErrCode int32, zMsg *int8) {
    52  			fmt.Fprintf(os.Stderr, "%v(%#x): %s\n", iErrCode, iErrCode, crt.GoString(zMsg))
    53  		},
    54  		unsafe.Pointer(nil),
    55  	) != 0 {
    56  		panic("sqlite: cannot configure error log callback")
    57  	}
    58  
    59  	sql.Register(driverName, newDrv())
    60  }
    61  
    62  func tracer(rx interface{}, format string, args ...interface{}) {
    63  	var b bytes.Buffer
    64  	_, file, line, _ := runtime.Caller(1)
    65  	fmt.Fprintf(&b, "%v:%v: (%[3]T)(%[3]p).", file, line, rx)
    66  	fmt.Fprintf(&b, format, args...)
    67  	fmt.Fprintf(os.Stderr, "%s\n", b.Bytes())
    68  }
    69  
// result implements driver.Result. The values are captured once, when the
// result is created by newResult.
type result struct {
	*stmt
	lastInsertID int64 // value of sqlite3_last_insert_rowid at creation time
	rowsAffected int   // value of sqlite3_changes at creation time
}
    75  
    76  func (r *result) String() string {
    77  	return fmt.Sprintf("&%T@%p{stmt: %p, LastInsertId: %v, RowsAffected: %v}", *r, r, r.stmt, r.lastInsertID, r.rowsAffected)
    78  }
    79  
    80  func newResult(s *stmt) (_ *result, err error) {
    81  	r := &result{stmt: s}
    82  	if r.rowsAffected, err = r.changes(); err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	if r.lastInsertID, err = r.lastInsertRowID(); err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	return r, nil
    91  }
    92  
    93  // sqlite3_int64 sqlite3_last_insert_rowid(sqlite3*);
    94  func (r *result) lastInsertRowID() (v int64, _ error) {
    95  	return bin.Xsqlite3_last_insert_rowid(r.tls, r.pdb()), nil
    96  }
    97  
    98  // int sqlite3_changes(sqlite3*);
    99  func (r *result) changes() (int, error) {
   100  	v := bin.Xsqlite3_changes(r.tls, r.pdb())
   101  	return int(v), nil
   102  }
   103  
   104  // LastInsertId returns the database's auto-generated ID after, for example, an
   105  // INSERT into a table with primary key.
   106  func (r *result) LastInsertId() (int64, error) {
   107  	if r == nil {
   108  		return 0, nil
   109  	}
   110  
   111  	return r.lastInsertID, nil
   112  }
   113  
   114  // RowsAffected returns the number of rows affected by the query.
   115  func (r *result) RowsAffected() (int64, error) {
   116  	if r == nil {
   117  		return 0, nil
   118  	}
   119  
   120  	return int64(r.rowsAffected), nil
   121  }
   122  
// rows implements driver.Rows on top of a single prepared statement.
type rows struct {
	*stmt
	columns []string       // column names, collected by newRows
	rc0     int            // result code used by Next while doStep is false
	pstmt   unsafe.Pointer // statement handle the column accessors read from
	doStep  bool           // set by the first Next; afterwards Next steps pstmt itself
}
   130  
   131  func (r *rows) String() string {
   132  	return fmt.Sprintf("&%T@%p{stmt: %p, columns: %v, rc0: %v, pstmt: %#x, doStep: %v}", *r, r, r.stmt, r.columns, r.rc0, r.pstmt, r.doStep)
   133  }
   134  
   135  func newRows(s *stmt, pstmt unsafe.Pointer, rc0 int) (*rows, error) {
   136  	r := &rows{
   137  		stmt:  s,
   138  		pstmt: pstmt,
   139  		rc0:   rc0,
   140  	}
   141  
   142  	n, err := r.columnCount()
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	r.columns = make([]string, n)
   148  	for i := range r.columns {
   149  		if r.columns[i], err = r.columnName(i); err != nil {
   150  			return nil, err
   151  		}
   152  	}
   153  
   154  	return r, nil
   155  }
   156  
   157  // Columns returns the names of the columns. The number of columns of the
   158  // result is inferred from the length of the slice. If a particular column name
   159  // isn't known, an empty string should be returned for that entry.
   160  func (r *rows) Columns() (c []string) {
   161  	if trace {
   162  		defer func() {
   163  			tracer(r, "Columns(): %v", c)
   164  		}()
   165  	}
   166  	return r.columns
   167  }
   168  
   169  // Close closes the rows iterator.
   170  func (r *rows) Close() (err error) {
   171  	if trace {
   172  		defer func() {
   173  			tracer(r, "Close(): %v", err)
   174  		}()
   175  	}
   176  	return r.finalize(r.pstmt)
   177  }
   178  
   179  // Next is called to populate the next row of data into the provided slice. The
   180  // provided slice will be the same size as the Columns() are wide.
   181  //
   182  // Next should return io.EOF when there are no more rows.
   183  func (r *rows) Next(dest []driver.Value) (err error) {
   184  	if trace {
   185  		defer func() {
   186  			tracer(r, "Next(%v): %v", dest, err)
   187  		}()
   188  	}
   189  	rc := r.rc0
   190  	if r.doStep {
   191  		if rc, err = r.step(r.pstmt); err != nil {
   192  			return err
   193  		}
   194  	}
   195  
   196  	r.doStep = true
   197  
   198  	switch rc {
   199  	case bin.XSQLITE_ROW:
   200  		if g, e := len(dest), len(r.columns); g != e {
   201  			return fmt.Errorf("Next(): have %v destination values, expected %v", g, e)
   202  		}
   203  
   204  		for i := range dest {
   205  			ct, err := r.columnType(i)
   206  			if err != nil {
   207  				return err
   208  			}
   209  
   210  			switch ct {
   211  			case bin.XSQLITE_INTEGER:
   212  				v, err := r.columnInt64(i)
   213  				if err != nil {
   214  					return err
   215  				}
   216  
   217  				dest[i] = v
   218  			case bin.XSQLITE_FLOAT:
   219  				v, err := r.columnDouble(i)
   220  				if err != nil {
   221  					return err
   222  				}
   223  
   224  				dest[i] = v
   225  			case bin.XSQLITE_TEXT:
   226  				v, err := r.columnText(i)
   227  				if err != nil {
   228  					return err
   229  				}
   230  
   231  				dest[i] = v
   232  			case bin.XSQLITE_BLOB:
   233  				v, err := r.columnBlob(i)
   234  				if err != nil {
   235  					return err
   236  				}
   237  
   238  				dest[i] = v
   239  			case bin.XSQLITE_NULL:
   240  				dest[i] = nil
   241  			default:
   242  				panic("internal error")
   243  			}
   244  		}
   245  		return nil
   246  	case bin.XSQLITE_DONE:
   247  		return io.EOF
   248  	default:
   249  		return r.errstr(int32(rc))
   250  	}
   251  }
   252  
   253  // int sqlite3_column_bytes(sqlite3_stmt*, int iCol);
   254  func (r *rows) columnBytes(iCol int) (_ int, err error) {
   255  	v := bin.Xsqlite3_column_bytes(r.tls, r.pstmt, int32(iCol))
   256  	return int(v), nil
   257  }
   258  
   259  // const void *sqlite3_column_blob(sqlite3_stmt*, int iCol);
   260  func (r *rows) columnBlob(iCol int) (v []byte, err error) {
   261  	p := bin.Xsqlite3_column_blob(r.tls, r.pstmt, int32(iCol))
   262  	len, err := r.columnBytes(iCol)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	return crt.GoBytesLen((*int8)(p), len), nil
   268  }
   269  
   270  // const unsigned char *sqlite3_column_text(sqlite3_stmt*, int iCol);
   271  func (r *rows) columnText(iCol int) (v string, err error) {
   272  	p := bin.Xsqlite3_column_text(r.tls, r.pstmt, int32(iCol))
   273  	len, err := r.columnBytes(iCol)
   274  	if err != nil {
   275  		return "", err
   276  	}
   277  
   278  	return crt.GoStringLen((*int8)(unsafe.Pointer(p)), len), nil
   279  }
   280  
   281  // double sqlite3_column_double(sqlite3_stmt*, int iCol);
   282  func (r *rows) columnDouble(iCol int) (v float64, err error) {
   283  	v = bin.Xsqlite3_column_double(r.tls, r.pstmt, int32(iCol))
   284  	return v, nil
   285  }
   286  
   287  // sqlite3_int64 sqlite3_column_int64(sqlite3_stmt*, int iCol);
   288  func (r *rows) columnInt64(iCol int) (v int64, err error) {
   289  	v = bin.Xsqlite3_column_int64(r.tls, r.pstmt, int32(iCol))
   290  	return v, nil
   291  }
   292  
   293  // int sqlite3_column_type(sqlite3_stmt*, int iCol);
   294  func (r *rows) columnType(iCol int) (_ int, err error) {
   295  	v := bin.Xsqlite3_column_type(r.tls, r.pstmt, int32(iCol))
   296  	return int(v), nil
   297  }
   298  
   299  // int sqlite3_column_count(sqlite3_stmt *pStmt);
   300  func (r *rows) columnCount() (_ int, err error) {
   301  	v := bin.Xsqlite3_column_count(r.tls, r.pstmt)
   302  	return int(v), nil
   303  }
   304  
   305  // const char *sqlite3_column_name(sqlite3_stmt*, int N);
   306  func (r *rows) columnName(n int) (string, error) {
   307  	p := bin.Xsqlite3_column_name(r.tls, r.pstmt, int32(n))
   308  	return crt.GoString(p), nil
   309  }
   310  
// stmt implements driver.Stmt. It keeps the SQL text and the output slots
// sqlite3_prepare_v2 writes to in memory obtained from the connection's
// malloc; Close releases all of it.
type stmt struct {
	*conn
	allocs []unsafe.Pointer // memory allocated while binding text and blob parameters
	psql   *int8            // NUL terminated copy of the SQL text
	ppstmt *unsafe.Pointer  // slot receiving the statement handle from prepareV2
	pzTail **int8           // slot receiving the pointer to the not yet prepared rest of psql
}
   318  
   319  func (s *stmt) String() string {
   320  	return fmt.Sprintf("&%T@%p{conn: %p, alloc %v, psql: %#x, ppstmt: %#x, pzTail: %#x}", *s, s, s.conn, s.allocs, s.psql, s.ppstmt, s.pzTail)
   321  }
   322  
   323  func newStmt(c *conn, sql string) (*stmt, error) {
   324  	s := &stmt{conn: c}
   325  	psql, err := s.cString(sql)
   326  	if err != nil {
   327  		return nil, err
   328  	}
   329  
   330  	s.psql = psql
   331  	ppstmt, err := s.malloc(ptrSize)
   332  	if err != nil {
   333  		s.free(unsafe.Pointer(psql))
   334  		return nil, err
   335  	}
   336  
   337  	s.ppstmt = (*unsafe.Pointer)(ppstmt)
   338  	pzTail, err := s.malloc(ptrSize)
   339  	if err != nil {
   340  		s.free(unsafe.Pointer(psql))
   341  		s.free(ppstmt)
   342  		return nil, err
   343  	}
   344  
   345  	s.pzTail = (**int8)(pzTail)
   346  	return s, nil
   347  }
   348  
   349  // Close closes the statement.
   350  //
   351  // As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
   352  func (s *stmt) Close() (err error) {
   353  	if trace {
   354  		defer func() {
   355  			tracer(s, "Close(): %v", err)
   356  		}()
   357  	}
   358  	if s.psql != nil {
   359  		err = s.free(unsafe.Pointer(s.psql))
   360  		s.psql = nil
   361  	}
   362  	if s.ppstmt != nil {
   363  		if err2 := s.free(unsafe.Pointer(s.ppstmt)); err2 != nil && err == nil {
   364  			err = err2
   365  		}
   366  		s.ppstmt = nil
   367  	}
   368  	if s.pzTail != nil {
   369  		if err2 := s.free(unsafe.Pointer(s.pzTail)); err2 != nil && err == nil {
   370  			err = err2
   371  		}
   372  		s.pzTail = nil
   373  	}
   374  	for _, v := range s.allocs {
   375  		if err2 := s.free(v); err2 != nil && err == nil {
   376  			err = err2
   377  		}
   378  	}
   379  	s.allocs = nil
   380  	return err
   381  }
   382  
   383  // NumInput returns the number of placeholder parameters.
   384  //
   385  // If NumInput returns >= 0, the sql package will sanity check argument counts
   386  // from callers and return errors to the caller before the statement's Exec or
   387  // Query methods are called.
   388  //
   389  // NumInput may also return -1, if the driver doesn't know its number of
   390  // placeholders. In that case, the sql package will not sanity check Exec or
   391  // Query argument counts.
   392  func (s *stmt) NumInput() (n int) {
   393  	if trace {
   394  		defer func() {
   395  			tracer(s, "NumInput(): %v", n)
   396  		}()
   397  	}
   398  	return -1
   399  }
   400  
   401  // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
   402  func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
   403  	return s.exec(context.Background(), toNamedValues(args))
   404  }
   405  
   406  func (s *stmt) exec(ctx context.Context, args []namedValue) (r driver.Result, err error) {
   407  	if trace {
   408  		defer func(args []namedValue) {
   409  			tracer(s, "Exec(%v): (%v, %v)", args, r, err)
   410  		}(args)
   411  	}
   412  
   413  	var pstmt unsafe.Pointer
   414  
   415  	donech := make(chan struct{})
   416  	defer close(donech)
   417  	go func() {
   418  		select {
   419  		case <-ctx.Done():
   420  			if pstmt != nil {
   421  				s.interrupt(s.pdb())
   422  			}
   423  		case <-donech:
   424  		}
   425  	}()
   426  
   427  	for psql := s.psql; *psql != 0; psql = *s.pzTail {
   428  		if err := s.prepareV2(psql); err != nil {
   429  			return nil, err
   430  		}
   431  
   432  		pstmt = *s.ppstmt
   433  		if pstmt == nil {
   434  			continue
   435  		}
   436  
   437  		n, err := s.bindParameterCount(pstmt)
   438  		if err != nil {
   439  			return nil, err
   440  		}
   441  
   442  		if n != 0 {
   443  			if err = s.bind(pstmt, n, args); err != nil {
   444  				return nil, err
   445  			}
   446  		}
   447  
   448  		rc, err := s.step(pstmt)
   449  		if err != nil {
   450  			s.finalize(pstmt)
   451  			return nil, err
   452  		}
   453  
   454  		switch rc & 0xff {
   455  		case bin.XSQLITE_DONE, bin.XSQLITE_ROW:
   456  			if err := s.finalize(pstmt); err != nil {
   457  				return nil, err
   458  			}
   459  		default:
   460  			err = s.errstr(int32(rc))
   461  			s.finalize(pstmt)
   462  			return nil, err
   463  		}
   464  	}
   465  	return newResult(s)
   466  }
   467  
   468  func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
   469  	return s.query(context.Background(), toNamedValues(args))
   470  }
   471  
   472  func (s *stmt) query(ctx context.Context, args []namedValue) (r driver.Rows, err error) {
   473  	if trace {
   474  		defer func(args []namedValue) {
   475  			tracer(s, "Query(%v): (%v, %v)", args, r, err)
   476  		}(args)
   477  	}
   478  
   479  	var pstmt, rowStmt unsafe.Pointer
   480  	var rc0 int
   481  
   482  	donech := make(chan struct{})
   483  	defer close(donech)
   484  	go func() {
   485  		select {
   486  		case <-ctx.Done():
   487  			if pstmt != nil {
   488  				s.interrupt(s.pdb())
   489  			}
   490  		case <-donech:
   491  		}
   492  	}()
   493  
   494  	for psql := s.psql; *psql != 0; psql = *s.pzTail {
   495  		if err := s.prepareV2(psql); err != nil {
   496  			return nil, err
   497  		}
   498  
   499  		pstmt = *s.ppstmt
   500  		if pstmt == nil {
   501  			continue
   502  		}
   503  
   504  		n, err := s.bindParameterCount(pstmt)
   505  		if err != nil {
   506  			return nil, err
   507  		}
   508  
   509  		if n != 0 {
   510  			if err = s.bind(pstmt, n, args); err != nil {
   511  				return nil, err
   512  			}
   513  		}
   514  
   515  		rc, err := s.step(pstmt)
   516  		if err != nil {
   517  			s.finalize(pstmt)
   518  			return nil, err
   519  		}
   520  
   521  		switch rc {
   522  		case bin.XSQLITE_ROW:
   523  			if rowStmt != nil {
   524  				if err := s.finalize(pstmt); err != nil {
   525  					return nil, err
   526  				}
   527  
   528  				return nil, fmt.Errorf("query contains multiple select statements")
   529  			}
   530  
   531  			rowStmt = pstmt
   532  			rc0 = rc
   533  		case bin.XSQLITE_DONE:
   534  			if rowStmt == nil {
   535  				rc0 = rc
   536  			}
   537  		default:
   538  			err = s.errstr(int32(rc))
   539  			s.finalize(pstmt)
   540  			return nil, err
   541  		}
   542  	}
   543  	return newRows(s, rowStmt, rc0)
   544  }
   545  
   546  // int sqlite3_bind_double(sqlite3_stmt*, int, double);
   547  func (s *stmt) bindDouble(pstmt unsafe.Pointer, idx1 int, value float64) (err error) {
   548  	if rc := bin.Xsqlite3_bind_double(s.tls, pstmt, int32(idx1), value); rc != 0 {
   549  		return s.errstr(rc)
   550  	}
   551  
   552  	return nil
   553  }
   554  
   555  // int sqlite3_bind_int(sqlite3_stmt*, int, int);
   556  func (s *stmt) bindInt(pstmt unsafe.Pointer, idx1, value int) (err error) {
   557  	if rc := bin.Xsqlite3_bind_int(s.tls, pstmt, int32(idx1), int32(value)); rc != bin.XSQLITE_OK {
   558  		return s.errstr(rc)
   559  	}
   560  
   561  	return nil
   562  }
   563  
   564  // int sqlite3_bind_int64(sqlite3_stmt*, int, sqlite3_int64);
   565  func (s *stmt) bindInt64(pstmt unsafe.Pointer, idx1 int, value int64) (err error) {
   566  	if rc := bin.Xsqlite3_bind_int64(s.tls, pstmt, int32(idx1), value); rc != bin.XSQLITE_OK {
   567  		return s.errstr(rc)
   568  	}
   569  
   570  	return nil
   571  }
   572  
   573  // int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*));
   574  func (s *stmt) bindBlob(pstmt unsafe.Pointer, idx1 int, value []byte) (err error) {
   575  	p, err := s.malloc(len(value))
   576  	if err != nil {
   577  		return err
   578  	}
   579  
   580  	s.allocs = append(s.allocs, p)
   581  	crt.CopyBytes(p, value, false)
   582  	if rc := bin.Xsqlite3_bind_blob(s.tls, pstmt, int32(idx1), p, int32(len(value)), nil); rc != bin.XSQLITE_OK {
   583  		return s.errstr(rc)
   584  	}
   585  
   586  	return nil
   587  }
   588  
   589  // int sqlite3_bind_text(sqlite3_stmt*,int,const char*,int,void(*)(void*));
   590  func (s *stmt) bindText(pstmt unsafe.Pointer, idx1 int, value string) (err error) {
   591  	p, err := s.cString(value)
   592  	if err != nil {
   593  		return err
   594  	}
   595  
   596  	s.allocs = append(s.allocs, unsafe.Pointer(p))
   597  	if rc := bin.Xsqlite3_bind_text(s.tls, pstmt, int32(idx1), p, int32(len(value)), nil); rc != bin.XSQLITE_OK {
   598  		return s.errstr(rc)
   599  	}
   600  
   601  	return nil
   602  }
   603  
   604  func (s *stmt) bind(pstmt unsafe.Pointer, n int, args []namedValue) error {
   605  	for i := 1; i <= n; i++ {
   606  		name, err := s.bindParameterName(pstmt, i)
   607  		if err != nil {
   608  			return err
   609  		}
   610  
   611  		var v namedValue
   612  		for _, v = range args {
   613  			if name != "" {
   614  				// sqlite supports '$', '@' and ':' prefixes for string
   615  				// identifiers and '?' for numeric, so we cannot
   616  				// combine different prefixes with the same name
   617  				// because `database/sql` requires variable names
   618  				// to start with a letter
   619  				if name[1:] == v.Name[:] {
   620  					break
   621  				}
   622  			} else {
   623  				if v.Ordinal == i {
   624  					break
   625  				}
   626  			}
   627  		}
   628  
   629  		if v.Ordinal == 0 {
   630  			if name != "" {
   631  				return fmt.Errorf("missing named argument %q", name[1:])
   632  			}
   633  
   634  			return fmt.Errorf("missing argument with %d index", i)
   635  		}
   636  
   637  		switch x := v.Value.(type) {
   638  		case int64:
   639  			if err := s.bindInt64(pstmt, i, x); err != nil {
   640  				return err
   641  			}
   642  		case float64:
   643  			if err := s.bindDouble(pstmt, i, x); err != nil {
   644  				return err
   645  			}
   646  		case bool:
   647  			v := 0
   648  			if x {
   649  				v = 1
   650  			}
   651  			if err := s.bindInt(pstmt, i, v); err != nil {
   652  				return err
   653  			}
   654  		case []byte:
   655  			if err := s.bindBlob(pstmt, i, x); err != nil {
   656  				return err
   657  			}
   658  		case string:
   659  			if err := s.bindText(pstmt, i, x); err != nil {
   660  				return err
   661  			}
   662  		case time.Time:
   663  			if err := s.bindText(pstmt, i, x.String()); err != nil {
   664  				return err
   665  			}
   666  		default:
   667  			return fmt.Errorf("invalid driver.Value type %T", x)
   668  		}
   669  	}
   670  	return nil
   671  }
   672  
   673  // int sqlite3_bind_parameter_count(sqlite3_stmt*);
   674  func (s *stmt) bindParameterCount(pstmt unsafe.Pointer) (_ int, err error) {
   675  	r := bin.Xsqlite3_bind_parameter_count(s.tls, pstmt)
   676  	return int(r), nil
   677  }
   678  
   679  // const char *sqlite3_bind_parameter_name(sqlite3_stmt*, int);
   680  func (s *stmt) bindParameterName(pstmt unsafe.Pointer, i int) (string, error) {
   681  	p := bin.Xsqlite3_bind_parameter_name(s.tls, pstmt, int32(i))
   682  	return crt.GoString(p), nil
   683  }
   684  
   685  // int sqlite3_finalize(sqlite3_stmt *pStmt);
   686  func (s *stmt) finalize(pstmt unsafe.Pointer) error {
   687  	if rc := bin.Xsqlite3_finalize(s.tls, pstmt); rc != bin.XSQLITE_OK {
   688  		return s.errstr(rc)
   689  	}
   690  
   691  	return nil
   692  }
   693  
   694  // int sqlite3_step(sqlite3_stmt*);
   695  func (s *stmt) step(pstmt unsafe.Pointer) (int, error) {
   696  	r := bin.Xsqlite3_step(s.tls, pstmt)
   697  	return int(r), nil
   698  }
   699  
   700  // int sqlite3_prepare_v2(
   701  //   sqlite3 *db,            /* Database handle */
   702  //   const char *zSql,       /* SQL statement, UTF-8 encoded */
   703  //   int nByte,              /* Maximum length of zSql in bytes. */
   704  //   sqlite3_stmt **ppStmt,  /* OUT: Statement handle */
   705  //   const char **pzTail     /* OUT: Pointer to unused portion of zSql */
   706  // );
   707  func (s *stmt) prepareV2(zSQL *int8) error {
   708  	if rc := bin.Xsqlite3_prepare_v2(s.tls, s.pdb(), zSQL, -1, s.ppstmt, s.pzTail); rc != bin.XSQLITE_OK {
   709  		return s.errstr(rc)
   710  	}
   711  
   712  	return nil
   713  }
   714  
// tx implements driver.Tx by executing "begin", "commit" and "rollback" on
// the embedded connection.
type tx struct {
	*conn
}
   718  
   719  func (t *tx) String() string { return fmt.Sprintf("&%T@%p{conn: %p}", *t, t, t.conn) }
   720  
   721  func newTx(c *conn) (*tx, error) {
   722  	t := &tx{conn: c}
   723  	if err := t.exec(context.Background(), "begin"); err != nil {
   724  		return nil, err
   725  	}
   726  
   727  	return t, nil
   728  }
   729  
   730  // Commit implements driver.Tx.
   731  func (t *tx) Commit() (err error) {
   732  	if trace {
   733  		defer func() {
   734  			tracer(t, "Commit(): %v", err)
   735  		}()
   736  	}
   737  	return t.exec(context.Background(), "commit")
   738  }
   739  
   740  // Rollback implements driver.Tx.
   741  func (t *tx) Rollback() (err error) {
   742  	if trace {
   743  		defer func() {
   744  			tracer(t, "Rollback(): %v", err)
   745  		}()
   746  	}
   747  	return t.exec(context.Background(), "rollback")
   748  }
   749  
   750  // int sqlite3_exec(
   751  //   sqlite3*,                                  /* An open database */
   752  //   const char *sql,                           /* SQL to be evaluated */
   753  //   int (*callback)(void*,int,char**,char**),  /* Callback function */
   754  //   void *,                                    /* 1st argument to callback */
   755  //   char **errmsg                              /* Error msg written here */
   756  // );
   757  func (t *tx) exec(ctx context.Context, sql string) (err error) {
   758  	psql, err := t.cString(sql)
   759  	if err != nil {
   760  		return err
   761  	}
   762  
   763  	defer t.free(unsafe.Pointer(psql))
   764  
   765  	// TODO: use t.conn.ExecContext() instead
   766  	donech := make(chan struct{})
   767  	defer close(donech)
   768  	go func() {
   769  		select {
   770  		case <-ctx.Done():
   771  			t.interrupt(t.pdb())
   772  		case <-donech:
   773  		}
   774  	}()
   775  
   776  	if rc := bin.Xsqlite3_exec(t.tls, t.pdb(), psql, nil, nil, nil); rc != bin.XSQLITE_OK {
   777  		return t.errstr(rc)
   778  	}
   779  
   780  	return nil
   781  }
   782  
// conn implements driver.Conn, driver.Execer and driver.Queryer for a single
// SQLite database handle.
type conn struct {
	*Driver
	ppdb **bin.Xsqlite3 // slot holding the database handle; allocated by openV2, freed by closeV2
	tls  *crt.TLS       // passed as first argument to every call into bin
}
   788  
   789  func (c *conn) String() string {
   790  	return fmt.Sprintf("&%T@%p{sqlite: %p, Thread: %p, ppdb: %#x}", *c, c, c.Driver, c.tls, c.ppdb)
   791  }
   792  
   793  func newConn(s *Driver, name string) (_ *conn, err error) {
   794  	c := &conn{Driver: s}
   795  
   796  	defer func() {
   797  		if err != nil {
   798  			c.close()
   799  		}
   800  	}()
   801  
   802  	c.Lock()
   803  
   804  	defer c.Unlock()
   805  
   806  	c.tls = crt.NewTLS()
   807  	if err = c.openV2(
   808  		name,
   809  		bin.XSQLITE_OPEN_READWRITE|bin.XSQLITE_OPEN_CREATE|
   810  			bin.XSQLITE_OPEN_FULLMUTEX|
   811  			bin.XSQLITE_OPEN_URI,
   812  	); err != nil {
   813  		return nil, err
   814  	}
   815  
   816  	if err = c.extendedResultCodes(true); err != nil {
   817  		return nil, err
   818  	}
   819  
   820  	return c, nil
   821  }
   822  
   823  // Prepare returns a prepared statement, bound to this connection.
   824  func (c *conn) Prepare(query string) (s driver.Stmt, err error) {
   825  	return c.prepare(context.Background(), query)
   826  }
   827  
   828  func (c *conn) prepare(ctx context.Context, query string) (s driver.Stmt, err error) {
   829  	if trace {
   830  		defer func() {
   831  			tracer(c, "Prepare(%s): (%v, %v)", query, s, err)
   832  		}()
   833  	}
   834  	return newStmt(c, query)
   835  }
   836  
   837  // Close invalidates and potentially stops any current prepared statements and
   838  // transactions, marking this connection as no longer in use.
   839  //
   840  // Because the sql package maintains a free pool of connections and only calls
   841  // Close when there's a surplus of idle connections, it shouldn't be necessary
   842  // for drivers to do their own connection caching.
   843  func (c *conn) Close() (err error) {
   844  	if trace {
   845  		defer func() {
   846  			tracer(c, "Close(): %v", err)
   847  		}()
   848  	}
   849  	return c.close()
   850  }
   851  
   852  // Begin starts a transaction.
   853  func (c *conn) Begin() (driver.Tx, error) {
   854  	return c.begin(context.Background(), txOptions{})
   855  }
   856  
// txOptions is a copy of driver.TxOptions.
type txOptions struct {
	Isolation int // driver.IsolationLevel
	ReadOnly  bool
}
   862  
   863  func (c *conn) begin(ctx context.Context, opts txOptions) (t driver.Tx, err error) {
   864  	if trace {
   865  		defer func() {
   866  			tracer(c, "BeginTx(): (%v, %v)", t, err)
   867  		}()
   868  	}
   869  	return newTx(c)
   870  }
   871  
   872  // Execer is an optional interface that may be implemented by a Conn.
   873  //
   874  // If a Conn does not implement Execer, the sql package's DB.Exec will first
   875  // prepare a query, execute the statement, and then close the statement.
   876  //
   877  // Exec may return ErrSkip.
   878  func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
   879  	return c.exec(context.Background(), query, toNamedValues(args))
   880  }
   881  
   882  func (c *conn) exec(ctx context.Context, query string, args []namedValue) (r driver.Result, err error) {
   883  	if trace {
   884  		defer func() {
   885  			tracer(c, "ExecContext(%s, %v): (%v, %v)", query, args, r, err)
   886  		}()
   887  	}
   888  
   889  	s, err := c.prepare(ctx, query)
   890  	if err != nil {
   891  		return nil, err
   892  	}
   893  
   894  	defer func() {
   895  		if err2 := s.Close(); err2 != nil && err == nil {
   896  			err = err2
   897  		}
   898  	}()
   899  
   900  	return s.(*stmt).exec(ctx, args)
   901  }
   902  
// namedValue is a copy of driver.NamedValue.
type namedValue struct {
	Name    string       // parameter name without prefix; empty for positional arguments
	Ordinal int          // 1-based position of the argument
	Value   driver.Value // the argument itself
}
   909  
   910  // toNamedValues converts []driver.Value to []namedValue
   911  func toNamedValues(vals []driver.Value) []namedValue {
   912  	args := make([]namedValue, 0, len(vals))
   913  	for i, val := range vals {
   914  		args = append(args, namedValue{Value: val, Ordinal: i + 1})
   915  	}
   916  	return args
   917  }
   918  
   919  // Queryer is an optional interface that may be implemented by a Conn.
   920  //
   921  // If a Conn does not implement Queryer, the sql package's DB.Query will first
   922  // prepare a query, execute the statement, and then close the statement.
   923  //
   924  // Query may return ErrSkip.
   925  func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
   926  	return c.query(context.Background(), query, toNamedValues(args))
   927  }
   928  
   929  func (c *conn) query(ctx context.Context, query string, args []namedValue) (r driver.Rows, err error) {
   930  	if trace {
   931  		defer func() {
   932  			tracer(c, "Query(%s, %v): (%v, %v)", query, args, r, err)
   933  		}()
   934  	}
   935  	s, err := c.prepare(ctx, query)
   936  	if err != nil {
   937  		return nil, err
   938  	}
   939  
   940  	defer func() {
   941  		if err2 := s.Close(); err2 != nil && err == nil {
   942  			err = err2
   943  		}
   944  	}()
   945  
   946  	return s.(*stmt).query(ctx, args)
   947  }
   948  
   949  func (c *conn) pdb() *bin.Xsqlite3 { return *c.ppdb }
   950  
   951  // int sqlite3_extended_result_codes(sqlite3*, int onoff);
   952  func (c *conn) extendedResultCodes(on bool) (err error) {
   953  	var v int32
   954  	if on {
   955  		v = 1
   956  	}
   957  	if rc := bin.Xsqlite3_extended_result_codes(c.tls, c.pdb(), v); rc != bin.XSQLITE_OK {
   958  		return c.errstr(rc)
   959  	}
   960  
   961  	return nil
   962  }
   963  
   964  // void *sqlite3_malloc(int);
   965  func (c *conn) malloc(n int) (r unsafe.Pointer, err error) {
   966  	if n > math.MaxInt32 {
   967  		panic("internal error")
   968  	}
   969  
   970  	r = bin.Xsqlite3_malloc(c.tls, int32(n))
   971  	if r == nil {
   972  		return nil, fmt.Errorf("malloc(%v) failed", n)
   973  	}
   974  
   975  	return r, nil
   976  }
   977  
   978  func (c *conn) cString(s string) (*int8, error) {
   979  	n := len(s)
   980  	p, err := c.malloc(n + 1)
   981  	if err != nil {
   982  		return nil, err
   983  	}
   984  
   985  	crt.CopyString(p, s, true)
   986  	return (*int8)(p), nil
   987  }
   988  
   989  // int sqlite3_open_v2(
   990  //   const char *filename,   /* Database filename (UTF-8) */
   991  //   sqlite3 **ppDb,         /* OUT: SQLite db handle */
   992  //   int flags,              /* Flags */
   993  //   const char *zVfs        /* Name of VFS module to use */
   994  // );
   995  func (c *conn) openV2(name string, flags int32) error {
   996  	filename, err := c.cString(name)
   997  	if err != nil {
   998  		return err
   999  	}
  1000  
  1001  	defer c.free(unsafe.Pointer(filename))
  1002  
  1003  	ppdb, err := c.malloc(ptrSize)
  1004  	if err != nil {
  1005  		return err
  1006  	}
  1007  
  1008  	c.ppdb = (**bin.Xsqlite3)(ppdb)
  1009  	if rc := bin.Xsqlite3_open_v2(c.tls, filename, c.ppdb, flags, nil); rc != bin.XSQLITE_OK {
  1010  		return c.errstr(rc)
  1011  	}
  1012  
  1013  	return nil
  1014  }
  1015  
  1016  // const char *sqlite3_errstr(int);
  1017  func (c *conn) errstr(rc int32) (err error) {
  1018  	p := bin.Xsqlite3_errstr(c.tls, rc)
  1019  	str := crt.GoString(p)
  1020  	p = bin.Xsqlite3_errmsg(c.tls, c.pdb())
  1021  
  1022  	switch msg := crt.GoString(p); {
  1023  	case msg == str:
  1024  		return fmt.Errorf("%s (%v)", str, rc)
  1025  	default:
  1026  		return fmt.Errorf("%s: %s (%v)", str, msg, rc)
  1027  	}
  1028  }
  1029  
  1030  // int sqlite3_close_v2(sqlite3*);
  1031  func (c *conn) closeV2() (err error) {
  1032  	if rc := bin.Xsqlite3_close_v2(c.tls, c.pdb()); rc != bin.XSQLITE_OK {
  1033  		return c.errstr(rc)
  1034  	}
  1035  
  1036  	err = c.free(unsafe.Pointer(c.ppdb))
  1037  	c.ppdb = nil
  1038  	return err
  1039  }
  1040  
  1041  // void sqlite3_free(void*);
  1042  func (c *conn) free(p unsafe.Pointer) (err error) {
  1043  	bin.Xsqlite3_free(c.tls, p)
  1044  	return nil
  1045  }
  1046  
  1047  // void sqlite3_interrupt(sqlite3*);
  1048  func (c *conn) interrupt(pdb *bin.Xsqlite3) (err error) {
  1049  	bin.Xsqlite3_interrupt(c.tls, pdb)
  1050  	return nil
  1051  }
  1052  
  1053  func (c *conn) close() (err error) {
  1054  	c.Lock()
  1055  
  1056  	defer c.Unlock()
  1057  
  1058  	if c.ppdb != nil {
  1059  		err = c.closeV2()
  1060  	}
  1061  	return err
  1062  }
  1063  
// Driver implements database/sql/driver.Driver.
//
// The embedded mutex is held while a connection is set up by newConn and
// while one is closed.
type Driver struct {
	sync.Mutex
}
  1068  
  1069  func newDrv() *Driver { return &Driver{} }
  1070  
  1071  // Open returns a new connection to the database.  The name is a string in a
  1072  // driver-specific format.
  1073  //
  1074  // Open may return a cached connection (one previously closed), but doing so is
  1075  // unnecessary; the sql package maintains a pool of idle connections for
  1076  // efficient re-use.
  1077  //
  1078  // The returned connection is only used by one goroutine at a time.
  1079  func (s *Driver) Open(name string) (c driver.Conn, err error) {
  1080  	if trace {
  1081  		defer func() {
  1082  			tracer(s, "Open(%s): (%v, %v)", name, c, err)
  1083  		}()
  1084  	}
  1085  	return newConn(s, name)
  1086  }