github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/starlib/starlarksql/sql.go (about) 1 // Copyright 2021 Edward McFarlane. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package sql provides an interface to conntect to SQL databases. 6 package starlarksql 7 8 import ( 9 "database/sql" 10 "database/sql/driver" 11 "fmt" 12 "net/url" 13 "sort" 14 "strings" 15 "time" 16 17 "github.com/emcfarlane/larking/starlib/starext" 18 "github.com/emcfarlane/larking/starlib/starlarkerrors" 19 "github.com/emcfarlane/larking/starlib/starlarkthread" 20 starlarktime "go.starlark.net/lib/time" 21 "go.starlark.net/starlark" 22 "go.starlark.net/starlarkstruct" 23 "gocloud.dev/mysql" 24 "gocloud.dev/postgres" 25 ) 26 27 func NewModule() *starlarkstruct.Module { 28 return &starlarkstruct.Module{ 29 Name: "sql", 30 Members: starlark.StringDict{ 31 "open": starext.MakeBuiltin("sql.open", Open), 32 33 // sql errors 34 "err_conn_done": starlarkerrors.NewError(sql.ErrConnDone), 35 "err_no_rows": starlarkerrors.NewError(sql.ErrNoRows), 36 "err_tx_done": starlarkerrors.NewError(sql.ErrTxDone), 37 }, 38 } 39 } 40 41 // genQueryOptions generates standard query options. 42 func genQueryOptions(q url.Values) string { 43 if s := q.Encode(); s != "" { 44 return "?" + s 45 } 46 return "" 47 } 48 49 // genOpaque generates a opaque file path DSN from the passed URL. 50 func genOpaque(u *url.URL) (string, error) { 51 if u.Opaque == "" { 52 return "", fmt.Errorf("error missing path") 53 } 54 return u.Opaque + genQueryOptions(u.Query()), nil 55 } 56 57 func Open(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 58 var name string 59 if err := starlark.UnpackPositionalArgs(fnname, args, kwargs, 1, &name); err != nil { 60 return nil, err 61 } 62 63 u, err := url.Parse(name) 64 if err != nil { 65 return nil, err 66 } 67 68 ctx := starlarkthread.GetContext(thread) 69 70 var db *sql.DB 71 switch { 72 case strings.HasSuffix(u.Scheme, "mysql"): 73 db, err = mysql.Open(ctx, name) 74 case strings.HasSuffix(u.Scheme, "postgres"): 75 db, err = postgres.Open(ctx, name) 76 case u.Scheme == "sqlite": 77 // build dsn 78 dsn, derr := genOpaque(u) 79 if derr != nil { 80 return nil, derr 81 } 82 83 db, err = sql.Open("sqlite", dsn) 84 85 default: 86 return nil, fmt.Errorf("unsupported database %s", u.Scheme) 87 } 88 if err != nil { 89 return nil, err 90 } 91 92 v := NewDB(name, db) 93 if err := starlarkthread.AddResource(thread, v); err != nil { 94 return nil, err 95 } 96 return v, nil 97 } 98 99 type DB struct { 100 name string 101 db *sql.DB 102 103 frozen bool 104 } 105 106 func NewDB(name string, db *sql.DB) *DB { return &DB{name: name, db: db} } 107 func (db *DB) Close() error { return db.db.Close() } 108 109 func (v *DB) String() string { return fmt.Sprintf("<db %q>", v.name) } 110 func (v *DB) Type() string { return "sql.db" } 111 func (v *DB) Freeze() { v.frozen = true } // immutable? 112 func (v *DB) Truth() starlark.Bool { return v.db != nil } 113 func (v *DB) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", v.Type()) } 114 115 type dbAttr func(v *DB) starlark.Value 116 117 var dbAttrs = map[string]dbAttr{ 118 "exec": func(v *DB) starlark.Value { return starext.MakeMethod(v, "exec", v.exec) }, 119 "query": func(v *DB) starlark.Value { return starext.MakeMethod(v, "query", v.query) }, 120 "query_row": func(v *DB) starlark.Value { return starext.MakeMethod(v, "query_row", v.queryRow) }, 121 "ping": func(v *DB) starlark.Value { return starext.MakeMethod(v, "ping", v.ping) }, 122 "close": func(v *DB) starlark.Value { return starext.MakeMethod(v, "close", v.close) }, 123 } 124 125 func (v *DB) Attr(name string) (starlark.Value, error) { 126 if a := dbAttrs[name]; a != nil { 127 return a(v), nil 128 } 129 return nil, nil 130 } 131 func (v *DB) AttrNames() []string { 132 names := make([]string, 0, len(dbAttrs)) 133 for name := range dbAttrs { 134 names = append(names, name) 135 } 136 sort.Strings(names) 137 return names 138 } 139 140 type Result struct { 141 result sql.Result 142 } 143 144 func (r *Result) String() string { return fmt.Sprintf("<result %t>", r.result != nil) } 145 func (r *Result) Type() string { return "sql.result" } 146 func (r *Result) Freeze() {} // immutable 147 func (r *Result) Truth() starlark.Bool { return r.result != nil } 148 func (r *Result) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", r.Type()) } 149 func (r *Result) AttrNames() []string { return []string{"last_insert_id", "rows_affected"} } 150 func (r *Result) Attr(name string) (starlark.Value, error) { 151 switch name { 152 case "last_insert_id": 153 i, err := r.result.LastInsertId() 154 if err != nil { 155 return nil, err 156 } 157 return starlark.MakeInt64(i), nil 158 case "rows_affected": 159 i, err := r.result.RowsAffected() 160 if err != nil { 161 return nil, err 162 } 163 return starlark.MakeInt64(i), nil 164 default: 165 return nil, nil 166 } 167 } 168 169 func makeArgs(args starlark.Tuple) ([]interface{}, error) { 170 // translate arg types 171 xs := make([]interface{}, len(args)) 172 for i, arg := range args { 173 switch arg := arg.(type) { 174 case starlark.NoneType: 175 xs[i] = nil 176 case starlark.Bool: 177 xs[i] = bool(arg) 178 case starlark.String: 179 xs[i] = string(arg) 180 case starlark.Bytes: 181 xs[i] = []byte(arg) 182 case starlark.Int: 183 x, ok := arg.Uint64() 184 if !ok { 185 return nil, fmt.Errorf("invalid arg int too larg: %v", arg.String()) 186 } 187 xs[i] = x 188 case starlark.Float: 189 xs[i] = float64(arg) 190 case starlarktime.Time: 191 xs[i] = time.Time(arg) 192 case driver.Valuer: 193 x, err := arg.Value() 194 if err != nil { 195 return nil, err 196 } 197 xs[i] = x 198 default: 199 return nil, fmt.Errorf("invalid arg type: %v", arg.Type()) 200 } 201 } 202 return xs, nil 203 } 204 205 //func dbBeginTx(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 206 // return nil, nil // TODO: Create struct TX. 207 //} 208 209 func (v *DB) exec(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 210 queryArgs := args 211 if len(args) > 1 { 212 queryArgs = args[:1] 213 } 214 var query string 215 if err := starlark.UnpackPositionalArgs(fnname, queryArgs, kwargs, 1, &query); err != nil { 216 return nil, err 217 } 218 219 dbArgs, err := makeArgs(args[1:]) 220 if err != nil { 221 return nil, err 222 } 223 224 ctx := starlarkthread.GetContext(thread) 225 result, err := v.db.ExecContext(ctx, query, dbArgs...) 226 if err != nil { 227 return nil, err 228 } 229 return &Result{result: result}, nil 230 231 } 232 233 func (v *DB) query(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 234 queryArgs := args 235 if len(args) > 1 { 236 queryArgs = args[:1] 237 } 238 var query string 239 if err := starlark.UnpackPositionalArgs(fnname, queryArgs, kwargs, 1, &query); err != nil { 240 return nil, err 241 } 242 243 dbArgs, err := makeArgs(args[1:]) 244 if err != nil { 245 return nil, err 246 } 247 248 ctx := starlarkthread.GetContext(thread) 249 rows, err := v.db.QueryContext(ctx, query, dbArgs...) 250 if err != nil { 251 return nil, err 252 } 253 254 cols, err := rows.ColumnTypes() 255 if err != nil { 256 return nil, err 257 } 258 columns := make([]string, len(cols)) 259 mapping := make(map[string]int, len(cols)) 260 for i, col := range cols { 261 columns[i] = col.Name() 262 mapping[col.Name()] = i 263 } 264 265 r := &Rows{ 266 columns: columns, 267 mapping: mapping, 268 rows: rows, 269 } 270 if err := starlarkthread.AddResource(thread, r); err != nil { 271 return nil, err 272 } 273 return r, nil 274 } 275 276 func (v *DB) queryRow(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 277 queryArgs := args 278 if len(args) > 1 { 279 queryArgs = args[:1] 280 } 281 var query string 282 if err := starlark.UnpackPositionalArgs(fnname, queryArgs, kwargs, 1, &query); err != nil { 283 return nil, err 284 } 285 286 dbArgs, err := makeArgs(args[1:]) 287 if err != nil { 288 return nil, err 289 } 290 291 ctx := starlarkthread.GetContext(thread) 292 rows, err := v.db.QueryContext(ctx, query, dbArgs...) 293 if err != nil { 294 return nil, err 295 } 296 defer rows.Close() 297 298 cols, err := rows.ColumnTypes() 299 if err != nil { 300 return nil, err 301 } 302 columns := make([]string, len(cols)) 303 for i, col := range cols { 304 columns[i] = col.Name() 305 } 306 307 if !rows.Next() { 308 return nil, sql.ErrNoRows 309 } 310 311 m := make(map[string]int, len(columns)) 312 x := &Row{ 313 mapping: m, 314 values: make([]starlark.Value, len(columns)), 315 } 316 317 dest := make([]interface{}, len(columns)) 318 for i, name := range columns { 319 m[name] = i 320 dest[i] = x.scanAt(i) 321 } 322 323 if err := rows.Scan(dest...); err != nil { 324 return nil, err 325 } 326 return x, nil 327 } 328 329 func (v *DB) ping(thread *starlark.Thread, fnname string, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { 330 if err := starlark.UnpackPositionalArgs(fnname, args, kwargs, 0); err != nil { 331 return nil, err 332 } 333 334 ctx := starlarkthread.GetContext(thread) 335 if err := v.db.PingContext(ctx); err != nil { 336 return nil, err 337 } 338 return starlark.None, nil 339 } 340 341 func (v *DB) close(_ *starlark.Thread, fnname string, _ starlark.Tuple, _ []starlark.Tuple) (starlark.Value, error) { 342 if err := v.db.Close(); err != nil { 343 return nil, err 344 } 345 return starlark.None, nil 346 } 347 348 type Rows struct { 349 columns []string 350 mapping map[string]int 351 rows *sql.Rows 352 353 frozen bool 354 iterErr error 355 closeErr error 356 } 357 358 func (v *Rows) Close() error { 359 v.Freeze() 360 return v.closeErr 361 } 362 func (v *Rows) String() string { return fmt.Sprintf("<rows %s>", strings.Join(v.columns, ", ")) } 363 func (v *Rows) Type() string { return "sql.rows" } 364 func (v *Rows) Freeze() { 365 if !v.frozen { 366 v.closeErr = v.rows.Close() 367 } 368 v.frozen = true 369 } 370 func (v *Rows) Truth() starlark.Bool { return v.rows != nil } 371 func (v *Rows) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", v.Type()) } 372 373 func (v *Rows) Iterate() starlark.Iterator { 374 return v 375 } 376 377 func (v *Rows) Next(p *starlark.Value) bool { 378 if ok := v.rows.Next(); !ok { 379 return false 380 } 381 382 x := &Row{ 383 mapping: v.mapping, 384 values: make([]starlark.Value, len(v.columns)), 385 } 386 387 dest := make([]interface{}, len(v.columns)) 388 for i := range v.columns { 389 dest[i] = x.scanAt(i) 390 } 391 392 v.iterErr = v.rows.Scan(dest...) 393 *p = x 394 return v.iterErr == nil 395 } 396 func (v *Rows) Done() { 397 v.closeErr = v.rows.Close() 398 v.frozen = true 399 } 400 401 type Row struct { 402 mapping map[string]int 403 values []starlark.Value 404 } 405 406 func (v *Row) String() string { return fmt.Sprintf("<row %q>", strings.Join(v.AttrNames(), ", ")) } 407 func (v *Row) Type() string { return "sql.row" } 408 func (v *Row) Freeze() {} // immutable 409 func (v *Row) Truth() starlark.Bool { return len(v.values) > 0 } 410 func (v *Row) Hash() (uint32, error) { return 0, fmt.Errorf("unhashable type: %s", v.Type()) } 411 412 func (v *Row) Attr(name string) (starlark.Value, error) { 413 if i, ok := v.mapping[name]; ok { 414 return v.values[i], nil 415 } 416 return nil, fmt.Errorf("unknown name") 417 } 418 func (v *Row) AttrNames() []string { 419 names := make([]string, 0, len(v.mapping)) 420 for name := range v.mapping { 421 names = append(names, name) 422 } 423 sort.Strings(names) 424 return names 425 } 426 func (v *Row) Index(i int) starlark.Value { return v.values[i] } 427 func (v *Row) Len() int { return len(v.mapping) } 428 429 type scanFn func(value interface{}) error 430 431 func (f scanFn) Scan(value interface{}) error { return f(value) } 432 433 func (r *Row) scanAt(index int) scanFn { 434 return func(value interface{}) (err error) { 435 var v starlark.Value 436 switch x := value.(type) { 437 case int64: 438 v = starlark.MakeInt64(x) 439 case float64: 440 v = starlark.Float(x) 441 case bool: 442 v = starlark.Bool(x) 443 case []byte: 444 v = starlark.Bytes(string(x)) 445 case string: 446 v = starlark.String(x) 447 case time.Time: 448 v = starlarktime.Time(x) 449 case nil: 450 v = starlark.None 451 default: 452 return fmt.Errorf("unhandled type: %T", value) 453 } 454 r.values[index] = v 455 return 456 } 457 }