github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-mssql/driver/override/test/singleton/mssql_main_test.go.tpl (about) 1 var rgxMSSQLkey = regexp.MustCompile(`(?m)^ALTER TABLE .*ADD\s+CONSTRAINT .* FOREIGN KEY.*?.*\n?REFERENCES.*`) 2 3 type mssqlTester struct { 4 dbConn *sql.DB 5 dbName string 6 host string 7 user string 8 pass string 9 sslmode string 10 port int 11 testDBName string 12 skipSQLCmd bool 13 } 14 15 func init() { 16 dbMain = &mssqlTester{} 17 } 18 19 func (m *mssqlTester) setup() error { 20 var err error 21 22 viper.SetDefault("mssql.schema", "dbo") 23 viper.SetDefault("mssql.sslmode", "true") 24 viper.SetDefault("mssql.port", 1433) 25 26 m.dbName = viper.GetString("mssql.dbname") 27 m.host = viper.GetString("mssql.host") 28 m.user = viper.GetString("mssql.user") 29 m.pass = viper.GetString("mssql.pass") 30 m.port = viper.GetInt("mssql.port") 31 m.sslmode = viper.GetString("mssql.sslmode") 32 m.testDBName = viper.GetString("mssql.testdbname") 33 m.skipSQLCmd = viper.GetBool("mssql.skipsqlcmd") 34 35 err = vala.BeginValidation().Validate( 36 vala.StringNotEmpty(viper.GetString("mssql.user"), "mssql.user"), 37 vala.StringNotEmpty(viper.GetString("mssql.host"), "mssql.host"), 38 vala.Not(vala.Equals(viper.GetInt("mssql.port"), 0, "mssql.port")), 39 vala.StringNotEmpty(viper.GetString("mssql.dbname"), "mssql.dbname"), 40 vala.StringNotEmpty(viper.GetString("mssql.sslmode"), "mssql.sslmode"), 41 ).Check() 42 43 if err != nil { 44 return err 45 } 46 47 // Create a randomized db name. 48 if len(m.testDBName) == 0 { 49 m.testDBName = randomize.StableDBName(m.dbName) 50 } 51 52 if !m.skipSQLCmd { 53 if err = m.dropTestDB(); err != nil { 54 return err 55 } 56 if err = m.createTestDB(); err != nil { 57 return err 58 } 59 60 createCmd := exec.Command("sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass, "-d", m.testDBName) 61 62 f, err := os.Open("tables_schema.sql") 63 if err != nil { 64 return errors.Wrap(err, "failed to open tables_schema.sql file") 65 } 66 67 defer func() { _ = f.Close() }() 68 69 stderr := &bytes.Buffer{} 70 createCmd.Stdin = newFKeyDestroyer(rgxMSSQLkey, f) 71 createCmd.Stderr = stderr 72 73 if err = createCmd.Start(); err != nil { 74 return errors.Wrap(err, "failed to start sqlcmd command") 75 } 76 77 if err = createCmd.Wait(); err != nil { 78 fmt.Println(err) 79 fmt.Println(stderr.String()) 80 return errors.Wrap(err, "failed to wait for sqlcmd command") 81 } 82 } 83 84 return nil 85 } 86 87 func (m *mssqlTester) sslMode(mode string) string { 88 switch mode { 89 case "true": 90 return "true" 91 case "false": 92 return "false" 93 default: 94 return "disable" 95 } 96 } 97 98 func (m *mssqlTester) createTestDB() error { 99 sql := fmt.Sprintf(` 100 CREATE DATABASE %s; 101 GO 102 ALTER DATABASE %[1]s 103 SET READ_COMMITTED_SNAPSHOT ON; 104 GO`, m.testDBName) 105 return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass) 106 } 107 108 func (m *mssqlTester) dropTestDB() error { 109 // Since MS SQL 2016 it can be done with 110 // DROP DATABASE [ IF EXISTS ] { database_name | database_snapshot_name } [ ,...n ] [;] 111 sql := fmt.Sprintf(` 112 IF EXISTS(SELECT name FROM sys.databases 113 WHERE name = '%s') 114 DROP DATABASE %s 115 GO`, m.testDBName, m.testDBName) 116 return m.runCmd(sql, "sqlcmd", "-S", m.host, "-U", m.user, "-P", m.pass) 117 } 118 119 func (m *mssqlTester) teardown() error { 120 if m.dbConn != nil { 121 m.dbConn.Close() 122 } 123 124 if !m.skipSQLCmd { 125 if err := m.dropTestDB(); err != nil { 126 return err 127 } 128 } 129 130 return nil 131 } 132 133 func (m *mssqlTester) runCmd(stdin, command string, args ...string) error { 134 cmd := exec.Command(command, args...) 135 cmd.Stdin = strings.NewReader(stdin) 136 137 stdout := &bytes.Buffer{} 138 stderr := &bytes.Buffer{} 139 cmd.Stdout = stdout 140 cmd.Stderr = stderr 141 if err := cmd.Run(); err != nil { 142 fmt.Println("failed running:", command, args) 143 fmt.Println(stdout.String()) 144 fmt.Println(stderr.String()) 145 return err 146 } 147 148 return nil 149 } 150 151 func (m *mssqlTester) conn() (*sql.DB, error) { 152 if m.dbConn != nil { 153 return m.dbConn, nil 154 } 155 156 var err error 157 m.dbConn, err = sql.Open("mssql", driver.MSSQLBuildQueryString(m.user, m.pass, m.testDBName, m.host, m.port, m.sslmode)) 158 if err != nil { 159 return nil, err 160 } 161 162 return m.dbConn, nil 163 }