github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-psql/driver/override/test/singleton/psql_main_test.go.tpl (about) 1 var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`) 2 3 type pgTester struct { 4 dbConn *sql.DB 5 6 dbName string 7 host string 8 user string 9 pass string 10 sslmode string 11 port int 12 13 pgPassFile string 14 15 testDBName string 16 skipSQLCmd bool 17 } 18 19 func init() { 20 dbMain = &pgTester{} 21 } 22 23 // setup dumps the database schema and imports it into a temporary randomly 24 // generated test database so that tests can be run against it using the 25 // generated sqlboiler ORM package. 26 func (p *pgTester) setup() error { 27 var err error 28 29 viper.SetDefault("psql.schema", "public") 30 viper.SetDefault("psql.port", 5432) 31 viper.SetDefault("psql.sslmode", "require") 32 33 p.dbName = viper.GetString("psql.dbname") 34 p.host = viper.GetString("psql.host") 35 p.user = viper.GetString("psql.user") 36 p.pass = viper.GetString("psql.pass") 37 p.port = viper.GetInt("psql.port") 38 p.sslmode = viper.GetString("psql.sslmode") 39 p.testDBName = viper.GetString("psql.testdbname") 40 p.skipSQLCmd = viper.GetBool("psql.skipsqlcmd") 41 42 err = vala.BeginValidation().Validate( 43 vala.StringNotEmpty(p.user, "psql.user"), 44 vala.StringNotEmpty(p.host, "psql.host"), 45 vala.Not(vala.Equals(p.port, 0, "psql.port")), 46 vala.StringNotEmpty(p.dbName, "psql.dbname"), 47 vala.StringNotEmpty(p.sslmode, "psql.sslmode"), 48 ).Check() 49 50 if err != nil { 51 return err 52 } 53 54 // if no testing DB passed 55 if len(p.testDBName) == 0 { 56 // Create a randomized db name. 57 p.testDBName = randomize.StableDBName(p.dbName) 58 } 59 60 if err = p.makePGPassFile(); err != nil { 61 return err 62 } 63 64 if !p.skipSQLCmd { 65 if err = p.dropTestDB(); err != nil { 66 return err 67 } 68 if err = p.createTestDB(); err != nil { 69 return err 70 } 71 72 dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName) 73 dumpCmd.Env = append(os.Environ(), p.pgEnv()...) 74 createCmd := exec.Command("psql", p.testDBName) 75 createCmd.Env = append(os.Environ(), p.pgEnv()...) 76 77 r, w := io.Pipe() 78 dumpCmdStderr := &bytes.Buffer{} 79 createCmdStderr := &bytes.Buffer{} 80 81 dumpCmd.Stdout = w 82 dumpCmd.Stderr = dumpCmdStderr 83 84 createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r) 85 createCmd.Stderr = createCmdStderr 86 87 if err = dumpCmd.Start(); err != nil { 88 return errors.Wrap(err, "failed to start pg_dump command") 89 } 90 if err = createCmd.Start(); err != nil { 91 return errors.Wrap(err, "failed to start psql command") 92 } 93 94 if err = dumpCmd.Wait(); err != nil { 95 fmt.Println(err) 96 fmt.Println(dumpCmdStderr.String()) 97 return errors.Wrap(err, "failed to wait for pg_dump command") 98 } 99 100 _ = w.Close() // After dumpCmd is done, close the write end of the pipe 101 102 if err = createCmd.Wait(); err != nil { 103 fmt.Println(err) 104 fmt.Println(createCmdStderr.String()) 105 return errors.Wrap(err, "failed to wait for psql command") 106 } 107 } 108 109 return nil 110 } 111 112 func (p *pgTester) runCmd(stdin, command string, args ...string) error { 113 cmd := exec.Command(command, args...) 114 cmd.Env = append(os.Environ(), p.pgEnv()...) 115 116 if len(stdin) != 0 { 117 cmd.Stdin = strings.NewReader(stdin) 118 } 119 120 stdout := &bytes.Buffer{} 121 stderr := &bytes.Buffer{} 122 cmd.Stdout = stdout 123 cmd.Stderr = stderr 124 if err := cmd.Run(); err != nil { 125 fmt.Println("failed running:", command, args) 126 fmt.Println(stdout.String()) 127 fmt.Println(stderr.String()) 128 return err 129 } 130 131 return nil 132 } 133 134 func (p *pgTester) pgEnv() []string { 135 return []string{ 136 fmt.Sprintf("PGHOST=%s", p.host), 137 fmt.Sprintf("PGPORT=%d", p.port), 138 fmt.Sprintf("PGUSER=%s", p.user), 139 fmt.Sprintf("PGPASSFILE=%s", p.pgPassFile), 140 } 141 } 142 143 func (p *pgTester) makePGPassFile() error { 144 tmp, err := os.CreateTemp("", "pgpass") 145 if err != nil { 146 return errors.Wrap(err, "failed to create option file") 147 } 148 149 fmt.Fprintf(tmp, "%s:%d:postgres:%s", p.host, p.port, p.user) 150 if len(p.pass) != 0 { 151 fmt.Fprintf(tmp, ":%s", p.pass) 152 } 153 fmt.Fprintln(tmp) 154 155 fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user) 156 if len(p.pass) != 0 { 157 fmt.Fprintf(tmp, ":%s", p.pass) 158 } 159 fmt.Fprintln(tmp) 160 161 fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user) 162 if len(p.pass) != 0 { 163 fmt.Fprintf(tmp, ":%s", p.pass) 164 } 165 fmt.Fprintln(tmp) 166 167 p.pgPassFile = tmp.Name() 168 return tmp.Close() 169 } 170 171 func (p *pgTester) createTestDB() error { 172 return p.runCmd("", "createdb", p.testDBName) 173 } 174 175 func (p *pgTester) dropTestDB() error { 176 return p.runCmd("", "dropdb", "--if-exists", p.testDBName) 177 } 178 179 // teardown executes cleanup tasks when the tests finish running 180 func (p *pgTester) teardown() error { 181 var err error 182 if err = p.dbConn.Close(); err != nil { 183 return err 184 } 185 p.dbConn = nil 186 187 if !p.skipSQLCmd { 188 if err = p.dropTestDB(); err != nil { 189 return err 190 } 191 } 192 193 return os.Remove(p.pgPassFile) 194 } 195 196 func (p *pgTester) conn() (*sql.DB, error) { 197 if p.dbConn != nil { 198 return p.dbConn, nil 199 } 200 201 var err error 202 p.dbConn, err = sql.Open("postgres", driver.PSQLBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode)) 203 if err != nil { 204 return nil, err 205 } 206 207 return p.dbConn, nil 208 }