github.com/cloudfoundry/postgres-release/src/acceptance-tests@v0.0.0-20240511030151-872bdd2e0dba/testing/helpers/postgres.go (about)

     1  package helpers
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"encoding/gob"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"strings"
    12  )
    13  
    14  const DefaultDB = "postgres"
    15  
    16  type PGData struct {
    17  	Data PGCommon
    18  	DBs  []PGConn
    19  }
    20  
    21  type User struct {
    22  	Name        string
    23  	Password    string
    24  	Certificate string
    25  	Key         string
    26  }
    27  
    28  type PGCommon struct {
    29  	Address     string
    30  	Port        int
    31  	SSLMode     string
    32  	SSLRootCert string
    33  	DefUser     User
    34  	AdminUser   User
    35  	CertUser    User
    36  	UseCert     bool
    37  }
    38  
    39  type PGConn struct {
    40  	TargetDB string
    41  	User     string
    42  	password string
    43  	DB       *sql.DB
    44  }
    45  
    46  type PGSetting struct {
    47  	Name    string `json:"name"`
    48  	Setting string `json:"setting"`
    49  	VarType string `json:"vartype"`
    50  }
    51  type PGDatabase struct {
    52  	Name   string `json:"datname"`
    53  	DBExts []PGDatabaseExtensions
    54  	Tables []PGTable
    55  }
    56  type PGDatabaseExtensions struct {
    57  	Name string `json:"extname"`
    58  }
    59  type PGTable struct {
    60  	SchemaName     string `json:"schemaname"`
    61  	TableName      string `json:"tablename"`
    62  	TableOwner     string `json:"tableowner"`
    63  	TableColumns   []PGTableColumn
    64  	TableRowsCount PGCount
    65  }
    66  type PGTableColumn struct {
    67  	ColumnName string `json:"column_name"`
    68  	DataType   string `json:"data_type"`
    69  	Position   int    `json:"ordinal_position"`
    70  }
    71  type PGCount struct {
    72  	Num int `json:"count"`
    73  }
    74  type PGVersion struct {
    75  	Version string `json:"version"`
    76  }
    77  type PGRole struct {
    78  	Name        string `json:"rolname"`
    79  	Super       bool   `json:"rolsuper"`
    80  	Inherit     bool   `json:"rolinherit"`
    81  	CreateRole  bool   `json:"rolcreaterole"`
    82  	CreateDb    bool   `json:"rolcreatedb"`
    83  	CanLogin    bool   `json:"rolcanlogin"`
    84  	Replication bool   `json:"rolreplication"`
    85  	ConnLimit   int    `json:"rolconnlimit"`
    86  	ValidUntil  string `json:"rolvaliduntil"`
    87  }
    88  
    89  type PGOutputData struct {
    90  	Roles     map[string]PGRole
    91  	Databases []PGDatabase
    92  	Settings  map[string]string
    93  	Version   PGVersion
    94  }
    95  
    96  const GetSettingsQuery = "SELECT * FROM pg_settings"
    97  const ListRolesQuery = "SELECT * from pg_roles"
    98  const GetRoleQuery = "SELECT * from pg_roles where rolname='%s'"
    99  const GetTableQuery = "SELECT * from pg_catalog.pg_tables where tablename='%s'"
   100  const ListDatabasesQuery = "SELECT datname from pg_database where datistemplate=false"
   101  const ListDBExtensionsQuery = "SELECT extname from pg_extension"
   102  const ConvertToDateCommand = "SELECT '%s'::timestamptz"
   103  const ListTablesQuery = "SELECT * from pg_catalog.pg_tables where schemaname not like 'pg_%' and schemaname != 'information_schema'"
   104  const ListTableColumnsQuery = "SELECT column_name, data_type, ordinal_position FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s' order by ordinal_position asc"
   105  const CountTableRowsQuery = "SELECT COUNT(*) FROM %s"
   106  const GetPostgreSQLVersionQuery = "SELECT version()"
   107  const QueryResultAsJson = "SELECT row_to_json(t) from (%s) as t;"
   108  const DropTable = "DROP TABLE %s"
   109  
   110  const NoConnectionAvailableErr = "No connections available"
   111  const MissingDBAddressErr = "Database address not specified"
   112  const MissingDBPortErr = "Database port not specified"
   113  const MissingDefaultUserErr = "Default user not specified"
   114  const MissingDefaultPasswordErr = "Default password not specified"
   115  const NoSuperUserProvidedErr = "No super user provided"
   116  const IncorrectSSLModeErr = "Incorrect SSL mode specified"
   117  const MissingSSLRootCertErr = "SSL Root Certificate missing"
   118  const MissingCertUserErr = "No user specified to authenticate with certificates"
   119  const MissingCertCertErr = "No certificate specified for cert user"
   120  const MissingCertKeyErr = "No private key specified for cert user's certificate"
   121  
   122  func GetFormattedQuery(query string) string {
   123  	return fmt.Sprintf(QueryResultAsJson, query)
   124  }
   125  
   126  func NewPostgres(props PGCommon) (PGData, error) {
   127  	var pg PGData
   128  	if props.SSLMode == "" {
   129  		props.SSLMode = "disable"
   130  	}
   131  	if err := checkSSLMode(props.SSLMode, props.SSLRootCert); err != nil {
   132  		return PGData{}, err
   133  	}
   134  	if props.Address == "" {
   135  		return PGData{}, errors.New(MissingDBAddressErr)
   136  	}
   137  	if props.Port == 0 {
   138  		return PGData{}, errors.New(MissingDBPortErr)
   139  	}
   140  	if props.DefUser == (User{}) {
   141  		return PGData{}, errors.New(MissingDefaultUserErr)
   142  	}
   143  	if props.DefUser.Name == "" {
   144  		return PGData{}, errors.New(MissingDefaultUserErr)
   145  	}
   146  	if props.DefUser.Password == "" {
   147  		return PGData{}, errors.New(MissingDefaultPasswordErr)
   148  	}
   149  	pg.Data = props
   150  	return pg, nil
   151  }
   152  
   153  func checkSSLMode(sslmode string, sslrootcert string) error {
   154  	var strong_sslmodes = [...]string{"verify-ca", "verify-full"}
   155  	var valid_sslmodes = [...]string{"disable", "require", "verify-ca", "verify-full"}
   156  	for _, valid_mode := range valid_sslmodes {
   157  		if valid_mode == sslmode {
   158  			for _, strong_mode := range strong_sslmodes {
   159  				if strong_mode == sslmode && sslrootcert == "" {
   160  					return errors.New(MissingSSLRootCertErr)
   161  				}
   162  			}
   163  			return nil
   164  		}
   165  	}
   166  	return errors.New(IncorrectSSLModeErr)
   167  }
   168  
   169  func (pg PGData) getDefaultUser() User {
   170  	if pg.Data.UseCert {
   171  		return pg.Data.CertUser
   172  	} else {
   173  		return pg.Data.DefUser
   174  	}
   175  }
   176  
   177  func (pg PGData) checkCertUser() error {
   178  	if pg.Data.CertUser == (User{}) {
   179  		return errors.New(MissingCertUserErr)
   180  	}
   181  	if pg.Data.CertUser.Certificate == "" {
   182  		return errors.New(MissingCertCertErr)
   183  	}
   184  	if pg.Data.CertUser.Key == "" {
   185  		return errors.New(MissingCertKeyErr)
   186  	}
   187  	return nil
   188  }
   189  func (pg *PGData) SetCertUserCertificates(user string, certs map[interface{}]interface{}) error {
   190  	if user == "" {
   191  		if pg.Data.UseCert {
   192  			return errors.New(MissingCertUserErr)
   193  		}
   194  		pg.Data.CertUser = User{}
   195  	} else {
   196  		clientCertPath, err := WriteFile(certs["certificate"].(string))
   197  		if err != nil {
   198  			return err
   199  		}
   200  		clientKeyPath, err := WriteFile(certs["private_key"].(string))
   201  		if err != nil {
   202  			return err
   203  		}
   204  		if pg.Data.CertUser.Certificate != "" {
   205  			os.Remove(pg.Data.CertUser.Certificate)
   206  		}
   207  		if pg.Data.CertUser.Key != "" {
   208  			os.Remove(pg.Data.CertUser.Key)
   209  		}
   210  		pg.Data.CertUser.Name = user
   211  		pg.Data.CertUser.Certificate = clientCertPath
   212  		pg.Data.CertUser.Key = clientKeyPath
   213  	}
   214  	return nil
   215  }
   216  
   217  func (pg *PGData) UseCertAuthentication(useCert bool) error {
   218  	if useCert {
   219  		if err := pg.checkCertUser(); err != nil {
   220  			return err
   221  		}
   222  	}
   223  	pg.Data.UseCert = useCert
   224  	pg.CloseConnections()
   225  	return nil
   226  }
   227  func (pg *PGData) ChangeSSLMode(sslmode string, sslrootcert string) error {
   228  	var err error
   229  	rootCertpath := ""
   230  	if err := checkSSLMode(sslmode, sslrootcert); err != nil {
   231  		return err
   232  	}
   233  	if sslrootcert != "" {
   234  		rootCertpath, err = WriteFile(sslrootcert)
   235  		if err != nil {
   236  			return err
   237  		}
   238  	}
   239  	if pg.Data.SSLRootCert != "" {
   240  		os.Remove(pg.Data.SSLRootCert)
   241  	}
   242  	pg.Data.SSLMode = sslmode
   243  	pg.Data.SSLRootCert = rootCertpath
   244  	pg.CloseConnections()
   245  	return nil
   246  }
   247  
   248  func (pg PGData) buildConnectionData(dbname string, user User) string {
   249  	result := fmt.Sprintf("dbname=%s user=%s host=%s port=%d sslmode=%s", dbname, user.Name, pg.Data.Address, pg.Data.Port, pg.Data.SSLMode)
   250  	if pg.Data.SSLRootCert != "" {
   251  		result = fmt.Sprintf("%s sslrootcert=%s", result, pg.Data.SSLRootCert)
   252  	}
   253  	if user.Password != "" {
   254  		result = fmt.Sprintf("%s password=%s", result, user.Password)
   255  	} else {
   256  		result = fmt.Sprintf("%s sslcert=%s sslkey=%s", result, user.Certificate, user.Key)
   257  	}
   258  	return result
   259  }
   260  
   261  func (pg *PGData) OpenConnection(dbname string, user User) (PGConn, error) {
   262  	var newConn PGConn
   263  	var err error
   264  
   265  	connectionData := pg.buildConnectionData(dbname, user)
   266  	newConn.DB, err = sql.Open("postgres", connectionData)
   267  	if err != nil {
   268  		return PGConn{}, err
   269  	}
   270  	err = newConn.DB.Ping()
   271  	if err != nil {
   272  		return PGConn{}, err
   273  	}
   274  	newConn.User = user.Name
   275  	newConn.password = user.Password
   276  	newConn.TargetDB = dbname
   277  	newConn.DB.SetMaxIdleConns(10)
   278  	pg.DBs = append(pg.DBs, newConn)
   279  	return newConn, nil
   280  }
   281  func (pg *PGData) CloseConnections() {
   282  	for _, conn := range pg.DBs {
   283  		conn.DB.Close()
   284  	}
   285  	pg.DBs = nil
   286  }
   287  func (pg *PGData) GetDefaultConnection() (PGConn, error) {
   288  	return pg.GetDBConnection(DefaultDB)
   289  }
   290  
   291  func (pg *PGData) GetDBSuperUserConnection(dbname string) (PGConn, error) {
   292  	if pg.Data.AdminUser == (User{}) ||
   293  		pg.Data.AdminUser.Name == "" ||
   294  		pg.Data.AdminUser.Password == "" {
   295  		return PGConn{}, errors.New(NoSuperUserProvidedErr)
   296  	}
   297  	conn, err := pg.GetDBConnectionForUser(dbname, pg.Data.AdminUser)
   298  	if err != nil {
   299  		conn, err = pg.OpenConnection(dbname, pg.Data.AdminUser)
   300  		if err != nil {
   301  			return PGConn{}, err
   302  		}
   303  	}
   304  	return conn, nil
   305  }
   306  func (pg *PGData) GetSuperUserConnection() (PGConn, error) {
   307  	return pg.GetDBSuperUserConnection(DefaultDB)
   308  }
   309  
   310  func (pg *PGData) GetDBConnection(dbname string) (PGConn, error) {
   311  	result, err := pg.GetDBConnectionForUser(dbname, pg.getDefaultUser())
   312  	if (PGConn{}) == result {
   313  		result, err = pg.OpenConnection(dbname, pg.getDefaultUser())
   314  		if err != nil {
   315  			return PGConn{}, err
   316  		}
   317  	}
   318  	return result, nil
   319  }
   320  func (pg PGData) GetDBConnectionForUser(dbname string, user User) (PGConn, error) {
   321  	if len(pg.DBs) == 0 {
   322  		return PGConn{}, errors.New(NoConnectionAvailableErr)
   323  	}
   324  	var result PGConn
   325  	for _, conn := range pg.DBs {
   326  		if conn.TargetDB == dbname {
   327  			if user.Name == "" || conn.User == user.Name {
   328  				result = conn
   329  				break
   330  			}
   331  		}
   332  	}
   333  	if (PGConn{}) == result {
   334  		return PGConn{}, errors.New(NoConnectionAvailableErr)
   335  	}
   336  	return result, nil
   337  }
   338  
   339  func (pg PGConn) Run(query string) ([]string, error) {
   340  	var result []string
   341  	if rows, err := pg.DB.Query(GetFormattedQuery(query)); err != nil {
   342  		return nil, err
   343  	} else {
   344  		defer rows.Close()
   345  		for rows.Next() {
   346  			var jsonRow string
   347  			if err := rows.Scan(&jsonRow); err != nil {
   348  				break
   349  			}
   350  			result = append(result, jsonRow)
   351  		}
   352  		if err := rows.Err(); err != nil {
   353  			return nil, err
   354  		}
   355  	}
   356  	return result, nil
   357  }
   358  
   359  func (pg PGConn) Exec(query string) error {
   360  	if _, err := pg.DB.Exec(query); err != nil {
   361  		return err
   362  	}
   363  	return nil
   364  }
   365  
   366  func (pg PGData) DropTable(dbName string, tableName string) error {
   367  
   368  	conn, err := pg.GetDBConnection(dbName)
   369  	if err != nil {
   370  		return err
   371  	}
   372  	err = conn.Exec(fmt.Sprintf(DropTable, tableName))
   373  	if err != nil {
   374  		return err
   375  	}
   376  	return nil
   377  }
   378  
   379  func (pg PGData) CreateAndPopulateTables(dbName string, loadType LoadType) error {
   380  	return pg.CreateAndPopulateTablesWithPrefix(dbName, loadType, "pgats_table")
   381  }
   382  
   383  func (pg PGData) CreateAndPopulateTablesWithPrefix(dbName string, loadType LoadType, prefix string) error {
   384  
   385  	conn, err := pg.GetDBConnection(dbName)
   386  	if err != nil {
   387  		return err
   388  	}
   389  	tables := GetSampleLoadWithPrefix(loadType, prefix)
   390  
   391  	for _, table := range tables {
   392  		err = conn.Exec(table.PrepareCreate())
   393  		if err != nil {
   394  			return err
   395  		}
   396  		err = conn.Exec(table.PrepareCreateIndex())
   397  		if err != nil {
   398  			return err
   399  		}
   400  		txn, err := conn.DB.Begin()
   401  		if err != nil {
   402  			return err
   403  		}
   404  
   405  		stmt, err := txn.Prepare(table.PrepareStatement())
   406  		if err != nil {
   407  			return err
   408  		}
   409  
   410  		for i := 0; i < table.NumRows; i++ {
   411  			_, err = stmt.Exec(table.PrepareRow(i)...)
   412  			if err != nil {
   413  				return err
   414  			}
   415  		}
   416  
   417  		_, err = stmt.Exec()
   418  		if err != nil {
   419  			return err
   420  		}
   421  
   422  		err = stmt.Close()
   423  		if err != nil {
   424  			return err
   425  		}
   426  
   427  		err = txn.Commit()
   428  		if err != nil {
   429  			return err
   430  		}
   431  	}
   432  
   433  	return err
   434  }
   435  
   436  func (pg PGData) ReadAllSettings() (map[string]string, error) {
   437  	result := make(map[string]string)
   438  	conn, err := pg.GetDefaultConnection()
   439  	if err != nil {
   440  		return nil, err
   441  	}
   442  	rows, err := conn.Run(GetSettingsQuery)
   443  	if err != nil {
   444  		return nil, err
   445  	}
   446  	for _, row := range rows {
   447  		out := PGSetting{}
   448  		err = json.Unmarshal([]byte(row), &out)
   449  		if err != nil {
   450  			return nil, err
   451  		}
   452  		result[out.Name] = out.Setting
   453  	}
   454  	return result, nil
   455  }
   456  func (pg PGData) GetPostgreSQLVersion() (PGVersion, error) {
   457  	var result PGVersion
   458  
   459  	conn, err := pg.GetDefaultConnection()
   460  	if err != nil {
   461  		return PGVersion{}, err
   462  	}
   463  	rows, err := conn.Run(GetPostgreSQLVersionQuery)
   464  	if err != nil {
   465  		return PGVersion{}, err
   466  	}
   467  	err = json.Unmarshal([]byte(rows[0]), &result)
   468  	if err != nil {
   469  		return PGVersion{}, err
   470  	}
   471  	return result, nil
   472  }
   473  func (pg PGData) ListDatabases() ([]PGDatabase, error) {
   474  	var result []PGDatabase
   475  	conn, err := pg.GetSuperUserConnection()
   476  	if err != nil {
   477  		return nil, err
   478  	}
   479  	rows, err := conn.Run(ListDatabasesQuery)
   480  	if err != nil {
   481  		return nil, err
   482  	}
   483  	for _, row := range rows {
   484  		out := PGDatabase{}
   485  		err = json.Unmarshal([]byte(row), &out)
   486  		if err != nil {
   487  			return nil, err
   488  		}
   489  		result = append(result, out)
   490  	}
   491  	for idx, database := range result {
   492  		result[idx].DBExts, err = pg.ListDatabaseExtensions(database.Name)
   493  		if err != nil {
   494  			return nil, err
   495  		}
   496  		result[idx].Tables, err = pg.ListDatabaseTables(database.Name)
   497  		if err != nil {
   498  			return nil, err
   499  		}
   500  	}
   501  	return result, nil
   502  }
   503  func (pg PGData) ListDatabaseExtensions(dbName string) ([]PGDatabaseExtensions, error) {
   504  	conn, err := pg.GetDBSuperUserConnection(dbName)
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	rows, err := conn.Run(ListDBExtensionsQuery)
   509  	if err != nil {
   510  		return nil, err
   511  	}
   512  	extensionsList := []PGDatabaseExtensions{}
   513  	for _, row := range rows {
   514  		out := PGDatabaseExtensions{}
   515  		err = json.Unmarshal([]byte(row), &out)
   516  		if err != nil {
   517  			return nil, err
   518  		}
   519  		extensionsList = append(extensionsList, out)
   520  	}
   521  	return extensionsList, nil
   522  }
   523  func (pg PGData) ListDatabaseTables(dbName string) ([]PGTable, error) {
   524  	conn, err := pg.GetDBSuperUserConnection(dbName)
   525  	if err != nil {
   526  		return nil, err
   527  	}
   528  	rows, err := conn.Run(ListTablesQuery)
   529  	if err != nil {
   530  		return nil, err
   531  	}
   532  	tableList := []PGTable{}
   533  	for _, row := range rows {
   534  		tableData := PGTable{}
   535  		err = json.Unmarshal([]byte(row), &tableData)
   536  		if err != nil {
   537  			return nil, err
   538  		}
   539  		tableData.TableColumns = []PGTableColumn{}
   540  		colRows, err := conn.Run(fmt.Sprintf(ListTableColumnsQuery, tableData.SchemaName, tableData.TableName))
   541  		if err != nil {
   542  			return nil, err
   543  		}
   544  		for _, colRow := range colRows {
   545  			colData := PGTableColumn{}
   546  			err = json.Unmarshal([]byte(colRow), &colData)
   547  			if err != nil {
   548  				return nil, err
   549  			}
   550  			tableData.TableColumns = append(tableData.TableColumns, colData)
   551  		}
   552  		countRows, err := conn.Run(fmt.Sprintf(CountTableRowsQuery, tableData.TableName))
   553  		if err != nil {
   554  			return nil, err
   555  		}
   556  		count := PGCount{}
   557  		err = json.Unmarshal([]byte(countRows[0]), &count)
   558  		if err != nil {
   559  			return nil, err
   560  		}
   561  		tableData.TableRowsCount = count
   562  
   563  		tableList = append(tableList, tableData)
   564  	}
   565  	return tableList, nil
   566  }
   567  func (pg PGData) CheckTableExist(table_name string, dbName string) (bool, error) {
   568  	conn, err := pg.GetDBSuperUserConnection(dbName)
   569  	if err != nil {
   570  		return false, err
   571  	}
   572  	rows, err := conn.Run(fmt.Sprintf(GetTableQuery, table_name))
   573  	if err != nil {
   574  		return false, err
   575  	}
   576  	if rows != nil && len(rows) != 0 {
   577  		return true, nil
   578  	}
   579  	return false, nil
   580  }
   581  func (pg PGData) ListRoles() (map[string]PGRole, error) {
   582  	result := make(map[string]PGRole)
   583  	conn, err := pg.GetDefaultConnection()
   584  	if err != nil {
   585  		return nil, err
   586  	}
   587  	rows, err := conn.Run(ListRolesQuery)
   588  	if err != nil {
   589  		return nil, err
   590  	}
   591  	for _, row := range rows {
   592  		out := PGRole{}
   593  		err = json.Unmarshal([]byte(row), &out)
   594  		if err != nil {
   595  			return nil, err
   596  		}
   597  		result[out.Name] = out
   598  	}
   599  	return result, nil
   600  }
   601  func (pg PGData) CheckRoleExist(role_name string) (bool, error) {
   602  	conn, err := pg.GetDefaultConnection()
   603  	if err != nil {
   604  		return false, err
   605  	}
   606  	rows, err := conn.Run(fmt.Sprintf(GetRoleQuery, role_name))
   607  	if err != nil {
   608  		return false, err
   609  	}
   610  	if rows != nil && len(rows) != 0 {
   611  		return true, nil
   612  	}
   613  	return false, nil
   614  }
   615  
   616  func (pg PGData) ConvertToPostgresDate(inputDate string) (string, error) {
   617  	type ConvertedDate struct {
   618  		Date string `json:"timestamptz"`
   619  	}
   620  	result := ConvertedDate{}
   621  	inputDate = strings.TrimLeft(inputDate, "'\"")
   622  	inputDate = strings.TrimRight(inputDate, "'\"")
   623  	conn, err := pg.GetDefaultConnection()
   624  	if err != nil {
   625  		return "", err
   626  	}
   627  	rows, err := conn.Run(fmt.Sprintf(ConvertToDateCommand, inputDate))
   628  	if err != nil {
   629  		return "", err
   630  	}
   631  	err = json.Unmarshal([]byte(rows[0]), &result)
   632  	if err != nil {
   633  		return "", err
   634  	}
   635  	return result.Date, nil
   636  }
   637  
   638  func (pg PGData) GetData() (PGOutputData, error) {
   639  	var result PGOutputData
   640  	var err error
   641  	result.Settings, err = pg.ReadAllSettings()
   642  	if err != nil {
   643  		return PGOutputData{}, err
   644  	}
   645  	result.Databases, err = pg.ListDatabases()
   646  	if err != nil {
   647  		return PGOutputData{}, err
   648  	}
   649  	result.Roles, err = pg.ListRoles()
   650  	if err != nil {
   651  		return PGOutputData{}, err
   652  	}
   653  	result.Version, err = pg.GetPostgreSQLVersion()
   654  	if err != nil {
   655  		return PGOutputData{}, err
   656  	}
   657  	return result, nil
   658  }
   659  
   660  func (o PGOutputData) CopyData() (PGOutputData, error) {
   661  	var to PGOutputData
   662  	var buffer bytes.Buffer
   663  	enc := gob.NewEncoder(&buffer)
   664  	dec := gob.NewDecoder(&buffer)
   665  	err := enc.Encode(o)
   666  	if err != nil {
   667  		return PGOutputData{}, err
   668  	}
   669  	err = dec.Decode(&to)
   670  	if err != nil {
   671  		return PGOutputData{}, err
   672  	}
   673  	return to, nil
   674  }