github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-psql/driver/psql.go (about) 1 // Package driver implements an sqlboiler driver. 2 // It can be used by either building the main.go in the same project 3 // and using as a binary or using the side effect import. 4 package driver 5 6 import ( 7 "database/sql" 8 "embed" 9 "encoding/base64" 10 "fmt" 11 "io/fs" 12 "os" 13 "strings" 14 15 "github.com/volatiletech/sqlboiler/v4/importers" 16 17 "github.com/friendsofgo/errors" 18 "github.com/volatiletech/sqlboiler/v4/drivers" 19 "github.com/volatiletech/strmangle" 20 21 // Side-effect import sql driver 22 _ "github.com/lib/pq" 23 ) 24 25 //go:embed override 26 var templates embed.FS 27 28 func init() { 29 drivers.RegisterFromInit("psql", &PostgresDriver{}) 30 } 31 32 // Assemble is more useful for calling into the library so you don't 33 // have to instantiate an empty type. 34 func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 35 driver := PostgresDriver{} 36 return driver.Assemble(config) 37 } 38 39 // PostgresDriver holds the database connection string and a handle 40 // to the database connection. 41 type PostgresDriver struct { 42 connStr string 43 conn *sql.DB 44 version int 45 addEnumTypes bool 46 enumNullPrefix string 47 48 uniqueColumns map[columnIdentifier]struct{} 49 } 50 51 type columnIdentifier struct { 52 Schema string 53 Table string 54 Column string 55 } 56 57 // Templates that should be added/overridden 58 func (p *PostgresDriver) Templates() (map[string]string, error) { 59 tpls := make(map[string]string) 60 fs.WalkDir(templates, "override", func(path string, d fs.DirEntry, err error) error { 61 if err != nil { 62 return err 63 } 64 65 if d.IsDir() { 66 return nil 67 } 68 69 b, err := fs.ReadFile(templates, path) 70 if err != nil { 71 return err 72 } 73 tpls[strings.Replace(path, "override/", "", 1)] = base64.StdEncoding.EncodeToString(b) 74 75 return nil 76 }) 77 78 return tpls, nil 79 } 80 81 // Assemble all the information we need to provide back to the driver 82 func (p *PostgresDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 83 defer func() { 84 if r := recover(); r != nil && err == nil { 85 dbinfo = nil 86 err = r.(error) 87 } 88 }() 89 90 user := config.MustString(drivers.ConfigUser) 91 pass, _ := config.String(drivers.ConfigPass) 92 dbname := config.MustString(drivers.ConfigDBName) 93 host := config.MustString(drivers.ConfigHost) 94 port := config.DefaultInt(drivers.ConfigPort, 5432) 95 sslmode := config.DefaultString(drivers.ConfigSSLMode, "require") 96 schema := config.DefaultString(drivers.ConfigSchema, "public") 97 whitelist, _ := config.StringSlice(drivers.ConfigWhitelist) 98 blacklist, _ := config.StringSlice(drivers.ConfigBlacklist) 99 concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency) 100 101 useSchema := schema != "public" 102 103 p.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool) 104 p.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null")) 105 p.connStr = PSQLBuildQueryString(user, pass, dbname, host, port, sslmode) 106 p.conn, err = sql.Open("postgres", p.connStr) 107 if err != nil { 108 return nil, errors.Wrap(err, "sqlboiler-psql failed to connect to database") 109 } 110 111 defer func() { 112 if e := p.conn.Close(); e != nil { 113 dbinfo = nil 114 err = e 115 } 116 }() 117 118 p.version, err = p.getVersion() 119 if err != nil { 120 return nil, errors.Wrap(err, "sqlboiler-psql failed to get database version") 121 } 122 123 if err = p.loadUniqueColumns(); err != nil { 124 return nil, errors.Wrap(err, "sqlboiler-psql failed to load unique columns") 125 } 126 127 dbinfo = &drivers.DBInfo{ 128 Schema: schema, 129 Dialect: drivers.Dialect{ 130 LQ: '"', 131 RQ: '"', 132 133 UseIndexPlaceholders: true, 134 UseSchema: useSchema, 135 UseDefaultKeyword: true, 136 }, 137 } 138 dbinfo.Tables, err = drivers.TablesConcurrently(p, schema, whitelist, blacklist, concurrency) 139 if err != nil { 140 return nil, err 141 } 142 143 return dbinfo, err 144 } 145 146 // PSQLBuildQueryString builds a query string. 147 func PSQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { 148 parts := []string{} 149 if len(user) != 0 { 150 parts = append(parts, fmt.Sprintf("user=%s", user)) 151 } 152 if len(pass) != 0 { 153 parts = append(parts, fmt.Sprintf("password=%s", pass)) 154 } 155 if len(dbname) != 0 { 156 parts = append(parts, fmt.Sprintf("dbname=%s", dbname)) 157 } 158 if len(host) != 0 { 159 parts = append(parts, fmt.Sprintf("host=%s", host)) 160 } 161 if port != 0 { 162 parts = append(parts, fmt.Sprintf("port=%d", port)) 163 } 164 if len(sslmode) != 0 { 165 parts = append(parts, fmt.Sprintf("sslmode=%s", sslmode)) 166 } 167 168 return strings.Join(parts, " ") 169 } 170 171 // TableNames connects to the postgres database and 172 // retrieves all table names from the information_schema where the 173 // table schema is schema. It uses a whitelist and blacklist. 174 func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { 175 var names []string 176 177 query := `select table_name from information_schema.tables where table_schema = $1 and table_type = 'BASE TABLE'` 178 args := []interface{}{schema} 179 if len(whitelist) > 0 { 180 tables := drivers.TablesFromList(whitelist) 181 if len(tables) > 0 { 182 query += fmt.Sprintf(" and table_name in (%s)", strmangle.Placeholders(true, len(tables), 2, 1)) 183 for _, w := range tables { 184 args = append(args, w) 185 } 186 } 187 } else if len(blacklist) > 0 { 188 tables := drivers.TablesFromList(blacklist) 189 if len(tables) > 0 { 190 query += fmt.Sprintf(" and table_name not in (%s)", strmangle.Placeholders(true, len(tables), 2, 1)) 191 for _, b := range tables { 192 args = append(args, b) 193 } 194 } 195 } 196 197 query += ` order by table_name;` 198 199 rows, err := p.conn.Query(query, args...) 200 if err != nil { 201 return nil, err 202 } 203 204 defer rows.Close() 205 for rows.Next() { 206 var name string 207 if err := rows.Scan(&name); err != nil { 208 return nil, err 209 } 210 names = append(names, name) 211 } 212 213 return names, nil 214 } 215 216 // ViewNames connects to the postgres database and 217 // retrieves all view names from the information_schema where the 218 // view schema is schema. It uses a whitelist and blacklist. 219 func (p *PostgresDriver) ViewNames(schema string, whitelist, blacklist []string) ([]string, error) { 220 var names []string 221 222 query := `select 223 table_name 224 from ( 225 select 226 table_name, 227 table_schema 228 from information_schema.views 229 UNION 230 select 231 matviewname as table_name, 232 schemaname as table_schema 233 from pg_matviews 234 ) as v where v.table_schema= $1` 235 args := []interface{}{schema} 236 if len(whitelist) > 0 { 237 views := drivers.TablesFromList(whitelist) 238 if len(views) > 0 { 239 query += fmt.Sprintf(" and table_name in (%s)", strmangle.Placeholders(true, len(views), 2, 1)) 240 for _, w := range views { 241 args = append(args, w) 242 } 243 } 244 } else if len(blacklist) > 0 { 245 views := drivers.TablesFromList(blacklist) 246 if len(views) > 0 { 247 query += fmt.Sprintf(" and table_name not in (%s)", strmangle.Placeholders(true, len(views), 2, 1)) 248 for _, b := range views { 249 args = append(args, b) 250 } 251 } 252 } 253 254 query += ` order by table_name;` 255 256 rows, err := p.conn.Query(query, args...) 257 if err != nil { 258 return nil, err 259 } 260 261 defer rows.Close() 262 for rows.Next() { 263 var name string 264 if err := rows.Scan(&name); err != nil { 265 return nil, err 266 } 267 268 names = append(names, name) 269 } 270 271 return names, nil 272 } 273 274 // ViewCapabilities return what actions are allowed for a view. 275 func (p *PostgresDriver) ViewCapabilities(schema, name string) (drivers.ViewCapabilities, error) { 276 capabilities := drivers.ViewCapabilities{} 277 278 query := `select 279 is_insertable_into, 280 is_updatable, 281 is_trigger_insertable_into, 282 is_trigger_updatable, 283 is_trigger_deletable 284 from ( 285 select 286 table_schema, 287 table_name, 288 is_insertable_into = 'YES' as is_insertable_into, 289 is_updatable = 'YES' as is_updatable, 290 is_trigger_insertable_into = 'YES' as is_trigger_insertable_into, 291 is_trigger_updatable = 'YES' as is_trigger_updatable, 292 is_trigger_deletable = 'YES' as is_trigger_deletable 293 from information_schema.views 294 UNION 295 select 296 schemaname as table_schema, 297 matviewname as table_name, 298 false as is_insertable_into, 299 false as is_updatable, 300 false as is_trigger_insertable_into, 301 false as is_trigger_updatable, 302 false as is_trigger_deletable 303 from pg_matviews 304 ) as v where v.table_schema= $1 and v.table_name = $2 305 order by table_name;` 306 307 row := p.conn.QueryRow(query, schema, name) 308 309 var insertable, updatable, trInsert, trUpdate, trDelete bool 310 if err := row.Scan(&insertable, &updatable, &trInsert, &trUpdate, &trDelete); err != nil { 311 return capabilities, err 312 } 313 314 capabilities.CanInsert = insertable || trInsert 315 capabilities.CanUpsert = insertable && updatable 316 317 return capabilities, nil 318 } 319 320 // loadUniqueColumns is responsible for populating p.uniqueColumns with an entry 321 // for every table or view column that is made unique by an index or constraint. 322 // This information is queried once, rather than for each table, for performance 323 // reasons. 324 func (p *PostgresDriver) loadUniqueColumns() error { 325 if p.uniqueColumns != nil { 326 return nil 327 } 328 p.uniqueColumns = map[columnIdentifier]struct{}{} 329 query := `with 330 method_a as ( 331 select 332 tc.table_schema as schema_name, 333 ccu.table_name as table_name, 334 ccu.column_name as column_name 335 from information_schema.table_constraints tc 336 inner join information_schema.constraint_column_usage as ccu 337 on tc.constraint_name = ccu.constraint_name 338 where 339 tc.constraint_type = 'UNIQUE' and ( 340 (select count(*) 341 from information_schema.constraint_column_usage 342 where constraint_schema = tc.table_schema and constraint_name = tc.constraint_name 343 ) = 1 344 ) 345 ), 346 method_b as ( 347 select 348 pgix.schemaname as schema_name, 349 pgix.tablename as table_name, 350 pga.attname as column_name 351 from pg_indexes pgix 352 inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i' and pgc.relnatts = 1 353 inner join pg_index pgi on pgi.indexrelid = pgc.oid 354 inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey) 355 where pgi.indisunique = true 356 ), 357 results as ( 358 select * from method_a 359 union 360 select * from method_b 361 ) 362 select * from results; 363 ` 364 rows, err := p.conn.Query(query) 365 if err != nil { 366 return err 367 } 368 defer rows.Close() 369 370 for rows.Next() { 371 var c columnIdentifier 372 if err := rows.Scan(&c.Schema, &c.Table, &c.Column); err != nil { 373 return errors.Wrapf(err, "unable to scan unique entry row") 374 } 375 p.uniqueColumns[c] = struct{}{} 376 } 377 return nil 378 } 379 380 func (p *PostgresDriver) ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 381 return p.Columns(schema, tableName, whitelist, blacklist) 382 } 383 384 // Columns takes a table name and attempts to retrieve the table information 385 // from the database information_schema.columns. It retrieves the column names 386 // and column types and returns those as a []Column after TranslateColumnType() 387 // converts the SQL types to Go types, for example: "varchar" to "string" 388 func (p *PostgresDriver) Columns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 389 var columns []drivers.Column 390 args := []interface{}{schema, tableName} 391 392 matviewQuery := `WITH cte_pg_attribute AS ( 393 SELECT 394 pg_catalog.format_type(a.atttypid, NULL) LIKE '%[]' = TRUE as is_array, 395 pg_catalog.format_type(a.atttypid, a.atttypmod) as column_full_type, 396 a.* 397 FROM pg_attribute a 398 ), cte_pg_namespace AS ( 399 SELECT 400 n.nspname NOT IN ('pg_catalog', 'information_schema') = TRUE as is_user_defined, 401 n.oid 402 FROM pg_namespace n 403 ), cte_information_schema_domains AS ( 404 SELECT 405 domain_name IS NOT NULL = TRUE as is_domain, 406 data_type LIKE '%[]' = TRUE as is_array, 407 domain_name, 408 udt_name, 409 data_type 410 FROM information_schema.domains 411 ) 412 SELECT 413 a.attnum as ordinal_position, 414 a.attname as column_name, 415 ( 416 case 417 when t.typtype = 'e' 418 then ( 419 select 'enum.' || t.typname || '(''' || string_agg(labels.label, ''',''') || ''')' 420 from ( 421 select pg_enum.enumlabel as label 422 from pg_enum 423 where pg_enum.enumtypid = 424 ( 425 select typelem 426 from pg_type 427 inner join pg_namespace ON pg_type.typnamespace = pg_namespace.oid 428 where pg_type.typtype = 'b' and pg_type.typname = ('_' || t.typname) and pg_namespace.nspname=$1 429 limit 1 430 ) 431 order by pg_enum.enumsortorder 432 ) as labels 433 ) 434 when a.is_array OR d.is_array 435 then 'ARRAY' 436 when d.is_domain 437 then d.data_type 438 when tn.is_user_defined 439 then 'USER-DEFINED' 440 else pg_catalog.format_type(a.atttypid, NULL) 441 end 442 ) as column_type, 443 ( 444 case 445 when d.is_domain 446 then d.udt_name 447 when a.column_full_type LIKE '%(%)%' AND t.typcategory IN ('S', 'V') 448 then a.column_full_type 449 else t.typname 450 end 451 ) as column_full_type, 452 ( 453 case 454 when d.is_domain 455 then d.udt_name 456 else t.typname 457 end 458 ) as udt_name, 459 ( 460 case when a.is_array 461 then 462 case when tn.is_user_defined 463 then 'USER-DEFINED' 464 else RTRIM(pg_catalog.format_type(a.atttypid, NULL), '[]') 465 end 466 else NULL 467 end 468 ) as array_type, 469 d.domain_name, 470 NULL as column_default, 471 '' as column_comment, 472 a.attnotnull = FALSE as is_nullable, 473 FALSE as is_generated, 474 a.attidentity <> '' as is_identity 475 FROM cte_pg_attribute a 476 JOIN pg_class c on a.attrelid = c.oid 477 JOIN pg_namespace cn on c.relnamespace = cn.oid 478 JOIN pg_type t ON t.oid = a.atttypid 479 LEFT JOIN cte_pg_namespace tn ON t.typnamespace = tn.oid 480 LEFT JOIN cte_information_schema_domains d ON d.domain_name = pg_catalog.format_type(a.atttypid, NULL) 481 WHERE a.attnum > 0 482 AND c.relkind = 'm' 483 AND NOT a.attisdropped 484 AND c.relname = $2 485 AND cn.nspname = $1` 486 487 tableQuery := ` 488 select 489 c.ordinal_position, 490 c.column_name, 491 ct.column_type, 492 ( 493 case when c.character_maximum_length != 0 494 then 495 ( 496 ct.column_type || '(' || c.character_maximum_length || ')' 497 ) 498 else c.udt_name 499 end 500 ) as column_full_type, 501 502 c.udt_name, 503 ( 504 SELECT 505 data_type 506 FROM 507 information_schema.element_types e 508 WHERE 509 c.table_catalog = e.object_catalog 510 AND c.table_schema = e.object_schema 511 AND c.table_name = e.object_name 512 AND 'TABLE' = e.object_type 513 AND c.dtd_identifier = e.collection_type_identifier 514 ) AS array_type, 515 c.domain_name, 516 c.column_default, 517 518 COALESCE(col_description(('"'||c.table_schema||'"."'||c.table_name||'"')::regclass::oid, ordinal_position), '') as column_comment, 519 520 c.is_nullable = 'YES' as is_nullable, 521 ( 522 case when c.is_generated = 'ALWAYS' or c.identity_generation = 'ALWAYS' 523 then TRUE else FALSE end 524 ) as is_generated, 525 (case 526 when (select 527 case 528 when column_name = 'is_identity' then (select c.is_identity = 'YES' as is_identity) 529 else 530 false 531 end as is_identity from information_schema.columns 532 WHERE table_schema='information_schema' and table_name='columns' and column_name='is_identity') IS NULL then 'NO' else is_identity end 533 ) = 'YES' as is_identity 534 535 from information_schema.columns as c 536 inner join pg_namespace as pgn on pgn.nspname = c.udt_schema 537 left join pg_type pgt on c.data_type = 'USER-DEFINED' and pgn.oid = pgt.typnamespace and c.udt_name = pgt.typname, 538 lateral (select 539 ( 540 case when pgt.typtype = 'e' 541 then 542 ( 543 select 'enum.' || c.udt_name || '(''' || string_agg(labels.label, ''',''') || ''')' 544 from ( 545 select pg_enum.enumlabel as label 546 from pg_enum 547 where pg_enum.enumtypid = 548 ( 549 select typelem 550 from pg_type 551 inner join pg_namespace ON pg_type.typnamespace = pg_namespace.oid 552 where pg_type.typtype = 'b' and pg_type.typname = ('_' || c.udt_name) and pg_namespace.nspname=$1 553 limit 1 554 ) 555 order by pg_enum.enumsortorder 556 ) as labels 557 ) 558 else c.data_type 559 end 560 ) as column_type 561 ) ct 562 where c.table_name = $2 and c.table_schema = $1` 563 564 query := fmt.Sprintf(`SELECT 565 column_name, 566 column_type, 567 column_full_type, 568 udt_name, 569 array_type, 570 domain_name, 571 column_default, 572 column_comment, 573 is_nullable, 574 is_generated, 575 is_identity 576 FROM ( 577 %s 578 UNION 579 %s 580 ) AS c`, matviewQuery, tableQuery) 581 582 if len(whitelist) > 0 { 583 cols := drivers.ColumnsFromList(whitelist, tableName) 584 if len(cols) > 0 { 585 query += fmt.Sprintf(" where c.column_name in (%s)", strmangle.Placeholders(true, len(cols), 3, 1)) 586 for _, w := range cols { 587 args = append(args, w) 588 } 589 } 590 } else if len(blacklist) > 0 { 591 cols := drivers.ColumnsFromList(blacklist, tableName) 592 if len(cols) > 0 { 593 query += fmt.Sprintf(" where c.column_name not in (%s)", strmangle.Placeholders(true, len(cols), 3, 1)) 594 for _, w := range cols { 595 args = append(args, w) 596 } 597 } 598 } 599 600 query += ` order by c.ordinal_position;` 601 602 rows, err := p.conn.Query(query, args...) 603 if err != nil { 604 return nil, err 605 } 606 defer rows.Close() 607 608 for rows.Next() { 609 var colName, colType, colFullType, udtName, comment string 610 var defaultValue, arrayType, domainName *string 611 var nullable, generated, identity bool 612 if err := rows.Scan(&colName, &colType, &colFullType, &udtName, &arrayType, &domainName, &defaultValue, &comment, &nullable, &generated, &identity); err != nil { 613 return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) 614 } 615 616 _, unique := p.uniqueColumns[columnIdentifier{schema, tableName, colName}] 617 column := drivers.Column{ 618 Name: colName, 619 DBType: colType, 620 FullDBType: colFullType, 621 ArrType: arrayType, 622 DomainName: domainName, 623 UDTName: udtName, 624 Comment: comment, 625 Nullable: nullable, 626 AutoGenerated: generated, 627 Unique: unique, 628 } 629 if defaultValue != nil { 630 column.Default = *defaultValue 631 } 632 633 if identity { 634 column.Default = "IDENTITY" 635 } 636 637 // A generated column technically has a default value 638 if generated && column.Default == "" { 639 column.Default = "GENERATED" 640 } 641 642 // A nullable column can always default to NULL 643 if nullable && column.Default == "" { 644 column.Default = "NULL" 645 } 646 647 columns = append(columns, column) 648 } 649 650 return columns, nil 651 } 652 653 // PrimaryKeyInfo looks up the primary key for a table. 654 func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.PrimaryKey, error) { 655 pkey := &drivers.PrimaryKey{} 656 var err error 657 658 query := ` 659 select tc.constraint_name 660 from information_schema.table_constraints as tc 661 where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;` 662 663 row := p.conn.QueryRow(query, tableName, schema) 664 if err = row.Scan(&pkey.Name); err != nil { 665 if errors.Is(err, sql.ErrNoRows) { 666 return nil, nil 667 } 668 return nil, err 669 } 670 671 queryColumns := ` 672 select kcu.column_name 673 from information_schema.key_column_usage as kcu 674 where constraint_name = $1 and table_name = $2 and table_schema = $3 675 order by kcu.ordinal_position;` 676 677 var rows *sql.Rows 678 if rows, err = p.conn.Query(queryColumns, pkey.Name, tableName, schema); err != nil { 679 return nil, err 680 } 681 defer rows.Close() 682 683 var columns []string 684 for rows.Next() { 685 var column string 686 687 err = rows.Scan(&column) 688 if err != nil { 689 return nil, err 690 } 691 692 columns = append(columns, column) 693 } 694 695 if err = rows.Err(); err != nil { 696 return nil, err 697 } 698 699 pkey.Columns = columns 700 701 return pkey, nil 702 } 703 704 // ForeignKeyInfo retrieves the foreign keys for a given table name. 705 func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { 706 var fkeys []drivers.ForeignKey 707 708 whereConditions := []string{"pgn.nspname = $2", "pgc.relname = $1", "pgcon.contype = 'f'"} 709 if p.version >= 120000 { 710 whereConditions = append(whereConditions, "pgasrc.attgenerated = ''", "pgadst.attgenerated = ''") 711 } 712 713 query := fmt.Sprintf(` 714 select 715 pgcon.conname, 716 pgc.relname as source_table, 717 pgasrc.attname as source_column, 718 dstlookupname.relname as dest_table, 719 pgadst.attname as dest_column 720 from pg_namespace pgn 721 inner join pg_class pgc on pgn.oid = pgc.relnamespace and pgc.relkind = 'r' 722 inner join pg_constraint pgcon on pgn.oid = pgcon.connamespace and pgc.oid = pgcon.conrelid 723 inner join pg_class dstlookupname on pgcon.confrelid = dstlookupname.oid 724 inner join pg_attribute pgasrc on pgc.oid = pgasrc.attrelid and pgasrc.attnum = ANY(pgcon.conkey) 725 inner join pg_attribute pgadst on pgcon.confrelid = pgadst.attrelid and pgadst.attnum = ANY(pgcon.confkey) 726 where %s 727 order by pgcon.conname, source_table, source_column, dest_table, dest_column`, 728 strings.Join(whereConditions, " and "), 729 ) 730 731 var rows *sql.Rows 732 var err error 733 if rows, err = p.conn.Query(query, tableName, schema); err != nil { 734 return nil, err 735 } 736 737 for rows.Next() { 738 var fkey drivers.ForeignKey 739 var sourceTable string 740 741 fkey.Table = tableName 742 err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn) 743 if err != nil { 744 return nil, err 745 } 746 747 fkeys = append(fkeys, fkey) 748 } 749 750 if err = rows.Err(); err != nil { 751 return nil, err 752 } 753 754 return fkeys, nil 755 } 756 757 // TranslateColumnType converts postgres database types to Go types, for example 758 // "varchar" to "string" and "bigint" to "int64". It returns this parsed data 759 // as a Column object. 760 func (p *PostgresDriver) TranslateColumnType(c drivers.Column) drivers.Column { 761 if c.Nullable { 762 switch c.DBType { 763 case "bigint", "bigserial": 764 c.Type = "null.Int64" 765 case "integer", "serial": 766 c.Type = "null.Int" 767 case "oid": 768 c.Type = "null.Uint32" 769 case "smallint", "smallserial": 770 c.Type = "null.Int16" 771 case "decimal", "numeric": 772 c.Type = "types.NullDecimal" 773 case "double precision": 774 c.Type = "null.Float64" 775 case "real": 776 c.Type = "null.Float32" 777 case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": 778 c.Type = "null.String" 779 case `"char"`: 780 c.Type = "null.Byte" 781 case "bytea": 782 c.Type = "null.Bytes" 783 case "json", "jsonb": 784 c.Type = "null.JSON" 785 case "boolean": 786 c.Type = "null.Bool" 787 case "date", "time", "timestamp without time zone", "timestamp with time zone", "time without time zone", "time with time zone": 788 c.Type = "null.Time" 789 case "point": 790 c.Type = "pgeo.NullPoint" 791 case "line": 792 c.Type = "pgeo.NullLine" 793 case "lseg": 794 c.Type = "pgeo.NullLseg" 795 case "box": 796 c.Type = "pgeo.NullBox" 797 case "path": 798 c.Type = "pgeo.NullPath" 799 case "polygon": 800 c.Type = "pgeo.NullPolygon" 801 case "circle": 802 c.Type = "pgeo.NullCircle" 803 case "ARRAY": 804 var dbType string 805 c.Type, dbType = getArrayType(c) 806 // Make DBType something like ARRAYinteger for parsing with randomize.Struct 807 c.DBType += dbType 808 case "USER-DEFINED": 809 switch c.UDTName { 810 case "hstore": 811 c.Type = "types.HStore" 812 c.DBType = "hstore" 813 case "citext": 814 c.Type = "null.String" 815 default: 816 c.Type = "string" 817 fmt.Fprintf(os.Stderr, "warning: incompatible data type detected: %s\n", c.UDTName) 818 } 819 default: 820 if enumName := strmangle.ParseEnumName(c.DBType); enumName != "" && p.addEnumTypes { 821 c.Type = p.enumNullPrefix + strmangle.TitleCase(enumName) 822 } else { 823 c.Type = "null.String" 824 } 825 } 826 } else { 827 switch c.DBType { 828 case "bigint", "bigserial": 829 c.Type = "int64" 830 case "integer", "serial": 831 c.Type = "int" 832 case "oid": 833 c.Type = "uint32" 834 case "smallint", "smallserial": 835 c.Type = "int16" 836 case "decimal", "numeric": 837 c.Type = "types.Decimal" 838 case "double precision": 839 c.Type = "float64" 840 case "real": 841 c.Type = "float32" 842 case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": 843 c.Type = "string" 844 case `"char"`: 845 c.Type = "types.Byte" 846 case "json", "jsonb": 847 c.Type = "types.JSON" 848 case "bytea": 849 c.Type = "[]byte" 850 case "boolean": 851 c.Type = "bool" 852 case "date", "time", "timestamp without time zone", "timestamp with time zone", "time without time zone", "time with time zone": 853 c.Type = "time.Time" 854 case "point": 855 c.Type = "pgeo.Point" 856 case "line": 857 c.Type = "pgeo.Line" 858 case "lseg": 859 c.Type = "pgeo.Lseg" 860 case "box": 861 c.Type = "pgeo.Box" 862 case "path": 863 c.Type = "pgeo.Path" 864 case "polygon": 865 c.Type = "pgeo.Polygon" 866 case "circle": 867 c.Type = "pgeo.Circle" 868 case "ARRAY": 869 var dbType string 870 c.Type, dbType = getArrayType(c) 871 // Make DBType something like ARRAYinteger for parsing with randomize.Struct 872 c.DBType += dbType 873 case "USER-DEFINED": 874 switch c.UDTName { 875 case "hstore": 876 c.Type = "types.HStore" 877 c.DBType = "hstore" 878 case "citext": 879 c.Type = "string" 880 default: 881 c.Type = "string" 882 fmt.Fprintf(os.Stderr, "warning: incompatible data type detected: %s\n", c.UDTName) 883 } 884 default: 885 if enumName := strmangle.ParseEnumName(c.DBType); enumName != "" && p.addEnumTypes { 886 c.Type = strmangle.TitleCase(enumName) 887 } else { 888 c.Type = "string" 889 } 890 } 891 } 892 893 return c 894 } 895 896 // getArrayType returns the correct boil.Array type for each database type 897 func getArrayType(c drivers.Column) (string, string) { 898 // If a domain is created with a statement like this: "CREATE DOMAIN 899 // text_array AS TEXT[] CHECK ( ... )" then the array type will be null, 900 // but the udt name will be whatever the underlying type is with a leading 901 // underscore. Note that this code handles some types, but not nearly all 902 // the possibities. Notably, an array of a user-defined type ("CREATE 903 // DOMAIN my_array AS my_type[]") will be treated as an array of strings, 904 // which is not guaranteed to be correct. 905 if c.ArrType != nil { 906 switch *c.ArrType { 907 case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial", "oid": 908 return "types.Int64Array", *c.ArrType 909 case "bytea": 910 return "types.BytesArray", *c.ArrType 911 case "bit", "interval", "uuint", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml": 912 return "types.StringArray", *c.ArrType 913 case "boolean": 914 return "types.BoolArray", *c.ArrType 915 case "decimal", "numeric": 916 return "types.DecimalArray", *c.ArrType 917 case "double precision", "real": 918 return "types.Float64Array", *c.ArrType 919 default: 920 return "types.StringArray", *c.ArrType 921 } 922 } else { 923 switch c.UDTName { 924 case "_int4", "_int8": 925 return "types.Int64Array", c.UDTName 926 case "_bytea": 927 return "types.BytesArray", c.UDTName 928 case "_bit", "_interval", "_varbit", "_char", "_money", "_varchar", "_cidr", "_inet", "_macaddr", "_citext", "_text", "_uuid", "_xml": 929 return "types.StringArray", c.UDTName 930 case "_bool": 931 return "types.BoolArray", c.UDTName 932 case "_numeric": 933 return "types.DecimalArray", c.UDTName 934 case "_float4", "_float8": 935 return "types.Float64Array", c.UDTName 936 default: 937 return "types.StringArray", c.UDTName 938 } 939 } 940 } 941 942 // Imports for the postgres driver 943 func (p PostgresDriver) Imports() (importers.Collection, error) { 944 var col importers.Collection 945 946 col.All = importers.Set{ 947 Standard: importers.List{ 948 `"strconv"`, 949 }, 950 } 951 col.Singleton = importers.Map{ 952 "psql_upsert": { 953 Standard: importers.List{ 954 `"fmt"`, 955 `"strings"`, 956 }, 957 ThirdParty: importers.List{ 958 `"github.com/volatiletech/strmangle"`, 959 `"github.com/volatiletech/sqlboiler/v4/drivers"`, 960 }, 961 }, 962 } 963 col.TestSingleton = importers.Map{ 964 "psql_suites_test": { 965 Standard: importers.List{ 966 `"testing"`, 967 }, 968 }, 969 "psql_main_test": { 970 Standard: importers.List{ 971 `"bytes"`, 972 `"database/sql"`, 973 `"fmt"`, 974 `"io"`, 975 `"os"`, 976 `"os/exec"`, 977 `"regexp"`, 978 `"strings"`, 979 }, 980 ThirdParty: importers.List{ 981 `"github.com/kat-co/vala"`, 982 `"github.com/friendsofgo/errors"`, 983 `"github.com/spf13/viper"`, 984 `"github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-psql/driver"`, 985 `"github.com/volatiletech/randomize"`, 986 `_ "github.com/lib/pq"`, 987 }, 988 }, 989 } 990 col.BasedOnType = importers.Map{ 991 "null.Float32": { 992 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 993 }, 994 "null.Float64": { 995 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 996 }, 997 "null.Int": { 998 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 999 }, 1000 "null.Int8": { 1001 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1002 }, 1003 "null.Int16": { 1004 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1005 }, 1006 "null.Int32": { 1007 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1008 }, 1009 "null.Int64": { 1010 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1011 }, 1012 "null.Uint": { 1013 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1014 }, 1015 "null.Uint8": { 1016 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1017 }, 1018 "null.Uint16": { 1019 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1020 }, 1021 "null.Uint32": { 1022 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1023 }, 1024 "null.Uint64": { 1025 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1026 }, 1027 "null.String": { 1028 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1029 }, 1030 "null.Bool": { 1031 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1032 }, 1033 "null.Time": { 1034 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1035 }, 1036 "null.JSON": { 1037 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1038 }, 1039 "null.Bytes": { 1040 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 1041 }, 1042 "time.Time": { 1043 Standard: importers.List{`"time"`}, 1044 }, 1045 "types.JSON": { 1046 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1047 }, 1048 "types.Decimal": { 1049 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1050 }, 1051 "types.BytesArray": { 1052 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1053 }, 1054 "types.Int64Array": { 1055 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1056 }, 1057 "types.Float64Array": { 1058 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1059 }, 1060 "types.BoolArray": { 1061 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1062 }, 1063 "types.StringArray": { 1064 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1065 }, 1066 "types.DecimalArray": { 1067 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1068 }, 1069 "types.HStore": { 1070 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1071 }, 1072 "pgeo.Point": { 1073 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1074 }, 1075 "pgeo.Line": { 1076 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1077 }, 1078 "pgeo.Lseg": { 1079 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1080 }, 1081 "pgeo.Box": { 1082 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1083 }, 1084 "pgeo.Path": { 1085 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1086 }, 1087 "pgeo.Polygon": { 1088 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1089 }, 1090 "types.NullDecimal": { 1091 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 1092 }, 1093 "pgeo.Circle": { 1094 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1095 }, 1096 "pgeo.NullPoint": { 1097 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1098 }, 1099 "pgeo.NullLine": { 1100 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1101 }, 1102 "pgeo.NullLseg": { 1103 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1104 }, 1105 "pgeo.NullBox": { 1106 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1107 }, 1108 "pgeo.NullPath": { 1109 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1110 }, 1111 "pgeo.NullPolygon": { 1112 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1113 }, 1114 "pgeo.NullCircle": { 1115 ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`}, 1116 }, 1117 } 1118 1119 return col, nil 1120 } 1121 1122 // getVersion gets the version of underlying database 1123 func (p *PostgresDriver) getVersion() (int, error) { 1124 type versionInfoType struct { 1125 ServerVersionNum int `json:"server_version_num"` 1126 } 1127 versionInfo := &versionInfoType{} 1128 1129 row := p.conn.QueryRow("SHOW server_version_num") 1130 if err := row.Scan(&versionInfo.ServerVersionNum); err != nil { 1131 return 0, err 1132 } 1133 1134 return versionInfo.ServerVersionNum, nil 1135 }