github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-mysql/driver/mysql.go (about) 1 package driver 2 3 import ( 4 "database/sql" 5 "embed" 6 "encoding/base64" 7 "fmt" 8 "io/fs" 9 "strconv" 10 "strings" 11 12 "github.com/friendsofgo/errors" 13 "github.com/go-sql-driver/mysql" 14 "github.com/volatiletech/sqlboiler/v4/drivers" 15 "github.com/volatiletech/sqlboiler/v4/importers" 16 "github.com/volatiletech/strmangle" 17 ) 18 19 //go:embed override 20 var templates embed.FS 21 22 func init() { 23 drivers.RegisterFromInit("mysql", &MySQLDriver{}) 24 } 25 26 // Assemble is more useful for calling into the library so you don't 27 // have to instantiate an empty type. 28 func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 29 driver := MySQLDriver{} 30 return driver.Assemble(config) 31 } 32 33 // MySQLDriver holds the database connection string and a handle 34 // to the database connection. 35 type MySQLDriver struct { 36 connStr string 37 conn *sql.DB 38 addEnumTypes bool 39 enumNullPrefix string 40 tinyIntAsInt bool 41 } 42 43 // Templates that should be added/overridden 44 func (MySQLDriver) Templates() (map[string]string, error) { 45 tpls := make(map[string]string) 46 fs.WalkDir(templates, "override", func(path string, d fs.DirEntry, err error) error { 47 if err != nil { 48 return err 49 } 50 51 if d.IsDir() { 52 return nil 53 } 54 55 b, err := fs.ReadFile(templates, path) 56 if err != nil { 57 return err 58 } 59 tpls[strings.Replace(path, "override/", "", 1)] = base64.StdEncoding.EncodeToString(b) 60 61 return nil 62 }) 63 64 return tpls, nil 65 } 66 67 // Assemble all the information we need to provide back to the driver 68 func (m *MySQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 69 defer func() { 70 if r := recover(); r != nil && err == nil { 71 dbinfo = nil 72 err = r.(error) 73 } 74 }() 75 76 user := config.MustString(drivers.ConfigUser) 77 pass, _ := config.String(drivers.ConfigPass) 78 dbname := config.MustString(drivers.ConfigDBName) 79 host := config.MustString(drivers.ConfigHost) 80 port := config.DefaultInt(drivers.ConfigPort, 3306) 81 sslmode := config.DefaultString(drivers.ConfigSSLMode, "true") 82 83 schema := dbname 84 whitelist, _ := config.StringSlice(drivers.ConfigWhitelist) 85 blacklist, _ := config.StringSlice(drivers.ConfigBlacklist) 86 concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency) 87 88 tinyIntAsIntIntf, ok := config["tinyint_as_int"] 89 if ok { 90 if b, ok := tinyIntAsIntIntf.(bool); ok { 91 m.tinyIntAsInt = b 92 } 93 } 94 95 m.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool) 96 m.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null")) 97 m.connStr = MySQLBuildQueryString(user, pass, dbname, host, port, sslmode) 98 m.conn, err = sql.Open("mysql", m.connStr) 99 if err != nil { 100 return nil, errors.Wrap(err, "sqlboiler-mysql failed to connect to database") 101 } 102 103 defer func() { 104 if e := m.conn.Close(); e != nil { 105 dbinfo = nil 106 err = e 107 } 108 }() 109 110 dbinfo = &drivers.DBInfo{ 111 Dialect: drivers.Dialect{ 112 LQ: '`', 113 RQ: '`', 114 115 UseLastInsertID: true, 116 UseSchema: false, 117 }, 118 } 119 120 dbinfo.Tables, err = drivers.TablesConcurrently(m, schema, whitelist, blacklist, concurrency) 121 if err != nil { 122 return nil, err 123 } 124 125 return dbinfo, err 126 } 127 128 // MySQLBuildQueryString builds a query string for MySQL. 129 func MySQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { 130 config := mysql.NewConfig() 131 132 config.User = user 133 if len(pass) != 0 { 134 config.Passwd = pass 135 } 136 config.DBName = dbname 137 config.Net = "tcp" 138 config.Addr = host 139 if port == 0 { 140 port = 3306 141 } 142 config.Addr += ":" + strconv.Itoa(port) 143 config.TLSConfig = sslmode 144 145 // MySQL is a bad, and by default reads date/datetime into a []byte 146 // instead of a time.Time. Tell it to stop being a bad. 147 config.ParseTime = true 148 149 return config.FormatDSN() 150 } 151 152 // TableNames connects to the mysql database and 153 // retrieves all table names from the information_schema where the 154 // table schema is public. 155 func (m *MySQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { 156 var names []string 157 158 query := fmt.Sprintf(`select table_name from information_schema.tables where table_schema = ? and table_type = 'BASE TABLE'`) 159 args := []interface{}{schema} 160 if len(whitelist) > 0 { 161 tables := drivers.TablesFromList(whitelist) 162 if len(tables) > 0 { 163 query += fmt.Sprintf(" and table_name in (%s)", strings.Repeat(",?", len(tables))[1:]) 164 for _, w := range tables { 165 args = append(args, w) 166 } 167 } 168 } else if len(blacklist) > 0 { 169 tables := drivers.TablesFromList(blacklist) 170 if len(tables) > 0 { 171 query += fmt.Sprintf(" and table_name not in (%s)", strings.Repeat(",?", len(tables))[1:]) 172 for _, b := range tables { 173 args = append(args, b) 174 } 175 } 176 } 177 178 query += ` order by table_name;` 179 180 rows, err := m.conn.Query(query, args...) 181 182 if err != nil { 183 return nil, err 184 } 185 186 defer rows.Close() 187 for rows.Next() { 188 var name string 189 if err := rows.Scan(&name); err != nil { 190 return nil, err 191 } 192 names = append(names, name) 193 } 194 195 return names, nil 196 } 197 198 // ViewNames connects to the postgres database and 199 // retrieves all view names from the information_schema where the 200 // view schema is schema. It uses a whitelist and blacklist. 201 func (m *MySQLDriver) ViewNames(schema string, whitelist, blacklist []string) ([]string, error) { 202 var names []string 203 204 query := `select table_name from information_schema.views where table_schema = ?` 205 args := []interface{}{schema} 206 if len(whitelist) > 0 { 207 tables := drivers.TablesFromList(whitelist) 208 if len(tables) > 0 { 209 query += fmt.Sprintf(" and table_name in (%s)", strings.Repeat(",?", len(tables))[1:]) 210 for _, w := range tables { 211 args = append(args, w) 212 } 213 } 214 } else if len(blacklist) > 0 { 215 tables := drivers.TablesFromList(blacklist) 216 if len(tables) > 0 { 217 query += fmt.Sprintf(" and table_name not in (%s)", strings.Repeat(",?", len(tables))[1:]) 218 for _, b := range tables { 219 args = append(args, b) 220 } 221 } 222 } 223 224 query += ` order by table_name;` 225 226 rows, err := m.conn.Query(query, args...) 227 228 if err != nil { 229 return nil, err 230 } 231 232 defer rows.Close() 233 for rows.Next() { 234 var name string 235 if err := rows.Scan(&name); err != nil { 236 return nil, err 237 } 238 239 names = append(names, name) 240 } 241 242 return names, nil 243 } 244 245 // ViewCapabilities return what actions are allowed for a view. 246 func (m *MySQLDriver) ViewCapabilities(schema, name string) (drivers.ViewCapabilities, error) { 247 capabilities := drivers.ViewCapabilities{ 248 // No definite way to check if a view is insertable 249 // See: https://dba.stackexchange.com/questions/285451/does-mysql-have-a-built-in-way-to-tell-whether-a-view-is-insertable-not-just-up?newreg=e6c571353a0948638bec10cf7f8c6f6f 250 CanInsert: false, 251 CanUpsert: false, 252 } 253 254 return capabilities, nil 255 } 256 257 func (m *MySQLDriver) ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 258 return m.Columns(schema, tableName, whitelist, blacklist) 259 } 260 261 // Columns takes a table name and attempts to retrieve the table information 262 // from the database information_schema.columns. It retrieves the column names 263 // and column types and returns those as a []Column after TranslateColumnType() 264 // converts the SQL types to Go types, for example: "varchar" to "string" 265 func (m *MySQLDriver) Columns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 266 var columns []drivers.Column 267 args := []interface{}{tableName, tableName, schema, schema, schema, schema, tableName, tableName, schema} 268 269 query := ` 270 select 271 c.column_name, 272 c.column_type, 273 c.column_comment, 274 if(c.data_type = 'enum', c.column_type, c.data_type), 275 if(extra = 'auto_increment','auto_increment', 276 if(version() like '%MariaDB%' and c.column_default = 'NULL', '', 277 if(version() like '%MariaDB%' and c.data_type in ('varchar','char','binary','date','datetime','time'), 278 replace(substring(c.column_default,2,length(c.column_default)-2),'\'\'','\''), 279 c.column_default))), 280 c.is_nullable = 'YES', 281 (c.extra = 'STORED GENERATED' OR c.extra = 'VIRTUAL GENERATED') is_generated, 282 exists ( 283 select c.column_name 284 from information_schema.table_constraints tc 285 inner join information_schema.key_column_usage kcu 286 on tc.constraint_name = kcu.constraint_name 287 where tc.table_name = ? and kcu.table_name = ? and tc.table_schema = ? and kcu.table_schema = ? and 288 c.column_name = kcu.column_name and 289 (tc.constraint_type = 'PRIMARY KEY' or tc.constraint_type = 'UNIQUE') and 290 (select count(*) from information_schema.key_column_usage where table_schema = ? and 291 constraint_schema = ? and table_name = ? and constraint_name = tc.constraint_name) = 1 292 ) as is_unique 293 from information_schema.columns as c 294 where table_name = ? and table_schema = ?` 295 296 if len(whitelist) > 0 { 297 cols := drivers.ColumnsFromList(whitelist, tableName) 298 if len(cols) > 0 { 299 query += fmt.Sprintf(" and c.column_name in (%s)", strings.Repeat(",?", len(cols))[1:]) 300 for _, w := range cols { 301 args = append(args, w) 302 } 303 } 304 } else if len(blacklist) > 0 { 305 cols := drivers.ColumnsFromList(blacklist, tableName) 306 if len(cols) > 0 { 307 query += fmt.Sprintf(" and c.column_name not in (%s)", strings.Repeat(",?", len(cols))[1:]) 308 for _, w := range cols { 309 args = append(args, w) 310 } 311 } 312 } 313 314 query += ` order by c.ordinal_position;` 315 316 rows, err := m.conn.Query(query, args...) 317 if err != nil { 318 return nil, err 319 } 320 defer rows.Close() 321 322 for rows.Next() { 323 var colName, colFullType, colComment, colType string 324 var nullable, generated, unique bool 325 var defaultValue *string 326 if err := rows.Scan(&colName, &colFullType, &colComment, &colType, &defaultValue, &nullable, &generated, &unique); err != nil { 327 return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) 328 } 329 330 column := drivers.Column{ 331 Name: colName, 332 Comment: colComment, 333 FullDBType: colFullType, // example: tinyint(1) instead of tinyint 334 DBType: colType, 335 Nullable: nullable, 336 Unique: unique, 337 AutoGenerated: generated, 338 } 339 340 if defaultValue != nil { 341 column.Default = *defaultValue 342 } 343 344 // A generated column technically has a default value 345 if column.Default == "" && column.AutoGenerated { 346 column.Default = "AUTO_GENERATED" 347 } 348 349 columns = append(columns, column) 350 } 351 352 return columns, nil 353 } 354 355 // PrimaryKeyInfo looks up the primary key for a table. 356 func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.PrimaryKey, error) { 357 pkey := &drivers.PrimaryKey{} 358 var err error 359 360 query := ` 361 select tc.constraint_name 362 from information_schema.table_constraints as tc 363 where tc.table_name = ? and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = ?;` 364 365 row := m.conn.QueryRow(query, tableName, schema) 366 if err = row.Scan(&pkey.Name); err != nil { 367 if errors.Is(err, sql.ErrNoRows) { 368 return nil, nil 369 } 370 return nil, err 371 } 372 373 queryColumns := ` 374 select kcu.column_name 375 from information_schema.key_column_usage as kcu 376 where table_name = ? and constraint_name = ? and table_schema = ? 377 order by kcu.ordinal_position;` 378 379 var rows *sql.Rows 380 if rows, err = m.conn.Query(queryColumns, tableName, pkey.Name, schema); err != nil { 381 return nil, err 382 } 383 defer rows.Close() 384 385 var columns []string 386 for rows.Next() { 387 var column string 388 389 err = rows.Scan(&column) 390 if err != nil { 391 return nil, err 392 } 393 394 columns = append(columns, column) 395 } 396 397 if err = rows.Err(); err != nil { 398 return nil, err 399 } 400 401 pkey.Columns = columns 402 403 return pkey, nil 404 } 405 406 // ForeignKeyInfo retrieves the foreign keys for a given table name. 407 func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { 408 var fkeys []drivers.ForeignKey 409 410 query := ` 411 select constraint_name, table_name, column_name, referenced_table_name, referenced_column_name 412 from information_schema.key_column_usage 413 where table_schema = ? and referenced_table_schema = ? and table_name = ? 414 order by constraint_name, table_name, column_name, referenced_table_name, referenced_column_name 415 ` 416 417 var rows *sql.Rows 418 var err error 419 if rows, err = m.conn.Query(query, schema, schema, tableName); err != nil { 420 return nil, err 421 } 422 423 for rows.Next() { 424 var fkey drivers.ForeignKey 425 var sourceTable string 426 427 fkey.Table = tableName 428 err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn) 429 if err != nil { 430 return nil, err 431 } 432 433 fkeys = append(fkeys, fkey) 434 } 435 436 if err = rows.Err(); err != nil { 437 return nil, err 438 } 439 440 return fkeys, nil 441 } 442 443 // TranslateColumnType converts mysql database types to Go types, for example 444 // "varchar" to "string" and "bigint" to "int64". It returns this parsed data 445 // as a Column object. 446 // Deprecated: for MySQL enum types to be created properly TranslateTableColumnType method should be used instead. 447 func (m *MySQLDriver) TranslateColumnType(drivers.Column) drivers.Column { 448 panic("TranslateTableColumnType should be called") 449 } 450 451 // TranslateTableColumnType converts mysql database types to Go types, for example 452 // "varchar" to "string" and "bigint" to "int64". It returns this parsed data 453 // as a Column object. 454 func (m *MySQLDriver) TranslateTableColumnType(c drivers.Column, tableName string) drivers.Column { 455 unsigned := strings.Contains(c.FullDBType, "unsigned") 456 if c.Nullable { 457 switch c.DBType { 458 case "tinyint": 459 // map tinyint(1) to bool if TinyintAsBool is true 460 if !m.tinyIntAsInt && c.FullDBType == "tinyint(1)" { 461 c.Type = "null.Bool" 462 } else if unsigned { 463 c.Type = "null.Uint8" 464 } else { 465 c.Type = "null.Int8" 466 } 467 case "smallint": 468 if unsigned { 469 c.Type = "null.Uint16" 470 } else { 471 c.Type = "null.Int16" 472 } 473 case "mediumint": 474 if unsigned { 475 c.Type = "null.Uint32" 476 } else { 477 c.Type = "null.Int32" 478 } 479 case "int", "integer": 480 if unsigned { 481 c.Type = "null.Uint" 482 } else { 483 c.Type = "null.Int" 484 } 485 case "bigint": 486 if unsigned { 487 c.Type = "null.Uint64" 488 } else { 489 c.Type = "null.Int64" 490 } 491 case "float": 492 c.Type = "null.Float32" 493 case "double", "double precision", "real": 494 c.Type = "null.Float64" 495 case "boolean", "bool": 496 c.Type = "null.Bool" 497 case "date", "datetime", "timestamp": 498 c.Type = "null.Time" 499 case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": 500 c.Type = "null.Bytes" 501 case "numeric", "decimal", "dec", "fixed": 502 c.Type = "types.NullDecimal" 503 case "json": 504 c.Type = "null.JSON" 505 default: 506 if len(strmangle.ParseEnumVals(c.DBType)) > 0 && m.addEnumTypes { 507 c.Type = strmangle.TitleCase(tableName) + m.enumNullPrefix + strmangle.TitleCase(c.Name) 508 } else { 509 c.Type = "null.String" 510 } 511 } 512 } else { 513 switch c.DBType { 514 case "tinyint": 515 // map tinyint(1) to bool if TinyintAsBool is true 516 if !m.tinyIntAsInt && c.FullDBType == "tinyint(1)" { 517 c.Type = "bool" 518 } else if unsigned { 519 c.Type = "uint8" 520 } else { 521 c.Type = "int8" 522 } 523 case "smallint": 524 if unsigned { 525 c.Type = "uint16" 526 } else { 527 c.Type = "int16" 528 } 529 case "mediumint": 530 if unsigned { 531 c.Type = "uint32" 532 } else { 533 c.Type = "int32" 534 } 535 case "int", "integer": 536 if unsigned { 537 c.Type = "uint" 538 } else { 539 c.Type = "int" 540 } 541 case "bigint": 542 if unsigned { 543 c.Type = "uint64" 544 } else { 545 c.Type = "int64" 546 } 547 case "float": 548 c.Type = "float32" 549 case "double", "double precision", "real": 550 c.Type = "float64" 551 case "boolean", "bool": 552 c.Type = "bool" 553 case "date", "datetime", "timestamp": 554 c.Type = "time.Time" 555 case "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": 556 c.Type = "[]byte" 557 case "numeric", "decimal", "dec", "fixed": 558 c.Type = "types.Decimal" 559 case "json": 560 c.Type = "types.JSON" 561 default: 562 if len(strmangle.ParseEnumVals(c.DBType)) > 0 && m.addEnumTypes { 563 c.Type = strmangle.TitleCase(tableName) + strmangle.TitleCase(c.Name) 564 } else { 565 c.Type = "string" 566 } 567 } 568 } 569 570 return c 571 } 572 573 // Imports returns important imports for the driver 574 func (MySQLDriver) Imports() (col importers.Collection, err error) { 575 col.All = importers.Set{ 576 Standard: importers.List{ 577 `"strconv"`, 578 }, 579 } 580 581 col.Singleton = importers.Map{ 582 "mysql_upsert": { 583 Standard: importers.List{ 584 `"fmt"`, 585 `"strings"`, 586 }, 587 ThirdParty: importers.List{ 588 `"github.com/volatiletech/strmangle"`, 589 `"github.com/volatiletech/sqlboiler/v4/drivers"`, 590 }, 591 }, 592 } 593 594 col.TestSingleton = importers.Map{ 595 "mysql_suites_test": { 596 Standard: importers.List{ 597 `"testing"`, 598 }, 599 }, 600 "mysql_main_test": { 601 Standard: importers.List{ 602 `"bytes"`, 603 `"database/sql"`, 604 `"fmt"`, 605 `"io"`, 606 `"os"`, 607 `"os/exec"`, 608 `"regexp"`, 609 `"strings"`, 610 }, 611 ThirdParty: importers.List{ 612 `"github.com/kat-co/vala"`, 613 `"github.com/friendsofgo/errors"`, 614 `"github.com/spf13/viper"`, 615 `"github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-mysql/driver"`, 616 `"github.com/volatiletech/randomize"`, 617 `_ "github.com/go-sql-driver/mysql"`, 618 }, 619 }, 620 } 621 622 col.BasedOnType = importers.Map{ 623 "null.Float32": { 624 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 625 }, 626 "null.Float64": { 627 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 628 }, 629 "null.Int": { 630 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 631 }, 632 "null.Int8": { 633 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 634 }, 635 "null.Int16": { 636 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 637 }, 638 "null.Int32": { 639 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 640 }, 641 "null.Int64": { 642 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 643 }, 644 "null.Uint": { 645 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 646 }, 647 "null.Uint8": { 648 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 649 }, 650 "null.Uint16": { 651 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 652 }, 653 "null.Uint32": { 654 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 655 }, 656 "null.Uint64": { 657 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 658 }, 659 "null.String": { 660 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 661 }, 662 "null.Bool": { 663 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 664 }, 665 "null.Time": { 666 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 667 }, 668 "null.Bytes": { 669 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 670 }, 671 "null.JSON": { 672 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 673 }, 674 675 "time.Time": { 676 Standard: importers.List{`"time"`}, 677 }, 678 "types.JSON": { 679 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 680 }, 681 "types.Decimal": { 682 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 683 }, 684 "types.NullDecimal": { 685 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 686 }, 687 } 688 return col, err 689 }