github.com/mg98/scriptup@v0.1.0/pkg/scriptup/storage/sql_storage.go (about) 1 package storage 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 _ "github.com/go-sql-driver/mysql" 8 _ "github.com/lib/pq" 9 "regexp" 10 ) 11 12 type SQLConnectionDetails struct { 13 Dialect string 14 Host string 15 Port int 16 User string 17 Password string 18 DatabaseName string 19 TableName string 20 SSLMode string 21 } 22 23 type SQLStorage struct { 24 connDetails *SQLConnectionDetails 25 conn *sql.DB 26 } 27 28 func NewSQLStorage(scd *SQLConnectionDetails) *SQLStorage { 29 return &SQLStorage{connDetails: scd} 30 } 31 32 func (s *SQLStorage) Open() error { 33 var connString string 34 switch s.connDetails.Dialect { 35 case "postgres": 36 connString = fmt.Sprintf( 37 "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", 38 s.connDetails.Host, 39 s.connDetails.Port, 40 s.connDetails.User, 41 s.connDetails.Password, 42 s.connDetails.DatabaseName, 43 s.connDetails.SSLMode, 44 ) 45 break 46 case "mysql": 47 connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", 48 s.connDetails.User, 49 s.connDetails.Password, 50 s.connDetails.Host, 51 s.connDetails.Port, 52 s.connDetails.DatabaseName, 53 ) 54 default: 55 return errors.New("sql dialect not supported") 56 } 57 58 var err error 59 s.conn, err = sql.Open(s.connDetails.Dialect, connString) 60 if err != nil { 61 return err 62 } 63 err = s.conn.Ping() 64 if err != nil { 65 return err 66 } 67 if err := s.setUp(); err != nil { 68 return err 69 } 70 return nil 71 } 72 73 // Close connection to the database. 74 func (s *SQLStorage) Close() error { 75 return s.conn.Close() 76 } 77 78 // setUp adds the required schema to the database if it does not already exist. 79 func (s *SQLStorage) setUp() error { 80 // loose validation just to protect against sql injection 81 if !regexp.MustCompile("^[a-zA-Z0-9-_.]*$").MatchString(s.connDetails.TableName) { 82 return errors.New("table name contains invalid characters") 83 } 84 var query string 85 switch s.connDetails.Dialect { 86 case "postgres": 87 query = fmt.Sprintf( 88 "CREATE TABLE IF NOT EXISTS %s (migration_name VARCHAR PRIMARY KEY)", 89 s.connDetails.TableName, 90 ) 91 break 92 case "mysql": 93 query = fmt.Sprintf( 94 "CREATE TABLE IF NOT EXISTS %s (migration_name VARCHAR(255), PRIMARY KEY (migration_name))", 95 s.connDetails.TableName, 96 ) 97 } 98 _, err := s.conn.Exec(query) 99 return err 100 } 101 102 // Append inserts a new record for that entry into the database. 103 func (s *SQLStorage) Append(entry string) error { 104 if _, err := s.conn.Exec(fmt.Sprintf("INSERT INTO %s VALUES (?)", s.connDetails.TableName), entry); err != nil { 105 return err 106 } 107 return nil 108 } 109 110 // Pop deletes the most recent record (according to the migration date). 111 func (s *SQLStorage) Pop() error { 112 var query string 113 switch s.connDetails.Dialect { 114 case "postgres": 115 query = fmt.Sprintf( 116 "DELETE FROM %s WHERE ctid IN (SELECT ctid FROM %s ORDER BY migration_name DESC LIMIT 1)", 117 s.connDetails.TableName, s.connDetails.TableName, 118 ) 119 break 120 case "mysql": 121 query = fmt.Sprintf("DELETE FROM %s ORDER BY migration_name DESC LIMIT 1", s.connDetails.TableName) 122 } 123 res, err := s.conn.Exec(query) 124 if err != nil { 125 return err 126 } 127 i, err := res.RowsAffected() 128 if err != nil { 129 return err 130 } 131 if i == 0 { 132 return errors.New("no migrations left") 133 } 134 return nil 135 } 136 137 // All returns all entries from the database. 138 func (s *SQLStorage) All(o Order) ([]string, error) { 139 rows, err := s.conn.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY migration_name %s", s.connDetails.TableName, o)) 140 if err != nil { 141 return nil, err 142 } 143 defer rows.Close() 144 var res []string 145 for rows.Next() { 146 var value string 147 if err := rows.Scan(&value); err != nil { 148 return nil, err 149 } 150 res = append(res, value) 151 } 152 return res, nil 153 } 154 155 // Latest returns the most recent record (according to the migration date). 156 func (s *SQLStorage) Latest() (*string, error) { 157 row := s.conn.QueryRow(fmt.Sprintf("SELECT * FROM %s ORDER BY migration_name DESC LIMIT 1", s.connDetails.TableName)) 158 if err := row.Err(); err != nil { 159 return nil, err 160 } 161 var name string 162 if err := row.Scan(&name); err != nil && !errors.Is(err, sql.ErrNoRows) { 163 return nil, err 164 } 165 if name == "" { 166 return nil, nil 167 } 168 return &name, nil 169 }