github.com/mg98/scriptup@v0.1.0/pkg/scriptup/storage/sql_storage.go (about)

     1  package storage
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	_ "github.com/go-sql-driver/mysql"
     8  	_ "github.com/lib/pq"
     9  	"regexp"
    10  )
    11  
    12  type SQLConnectionDetails struct {
    13  	Dialect      string
    14  	Host         string
    15  	Port         int
    16  	User         string
    17  	Password     string
    18  	DatabaseName string
    19  	TableName    string
    20  	SSLMode      string
    21  }
    22  
    23  type SQLStorage struct {
    24  	connDetails *SQLConnectionDetails
    25  	conn        *sql.DB
    26  }
    27  
    28  func NewSQLStorage(scd *SQLConnectionDetails) *SQLStorage {
    29  	return &SQLStorage{connDetails: scd}
    30  }
    31  
    32  func (s *SQLStorage) Open() error {
    33  	var connString string
    34  	switch s.connDetails.Dialect {
    35  	case "postgres":
    36  		connString = fmt.Sprintf(
    37  			"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
    38  			s.connDetails.Host,
    39  			s.connDetails.Port,
    40  			s.connDetails.User,
    41  			s.connDetails.Password,
    42  			s.connDetails.DatabaseName,
    43  			s.connDetails.SSLMode,
    44  		)
    45  		break
    46  	case "mysql":
    47  		connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
    48  			s.connDetails.User,
    49  			s.connDetails.Password,
    50  			s.connDetails.Host,
    51  			s.connDetails.Port,
    52  			s.connDetails.DatabaseName,
    53  		)
    54  	default:
    55  		return errors.New("sql dialect not supported")
    56  	}
    57  
    58  	var err error
    59  	s.conn, err = sql.Open(s.connDetails.Dialect, connString)
    60  	if err != nil {
    61  		return err
    62  	}
    63  	err = s.conn.Ping()
    64  	if err != nil {
    65  		return err
    66  	}
    67  	if err := s.setUp(); err != nil {
    68  		return err
    69  	}
    70  	return nil
    71  }
    72  
    73  // Close connection to the database.
    74  func (s *SQLStorage) Close() error {
    75  	return s.conn.Close()
    76  }
    77  
    78  // setUp adds the required schema to the database if it does not already exist.
    79  func (s *SQLStorage) setUp() error {
    80  	// loose validation just to protect against sql injection
    81  	if !regexp.MustCompile("^[a-zA-Z0-9-_.]*$").MatchString(s.connDetails.TableName) {
    82  		return errors.New("table name contains invalid characters")
    83  	}
    84  	var query string
    85  	switch s.connDetails.Dialect {
    86  	case "postgres":
    87  		query = fmt.Sprintf(
    88  			"CREATE TABLE IF NOT EXISTS %s (migration_name VARCHAR PRIMARY KEY)",
    89  			s.connDetails.TableName,
    90  		)
    91  		break
    92  	case "mysql":
    93  		query = fmt.Sprintf(
    94  			"CREATE TABLE IF NOT EXISTS %s (migration_name VARCHAR(255), PRIMARY KEY (migration_name))",
    95  			s.connDetails.TableName,
    96  		)
    97  	}
    98  	_, err := s.conn.Exec(query)
    99  	return err
   100  }
   101  
   102  // Append inserts a new record for that entry into the database.
   103  func (s *SQLStorage) Append(entry string) error {
   104  	if _, err := s.conn.Exec(fmt.Sprintf("INSERT INTO %s VALUES (?)", s.connDetails.TableName), entry); err != nil {
   105  		return err
   106  	}
   107  	return nil
   108  }
   109  
   110  // Pop deletes the most recent record (according to the migration date).
   111  func (s *SQLStorage) Pop() error {
   112  	var query string
   113  	switch s.connDetails.Dialect {
   114  	case "postgres":
   115  		query = fmt.Sprintf(
   116  			"DELETE FROM %s WHERE ctid IN (SELECT ctid FROM %s ORDER BY migration_name DESC LIMIT 1)",
   117  			s.connDetails.TableName, s.connDetails.TableName,
   118  		)
   119  		break
   120  	case "mysql":
   121  		query = fmt.Sprintf("DELETE FROM %s ORDER BY migration_name DESC LIMIT 1", s.connDetails.TableName)
   122  	}
   123  	res, err := s.conn.Exec(query)
   124  	if err != nil {
   125  		return err
   126  	}
   127  	i, err := res.RowsAffected()
   128  	if err != nil {
   129  		return err
   130  	}
   131  	if i == 0 {
   132  		return errors.New("no migrations left")
   133  	}
   134  	return nil
   135  }
   136  
   137  // All returns all entries from the database.
   138  func (s *SQLStorage) All(o Order) ([]string, error) {
   139  	rows, err := s.conn.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY migration_name %s", s.connDetails.TableName, o))
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  	defer rows.Close()
   144  	var res []string
   145  	for rows.Next() {
   146  		var value string
   147  		if err := rows.Scan(&value); err != nil {
   148  			return nil, err
   149  		}
   150  		res = append(res, value)
   151  	}
   152  	return res, nil
   153  }
   154  
   155  // Latest returns the most recent record (according to the migration date).
   156  func (s *SQLStorage) Latest() (*string, error) {
   157  	row := s.conn.QueryRow(fmt.Sprintf("SELECT * FROM %s ORDER BY migration_name DESC LIMIT 1", s.connDetails.TableName))
   158  	if err := row.Err(); err != nil {
   159  		return nil, err
   160  	}
   161  	var name string
   162  	if err := row.Scan(&name); err != nil && !errors.Is(err, sql.ErrNoRows) {
   163  		return nil, err
   164  	}
   165  	if name == "" {
   166  		return nil, nil
   167  	}
   168  	return &name, nil
   169  }