github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/starlib/starlarksql/sql.go (about)

     1  // Copyright 2021 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package starlarksql provides an interface to connect to SQL databases.
     6  package starlarksql
     7  
     8  import (
     9  	"database/sql"
    10  	"database/sql/driver"
    11  	"fmt"
    12  	"net/url"
    13  	"sort"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/emcfarlane/larking/starlib/starext"
    18  	"github.com/emcfarlane/larking/starlib/starlarkerrors"
    19  	"github.com/emcfarlane/larking/starlib/starlarkthread"
    20  	starlarktime "go.starlark.net/lib/time"
    21  	"go.starlark.net/starlark"
    22  	"go.starlark.net/starlarkstruct"
    23  	"gocloud.dev/mysql"
    24  	"gocloud.dev/postgres"
    25  )
    26  
    27  func NewModule() *starlarkstruct.Module {
    28  	return &starlarkstruct.Module{
    29  		Name: "sql",
    30  		Members: starlark.StringDict{
    31  			"open": starext.MakeBuiltin("sql.open", Open),
    32  
    33  			// sql errors
    34  			"err_conn_done": starlarkerrors.NewError(sql.ErrConnDone),
    35  			"err_no_rows":   starlarkerrors.NewError(sql.ErrNoRows),
    36  			"err_tx_done":   starlarkerrors.NewError(sql.ErrTxDone),
    37  		},
    38  	}
    39  }
    40  
    41  // genQueryOptions generates standard query options.
    42  func genQueryOptions(q url.Values) string {
    43  	if s := q.Encode(); s != "" {
    44  		return "?" + s
    45  	}
    46  	return ""
    47  }
    48  
    49  // genOpaque generates a opaque file path DSN from the passed URL.
    50  func genOpaque(u *url.URL) (string, error) {
    51  	if u.Opaque == "" {
    52  		return "", fmt.Errorf("error missing path")
    53  	}
    54  	return u.Opaque + genQueryOptions(u.Query()), nil
    55  }
    56  
    57  func Open(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
    58  	var name string
    59  	if err := starlark.UnpackPositionalArgs(fnname, args, kwargs, 1, &name); err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	u, err := url.Parse(name)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	ctx := starlarkthread.GetContext(thread)
    69  
    70  	var db *sql.DB
    71  	switch {
    72  	case strings.HasSuffix(u.Scheme, "mysql"):
    73  		db, err = mysql.Open(ctx, name)
    74  	case strings.HasSuffix(u.Scheme, "postgres"):
    75  		db, err = postgres.Open(ctx, name)
    76  	case u.Scheme == "sqlite":
    77  		// build dsn
    78  		dsn, derr := genOpaque(u)
    79  		if derr != nil {
    80  			return nil, derr
    81  		}
    82  
    83  		db, err = sql.Open("sqlite", dsn)
    84  
    85  	default:
    86  		return nil, fmt.Errorf("unsupported database %s", u.Scheme)
    87  	}
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	v := NewDB(name, db)
    93  	if err := starlarkthread.AddResource(thread, v); err != nil {
    94  		return nil, err
    95  	}
    96  	return v, nil
    97  }
    98  
    99  type DB struct {
   100  	name string
   101  	db   *sql.DB
   102  
   103  	frozen bool
   104  }
   105  
   106  func NewDB(name string, db *sql.DB) *DB { return &DB{name: name, db: db} }
   107  func (db *DB) Close() error             { return db.db.Close() }
   108  
   109  func (v *DB) String() string        { return fmt.Sprintf("<db %q>", v.name) }
   110  func (v *DB) Type() string          { return "sql.db" }
   111  func (v *DB) Freeze()               { v.frozen = true } // immutable?
   112  func (v *DB) Truth() starlark.Bool  { return v.db != nil }
   113  func (v *DB) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", v.Type()) }
   114  
   115  type dbAttr func(v *DB) starlark.Value
   116  
   117  var dbAttrs = map[string]dbAttr{
   118  	"exec":      func(v *DB) starlark.Value { return starext.MakeMethod(v, "exec", v.exec) },
   119  	"query":     func(v *DB) starlark.Value { return starext.MakeMethod(v, "query", v.query) },
   120  	"query_row": func(v *DB) starlark.Value { return starext.MakeMethod(v, "query_row", v.queryRow) },
   121  	"ping":      func(v *DB) starlark.Value { return starext.MakeMethod(v, "ping", v.ping) },
   122  	"close":     func(v *DB) starlark.Value { return starext.MakeMethod(v, "close", v.close) },
   123  }
   124  
   125  func (v *DB) Attr(name string) (starlark.Value, error) {
   126  	if a := dbAttrs[name]; a != nil {
   127  		return a(v), nil
   128  	}
   129  	return nil, nil
   130  }
   131  func (v *DB) AttrNames() []string {
   132  	names := make([]string, 0, len(dbAttrs))
   133  	for name := range dbAttrs {
   134  		names = append(names, name)
   135  	}
   136  	sort.Strings(names)
   137  	return names
   138  }
   139  
   140  type Result struct {
   141  	result sql.Result
   142  }
   143  
   144  func (r *Result) String() string        { return fmt.Sprintf("<result %t>", r.result != nil) }
   145  func (r *Result) Type() string          { return "sql.result" }
   146  func (r *Result) Freeze()               {} // immutable
   147  func (r *Result) Truth() starlark.Bool  { return r.result != nil }
   148  func (r *Result) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", r.Type()) }
   149  func (r *Result) AttrNames() []string   { return []string{"last_insert_id", "rows_affected"} }
   150  func (r *Result) Attr(name string) (starlark.Value, error) {
   151  	switch name {
   152  	case "last_insert_id":
   153  		i, err := r.result.LastInsertId()
   154  		if err != nil {
   155  			return nil, err
   156  		}
   157  		return starlark.MakeInt64(i), nil
   158  	case "rows_affected":
   159  		i, err := r.result.RowsAffected()
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		return starlark.MakeInt64(i), nil
   164  	default:
   165  		return nil, nil
   166  	}
   167  }
   168  
   169  func makeArgs(args starlark.Tuple) ([]interface{}, error) {
   170  	// translate arg types
   171  	xs := make([]interface{}, len(args))
   172  	for i, arg := range args {
   173  		switch arg := arg.(type) {
   174  		case starlark.NoneType:
   175  			xs[i] = nil
   176  		case starlark.Bool:
   177  			xs[i] = bool(arg)
   178  		case starlark.String:
   179  			xs[i] = string(arg)
   180  		case starlark.Bytes:
   181  			xs[i] = []byte(arg)
   182  		case starlark.Int:
   183  			x, ok := arg.Uint64()
   184  			if !ok {
   185  				return nil, fmt.Errorf("invalid arg int too larg: %v", arg.String())
   186  			}
   187  			xs[i] = x
   188  		case starlark.Float:
   189  			xs[i] = float64(arg)
   190  		case starlarktime.Time:
   191  			xs[i] = time.Time(arg)
   192  		case driver.Valuer:
   193  			x, err := arg.Value()
   194  			if err != nil {
   195  				return nil, err
   196  			}
   197  			xs[i] = x
   198  		default:
   199  			return nil, fmt.Errorf("invalid arg type: %v", arg.Type())
   200  		}
   201  	}
   202  	return xs, nil
   203  }
   204  
   205  //func dbBeginTx(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   206  //	return nil, nil // TODO: Create struct TX.
   207  //}
   208  
   209  func (v *DB) exec(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   210  	queryArgs := args
   211  	if len(args) > 1 {
   212  		queryArgs = args[:1]
   213  	}
   214  	var query string
   215  	if err := starlark.UnpackPositionalArgs(fnname, queryArgs, kwargs, 1, &query); err != nil {
   216  		return nil, err
   217  	}
   218  
   219  	dbArgs, err := makeArgs(args[1:])
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	ctx := starlarkthread.GetContext(thread)
   225  	result, err := v.db.ExecContext(ctx, query, dbArgs...)
   226  	if err != nil {
   227  		return nil, err
   228  	}
   229  	return &Result{result: result}, nil
   230  
   231  }
   232  
   233  func (v *DB) query(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   234  	queryArgs := args
   235  	if len(args) > 1 {
   236  		queryArgs = args[:1]
   237  	}
   238  	var query string
   239  	if err := starlark.UnpackPositionalArgs(fnname, queryArgs, kwargs, 1, &query); err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	dbArgs, err := makeArgs(args[1:])
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	ctx := starlarkthread.GetContext(thread)
   249  	rows, err := v.db.QueryContext(ctx, query, dbArgs...)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  
   254  	cols, err := rows.ColumnTypes()
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  	columns := make([]string, len(cols))
   259  	mapping := make(map[string]int, len(cols))
   260  	for i, col := range cols {
   261  		columns[i] = col.Name()
   262  		mapping[col.Name()] = i
   263  	}
   264  
   265  	r := &Rows{
   266  		columns: columns,
   267  		mapping: mapping,
   268  		rows:    rows,
   269  	}
   270  	if err := starlarkthread.AddResource(thread, r); err != nil {
   271  		return nil, err
   272  	}
   273  	return r, nil
   274  }
   275  
   276  func (v *DB) queryRow(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   277  	queryArgs := args
   278  	if len(args) > 1 {
   279  		queryArgs = args[:1]
   280  	}
   281  	var query string
   282  	if err := starlark.UnpackPositionalArgs(fnname, queryArgs, kwargs, 1, &query); err != nil {
   283  		return nil, err
   284  	}
   285  
   286  	dbArgs, err := makeArgs(args[1:])
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  
   291  	ctx := starlarkthread.GetContext(thread)
   292  	rows, err := v.db.QueryContext(ctx, query, dbArgs...)
   293  	if err != nil {
   294  		return nil, err
   295  	}
   296  	defer rows.Close()
   297  
   298  	cols, err := rows.ColumnTypes()
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  	columns := make([]string, len(cols))
   303  	for i, col := range cols {
   304  		columns[i] = col.Name()
   305  	}
   306  
   307  	if !rows.Next() {
   308  		return nil, sql.ErrNoRows
   309  	}
   310  
   311  	m := make(map[string]int, len(columns))
   312  	x := &Row{
   313  		mapping: m,
   314  		values:  make([]starlark.Value, len(columns)),
   315  	}
   316  
   317  	dest := make([]interface{}, len(columns))
   318  	for i, name := range columns {
   319  		m[name] = i
   320  		dest[i] = x.scanAt(i)
   321  	}
   322  
   323  	if err := rows.Scan(dest...); err != nil {
   324  		return nil, err
   325  	}
   326  	return x, nil
   327  }
   328  
   329  func (v *DB) ping(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
   330  	if err := starlark.UnpackPositionalArgs(fnname, args, kwargs, 0); err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	ctx := starlarkthread.GetContext(thread)
   335  	if err := v.db.PingContext(ctx); err != nil {
   336  		return nil, err
   337  	}
   338  	return starlark.None, nil
   339  }
   340  
   341  func (v *DB) close(_ *starlark.Thread, fnname string, _ starlark.Tuple, _ []starlark.Tuple) (starlark.Value, error) {
   342  	if err := v.db.Close(); err != nil {
   343  		return nil, err
   344  	}
   345  	return starlark.None, nil
   346  }
   347  
   348  type Rows struct {
   349  	columns []string
   350  	mapping map[string]int
   351  	rows    *sql.Rows
   352  
   353  	frozen   bool
   354  	iterErr  error
   355  	closeErr error
   356  }
   357  
   358  func (v *Rows) Close() error {
   359  	v.Freeze()
   360  	return v.closeErr
   361  }
   362  func (v *Rows) String() string { return fmt.Sprintf("<rows %s>", strings.Join(v.columns, ", ")) }
   363  func (v *Rows) Type() string   { return "sql.rows" }
   364  func (v *Rows) Freeze() {
   365  	if !v.frozen {
   366  		v.closeErr = v.rows.Close()
   367  	}
   368  	v.frozen = true
   369  }
   370  func (v *Rows) Truth() starlark.Bool  { return v.rows != nil }
   371  func (v *Rows) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", v.Type()) }
   372  
   373  func (v *Rows) Iterate() starlark.Iterator {
   374  	return v
   375  }
   376  
   377  func (v *Rows) Next(p *starlark.Value) bool {
   378  	if ok := v.rows.Next(); !ok {
   379  		return false
   380  	}
   381  
   382  	x := &Row{
   383  		mapping: v.mapping,
   384  		values:  make([]starlark.Value, len(v.columns)),
   385  	}
   386  
   387  	dest := make([]interface{}, len(v.columns))
   388  	for i := range v.columns {
   389  		dest[i] = x.scanAt(i)
   390  	}
   391  
   392  	v.iterErr = v.rows.Scan(dest...)
   393  	*p = x
   394  	return v.iterErr == nil
   395  }
   396  func (v *Rows) Done() {
   397  	v.closeErr = v.rows.Close()
   398  	v.frozen = true
   399  }
   400  
   401  type Row struct {
   402  	mapping map[string]int
   403  	values  []starlark.Value
   404  }
   405  
   406  func (v *Row) String() string        { return fmt.Sprintf("<row %q>", strings.Join(v.AttrNames(), ", ")) }
   407  func (v *Row) Type() string          { return "sql.row" }
   408  func (v *Row) Freeze()               {} // immutable
   409  func (v *Row) Truth() starlark.Bool  { return len(v.values) > 0 }
   410  func (v *Row) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", v.Type()) }
   411  
   412  func (v *Row) Attr(name string) (starlark.Value, error) {
   413  	if i, ok := v.mapping[name]; ok {
   414  		return v.values[i], nil
   415  	}
   416  	return nil, fmt.Errorf("unknown name")
   417  }
   418  func (v *Row) AttrNames() []string {
   419  	names := make([]string, 0, len(v.mapping))
   420  	for name := range v.mapping {
   421  		names = append(names, name)
   422  	}
   423  	sort.Strings(names)
   424  	return names
   425  }
   426  func (v *Row) Index(i int) starlark.Value { return v.values[i] }
   427  func (v *Row) Len() int                   { return len(v.mapping) }
   428  
   429  type scanFn func(value interface{}) error
   430  
   431  func (f scanFn) Scan(value interface{}) error { return f(value) }
   432  
   433  func (r *Row) scanAt(index int) scanFn {
   434  	return func(value interface{}) (err error) {
   435  		var v starlark.Value
   436  		switch x := value.(type) {
   437  		case int64:
   438  			v = starlark.MakeInt64(x)
   439  		case float64:
   440  			v = starlark.Float(x)
   441  		case bool:
   442  			v = starlark.Bool(x)
   443  		case []byte:
   444  			v = starlark.Bytes(string(x))
   445  		case string:
   446  			v = starlark.String(x)
   447  		case time.Time:
   448  			v = starlarktime.Time(x)
   449  		case nil:
   450  			v = starlark.None
   451  		default:
   452  			return fmt.Errorf("unhandled type: %T", value)
   453  		}
   454  		r.values[index] = v
   455  		return
   456  	}
   457  }