github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-psql/driver/psql.go

     1  // Package driver implements an sqlboiler driver.
      2  // It can be used either by building the main.go in the same project and
      3  // running it as a binary, or via the side-effect import.
     4  package driver
     5  
     6  import (
     7  	"database/sql"
     8  	"embed"
     9  	"encoding/base64"
    10  	"fmt"
    11  	"io/fs"
    12  	"os"
    13  	"strings"
    14  
    15  	"github.com/volatiletech/sqlboiler/v4/importers"
    16  
    17  	"github.com/friendsofgo/errors"
    18  	"github.com/volatiletech/sqlboiler/v4/drivers"
    19  	"github.com/volatiletech/strmangle"
    20  
    21  	// Side-effect import sql driver
    22  	_ "github.com/lib/pq"
    23  )
    24  
    25  //go:embed override
    26  var templates embed.FS
    27  
    28  func init() {
    29  	drivers.RegisterFromInit("psql", &PostgresDriver{})
    30  }
    31  
     32  // Assemble is a convenience function for calling into the library without
     33  // having to instantiate an empty driver type yourself.
    34  func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
    35  	driver := PostgresDriver{}
    36  	return driver.Assemble(config)
    37  }
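
A minimal sketch of calling Assemble from another program, using the drivers.Config keys that the method below consumes. The connection values are placeholders, and the driver import path is inferred from this fork's repository path; depending on its go.mod it may instead be the upstream volatiletech path or require a replace directive.

    package main

    import (
        "fmt"
        "log"

        "github.com/Ali-iotechsys/sqlboiler/v4/drivers/sqlboiler-psql/driver"
        "github.com/volatiletech/sqlboiler/v4/drivers"
    )

    func main() {
        cfg := drivers.Config{
            drivers.ConfigUser:    "postgres", // placeholder credentials
            drivers.ConfigPass:    "secret",
            drivers.ConfigDBName:  "mydb",
            drivers.ConfigHost:    "localhost",
            drivers.ConfigSSLMode: "disable", // defaults to "require" when omitted
        }
        dbinfo, err := driver.Assemble(cfg)
        if err != nil {
            log.Fatal(err)
        }
        fmt.Printf("assembled %d tables\n", len(dbinfo.Tables))
    }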
    38  
    39  // PostgresDriver holds the database connection string and a handle
    40  // to the database connection.
    41  type PostgresDriver struct {
    42  	connStr        string
    43  	conn           *sql.DB
    44  	version        int
    45  	addEnumTypes   bool
    46  	enumNullPrefix string
    47  
    48  	uniqueColumns map[columnIdentifier]struct{}
    49  }
    50  
    51  type columnIdentifier struct {
    52  	Schema string
    53  	Table  string
    54  	Column string
    55  }
    56  
    57  // Templates that should be added/overridden
    58  func (p *PostgresDriver) Templates() (map[string]string, error) {
    59  	tpls := make(map[string]string)
     60  	err := fs.WalkDir(templates, "override", func(path string, d fs.DirEntry, err error) error {
    61  		if err != nil {
    62  			return err
    63  		}
    64  
    65  		if d.IsDir() {
    66  			return nil
    67  		}
    68  
    69  		b, err := fs.ReadFile(templates, path)
    70  		if err != nil {
    71  			return err
    72  		}
    73  		tpls[strings.Replace(path, "override/", "", 1)] = base64.StdEncoding.EncodeToString(b)
    74  
    75  		return nil
    76  	})
    77  
     78  	return tpls, err
    79  }
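
The map returned above is keyed by the template path relative to the embedded override directory, and each value holds the base64-encoded template body. A small sketch, inside this package, of decoding the entries; listTemplates is an illustrative helper, not part of the driver:

    // listTemplates prints each embedded override template and its decoded size.
    func listTemplates() error {
        tpls, err := (&PostgresDriver{}).Templates()
        if err != nil {
            return err
        }
        for name, enc := range tpls {
            body, err := base64.StdEncoding.DecodeString(enc)
            if err != nil {
                return err
            }
            fmt.Printf("%s: %d bytes\n", name, len(body))
        }
        return nil
    }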
    80  
    81  // Assemble all the information we need to provide back to the driver
    82  func (p *PostgresDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
    83  	defer func() {
    84  		if r := recover(); r != nil && err == nil {
    85  			dbinfo = nil
    86  			err = r.(error)
    87  		}
    88  	}()
    89  
    90  	user := config.MustString(drivers.ConfigUser)
    91  	pass, _ := config.String(drivers.ConfigPass)
    92  	dbname := config.MustString(drivers.ConfigDBName)
    93  	host := config.MustString(drivers.ConfigHost)
    94  	port := config.DefaultInt(drivers.ConfigPort, 5432)
    95  	sslmode := config.DefaultString(drivers.ConfigSSLMode, "require")
    96  	schema := config.DefaultString(drivers.ConfigSchema, "public")
    97  	whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
    98  	blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)
    99  	concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency)
   100  
   101  	useSchema := schema != "public"
   102  
   103  	p.addEnumTypes, _ = config[drivers.ConfigAddEnumTypes].(bool)
   104  	p.enumNullPrefix = strmangle.TitleCase(config.DefaultString(drivers.ConfigEnumNullPrefix, "Null"))
   105  	p.connStr = PSQLBuildQueryString(user, pass, dbname, host, port, sslmode)
   106  	p.conn, err = sql.Open("postgres", p.connStr)
   107  	if err != nil {
   108  		return nil, errors.Wrap(err, "sqlboiler-psql failed to connect to database")
   109  	}
   110  
   111  	defer func() {
   112  		if e := p.conn.Close(); e != nil {
   113  			dbinfo = nil
   114  			err = e
   115  		}
   116  	}()
   117  
   118  	p.version, err = p.getVersion()
   119  	if err != nil {
   120  		return nil, errors.Wrap(err, "sqlboiler-psql failed to get database version")
   121  	}
   122  
   123  	if err = p.loadUniqueColumns(); err != nil {
   124  		return nil, errors.Wrap(err, "sqlboiler-psql failed to load unique columns")
   125  	}
   126  
   127  	dbinfo = &drivers.DBInfo{
   128  		Schema: schema,
   129  		Dialect: drivers.Dialect{
   130  			LQ: '"',
   131  			RQ: '"',
   132  
   133  			UseIndexPlaceholders: true,
   134  			UseSchema:            useSchema,
   135  			UseDefaultKeyword:    true,
   136  		},
   137  	}
   138  	dbinfo.Tables, err = drivers.TablesConcurrently(p, schema, whitelist, blacklist, concurrency)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return dbinfo, err
   144  }
   145  
    146  // PSQLBuildQueryString builds a libpq key/value connection string.
   147  func PSQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
   148  	parts := []string{}
   149  	if len(user) != 0 {
   150  		parts = append(parts, fmt.Sprintf("user=%s", user))
   151  	}
   152  	if len(pass) != 0 {
   153  		parts = append(parts, fmt.Sprintf("password=%s", pass))
   154  	}
   155  	if len(dbname) != 0 {
   156  		parts = append(parts, fmt.Sprintf("dbname=%s", dbname))
   157  	}
   158  	if len(host) != 0 {
   159  		parts = append(parts, fmt.Sprintf("host=%s", host))
   160  	}
   161  	if port != 0 {
   162  		parts = append(parts, fmt.Sprintf("port=%d", port))
   163  	}
   164  	if len(sslmode) != 0 {
   165  		parts = append(parts, fmt.Sprintf("sslmode=%s", sslmode))
   166  	}
   167  
   168  	return strings.Join(parts, " ")
   169  }
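
The parts are appended in a fixed order and joined with spaces, and empty fields are simply skipped, so the result is the familiar libpq key/value form. Two illustrative calls with placeholder values:

    s := PSQLBuildQueryString("postgres", "secret", "mydb", "localhost", 5432, "disable")
    // s == "user=postgres password=secret dbname=mydb host=localhost port=5432 sslmode=disable"

    // With no password and no explicit host, the corresponding keys are dropped:
    s = PSQLBuildQueryString("postgres", "", "mydb", "", 5432, "require")
    // s == "user=postgres dbname=mydb port=5432 sslmode=require"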
   170  
   171  // TableNames connects to the postgres database and
   172  // retrieves all table names from the information_schema where the
   173  // table schema is schema. It uses a whitelist and blacklist.
   174  func (p *PostgresDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
   175  	var names []string
   176  
   177  	query := `select table_name from information_schema.tables where table_schema = $1 and table_type = 'BASE TABLE'`
   178  	args := []interface{}{schema}
   179  	if len(whitelist) > 0 {
   180  		tables := drivers.TablesFromList(whitelist)
   181  		if len(tables) > 0 {
   182  			query += fmt.Sprintf(" and table_name in (%s)", strmangle.Placeholders(true, len(tables), 2, 1))
   183  			for _, w := range tables {
   184  				args = append(args, w)
   185  			}
   186  		}
   187  	} else if len(blacklist) > 0 {
   188  		tables := drivers.TablesFromList(blacklist)
   189  		if len(tables) > 0 {
   190  			query += fmt.Sprintf(" and table_name not in (%s)", strmangle.Placeholders(true, len(tables), 2, 1))
   191  			for _, b := range tables {
   192  				args = append(args, b)
   193  			}
   194  		}
   195  	}
   196  
   197  	query += ` order by table_name;`
   198  
   199  	rows, err := p.conn.Query(query, args...)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	defer rows.Close()
   205  	for rows.Next() {
   206  		var name string
   207  		if err := rows.Scan(&name); err != nil {
   208  			return nil, err
   209  		}
   210  		names = append(names, name)
   211  	}
   212  
   213  	return names, nil
   214  }
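
When a whitelist is present, the filter is appended with index placeholders starting at $2, because $1 already binds the schema. Assuming strmangle.Placeholders renders "$2,$3,..." when its first argument is true (consistent with the UseIndexPlaceholders dialect set in Assemble), a two-table whitelist behaves roughly like this fragment, given a connected *PostgresDriver p and illustrative names:

    names, err := p.TableNames("public", []string{"users", "videos"}, nil)
    // The query gains " and table_name in ($2,$3)" and the bound args become
    // "public", "users", "videos"; names comes back sorted by table_name.
    _ = names
    _ = err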
   215  
   216  // ViewNames connects to the postgres database and
   217  // retrieves all view names from the information_schema where the
   218  // view schema is schema. It uses a whitelist and blacklist.
   219  func (p *PostgresDriver) ViewNames(schema string, whitelist, blacklist []string) ([]string, error) {
   220  	var names []string
   221  
   222  	query := `select 
   223  		table_name 
   224  	from (
   225  			select 
   226  				table_name, 
   227  				table_schema 
   228  			from information_schema.views
   229  			UNION
   230  			select 
   231  				matviewname as table_name, 
   232  				schemaname as table_schema 
   233  			from pg_matviews 
   234  	) as v where v.table_schema= $1`
   235  	args := []interface{}{schema}
   236  	if len(whitelist) > 0 {
   237  		views := drivers.TablesFromList(whitelist)
   238  		if len(views) > 0 {
   239  			query += fmt.Sprintf(" and table_name in (%s)", strmangle.Placeholders(true, len(views), 2, 1))
   240  			for _, w := range views {
   241  				args = append(args, w)
   242  			}
   243  		}
   244  	} else if len(blacklist) > 0 {
   245  		views := drivers.TablesFromList(blacklist)
   246  		if len(views) > 0 {
   247  			query += fmt.Sprintf(" and table_name not in (%s)", strmangle.Placeholders(true, len(views), 2, 1))
   248  			for _, b := range views {
   249  				args = append(args, b)
   250  			}
   251  		}
   252  	}
   253  
   254  	query += ` order by table_name;`
   255  
   256  	rows, err := p.conn.Query(query, args...)
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  
   261  	defer rows.Close()
   262  	for rows.Next() {
   263  		var name string
   264  		if err := rows.Scan(&name); err != nil {
   265  			return nil, err
   266  		}
   267  
   268  		names = append(names, name)
   269  	}
   270  
   271  	return names, nil
   272  }
   273  
    274  // ViewCapabilities returns what actions are allowed for a view.
   275  func (p *PostgresDriver) ViewCapabilities(schema, name string) (drivers.ViewCapabilities, error) {
   276  	capabilities := drivers.ViewCapabilities{}
   277  
   278  	query := `select 
   279  		is_insertable_into,
   280  		is_updatable,
   281  		is_trigger_insertable_into,
   282  		is_trigger_updatable,
   283  		is_trigger_deletable
   284  	from (
   285  		select
   286  			table_schema,
   287  			table_name,
   288  			is_insertable_into = 'YES' as is_insertable_into,
   289  			is_updatable = 'YES' as is_updatable,
   290  			is_trigger_insertable_into = 'YES' as is_trigger_insertable_into,
   291  			is_trigger_updatable = 'YES' as is_trigger_updatable,
   292  			is_trigger_deletable = 'YES' as is_trigger_deletable
   293  		from information_schema.views
   294  		UNION
   295  		select 
   296  			schemaname as table_schema,
   297  			matviewname as table_name, 
   298  			false as is_insertable_into,
   299  			false as is_updatable,
   300  			false as is_trigger_insertable_into,
   301  			false as is_trigger_updatable, 
   302  			false as is_trigger_deletable
   303  		from pg_matviews 
   304  	) as v where v.table_schema= $1 and v.table_name = $2 
   305  	order by table_name;`
   306  
   307  	row := p.conn.QueryRow(query, schema, name)
   308  
   309  	var insertable, updatable, trInsert, trUpdate, trDelete bool
   310  	if err := row.Scan(&insertable, &updatable, &trInsert, &trUpdate, &trDelete); err != nil {
   311  		return capabilities, err
   312  	}
   313  
   314  	capabilities.CanInsert = insertable || trInsert
   315  	capabilities.CanUpsert = insertable && updatable
   316  
   317  	return capabilities, nil
   318  }
   319  
   320  // loadUniqueColumns is responsible for populating p.uniqueColumns with an entry
   321  // for every table or view column that is made unique by an index or constraint.
   322  // This information is queried once, rather than for each table, for performance
   323  // reasons.
   324  func (p *PostgresDriver) loadUniqueColumns() error {
   325  	if p.uniqueColumns != nil {
   326  		return nil
   327  	}
   328  	p.uniqueColumns = map[columnIdentifier]struct{}{}
   329  	query := `with
   330  method_a as (
   331      select
   332          tc.table_schema as schema_name,
   333          ccu.table_name as table_name,
   334          ccu.column_name as column_name
   335      from information_schema.table_constraints tc
   336      inner join information_schema.constraint_column_usage as ccu
   337          on tc.constraint_name = ccu.constraint_name
   338      where
   339          tc.constraint_type = 'UNIQUE' and (
   340              (select count(*)
   341              from information_schema.constraint_column_usage
   342              where constraint_schema = tc.table_schema and constraint_name = tc.constraint_name
   343              ) = 1
   344          )
   345  ),
   346  method_b as (
   347      select
   348          pgix.schemaname as schema_name,
   349          pgix.tablename as table_name,
   350          pga.attname as column_name
   351      from pg_indexes pgix
   352      inner join pg_class pgc on pgix.indexname = pgc.relname and pgc.relkind = 'i' and pgc.relnatts = 1
   353      inner join pg_index pgi on pgi.indexrelid = pgc.oid
   354      inner join pg_attribute pga on pga.attrelid = pgi.indrelid and pga.attnum = ANY(pgi.indkey)
   355      where pgi.indisunique = true
   356  ),
   357  results as (
   358      select * from method_a
   359      union
   360      select * from method_b
   361  )
   362  select * from results;
   363  `
   364  	rows, err := p.conn.Query(query)
   365  	if err != nil {
   366  		return err
   367  	}
   368  	defer rows.Close()
   369  
   370  	for rows.Next() {
   371  		var c columnIdentifier
   372  		if err := rows.Scan(&c.Schema, &c.Table, &c.Column); err != nil {
   373  			return errors.Wrapf(err, "unable to scan unique entry row")
   374  		}
   375  		p.uniqueColumns[c] = struct{}{}
   376  	}
   377  	return nil
   378  }
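
The cache built above is consulted later in Columns with a plain map lookup. A minimal sketch of that check, with illustrative schema/table/column names:

    _, unique := p.uniqueColumns[columnIdentifier{
        Schema: "public",
        Table:  "users",
        Column: "email",
    }]
    // unique is true only if a single-column unique index or constraint covers users.email.
    _ = unique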
    379  // ViewColumns behaves like Columns; the query in Columns covers tables, views and materialized views alike.
   380  func (p *PostgresDriver) ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) {
   381  	return p.Columns(schema, tableName, whitelist, blacklist)
   382  }
   383  
    384  // Columns takes a table name and attempts to retrieve the column information
    385  // from information_schema.columns (and pg_catalog for materialized views).
    386  // It returns the column names and types as a []Column; TranslateColumnType
    387  // later converts the SQL types to Go types, for example "varchar" to "string".
   388  func (p *PostgresDriver) Columns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) {
   389  	var columns []drivers.Column
   390  	args := []interface{}{schema, tableName}
   391  
   392  	matviewQuery := `WITH cte_pg_attribute AS (
   393  		SELECT
   394  			pg_catalog.format_type(a.atttypid, NULL) LIKE '%[]' = TRUE as is_array,
   395  			pg_catalog.format_type(a.atttypid, a.atttypmod) as column_full_type,
   396  			a.*
   397  		FROM pg_attribute a
   398  	), cte_pg_namespace AS (
   399  		SELECT
   400  			n.nspname NOT IN ('pg_catalog', 'information_schema') = TRUE as is_user_defined,
   401  			n.oid
   402  		FROM pg_namespace n
   403  	), cte_information_schema_domains AS (
   404  		SELECT
   405  			domain_name IS NOT NULL = TRUE as is_domain,
   406  			data_type LIKE '%[]' = TRUE as is_array,
   407  			domain_name,
   408  			udt_name,
   409  			data_type
   410  		FROM information_schema.domains
   411  	)
   412  	SELECT 
   413  		a.attnum as ordinal_position,
   414  		a.attname as column_name,
   415  		(
   416  			case 
   417  			when t.typtype = 'e'
   418  			then (
   419  				select 'enum.' || t.typname || '(''' || string_agg(labels.label, ''',''') || ''')'
   420  				from (
   421  					select pg_enum.enumlabel as label
   422  					from pg_enum
   423  					where pg_enum.enumtypid =
   424  					(
   425  						select typelem
   426  						from pg_type
   427  						inner join pg_namespace ON pg_type.typnamespace = pg_namespace.oid
   428  						where pg_type.typtype = 'b' and pg_type.typname = ('_' || t.typname) and pg_namespace.nspname=$1
   429  						limit 1
   430  					)
   431  					order by pg_enum.enumsortorder
   432  				) as labels
   433  			)
   434  			when a.is_array OR d.is_array
   435  			then 'ARRAY'
   436  			when d.is_domain
   437  			then d.data_type
   438  			when tn.is_user_defined
   439  			then 'USER-DEFINED'
   440  			else pg_catalog.format_type(a.atttypid, NULL)
   441  			end
   442  		) as column_type,
   443  		(
   444  			case 
   445  			when d.is_domain
   446  			then d.udt_name		
   447  			when a.column_full_type LIKE '%(%)%' AND t.typcategory IN ('S', 'V')
   448  			then a.column_full_type
   449  			else t.typname
   450  			end
   451  		) as column_full_type,
   452  		(
   453  			case 
   454  			when d.is_domain
   455  			then d.udt_name		
   456  			else t.typname
   457  			end
   458  		) as udt_name,
   459  		(
   460  			case when a.is_array
   461  			then
   462  				case when tn.is_user_defined
   463  				then 'USER-DEFINED'
   464  				else RTRIM(pg_catalog.format_type(a.atttypid, NULL), '[]')
   465  				end
   466  			else NULL
   467  			end
   468  		) as array_type,
   469  		d.domain_name,
   470  		NULL as column_default,
   471  		'' as column_comment,
   472  		a.attnotnull = FALSE as is_nullable,
   473  		FALSE as is_generated,
   474  		a.attidentity <> '' as is_identity
   475  	FROM cte_pg_attribute a
   476  		JOIN pg_class c on a.attrelid = c.oid
   477  		JOIN pg_namespace cn on c.relnamespace = cn.oid
   478  		JOIN pg_type t ON t.oid = a.atttypid
   479  		LEFT JOIN cte_pg_namespace tn ON t.typnamespace = tn.oid
   480  		LEFT JOIN cte_information_schema_domains d ON d.domain_name = pg_catalog.format_type(a.atttypid, NULL)
   481  		WHERE a.attnum > 0 
   482  		AND c.relkind = 'm'
   483  		AND NOT a.attisdropped
   484  		AND c.relname = $2
   485  		AND cn.nspname = $1`
   486  
   487  	tableQuery := `
   488  	select
   489  		c.ordinal_position,
   490  		c.column_name,
   491  		ct.column_type,
   492  		(
   493  			case when c.character_maximum_length != 0
   494  			then
   495  			(
   496  				ct.column_type || '(' || c.character_maximum_length || ')'
   497  			)
   498  			else c.udt_name
   499  			end
   500  		) as column_full_type,
   501  
   502  		c.udt_name,
   503  		(
   504  			SELECT
   505  				data_type
   506  			FROM
   507  				information_schema.element_types e
   508  			WHERE
   509  				c.table_catalog = e.object_catalog
   510  				AND c.table_schema = e.object_schema
   511  				AND c.table_name = e.object_name
   512  				AND 'TABLE' = e.object_type
   513  				AND c.dtd_identifier = e.collection_type_identifier
   514  		) AS array_type,
   515  		c.domain_name,
   516  		c.column_default,
   517  
   518  		COALESCE(col_description(('"'||c.table_schema||'"."'||c.table_name||'"')::regclass::oid, ordinal_position), '') as column_comment,
   519  
   520  		c.is_nullable = 'YES' as is_nullable,
   521  		(
   522  				case when c.is_generated = 'ALWAYS' or c.identity_generation = 'ALWAYS'
   523  				then TRUE else FALSE end
   524  		) as is_generated,
   525  		(case
   526  			when (select
   527  		    case
   528  			    when column_name = 'is_identity' then (select c.is_identity = 'YES' as is_identity)
   529  		    else
   530  			    false
   531  		    end as is_identity from information_schema.columns
   532  		    WHERE table_schema='information_schema' and table_name='columns' and column_name='is_identity') IS NULL then 'NO' else is_identity end
   533  		) = 'YES' as is_identity
   534  
   535  		from information_schema.columns as c
   536  		inner join pg_namespace as pgn on pgn.nspname = c.udt_schema
   537  		left join pg_type pgt on c.data_type = 'USER-DEFINED' and pgn.oid = pgt.typnamespace and c.udt_name = pgt.typname,
   538  		lateral (select
   539  			(
   540  				case when pgt.typtype = 'e'
   541  				then
   542  				(
   543  					select 'enum.' || c.udt_name || '(''' || string_agg(labels.label, ''',''') || ''')'
   544  					from (
   545  						select pg_enum.enumlabel as label
   546  						from pg_enum
   547  						where pg_enum.enumtypid =
   548  						(
   549  							select typelem
   550  							from pg_type
   551  							inner join pg_namespace ON pg_type.typnamespace = pg_namespace.oid
   552  							where pg_type.typtype = 'b' and pg_type.typname = ('_' || c.udt_name) and pg_namespace.nspname=$1
   553  							limit 1
   554  						)
   555  						order by pg_enum.enumsortorder
   556  					) as labels
   557  				)
   558  				else c.data_type
   559  				end
   560  			) as column_type
   561  		) ct
   562  		where c.table_name = $2 and c.table_schema = $1`
   563  
   564  	query := fmt.Sprintf(`SELECT 
   565  		column_name,
   566  		column_type,
   567  		column_full_type,
   568  		udt_name,
   569  		array_type,
   570  		domain_name,
   571  		column_default,
   572  		column_comment,
   573  		is_nullable,
   574  		is_generated,
   575  		is_identity
   576  	FROM (
   577  		%s
   578  		UNION
   579  		%s
   580  	) AS c`, matviewQuery, tableQuery)
   581  
   582  	if len(whitelist) > 0 {
   583  		cols := drivers.ColumnsFromList(whitelist, tableName)
   584  		if len(cols) > 0 {
   585  			query += fmt.Sprintf(" where c.column_name in (%s)", strmangle.Placeholders(true, len(cols), 3, 1))
   586  			for _, w := range cols {
   587  				args = append(args, w)
   588  			}
   589  		}
   590  	} else if len(blacklist) > 0 {
   591  		cols := drivers.ColumnsFromList(blacklist, tableName)
   592  		if len(cols) > 0 {
   593  			query += fmt.Sprintf(" where c.column_name not in (%s)", strmangle.Placeholders(true, len(cols), 3, 1))
   594  			for _, w := range cols {
   595  				args = append(args, w)
   596  			}
   597  		}
   598  	}
   599  
   600  	query += ` order by c.ordinal_position;`
   601  
   602  	rows, err := p.conn.Query(query, args...)
   603  	if err != nil {
   604  		return nil, err
   605  	}
   606  	defer rows.Close()
   607  
   608  	for rows.Next() {
   609  		var colName, colType, colFullType, udtName, comment string
   610  		var defaultValue, arrayType, domainName *string
   611  		var nullable, generated, identity bool
   612  		if err := rows.Scan(&colName, &colType, &colFullType, &udtName, &arrayType, &domainName, &defaultValue, &comment, &nullable, &generated, &identity); err != nil {
   613  			return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
   614  		}
   615  
   616  		_, unique := p.uniqueColumns[columnIdentifier{schema, tableName, colName}]
   617  		column := drivers.Column{
   618  			Name:          colName,
   619  			DBType:        colType,
   620  			FullDBType:    colFullType,
   621  			ArrType:       arrayType,
   622  			DomainName:    domainName,
   623  			UDTName:       udtName,
   624  			Comment:       comment,
   625  			Nullable:      nullable,
   626  			AutoGenerated: generated,
   627  			Unique:        unique,
   628  		}
   629  		if defaultValue != nil {
   630  			column.Default = *defaultValue
   631  		}
   632  
   633  		if identity {
   634  			column.Default = "IDENTITY"
   635  		}
   636  
   637  		// A generated column technically has a default value
   638  		if generated && column.Default == "" {
   639  			column.Default = "GENERATED"
   640  		}
   641  
   642  		// A nullable column can always default to NULL
   643  		if nullable && column.Default == "" {
   644  			column.Default = "NULL"
   645  		}
   646  
   647  		columns = append(columns, column)
   648  	}
   649  
   650  	return columns, nil
   651  }
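
For orientation, a nullable column declared as email varchar(255) on an ordinary table would come out of the scan loop above roughly like the literal below. The field values are illustrative, not captured output:

    col := drivers.Column{
        Name:       "email",
        DBType:     "character varying",
        FullDBType: "character varying(255)",
        UDTName:    "varchar",
        Nullable:   true,
        Unique:     false,
        Default:    "NULL", // the nullable fallback above, since there is no explicit default
    }
    _ = col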
   652  
   653  // PrimaryKeyInfo looks up the primary key for a table.
   654  func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.PrimaryKey, error) {
   655  	pkey := &drivers.PrimaryKey{}
   656  	var err error
   657  
   658  	query := `
   659  	select tc.constraint_name
   660  	from information_schema.table_constraints as tc
   661  	where tc.table_name = $1 and tc.constraint_type = 'PRIMARY KEY' and tc.table_schema = $2;`
   662  
   663  	row := p.conn.QueryRow(query, tableName, schema)
   664  	if err = row.Scan(&pkey.Name); err != nil {
   665  		if errors.Is(err, sql.ErrNoRows) {
   666  			return nil, nil
   667  		}
   668  		return nil, err
   669  	}
   670  
   671  	queryColumns := `
   672  	select kcu.column_name
   673  	from   information_schema.key_column_usage as kcu
   674  	where  constraint_name = $1 and table_name = $2 and table_schema = $3
   675  	order by kcu.ordinal_position;`
   676  
   677  	var rows *sql.Rows
   678  	if rows, err = p.conn.Query(queryColumns, pkey.Name, tableName, schema); err != nil {
   679  		return nil, err
   680  	}
   681  	defer rows.Close()
   682  
   683  	var columns []string
   684  	for rows.Next() {
   685  		var column string
   686  
   687  		err = rows.Scan(&column)
   688  		if err != nil {
   689  			return nil, err
   690  		}
   691  
   692  		columns = append(columns, column)
   693  	}
   694  
   695  	if err = rows.Err(); err != nil {
   696  		return nil, err
   697  	}
   698  
   699  	pkey.Columns = columns
   700  
   701  	return pkey, nil
   702  }
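
For a composite key, the columns come back in ordinal_position order. For example, a table named user_videos created with primary key (user_id, video_id) and PostgreSQL's default constraint naming would yield roughly:

    pkey := &drivers.PrimaryKey{
        Name:    "user_videos_pkey",
        Columns: []string{"user_id", "video_id"},
    }
    _ = pkey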
   703  
   704  // ForeignKeyInfo retrieves the foreign keys for a given table name.
   705  func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) {
   706  	var fkeys []drivers.ForeignKey
   707  
   708  	whereConditions := []string{"pgn.nspname = $2", "pgc.relname = $1", "pgcon.contype = 'f'"}
   709  	if p.version >= 120000 {
   710  		whereConditions = append(whereConditions, "pgasrc.attgenerated = ''", "pgadst.attgenerated = ''")
   711  	}
   712  
   713  	query := fmt.Sprintf(`
   714  	select
   715  		pgcon.conname,
   716  		pgc.relname as source_table,
   717  		pgasrc.attname as source_column,
   718  		dstlookupname.relname as dest_table,
   719  		pgadst.attname as dest_column
   720  	from pg_namespace pgn
   721  		inner join pg_class pgc on pgn.oid = pgc.relnamespace and pgc.relkind = 'r'
   722  		inner join pg_constraint pgcon on pgn.oid = pgcon.connamespace and pgc.oid = pgcon.conrelid
   723  		inner join pg_class dstlookupname on pgcon.confrelid = dstlookupname.oid
   724  		inner join pg_attribute pgasrc on pgc.oid = pgasrc.attrelid and pgasrc.attnum = ANY(pgcon.conkey)
   725  		inner join pg_attribute pgadst on pgcon.confrelid = pgadst.attrelid and pgadst.attnum = ANY(pgcon.confkey)
   726  	where %s
   727  	order by pgcon.conname, source_table, source_column, dest_table, dest_column`,
   728  		strings.Join(whereConditions, " and "),
   729  	)
   730  
   731  	var rows *sql.Rows
   732  	var err error
   733  	if rows, err = p.conn.Query(query, tableName, schema); err != nil {
   734  		return nil, err
   735  	}
    736  	defer rows.Close()
   737  	for rows.Next() {
   738  		var fkey drivers.ForeignKey
   739  		var sourceTable string
   740  
   741  		fkey.Table = tableName
   742  		err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn)
   743  		if err != nil {
   744  			return nil, err
   745  		}
   746  
   747  		fkeys = append(fkeys, fkey)
   748  	}
   749  
   750  	if err = rows.Err(); err != nil {
   751  		return nil, err
   752  	}
   753  
   754  	return fkeys, nil
   755  }
   756  
   757  // TranslateColumnType converts postgres database types to Go types, for example
   758  // "varchar" to "string" and "bigint" to "int64". It returns this parsed data
   759  // as a Column object.
   760  func (p *PostgresDriver) TranslateColumnType(c drivers.Column) drivers.Column {
   761  	if c.Nullable {
   762  		switch c.DBType {
   763  		case "bigint", "bigserial":
   764  			c.Type = "null.Int64"
   765  		case "integer", "serial":
   766  			c.Type = "null.Int"
   767  		case "oid":
   768  			c.Type = "null.Uint32"
   769  		case "smallint", "smallserial":
   770  			c.Type = "null.Int16"
   771  		case "decimal", "numeric":
   772  			c.Type = "types.NullDecimal"
   773  		case "double precision":
   774  			c.Type = "null.Float64"
   775  		case "real":
   776  			c.Type = "null.Float32"
   777  		case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
   778  			c.Type = "null.String"
   779  		case `"char"`:
   780  			c.Type = "null.Byte"
   781  		case "bytea":
   782  			c.Type = "null.Bytes"
   783  		case "json", "jsonb":
   784  			c.Type = "null.JSON"
   785  		case "boolean":
   786  			c.Type = "null.Bool"
   787  		case "date", "time", "timestamp without time zone", "timestamp with time zone", "time without time zone", "time with time zone":
   788  			c.Type = "null.Time"
   789  		case "point":
   790  			c.Type = "pgeo.NullPoint"
   791  		case "line":
   792  			c.Type = "pgeo.NullLine"
   793  		case "lseg":
   794  			c.Type = "pgeo.NullLseg"
   795  		case "box":
   796  			c.Type = "pgeo.NullBox"
   797  		case "path":
   798  			c.Type = "pgeo.NullPath"
   799  		case "polygon":
   800  			c.Type = "pgeo.NullPolygon"
   801  		case "circle":
   802  			c.Type = "pgeo.NullCircle"
   803  		case "ARRAY":
   804  			var dbType string
   805  			c.Type, dbType = getArrayType(c)
   806  			// Make DBType something like ARRAYinteger for parsing with randomize.Struct
   807  			c.DBType += dbType
   808  		case "USER-DEFINED":
   809  			switch c.UDTName {
   810  			case "hstore":
   811  				c.Type = "types.HStore"
   812  				c.DBType = "hstore"
   813  			case "citext":
   814  				c.Type = "null.String"
   815  			default:
   816  				c.Type = "string"
   817  				fmt.Fprintf(os.Stderr, "warning: incompatible data type detected: %s\n", c.UDTName)
   818  			}
   819  		default:
   820  			if enumName := strmangle.ParseEnumName(c.DBType); enumName != "" && p.addEnumTypes {
   821  				c.Type = p.enumNullPrefix + strmangle.TitleCase(enumName)
   822  			} else {
   823  				c.Type = "null.String"
   824  			}
   825  		}
   826  	} else {
   827  		switch c.DBType {
   828  		case "bigint", "bigserial":
   829  			c.Type = "int64"
   830  		case "integer", "serial":
   831  			c.Type = "int"
   832  		case "oid":
   833  			c.Type = "uint32"
   834  		case "smallint", "smallserial":
   835  			c.Type = "int16"
   836  		case "decimal", "numeric":
   837  			c.Type = "types.Decimal"
   838  		case "double precision":
   839  			c.Type = "float64"
   840  		case "real":
   841  			c.Type = "float32"
    842  		case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
   843  			c.Type = "string"
   844  		case `"char"`:
   845  			c.Type = "types.Byte"
   846  		case "json", "jsonb":
   847  			c.Type = "types.JSON"
   848  		case "bytea":
   849  			c.Type = "[]byte"
   850  		case "boolean":
   851  			c.Type = "bool"
   852  		case "date", "time", "timestamp without time zone", "timestamp with time zone", "time without time zone", "time with time zone":
   853  			c.Type = "time.Time"
   854  		case "point":
   855  			c.Type = "pgeo.Point"
   856  		case "line":
   857  			c.Type = "pgeo.Line"
   858  		case "lseg":
   859  			c.Type = "pgeo.Lseg"
   860  		case "box":
   861  			c.Type = "pgeo.Box"
   862  		case "path":
   863  			c.Type = "pgeo.Path"
   864  		case "polygon":
   865  			c.Type = "pgeo.Polygon"
   866  		case "circle":
   867  			c.Type = "pgeo.Circle"
   868  		case "ARRAY":
   869  			var dbType string
   870  			c.Type, dbType = getArrayType(c)
   871  			// Make DBType something like ARRAYinteger for parsing with randomize.Struct
   872  			c.DBType += dbType
   873  		case "USER-DEFINED":
   874  			switch c.UDTName {
   875  			case "hstore":
   876  				c.Type = "types.HStore"
   877  				c.DBType = "hstore"
   878  			case "citext":
   879  				c.Type = "string"
   880  			default:
   881  				c.Type = "string"
   882  				fmt.Fprintf(os.Stderr, "warning: incompatible data type detected: %s\n", c.UDTName)
   883  			}
   884  		default:
   885  			if enumName := strmangle.ParseEnumName(c.DBType); enumName != "" && p.addEnumTypes {
   886  				c.Type = strmangle.TitleCase(enumName)
   887  			} else {
   888  				c.Type = "string"
   889  			}
   890  		}
   891  	}
   892  
   893  	return c
   894  }
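
Two quick illustrations of the mapping above. TranslateColumnType only reads the column fields (plus the enum settings, which these cases do not hit), so a zero-value driver is enough:

    d := &PostgresDriver{}
    c1 := d.TranslateColumnType(drivers.Column{DBType: "integer"})
    // c1.Type == "int"
    c2 := d.TranslateColumnType(drivers.Column{DBType: "timestamp with time zone", Nullable: true})
    // c2.Type == "null.Time"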
   895  
   896  // getArrayType returns the correct boil.Array type for each database type
   897  func getArrayType(c drivers.Column) (string, string) {
   898  	// If a domain is created with a statement like this: "CREATE DOMAIN
   899  	// text_array AS TEXT[] CHECK ( ... )" then the array type will be null,
   900  	// but the udt name will be whatever the underlying type is with a leading
   901  	// underscore. Note that this code handles some types, but not nearly all
    902  	// the possibilities. Notably, an array of a user-defined type ("CREATE
   903  	// DOMAIN my_array AS my_type[]") will be treated as an array of strings,
   904  	// which is not guaranteed to be correct.
   905  	if c.ArrType != nil {
   906  		switch *c.ArrType {
   907  		case "bigint", "bigserial", "integer", "serial", "smallint", "smallserial", "oid":
   908  			return "types.Int64Array", *c.ArrType
   909  		case "bytea":
   910  			return "types.BytesArray", *c.ArrType
    911  		case "bit", "interval", "bit varying", "character", "money", "character varying", "cidr", "inet", "macaddr", "text", "uuid", "xml":
   912  			return "types.StringArray", *c.ArrType
   913  		case "boolean":
   914  			return "types.BoolArray", *c.ArrType
   915  		case "decimal", "numeric":
   916  			return "types.DecimalArray", *c.ArrType
   917  		case "double precision", "real":
   918  			return "types.Float64Array", *c.ArrType
   919  		default:
   920  			return "types.StringArray", *c.ArrType
   921  		}
   922  	} else {
   923  		switch c.UDTName {
   924  		case "_int4", "_int8":
   925  			return "types.Int64Array", c.UDTName
   926  		case "_bytea":
   927  			return "types.BytesArray", c.UDTName
   928  		case "_bit", "_interval", "_varbit", "_char", "_money", "_varchar", "_cidr", "_inet", "_macaddr", "_citext", "_text", "_uuid", "_xml":
   929  			return "types.StringArray", c.UDTName
   930  		case "_bool":
   931  			return "types.BoolArray", c.UDTName
   932  		case "_numeric":
   933  			return "types.DecimalArray", c.UDTName
   934  		case "_float4", "_float8":
   935  			return "types.Float64Array", c.UDTName
   936  		default:
   937  			return "types.StringArray", c.UDTName
   938  		}
   939  	}
   940  }
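
The two branches above differ only in what identifies the element type: the element type reported by information_schema, or the leading-underscore udt name when that is missing (as with domains over arrays). A small sketch of both, with illustrative values:

    elem := "bigint"
    t1, db1 := getArrayType(drivers.Column{ArrType: &elem})
    // t1 == "types.Int64Array", db1 == "bigint"

    t2, db2 := getArrayType(drivers.Column{UDTName: "_text"})
    // t2 == "types.StringArray", db2 == "_text"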
   941  
   942  // Imports for the postgres driver
   943  func (p PostgresDriver) Imports() (importers.Collection, error) {
   944  	var col importers.Collection
   945  
   946  	col.All = importers.Set{
   947  		Standard: importers.List{
   948  			`"strconv"`,
   949  		},
   950  	}
   951  	col.Singleton = importers.Map{
   952  		"psql_upsert": {
   953  			Standard: importers.List{
   954  				`"fmt"`,
   955  				`"strings"`,
   956  			},
   957  			ThirdParty: importers.List{
   958  				`"github.com/volatiletech/strmangle"`,
   959  				`"github.com/volatiletech/sqlboiler/v4/drivers"`,
   960  			},
   961  		},
   962  	}
   963  	col.TestSingleton = importers.Map{
   964  		"psql_suites_test": {
   965  			Standard: importers.List{
   966  				`"testing"`,
   967  			},
   968  		},
   969  		"psql_main_test": {
   970  			Standard: importers.List{
   971  				`"bytes"`,
   972  				`"database/sql"`,
   973  				`"fmt"`,
   974  				`"io"`,
   975  				`"os"`,
   976  				`"os/exec"`,
   977  				`"regexp"`,
   978  				`"strings"`,
   979  			},
   980  			ThirdParty: importers.List{
   981  				`"github.com/kat-co/vala"`,
   982  				`"github.com/friendsofgo/errors"`,
   983  				`"github.com/spf13/viper"`,
   984  				`"github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-psql/driver"`,
   985  				`"github.com/volatiletech/randomize"`,
   986  				`_ "github.com/lib/pq"`,
   987  			},
   988  		},
   989  	}
   990  	col.BasedOnType = importers.Map{
   991  		"null.Float32": {
   992  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   993  		},
   994  		"null.Float64": {
   995  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   996  		},
   997  		"null.Int": {
   998  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   999  		},
  1000  		"null.Int8": {
  1001  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1002  		},
  1003  		"null.Int16": {
  1004  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1005  		},
  1006  		"null.Int32": {
  1007  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1008  		},
  1009  		"null.Int64": {
  1010  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1011  		},
  1012  		"null.Uint": {
  1013  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1014  		},
  1015  		"null.Uint8": {
  1016  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1017  		},
  1018  		"null.Uint16": {
  1019  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1020  		},
  1021  		"null.Uint32": {
  1022  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1023  		},
  1024  		"null.Uint64": {
  1025  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1026  		},
  1027  		"null.String": {
  1028  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1029  		},
  1030  		"null.Bool": {
  1031  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1032  		},
  1033  		"null.Time": {
  1034  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1035  		},
  1036  		"null.JSON": {
  1037  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1038  		},
  1039  		"null.Bytes": {
  1040  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
  1041  		},
  1042  		"time.Time": {
  1043  			Standard: importers.List{`"time"`},
  1044  		},
  1045  		"types.JSON": {
  1046  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1047  		},
  1048  		"types.Decimal": {
  1049  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1050  		},
  1051  		"types.BytesArray": {
  1052  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1053  		},
  1054  		"types.Int64Array": {
  1055  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1056  		},
  1057  		"types.Float64Array": {
  1058  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1059  		},
  1060  		"types.BoolArray": {
  1061  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1062  		},
  1063  		"types.StringArray": {
  1064  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1065  		},
  1066  		"types.DecimalArray": {
  1067  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1068  		},
  1069  		"types.HStore": {
  1070  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1071  		},
  1072  		"pgeo.Point": {
  1073  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1074  		},
  1075  		"pgeo.Line": {
  1076  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1077  		},
  1078  		"pgeo.Lseg": {
  1079  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1080  		},
  1081  		"pgeo.Box": {
  1082  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1083  		},
  1084  		"pgeo.Path": {
  1085  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1086  		},
  1087  		"pgeo.Polygon": {
  1088  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1089  		},
  1090  		"types.NullDecimal": {
  1091  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
  1092  		},
  1093  		"pgeo.Circle": {
  1094  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1095  		},
  1096  		"pgeo.NullPoint": {
  1097  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1098  		},
  1099  		"pgeo.NullLine": {
  1100  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1101  		},
  1102  		"pgeo.NullLseg": {
  1103  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1104  		},
  1105  		"pgeo.NullBox": {
  1106  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1107  		},
  1108  		"pgeo.NullPath": {
  1109  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1110  		},
  1111  		"pgeo.NullPolygon": {
  1112  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1113  		},
  1114  		"pgeo.NullCircle": {
  1115  			ThirdParty: importers.List{`"github.com/volatiletech/sqlboiler/v4/types/pgeo"`},
  1116  		},
  1117  	}
  1118  
  1119  	return col, nil
  1120  }
  1121  
   1122  // getVersion gets the version of the underlying database (server_version_num).
  1123  func (p *PostgresDriver) getVersion() (int, error) {
  1124  	type versionInfoType struct {
  1125  		ServerVersionNum int `json:"server_version_num"`
  1126  	}
  1127  	versionInfo := &versionInfoType{}
  1128  
  1129  	row := p.conn.QueryRow("SHOW server_version_num")
  1130  	if err := row.Scan(&versionInfo.ServerVersionNum); err != nil {
  1131  		return 0, err
  1132  	}
  1133  
  1134  	return versionInfo.ServerVersionNum, nil
  1135  }
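
server_version_num encodes the version as major*10000 plus minor (120005 for PostgreSQL 12.5, 140002 for 14.2), which is why ForeignKeyInfo compares against 120000: pg_attribute.attgenerated only exists from PostgreSQL 12 on. A tiny sketch of how the value is typically used:

    version, err := p.getVersion()
    if err == nil && version >= 120000 {
        // PostgreSQL 12+: generated columns exist, so the attgenerated = '' filters apply.
    }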