github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-psql/driver/override/test/singleton/psql_main_test.go.tpl (about)

     1  var rgxPGFkey = regexp.MustCompile(`(?m)^ALTER TABLE .*\n\s+ADD CONSTRAINT .*? FOREIGN KEY .*?;\n`)
     2  
     3  type pgTester struct {
     4  	dbConn *sql.DB
     5  
     6  	dbName  string
     7  	host    string
     8  	user    string
     9  	pass    string
    10  	sslmode string
    11  	port    int
    12  
    13  	pgPassFile string
    14  
    15  	testDBName string
    16  	skipSQLCmd bool
    17  }
    18  
    19  func init() {
    20  	dbMain = &pgTester{}
    21  }
    22  
    23  // setup dumps the database schema and imports it into a temporary randomly
    24  // generated test database so that tests can be run against it using the
    25  // generated sqlboiler ORM package.
    26  func (p *pgTester) setup() error {
    27  	var err error
    28  
    29  	viper.SetDefault("psql.schema", "public")
    30  	viper.SetDefault("psql.port", 5432)
    31  	viper.SetDefault("psql.sslmode", "require")
    32  
    33  	p.dbName = viper.GetString("psql.dbname")
    34  	p.host = viper.GetString("psql.host")
    35  	p.user = viper.GetString("psql.user")
    36  	p.pass = viper.GetString("psql.pass")
    37  	p.port = viper.GetInt("psql.port")
    38  	p.sslmode = viper.GetString("psql.sslmode")
    39  	p.testDBName = viper.GetString("psql.testdbname")
    40  	p.skipSQLCmd = viper.GetBool("psql.skipsqlcmd")
    41  
    42  	err = vala.BeginValidation().Validate(
    43  		vala.StringNotEmpty(p.user, "psql.user"),
    44  		vala.StringNotEmpty(p.host, "psql.host"),
    45  		vala.Not(vala.Equals(p.port, 0, "psql.port")),
    46  		vala.StringNotEmpty(p.dbName, "psql.dbname"),
    47  		vala.StringNotEmpty(p.sslmode, "psql.sslmode"),
    48  	).Check()
    49  
    50  	if err != nil {
    51  		return err
    52  	}
    53  
    54  	// if no testing DB passed
    55  	if len(p.testDBName) == 0 {
    56  		// Create a randomized db name.
    57  		p.testDBName = randomize.StableDBName(p.dbName)
    58  	}
    59  
    60  	if err = p.makePGPassFile(); err != nil {
    61  		return err
    62  	}
    63  
    64  	if !p.skipSQLCmd {
    65  		if err = p.dropTestDB(); err != nil {
    66  			return err
    67  		}
    68  		if err = p.createTestDB(); err != nil {
    69  			return err
    70  		}
    71  
    72  		dumpCmd := exec.Command("pg_dump", "--schema-only", p.dbName)
    73  		dumpCmd.Env = append(os.Environ(), p.pgEnv()...)
    74  		createCmd := exec.Command("psql", p.testDBName)
    75  		createCmd.Env = append(os.Environ(), p.pgEnv()...)
    76  
    77  		r, w := io.Pipe()
    78  		dumpCmdStderr := &bytes.Buffer{}
    79  		createCmdStderr := &bytes.Buffer{}
    80  
    81  		dumpCmd.Stdout = w
    82  		dumpCmd.Stderr = dumpCmdStderr
    83  
    84  		createCmd.Stdin = newFKeyDestroyer(rgxPGFkey, r)
    85  		createCmd.Stderr = createCmdStderr
    86  
    87  		if err = dumpCmd.Start(); err != nil {
    88  			return errors.Wrap(err, "failed to start pg_dump command")
    89  		}
    90  		if err = createCmd.Start(); err != nil {
    91  			return errors.Wrap(err, "failed to start psql command")
    92  		}
    93  
    94  		if err = dumpCmd.Wait(); err != nil {
    95  			fmt.Println(err)
    96  			fmt.Println(dumpCmdStderr.String())
    97  			return errors.Wrap(err, "failed to wait for pg_dump command")
    98  		}
    99  
   100  		_ = w.Close() // After dumpCmd is done, close the write end of the pipe
   101  
   102  		if err = createCmd.Wait(); err != nil {
   103  			fmt.Println(err)
   104  			fmt.Println(createCmdStderr.String())
   105  			return errors.Wrap(err, "failed to wait for psql command")
   106  		}
   107  	}
   108  
   109  	return nil
   110  }
   111  
   112  func (p *pgTester) runCmd(stdin, command string, args ...string) error {
   113  	cmd := exec.Command(command, args...)
   114  	cmd.Env = append(os.Environ(), p.pgEnv()...)
   115  
   116  	if len(stdin) != 0 {
   117  		cmd.Stdin = strings.NewReader(stdin)
   118  	}
   119  
   120  	stdout := &bytes.Buffer{}
   121  	stderr := &bytes.Buffer{}
   122  	cmd.Stdout = stdout
   123  	cmd.Stderr = stderr
   124  	if err := cmd.Run(); err != nil {
   125  		fmt.Println("failed running:", command, args)
   126  		fmt.Println(stdout.String())
   127  		fmt.Println(stderr.String())
   128  		return err
   129  	}
   130  
   131  	return nil
   132  }
   133  
   134  func (p *pgTester) pgEnv() []string {
   135  	return []string{
   136  		fmt.Sprintf("PGHOST=%s", p.host),
   137  		fmt.Sprintf("PGPORT=%d", p.port),
   138  		fmt.Sprintf("PGUSER=%s", p.user),
   139  		fmt.Sprintf("PGPASSFILE=%s", p.pgPassFile),
   140  	}
   141  }
   142  
   143  func (p *pgTester) makePGPassFile() error {
   144  	tmp, err := os.CreateTemp("", "pgpass")
   145  	if err != nil {
   146  		return errors.Wrap(err, "failed to create option file")
   147  	}
   148  
   149  	fmt.Fprintf(tmp, "%s:%d:postgres:%s", p.host, p.port, p.user)
   150  	if len(p.pass) != 0 {
   151  		fmt.Fprintf(tmp, ":%s", p.pass)
   152  	}
   153  	fmt.Fprintln(tmp)
   154  
   155  	fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.dbName, p.user)
   156  	if len(p.pass) != 0 {
   157  		fmt.Fprintf(tmp, ":%s", p.pass)
   158  	}
   159  	fmt.Fprintln(tmp)
   160  
   161  	fmt.Fprintf(tmp, "%s:%d:%s:%s", p.host, p.port, p.testDBName, p.user)
   162  	if len(p.pass) != 0 {
   163  		fmt.Fprintf(tmp, ":%s", p.pass)
   164  	}
   165  	fmt.Fprintln(tmp)
   166  
   167  	p.pgPassFile = tmp.Name()
   168  	return tmp.Close()
   169  }
   170  
   171  func (p *pgTester) createTestDB() error {
   172  	return p.runCmd("", "createdb", p.testDBName)
   173  }
   174  
   175  func (p *pgTester) dropTestDB() error {
   176  	return p.runCmd("", "dropdb", "--if-exists", p.testDBName)
   177  }
   178  
   179  // teardown executes cleanup tasks when the tests finish running
   180  func (p *pgTester) teardown() error {
   181  	var err error
   182  	if err = p.dbConn.Close(); err != nil {
   183  		return err
   184  	}
   185  	p.dbConn = nil
   186  
   187  	if !p.skipSQLCmd {
   188  		if err = p.dropTestDB(); err != nil {
   189  			return err
   190  		}
   191  	}
   192  
   193  	return os.Remove(p.pgPassFile)
   194  }
   195  
   196  func (p *pgTester) conn() (*sql.DB, error) {
   197  	if p.dbConn != nil {
   198  		return p.dbConn, nil
   199  	}
   200  
   201  	var err error
   202  	p.dbConn, err = sql.Open("postgres", driver.PSQLBuildQueryString(p.user, p.pass, p.testDBName, p.host, p.port, p.sslmode))
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	return p.dbConn, nil
   208  }