github.com/cloudfoundry/postgres-release/src/acceptance-tests@v0.0.0-20240511030151-872bdd2e0dba/testing/helpers/postgres.go (about) 1 package helpers 2 3 import ( 4 "bytes" 5 "database/sql" 6 "encoding/gob" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "os" 11 "strings" 12 ) 13 14 const DefaultDB = "postgres" 15 16 type PGData struct { 17 Data PGCommon 18 DBs []PGConn 19 } 20 21 type User struct { 22 Name string 23 Password string 24 Certificate string 25 Key string 26 } 27 28 type PGCommon struct { 29 Address string 30 Port int 31 SSLMode string 32 SSLRootCert string 33 DefUser User 34 AdminUser User 35 CertUser User 36 UseCert bool 37 } 38 39 type PGConn struct { 40 TargetDB string 41 User string 42 password string 43 DB *sql.DB 44 } 45 46 type PGSetting struct { 47 Name string `json:"name"` 48 Setting string `json:"setting"` 49 VarType string `json:"vartype"` 50 } 51 type PGDatabase struct { 52 Name string `json:"datname"` 53 DBExts []PGDatabaseExtensions 54 Tables []PGTable 55 } 56 type PGDatabaseExtensions struct { 57 Name string `json:"extname"` 58 } 59 type PGTable struct { 60 SchemaName string `json:"schemaname"` 61 TableName string `json:"tablename"` 62 TableOwner string `json:"tableowner"` 63 TableColumns []PGTableColumn 64 TableRowsCount PGCount 65 } 66 type PGTableColumn struct { 67 ColumnName string `json:"column_name"` 68 DataType string `json:"data_type"` 69 Position int `json:"ordinal_position"` 70 } 71 type PGCount struct { 72 Num int `json:"count"` 73 } 74 type PGVersion struct { 75 Version string `json:"version"` 76 } 77 type PGRole struct { 78 Name string `json:"rolname"` 79 Super bool `json:"rolsuper"` 80 Inherit bool `json:"rolinherit"` 81 CreateRole bool `json:"rolcreaterole"` 82 CreateDb bool `json:"rolcreatedb"` 83 CanLogin bool `json:"rolcanlogin"` 84 Replication bool `json:"rolreplication"` 85 ConnLimit int `json:"rolconnlimit"` 86 ValidUntil string `json:"rolvaliduntil"` 87 } 88 89 type PGOutputData struct { 90 Roles map[string]PGRole 91 Databases []PGDatabase 92 Settings map[string]string 93 Version PGVersion 94 } 95 96 const GetSettingsQuery = "SELECT * FROM pg_settings" 97 const ListRolesQuery = "SELECT * from pg_roles" 98 const GetRoleQuery = "SELECT * from pg_roles where rolname='%s'" 99 const GetTableQuery = "SELECT * from pg_catalog.pg_tables where tablename='%s'" 100 const ListDatabasesQuery = "SELECT datname from pg_database where datistemplate=false" 101 const ListDBExtensionsQuery = "SELECT extname from pg_extension" 102 const ConvertToDateCommand = "SELECT '%s'::timestamptz" 103 const ListTablesQuery = "SELECT * from pg_catalog.pg_tables where schemaname not like 'pg_%' and schemaname != 'information_schema'" 104 const ListTableColumnsQuery = "SELECT column_name, data_type, ordinal_position FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s' order by ordinal_position asc" 105 const CountTableRowsQuery = "SELECT COUNT(*) FROM %s" 106 const GetPostgreSQLVersionQuery = "SELECT version()" 107 const QueryResultAsJson = "SELECT row_to_json(t) from (%s) as t;" 108 const DropTable = "DROP TABLE %s" 109 110 const NoConnectionAvailableErr = "No connections available" 111 const MissingDBAddressErr = "Database address not specified" 112 const MissingDBPortErr = "Database port not specified" 113 const MissingDefaultUserErr = "Default user not specified" 114 const MissingDefaultPasswordErr = "Default password not specified" 115 const NoSuperUserProvidedErr = "No super user provided" 116 const IncorrectSSLModeErr = "Incorrect SSL mode specified" 117 const MissingSSLRootCertErr = "SSL Root Certificate missing" 118 const MissingCertUserErr = "No user specified to authenticate with certificates" 119 const MissingCertCertErr = "No certificate specified for cert user" 120 const MissingCertKeyErr = "No private key specified for cert user's certificate" 121 122 func GetFormattedQuery(query string) string { 123 return fmt.Sprintf(QueryResultAsJson, query) 124 } 125 126 func NewPostgres(props PGCommon) (PGData, error) { 127 var pg PGData 128 if props.SSLMode == "" { 129 props.SSLMode = "disable" 130 } 131 if err := checkSSLMode(props.SSLMode, props.SSLRootCert); err != nil { 132 return PGData{}, err 133 } 134 if props.Address == "" { 135 return PGData{}, errors.New(MissingDBAddressErr) 136 } 137 if props.Port == 0 { 138 return PGData{}, errors.New(MissingDBPortErr) 139 } 140 if props.DefUser == (User{}) { 141 return PGData{}, errors.New(MissingDefaultUserErr) 142 } 143 if props.DefUser.Name == "" { 144 return PGData{}, errors.New(MissingDefaultUserErr) 145 } 146 if props.DefUser.Password == "" { 147 return PGData{}, errors.New(MissingDefaultPasswordErr) 148 } 149 pg.Data = props 150 return pg, nil 151 } 152 153 func checkSSLMode(sslmode string, sslrootcert string) error { 154 var strong_sslmodes = [...]string{"verify-ca", "verify-full"} 155 var valid_sslmodes = [...]string{"disable", "require", "verify-ca", "verify-full"} 156 for _, valid_mode := range valid_sslmodes { 157 if valid_mode == sslmode { 158 for _, strong_mode := range strong_sslmodes { 159 if strong_mode == sslmode && sslrootcert == "" { 160 return errors.New(MissingSSLRootCertErr) 161 } 162 } 163 return nil 164 } 165 } 166 return errors.New(IncorrectSSLModeErr) 167 } 168 169 func (pg PGData) getDefaultUser() User { 170 if pg.Data.UseCert { 171 return pg.Data.CertUser 172 } else { 173 return pg.Data.DefUser 174 } 175 } 176 177 func (pg PGData) checkCertUser() error { 178 if pg.Data.CertUser == (User{}) { 179 return errors.New(MissingCertUserErr) 180 } 181 if pg.Data.CertUser.Certificate == "" { 182 return errors.New(MissingCertCertErr) 183 } 184 if pg.Data.CertUser.Key == "" { 185 return errors.New(MissingCertKeyErr) 186 } 187 return nil 188 } 189 func (pg *PGData) SetCertUserCertificates(user string, certs map[interface{}]interface{}) error { 190 if user == "" { 191 if pg.Data.UseCert { 192 return errors.New(MissingCertUserErr) 193 } 194 pg.Data.CertUser = User{} 195 } else { 196 clientCertPath, err := WriteFile(certs["certificate"].(string)) 197 if err != nil { 198 return err 199 } 200 clientKeyPath, err := WriteFile(certs["private_key"].(string)) 201 if err != nil { 202 return err 203 } 204 if pg.Data.CertUser.Certificate != "" { 205 os.Remove(pg.Data.CertUser.Certificate) 206 } 207 if pg.Data.CertUser.Key != "" { 208 os.Remove(pg.Data.CertUser.Key) 209 } 210 pg.Data.CertUser.Name = user 211 pg.Data.CertUser.Certificate = clientCertPath 212 pg.Data.CertUser.Key = clientKeyPath 213 } 214 return nil 215 } 216 217 func (pg *PGData) UseCertAuthentication(useCert bool) error { 218 if useCert { 219 if err := pg.checkCertUser(); err != nil { 220 return err 221 } 222 } 223 pg.Data.UseCert = useCert 224 pg.CloseConnections() 225 return nil 226 } 227 func (pg *PGData) ChangeSSLMode(sslmode string, sslrootcert string) error { 228 var err error 229 rootCertpath := "" 230 if err := checkSSLMode(sslmode, sslrootcert); err != nil { 231 return err 232 } 233 if sslrootcert != "" { 234 rootCertpath, err = WriteFile(sslrootcert) 235 if err != nil { 236 return err 237 } 238 } 239 if pg.Data.SSLRootCert != "" { 240 os.Remove(pg.Data.SSLRootCert) 241 } 242 pg.Data.SSLMode = sslmode 243 pg.Data.SSLRootCert = rootCertpath 244 pg.CloseConnections() 245 return nil 246 } 247 248 func (pg PGData) buildConnectionData(dbname string, user User) string { 249 result := fmt.Sprintf("dbname=%s user=%s host=%s port=%d sslmode=%s", dbname, user.Name, pg.Data.Address, pg.Data.Port, pg.Data.SSLMode) 250 if pg.Data.SSLRootCert != "" { 251 result = fmt.Sprintf("%s sslrootcert=%s", result, pg.Data.SSLRootCert) 252 } 253 if user.Password != "" { 254 result = fmt.Sprintf("%s password=%s", result, user.Password) 255 } else { 256 result = fmt.Sprintf("%s sslcert=%s sslkey=%s", result, user.Certificate, user.Key) 257 } 258 return result 259 } 260 261 func (pg *PGData) OpenConnection(dbname string, user User) (PGConn, error) { 262 var newConn PGConn 263 var err error 264 265 connectionData := pg.buildConnectionData(dbname, user) 266 newConn.DB, err = sql.Open("postgres", connectionData) 267 if err != nil { 268 return PGConn{}, err 269 } 270 err = newConn.DB.Ping() 271 if err != nil { 272 return PGConn{}, err 273 } 274 newConn.User = user.Name 275 newConn.password = user.Password 276 newConn.TargetDB = dbname 277 newConn.DB.SetMaxIdleConns(10) 278 pg.DBs = append(pg.DBs, newConn) 279 return newConn, nil 280 } 281 func (pg *PGData) CloseConnections() { 282 for _, conn := range pg.DBs { 283 conn.DB.Close() 284 } 285 pg.DBs = nil 286 } 287 func (pg *PGData) GetDefaultConnection() (PGConn, error) { 288 return pg.GetDBConnection(DefaultDB) 289 } 290 291 func (pg *PGData) GetDBSuperUserConnection(dbname string) (PGConn, error) { 292 if pg.Data.AdminUser == (User{}) || 293 pg.Data.AdminUser.Name == "" || 294 pg.Data.AdminUser.Password == "" { 295 return PGConn{}, errors.New(NoSuperUserProvidedErr) 296 } 297 conn, err := pg.GetDBConnectionForUser(dbname, pg.Data.AdminUser) 298 if err != nil { 299 conn, err = pg.OpenConnection(dbname, pg.Data.AdminUser) 300 if err != nil { 301 return PGConn{}, err 302 } 303 } 304 return conn, nil 305 } 306 func (pg *PGData) GetSuperUserConnection() (PGConn, error) { 307 return pg.GetDBSuperUserConnection(DefaultDB) 308 } 309 310 func (pg *PGData) GetDBConnection(dbname string) (PGConn, error) { 311 result, err := pg.GetDBConnectionForUser(dbname, pg.getDefaultUser()) 312 if (PGConn{}) == result { 313 result, err = pg.OpenConnection(dbname, pg.getDefaultUser()) 314 if err != nil { 315 return PGConn{}, err 316 } 317 } 318 return result, nil 319 } 320 func (pg PGData) GetDBConnectionForUser(dbname string, user User) (PGConn, error) { 321 if len(pg.DBs) == 0 { 322 return PGConn{}, errors.New(NoConnectionAvailableErr) 323 } 324 var result PGConn 325 for _, conn := range pg.DBs { 326 if conn.TargetDB == dbname { 327 if user.Name == "" || conn.User == user.Name { 328 result = conn 329 break 330 } 331 } 332 } 333 if (PGConn{}) == result { 334 return PGConn{}, errors.New(NoConnectionAvailableErr) 335 } 336 return result, nil 337 } 338 339 func (pg PGConn) Run(query string) ([]string, error) { 340 var result []string 341 if rows, err := pg.DB.Query(GetFormattedQuery(query)); err != nil { 342 return nil, err 343 } else { 344 defer rows.Close() 345 for rows.Next() { 346 var jsonRow string 347 if err := rows.Scan(&jsonRow); err != nil { 348 break 349 } 350 result = append(result, jsonRow) 351 } 352 if err := rows.Err(); err != nil { 353 return nil, err 354 } 355 } 356 return result, nil 357 } 358 359 func (pg PGConn) Exec(query string) error { 360 if _, err := pg.DB.Exec(query); err != nil { 361 return err 362 } 363 return nil 364 } 365 366 func (pg PGData) DropTable(dbName string, tableName string) error { 367 368 conn, err := pg.GetDBConnection(dbName) 369 if err != nil { 370 return err 371 } 372 err = conn.Exec(fmt.Sprintf(DropTable, tableName)) 373 if err != nil { 374 return err 375 } 376 return nil 377 } 378 379 func (pg PGData) CreateAndPopulateTables(dbName string, loadType LoadType) error { 380 return pg.CreateAndPopulateTablesWithPrefix(dbName, loadType, "pgats_table") 381 } 382 383 func (pg PGData) CreateAndPopulateTablesWithPrefix(dbName string, loadType LoadType, prefix string) error { 384 385 conn, err := pg.GetDBConnection(dbName) 386 if err != nil { 387 return err 388 } 389 tables := GetSampleLoadWithPrefix(loadType, prefix) 390 391 for _, table := range tables { 392 err = conn.Exec(table.PrepareCreate()) 393 if err != nil { 394 return err 395 } 396 err = conn.Exec(table.PrepareCreateIndex()) 397 if err != nil { 398 return err 399 } 400 txn, err := conn.DB.Begin() 401 if err != nil { 402 return err 403 } 404 405 stmt, err := txn.Prepare(table.PrepareStatement()) 406 if err != nil { 407 return err 408 } 409 410 for i := 0; i < table.NumRows; i++ { 411 _, err = stmt.Exec(table.PrepareRow(i)...) 412 if err != nil { 413 return err 414 } 415 } 416 417 _, err = stmt.Exec() 418 if err != nil { 419 return err 420 } 421 422 err = stmt.Close() 423 if err != nil { 424 return err 425 } 426 427 err = txn.Commit() 428 if err != nil { 429 return err 430 } 431 } 432 433 return err 434 } 435 436 func (pg PGData) ReadAllSettings() (map[string]string, error) { 437 result := make(map[string]string) 438 conn, err := pg.GetDefaultConnection() 439 if err != nil { 440 return nil, err 441 } 442 rows, err := conn.Run(GetSettingsQuery) 443 if err != nil { 444 return nil, err 445 } 446 for _, row := range rows { 447 out := PGSetting{} 448 err = json.Unmarshal([]byte(row), &out) 449 if err != nil { 450 return nil, err 451 } 452 result[out.Name] = out.Setting 453 } 454 return result, nil 455 } 456 func (pg PGData) GetPostgreSQLVersion() (PGVersion, error) { 457 var result PGVersion 458 459 conn, err := pg.GetDefaultConnection() 460 if err != nil { 461 return PGVersion{}, err 462 } 463 rows, err := conn.Run(GetPostgreSQLVersionQuery) 464 if err != nil { 465 return PGVersion{}, err 466 } 467 err = json.Unmarshal([]byte(rows[0]), &result) 468 if err != nil { 469 return PGVersion{}, err 470 } 471 return result, nil 472 } 473 func (pg PGData) ListDatabases() ([]PGDatabase, error) { 474 var result []PGDatabase 475 conn, err := pg.GetSuperUserConnection() 476 if err != nil { 477 return nil, err 478 } 479 rows, err := conn.Run(ListDatabasesQuery) 480 if err != nil { 481 return nil, err 482 } 483 for _, row := range rows { 484 out := PGDatabase{} 485 err = json.Unmarshal([]byte(row), &out) 486 if err != nil { 487 return nil, err 488 } 489 result = append(result, out) 490 } 491 for idx, database := range result { 492 result[idx].DBExts, err = pg.ListDatabaseExtensions(database.Name) 493 if err != nil { 494 return nil, err 495 } 496 result[idx].Tables, err = pg.ListDatabaseTables(database.Name) 497 if err != nil { 498 return nil, err 499 } 500 } 501 return result, nil 502 } 503 func (pg PGData) ListDatabaseExtensions(dbName string) ([]PGDatabaseExtensions, error) { 504 conn, err := pg.GetDBSuperUserConnection(dbName) 505 if err != nil { 506 return nil, err 507 } 508 rows, err := conn.Run(ListDBExtensionsQuery) 509 if err != nil { 510 return nil, err 511 } 512 extensionsList := []PGDatabaseExtensions{} 513 for _, row := range rows { 514 out := PGDatabaseExtensions{} 515 err = json.Unmarshal([]byte(row), &out) 516 if err != nil { 517 return nil, err 518 } 519 extensionsList = append(extensionsList, out) 520 } 521 return extensionsList, nil 522 } 523 func (pg PGData) ListDatabaseTables(dbName string) ([]PGTable, error) { 524 conn, err := pg.GetDBSuperUserConnection(dbName) 525 if err != nil { 526 return nil, err 527 } 528 rows, err := conn.Run(ListTablesQuery) 529 if err != nil { 530 return nil, err 531 } 532 tableList := []PGTable{} 533 for _, row := range rows { 534 tableData := PGTable{} 535 err = json.Unmarshal([]byte(row), &tableData) 536 if err != nil { 537 return nil, err 538 } 539 tableData.TableColumns = []PGTableColumn{} 540 colRows, err := conn.Run(fmt.Sprintf(ListTableColumnsQuery, tableData.SchemaName, tableData.TableName)) 541 if err != nil { 542 return nil, err 543 } 544 for _, colRow := range colRows { 545 colData := PGTableColumn{} 546 err = json.Unmarshal([]byte(colRow), &colData) 547 if err != nil { 548 return nil, err 549 } 550 tableData.TableColumns = append(tableData.TableColumns, colData) 551 } 552 countRows, err := conn.Run(fmt.Sprintf(CountTableRowsQuery, tableData.TableName)) 553 if err != nil { 554 return nil, err 555 } 556 count := PGCount{} 557 err = json.Unmarshal([]byte(countRows[0]), &count) 558 if err != nil { 559 return nil, err 560 } 561 tableData.TableRowsCount = count 562 563 tableList = append(tableList, tableData) 564 } 565 return tableList, nil 566 } 567 func (pg PGData) CheckTableExist(table_name string, dbName string) (bool, error) { 568 conn, err := pg.GetDBSuperUserConnection(dbName) 569 if err != nil { 570 return false, err 571 } 572 rows, err := conn.Run(fmt.Sprintf(GetTableQuery, table_name)) 573 if err != nil { 574 return false, err 575 } 576 if rows != nil && len(rows) != 0 { 577 return true, nil 578 } 579 return false, nil 580 } 581 func (pg PGData) ListRoles() (map[string]PGRole, error) { 582 result := make(map[string]PGRole) 583 conn, err := pg.GetDefaultConnection() 584 if err != nil { 585 return nil, err 586 } 587 rows, err := conn.Run(ListRolesQuery) 588 if err != nil { 589 return nil, err 590 } 591 for _, row := range rows { 592 out := PGRole{} 593 err = json.Unmarshal([]byte(row), &out) 594 if err != nil { 595 return nil, err 596 } 597 result[out.Name] = out 598 } 599 return result, nil 600 } 601 func (pg PGData) CheckRoleExist(role_name string) (bool, error) { 602 conn, err := pg.GetDefaultConnection() 603 if err != nil { 604 return false, err 605 } 606 rows, err := conn.Run(fmt.Sprintf(GetRoleQuery, role_name)) 607 if err != nil { 608 return false, err 609 } 610 if rows != nil && len(rows) != 0 { 611 return true, nil 612 } 613 return false, nil 614 } 615 616 func (pg PGData) ConvertToPostgresDate(inputDate string) (string, error) { 617 type ConvertedDate struct { 618 Date string `json:"timestamptz"` 619 } 620 result := ConvertedDate{} 621 inputDate = strings.TrimLeft(inputDate, "'\"") 622 inputDate = strings.TrimRight(inputDate, "'\"") 623 conn, err := pg.GetDefaultConnection() 624 if err != nil { 625 return "", err 626 } 627 rows, err := conn.Run(fmt.Sprintf(ConvertToDateCommand, inputDate)) 628 if err != nil { 629 return "", err 630 } 631 err = json.Unmarshal([]byte(rows[0]), &result) 632 if err != nil { 633 return "", err 634 } 635 return result.Date, nil 636 } 637 638 func (pg PGData) GetData() (PGOutputData, error) { 639 var result PGOutputData 640 var err error 641 result.Settings, err = pg.ReadAllSettings() 642 if err != nil { 643 return PGOutputData{}, err 644 } 645 result.Databases, err = pg.ListDatabases() 646 if err != nil { 647 return PGOutputData{}, err 648 } 649 result.Roles, err = pg.ListRoles() 650 if err != nil { 651 return PGOutputData{}, err 652 } 653 result.Version, err = pg.GetPostgreSQLVersion() 654 if err != nil { 655 return PGOutputData{}, err 656 } 657 return result, nil 658 } 659 660 func (o PGOutputData) CopyData() (PGOutputData, error) { 661 var to PGOutputData 662 var buffer bytes.Buffer 663 enc := gob.NewEncoder(&buffer) 664 dec := gob.NewDecoder(&buffer) 665 err := enc.Encode(o) 666 if err != nil { 667 return PGOutputData{}, err 668 } 669 err = dec.Decode(&to) 670 if err != nil { 671 return PGOutputData{}, err 672 } 673 return to, nil 674 }