github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-mssql/driver/override/test/singleton/mssql_main_test.go.tpl (about)

     1  var rgxMSSQLkey = regexp.MustCompile(`(?m)^ALTER TABLE .*ADD\s+CONSTRAINT .* FOREIGN KEY.*?.*\n?REFERENCES.*`)
     2  
     3  type mssqlTester struct {
     4  	dbConn     *sql.DB
     5  	dbName     string
     6  	host       string
     7  	user       string
     8  	pass       string
     9  	sslmode    string
    10  	port       int
    11  	testDBName string
    12  	skipSQLCmd bool
    13  }
    14  
    15  func init() {
    16  	dbMain = &mssqlTester{}
    17  }
    18  
    19  func (m *mssqlTester) setup() error {
    20  	var err error
    21  
    22  	viper.SetDefault("mssql.schema", "dbo")
    23  	viper.SetDefault("mssql.sslmode", "true")
    24  	viper.SetDefault("mssql.port", 1433)
    25  
    26  	m.dbName = viper.GetString("mssql.dbname")
    27  	m.host = viper.GetString("mssql.host")
    28  	m.user = viper.GetString("mssql.user")
    29  	m.pass = viper.GetString("mssql.pass")
    30  	m.port = viper.GetInt("mssql.port")
    31  	m.sslmode = viper.GetString("mssql.sslmode")
    32  	m.testDBName = viper.GetString("mssql.testdbname")
    33  	m.skipSQLCmd = viper.GetBool("mssql.skipsqlcmd")
    34  
    35  	err = vala.BeginValidation().Validate(
    36  		vala.StringNotEmpty(viper.GetString("mssql.user"), "mssql.user"),
    37  		vala.StringNotEmpty(viper.GetString("mssql.host"), "mssql.host"),
    38  		vala.Not(vala.Equals(viper.GetInt("mssql.port"), 0, "mssql.port")),
    39  		vala.StringNotEmpty(viper.GetString("mssql.dbname"), "mssql.dbname"),
    40  		vala.StringNotEmpty(viper.GetString("mssql.sslmode"), "mssql.sslmode"),
    41  	).Check()
    42  
    43  	if err != nil {
    44  		return err
    45  	}
    46  
    47  	// Create a randomized db name.
    48  	if len(m.testDBName) == 0 {
    49  		m.testDBName = randomize.StableDBName(m.dbName)
    50  	}
    51  
    52  	if !m.skipSQLCmd {
    53  		if err = m.dropTestDB(); err != nil {
    54  			return err
    55  		}
    56  		if err = m.createTestDB(); err != nil {
    57  			return err
    58  		}
    59  
    60  		createCmd := exec.Command("sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass, "-d", m.testDBName)
    61  
    62  		f, err := os.Open("tables_schema.sql")
    63  		if err != nil {
    64  			return errors.Wrap(err, "failed to open tables_schema.sql file")
    65  		}
    66  
    67  		defer func() { _ = f.Close() }()
    68  
    69  		stderr := &bytes.Buffer{}
    70  		createCmd.Stdin = newFKeyDestroyer(rgxMSSQLkey, f)
    71  		createCmd.Stderr = stderr
    72  
    73  		if err = createCmd.Start(); err != nil {
    74  			return errors.Wrap(err, "failed to start sqlcmd command")
    75  		}
    76  
    77  		if err = createCmd.Wait(); err != nil {
    78  			fmt.Println(err)
    79  			fmt.Println(stderr.String())
    80  			return errors.Wrap(err, "failed to wait for sqlcmd command")
    81  		}
    82  	}
    83  
    84  	return nil
    85  }
    86  
    87  func (m *mssqlTester) sslMode(mode string) string {
    88  	switch mode {
    89  	case "true":
    90  		return "true"
    91  	case "false":
    92  		return "false"
    93  	default:
    94  		return "disable"
    95  	}
    96  }
    97  
    98  func (m *mssqlTester) createTestDB() error {
    99  	sql := fmt.Sprintf(`
   100  	CREATE DATABASE %s;
   101  	GO
   102  	ALTER DATABASE %[1]s
   103  	SET READ_COMMITTED_SNAPSHOT ON;
   104  	GO`, m.testDBName)
   105  	return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass)
   106  }
   107  
   108  func (m *mssqlTester) dropTestDB() error {
   109  	// Since MS SQL 2016 it can be done with
   110  	// DROP DATABASE [ IF EXISTS ] { database_name | database_snapshot_name } [ ,...n ] [;]
   111  	sql := fmt.Sprintf(`
   112  	IF EXISTS(SELECT name FROM sys.databases 
   113  		WHERE name = '%s')
   114  		DROP DATABASE %s
   115  	GO`, m.testDBName, m.testDBName)
   116  	return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass)
   117  }
   118  
   119  func (m *mssqlTester) teardown() error {
   120  	if m.dbConn != nil {
   121  		m.dbConn.Close()
   122  	}
   123  
   124  	if !m.skipSQLCmd {
   125  		if err := m.dropTestDB(); err != nil {
   126  			return err
   127  		}
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  func (m *mssqlTester) runCmd(stdin, command string, args ...string) error {
   134  	cmd := exec.Command(command, args...)
   135  	cmd.Stdin = strings.NewReader(stdin)
   136  
   137  	stdout := &bytes.Buffer{}
   138  	stderr := &bytes.Buffer{}
   139  	cmd.Stdout = stdout
   140  	cmd.Stderr = stderr
   141  	if err := cmd.Run(); err != nil {
   142  		fmt.Println("failed running:", command, args)
   143  		fmt.Println(stdout.String())
   144  		fmt.Println(stderr.String())
   145  		return err
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  func (m *mssqlTester) conn() (*sql.DB, error) {
   152  	if m.dbConn != nil {
   153  		return m.dbConn, nil
   154  	}
   155  
   156  	var err error
   157  	m.dbConn, err = sql.Open("mssql", driver.MSSQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode))
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	return m.dbConn, nil
   163  }