github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-sqlite3/driver/sqlite3.go (about) 1 // Package driver is an sqlite driver. 2 package driver 3 4 import ( 5 "database/sql" 6 "embed" 7 "encoding/base64" 8 "fmt" 9 "io/fs" 10 "strings" 11 12 "github.com/volatiletech/sqlboiler/v4/drivers" 13 "github.com/volatiletech/sqlboiler/v4/importers" 14 _ "modernc.org/sqlite" 15 ) 16 17 //go:embed override 18 var templates embed.FS 19 20 func init() { 21 drivers.RegisterFromInit("sqlite3", &SQLiteDriver{}) 22 } 23 24 // Assemble the db info 25 func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 26 driver := &SQLiteDriver{} 27 return driver.Assemble(config) 28 } 29 30 // SQLiteDriver holds the database connection string and a handle 31 // to the database connection. 32 type SQLiteDriver struct { 33 connStr string 34 dbConn *sql.DB 35 } 36 37 // Templates that should be added/overridden 38 func (s SQLiteDriver) Templates() (map[string]string, error) { 39 tpls := make(map[string]string) 40 fs.WalkDir(templates, "override", func(path string, d fs.DirEntry, err error) error { 41 if err != nil { 42 return err 43 } 44 45 if d.IsDir() { 46 return nil 47 } 48 49 b, err := fs.ReadFile(templates, path) 50 if err != nil { 51 return err 52 } 53 tpls[strings.Replace(path, "override/", "", 1)] = base64.StdEncoding.EncodeToString(b) 54 55 return nil 56 }) 57 58 return tpls, nil 59 } 60 61 // Assemble the db info 62 func (s SQLiteDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 63 defer func() { 64 if r := recover(); r != nil && err == nil { 65 dbinfo = nil 66 err = r.(error) 67 } 68 }() 69 70 dbname := config.MustString(drivers.ConfigDBName) 71 whitelist, _ := config.StringSlice(drivers.ConfigWhitelist) 72 blacklist, _ := config.StringSlice(drivers.ConfigBlacklist) 73 concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency) 74 75 s.connStr = SQLiteBuildQueryString(dbname) 76 s.dbConn, err = sql.Open("sqlite", s.connStr) 77 if err != nil { 78 return nil, fmt.Errorf("sqlboiler-sqlite failed to connect to database: %w", err) 79 } 80 81 defer func() { 82 if e := s.dbConn.Close(); e != nil { 83 dbinfo = nil 84 err = e 85 } 86 }() 87 88 dbinfo = &drivers.DBInfo{ 89 Dialect: drivers.Dialect{ 90 LQ: '"', 91 RQ: '"', 92 93 UseSchema: false, 94 UseDefaultKeyword: true, 95 UseLastInsertID: false, 96 }, 97 } 98 99 dbinfo.Tables, err = drivers.TablesConcurrently(s, "", whitelist, blacklist, concurrency) 100 if err != nil { 101 return nil, err 102 } 103 104 return dbinfo, err 105 } 106 107 // SQLiteBuildQueryString builds a query string for SQLite. 108 func SQLiteBuildQueryString(file string) string { 109 return "file:" + file + "?_loc=UTC&mode=ro" 110 } 111 112 // Open opens the database connection using the connection string 113 func (s SQLiteDriver) Open() error { 114 var err error 115 116 s.dbConn, err = sql.Open("sqlite3", s.connStr) 117 if err != nil { 118 return err 119 } 120 121 return nil 122 } 123 124 // Close closes the database connection 125 func (s SQLiteDriver) Close() { 126 s.dbConn.Close() 127 } 128 129 // TableNames connects to the sqlite database and 130 // retrieves all table names from sqlite_master 131 func (s SQLiteDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { 132 query := `SELECT name FROM sqlite_master WHERE type='table'` 133 args := []interface{}{} 134 135 if len(whitelist) > 0 { 136 tables := drivers.TablesFromList(whitelist) 137 if len(tables) > 0 { 138 query += fmt.Sprintf(" and tbl_name in (%s)", strings.Repeat(",?", len(tables))[1:]) 139 for _, w := range tables { 140 args = append(args, w) 141 } 142 } 143 } 144 145 if len(blacklist) > 0 { 146 tables := drivers.TablesFromList(blacklist) 147 if len(tables) > 0 { 148 query += fmt.Sprintf(" and tbl_name not in (%s)", strings.Repeat(",?", len(tables))[1:]) 149 for _, b := range tables { 150 args = append(args, b) 151 } 152 } 153 } 154 155 rows, err := s.dbConn.Query(query, args...) 156 157 if err != nil { 158 return nil, err 159 } 160 161 var names []string 162 defer rows.Close() 163 for rows.Next() { 164 var name string 165 if err := rows.Scan(&name); err != nil { 166 return nil, err 167 } 168 if name != "sqlite_sequence" { 169 names = append(names, name) 170 } 171 } 172 173 return names, nil 174 } 175 176 // ViewNames connects to the sqlite database and 177 // retrieves all view names from sqlite_master 178 func (s SQLiteDriver) ViewNames(schema string, whitelist, blacklist []string) ([]string, error) { 179 query := `SELECT name FROM sqlite_master WHERE type='view'` 180 args := []interface{}{} 181 182 if len(whitelist) > 0 { 183 views := drivers.TablesFromList(whitelist) 184 if len(views) > 0 { 185 query += fmt.Sprintf(" and tbl_name in (%s)", strings.Repeat(",?", len(views))[1:]) 186 for _, w := range views { 187 args = append(args, w) 188 } 189 } 190 } 191 192 if len(blacklist) > 0 { 193 views := drivers.TablesFromList(blacklist) 194 if len(views) > 0 { 195 query += fmt.Sprintf(" and tbl_name not in (%s)", strings.Repeat(",?", len(views))[1:]) 196 for _, b := range views { 197 args = append(args, b) 198 } 199 } 200 } 201 202 rows, err := s.dbConn.Query(query, args...) 203 204 if err != nil { 205 return nil, err 206 } 207 208 var names []string 209 defer rows.Close() 210 for rows.Next() { 211 var name string 212 if err := rows.Scan(&name); err != nil { 213 return nil, err 214 } 215 if name != "sqlite_sequence" { 216 names = append(names, name) 217 } 218 } 219 220 return names, nil 221 } 222 223 // ViewCapabilities return what actions are allowed for a view. 224 func (s SQLiteDriver) ViewCapabilities(schema, name string) (drivers.ViewCapabilities, error) { 225 // Inserts may be allowed with the presence of an INSTEAD OF TRIGGER 226 // but it is not yet implemented. 227 // See: https://www.sqlite.org/lang_createview.html 228 capabilities := drivers.ViewCapabilities{ 229 CanInsert: false, 230 CanUpsert: false, 231 } 232 233 return capabilities, nil 234 } 235 236 func (s SQLiteDriver) ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 237 return s.Columns(schema, tableName, whitelist, blacklist) 238 } 239 240 type sqliteIndex struct { 241 SeqNum int 242 Unique int 243 Partial int 244 Name string 245 Origin string 246 Columns []string 247 } 248 249 type sqliteTableInfo struct { 250 Cid string 251 Name string 252 Type string 253 NotNull bool 254 DefaultValue *string 255 Pk int 256 Hidden int 257 } 258 259 func (s SQLiteDriver) tableInfo(tableName string) ([]*sqliteTableInfo, error) { 260 var ret []*sqliteTableInfo 261 rows, err := s.dbConn.Query(fmt.Sprintf("PRAGMA table_xinfo('%s')", tableName)) 262 263 if err != nil { 264 return nil, err 265 } 266 defer rows.Close() 267 268 for rows.Next() { 269 tinfo := &sqliteTableInfo{} 270 if err := rows.Scan(&tinfo.Cid, &tinfo.Name, &tinfo.Type, &tinfo.NotNull, &tinfo.DefaultValue, &tinfo.Pk, &tinfo.Hidden); err != nil { 271 return nil, fmt.Errorf("unable to scan for table %s: %w", tableName, err) 272 } 273 274 ret = append(ret, tinfo) 275 } 276 return ret, nil 277 } 278 279 func (s SQLiteDriver) indexes(tableName string) ([]*sqliteIndex, error) { 280 var ret []*sqliteIndex 281 rows, err := s.dbConn.Query(fmt.Sprintf("PRAGMA index_list('%s')", tableName)) 282 if err != nil { 283 return nil, err 284 } 285 defer rows.Close() 286 287 for rows.Next() { 288 var idx = &sqliteIndex{} 289 var columns []string 290 if err := rows.Scan(&idx.SeqNum, &idx.Name, &idx.Unique, &idx.Origin, &idx.Partial); err != nil { 291 return nil, err 292 } 293 // get all columns stored within the index 294 rowsColumns, err := s.dbConn.Query(fmt.Sprintf("PRAGMA index_info('%s')", idx.Name)) 295 if err != nil { 296 return nil, err 297 } 298 for rowsColumns.Next() { 299 var rankIndex, rankTable int 300 var colName string 301 if err := rowsColumns.Scan(&rankIndex, &rankTable, &colName); err != nil { 302 return nil, fmt.Errorf("unable to scan for index %s: %w", idx.Name, err) 303 } 304 columns = append(columns, colName) 305 } 306 rowsColumns.Close() 307 idx.Columns = columns 308 ret = append(ret, idx) 309 } 310 return ret, nil 311 } 312 313 // Columns takes a table name and attempts to retrieve the table information 314 // from the database. It retrieves the column names 315 // and column types and returns those as a []Column after TranslateColumnType() 316 // converts the SQL types to Go types, for example: "varchar" to "string" 317 func (s SQLiteDriver) Columns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 318 var columns []drivers.Column 319 320 // get all indexes 321 idxs, err := s.indexes(tableName) 322 if err != nil { 323 return nil, err 324 } 325 326 // finally get the remaining information about the columns 327 tinfo, err := s.tableInfo(tableName) 328 if err != nil { 329 return nil, err 330 } 331 332 query := "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? AND sql LIKE '%AUTOINCREMENT%'" 333 result, err := s.dbConn.Query(query, tableName) 334 if err != nil { 335 return nil, err 336 } 337 tableHasAutoIncr := result.Next() 338 if err := result.Close(); err != nil { 339 return nil, err 340 } 341 342 var whiteColumns, blackColumns []string 343 if len(whitelist) != 0 { 344 whiteColumns = drivers.ColumnsFromList(whitelist, tableName) 345 } 346 if len(blacklist) != 0 { 347 blackColumns = drivers.ColumnsFromList(blacklist, tableName) 348 } 349 350 nPkeys := 0 351 for _, column := range tinfo { 352 if column.Pk == 1 { 353 nPkeys++ 354 } 355 } 356 357 ColumnLoop: 358 for _, column := range tinfo { 359 if len(whitelist) != 0 { 360 found := false 361 for _, white := range whiteColumns { 362 if white == column.Name { 363 found = true 364 break 365 } 366 } 367 if !found { 368 continue 369 } 370 } else if len(blacklist) != 0 { 371 for _, black := range blackColumns { 372 if black == column.Name { 373 continue ColumnLoop 374 } 375 } 376 } 377 378 bColumn := drivers.Column{ 379 Name: column.Name, 380 FullDBType: strings.ToUpper(column.Type), 381 DBType: strings.ToUpper(column.Type), 382 Nullable: !column.NotNull, 383 } 384 385 // also get a correct information for Unique 386 for _, idx := range idxs { 387 // A unique index with multiple columns does not make 388 // the individual column unique 389 if len(idx.Columns) > 1 { 390 continue 391 } 392 for _, name := range idx.Columns { 393 if name == column.Name { 394 // A column is unique if it has a unique non-partial index 395 bColumn.Unique = idx.Unique > 0 && idx.Partial == 0 396 } 397 } 398 } 399 400 isPrimaryKeyInteger := column.Pk == 1 && bColumn.FullDBType == "INTEGER" 401 // This is special behavior noted in the sqlite documentation. 402 // An integer primary key becomes synonymous with the internal ROWID 403 // and acts as an auto incrementing value. Although there's important 404 // differences between using the keyword AUTOINCREMENT and this inferred 405 // version, they don't matter here so just masquerade as the same thing as 406 // above. 407 autoIncr := isPrimaryKeyInteger && (tableHasAutoIncr || nPkeys == 1) 408 409 // See: https://github.com/sqlite/sqlite/blob/91f621531dc1cb9ba5f6a47eb51b1de9ed8bdd07/src/pragma.c#L1165 410 bColumn.AutoGenerated = autoIncr || column.Hidden == 2 || column.Hidden == 3 411 412 if column.DefaultValue != nil { 413 bColumn.Default = *column.DefaultValue 414 } else if autoIncr { 415 bColumn.Default = "auto_increment" 416 } else if bColumn.AutoGenerated { 417 bColumn.Default = "auto_generated" 418 } 419 420 if bColumn.Nullable && bColumn.Default == "" { 421 bColumn.Default = "NULL" 422 } 423 424 columns = append(columns, bColumn) 425 } 426 427 return columns, nil 428 } 429 430 // PrimaryKeyInfo looks up the primary key for a table. 431 func (s SQLiteDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.PrimaryKey, error) { 432 // lookup the columns affected by the PK 433 tinfo, err := s.tableInfo(tableName) 434 if err != nil { 435 return nil, err 436 } 437 438 var columns []string 439 for _, column := range tinfo { 440 if column.Pk > 0 { 441 columns = append(columns, column.Name) 442 } 443 } 444 445 var pk *drivers.PrimaryKey 446 if len(columns) > 0 { 447 pk = &drivers.PrimaryKey{Columns: columns} 448 } 449 return pk, nil 450 } 451 452 // ForeignKeyInfo retrieves the foreign keys for a given table name. 453 func (s SQLiteDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { 454 var fkeys []drivers.ForeignKey 455 456 query := fmt.Sprintf("PRAGMA foreign_key_list('%s')", tableName) 457 458 var rows *sql.Rows 459 var err error 460 if rows, err = s.dbConn.Query(query, tableName); err != nil { 461 return nil, err 462 } 463 defer rows.Close() 464 465 for rows.Next() { 466 var fkey drivers.ForeignKey 467 var onu, ond, match string 468 var id, seq int 469 470 fkey.Table = tableName 471 err = rows.Scan(&id, &seq, &fkey.ForeignTable, &fkey.Column, &fkey.ForeignColumn, &onu, &ond, &match) 472 if err != nil { 473 return nil, err 474 } 475 fkey.Name = fmt.Sprintf("FK_%d", id) 476 477 fkeys = append(fkeys, fkey) 478 } 479 480 if err = rows.Err(); err != nil { 481 return nil, err 482 } 483 484 return fkeys, nil 485 } 486 487 // TranslateColumnType converts sqlite database types to Go types, for example 488 // "varchar" to "string" and "bigint" to "int64". It returns this parsed data 489 // as a Column object. 490 // https://sqlite.org/datatype3.html 491 func (SQLiteDriver) TranslateColumnType(c drivers.Column) drivers.Column { 492 if c.Nullable { 493 switch strings.Split(c.DBType, "(")[0] { 494 case "INT", "INTEGER", "BIGINT": 495 c.Type = "null.Int64" 496 case "TINYINT", "INT8": 497 c.Type = "null.Int8" 498 case "SMALLINT", "INT2": 499 c.Type = "null.Int16" 500 case "MEDIUMINT": 501 c.Type = "null.Int32" 502 case "UNSIGNED BIG INT": 503 c.Type = "null.Uint64" 504 case "CHARACTER", "VARCHAR", "VARYING CHARACTER", "NCHAR", 505 "NATIVE CHARACTER", "NVARCHAR", "TEXT", "CLOB": 506 c.Type = "null.String" 507 case "BLOB": 508 c.Type = "null.Bytes" 509 case "FLOAT": 510 c.Type = "null.Float32" 511 case "REAL", "DOUBLE", "DOUBLE PRECISION": 512 c.Type = "null.Float64" 513 case "NUMERIC", "DECIMAL": 514 c.Type = "types.NullDecimal" 515 case "BOOLEAN": 516 c.Type = "null.Bool" 517 case "DATE", "DATETIME": 518 c.Type = "null.Time" 519 case "JSON": 520 c.Type = "null.JSON" 521 522 default: 523 c.Type = "null.String" 524 } 525 } else { 526 switch c.DBType { 527 case "INT", "INTEGER", "BIGINT": 528 c.Type = "int64" 529 case "TINYINT", "INT8": 530 c.Type = "int8" 531 case "SMALLINT", "INT2": 532 c.Type = "int16" 533 case "MEDIUMINT": 534 c.Type = "int32" 535 case "UNSIGNED BIG INT": 536 c.Type = "uint64" 537 case "CHARACTER", "VARCHAR", "VARYING CHARACTER", "NCHAR", 538 "NATIVE CHARACTER", "NVARCHAR", "TEXT", "CLOB": 539 c.Type = "string" 540 case "BLOB": 541 c.Type = "[]byte" 542 case "FLOAT": 543 c.Type = "float32" 544 case "REAL", "DOUBLE", "DOUBLE PRECISION": 545 c.Type = "float64" 546 case "NUMERIC", "DECIMAL": 547 c.Type = "types.Decimal" 548 case "BOOLEAN": 549 c.Type = "bool" 550 case "DATE", "DATETIME": 551 c.Type = "time.Time" 552 case "JSON": 553 c.Type = "types.JSON" 554 555 default: 556 c.Type = "string" 557 } 558 } 559 560 return c 561 } 562 563 // Imports returns important imports for the driver 564 func (SQLiteDriver) Imports() (col importers.Collection, err error) { 565 col.All = importers.Set{ 566 Standard: importers.List{ 567 `"strconv"`, 568 }, 569 } 570 571 col.Singleton = importers.Map{ 572 "sqlite_upsert": { 573 Standard: importers.List{ 574 `"fmt"`, 575 `"strings"`, 576 }, 577 ThirdParty: importers.List{ 578 `"github.com/volatiletech/strmangle"`, 579 `"github.com/volatiletech/sqlboiler/v4/drivers"`, 580 }, 581 }, 582 } 583 584 col.TestSingleton = importers.Map{ 585 "sqlite3_suites_test": { 586 Standard: importers.List{ 587 `"testing"`, 588 }, 589 }, 590 "sqlite3_main_test": { 591 Standard: importers.List{ 592 `"database/sql"`, 593 `"fmt"`, 594 `"io"`, 595 `"math/rand"`, 596 `"os"`, 597 `"os/exec"`, 598 `"path/filepath"`, 599 `"regexp"`, 600 }, 601 ThirdParty: importers.List{ 602 `"github.com/pkg/errors"`, 603 `"github.com/spf13/viper"`, 604 `_ "modernc.org/sqlite"`, 605 }, 606 }, 607 } 608 609 col.BasedOnType = importers.Map{ 610 "null.Float32": { 611 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 612 }, 613 "null.Float64": { 614 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 615 }, 616 "null.Int": { 617 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 618 }, 619 "null.Int8": { 620 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 621 }, 622 "null.Int16": { 623 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 624 }, 625 "null.Int32": { 626 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 627 }, 628 "null.Int64": { 629 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 630 }, 631 "null.Uint": { 632 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 633 }, 634 "null.Uint8": { 635 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 636 }, 637 "null.Uint16": { 638 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 639 }, 640 "null.Uint32": { 641 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 642 }, 643 "null.Uint64": { 644 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 645 }, 646 "null.String": { 647 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 648 }, 649 "null.Bool": { 650 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 651 }, 652 "null.Time": { 653 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 654 }, 655 "null.Bytes": { 656 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 657 }, 658 659 "time.Time": { 660 Standard: importers.List{`"time"`}, 661 }, 662 "types.Decimal": { 663 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 664 }, 665 "types.NullDecimal": { 666 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 667 }, 668 669 "types.JSON": { 670 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 671 }, 672 "null.JSON": { 673 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 674 }, 675 } 676 return col, err 677 }