github.com/adharshmk96/stk@v1.2.3/pkg/sqlMigrator/core.go (about)

     1  package sqlmigrator
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"path"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/adharshmk96/stk/consts"
    12  	"github.com/adharshmk96/stk/pkg/utils"
    13  	"github.com/spf13/viper"
    14  )
    15  
    16  const (
    17  	DEFAULT_LOG_FILE = ".commit-status"
    18  )
    19  
    20  type MigrationType string
    21  
    22  const (
    23  	MigrationUp   MigrationType = "up"
    24  	MigrationDown MigrationType = "down"
    25  )
    26  
    27  type Database string
    28  
    29  const (
    30  	PostgresDB Database = "postgres"
    31  	MySQLDB    Database = "mysql"
    32  	SQLiteDB   Database = "sqlite"
    33  )
    34  
// Migrations is an ordered list of migration entries. LoadMigrationEntries
// fills it in log-file order and WriteMigrationEntries writes it back in
// slice order.
type Migrations []*MigrationFileEntry

// MigrationFileEntry is one migration as recorded in the commit-status log
// file, plus the paths of its up/down script files.
type MigrationFileEntry struct {
	// Number is the leading numeric segment of the entry (3 in "3_add_users_up").
	Number       int
	// Name is the segment(s) between the number and the direction; may be empty.
	Name         string
	// Committed is true when the entry's direction segment is "up".
	Committed    bool
	// UpFilePath and DownFilePath are not set by ParseMigrationEntry;
	// LoadMigrationEntries fills them in relative to the context's WorkDir.
	UpFilePath   string
	DownFilePath string
}
    44  
    45  func ParseMigrationEntry(migrationEntry string) (*MigrationFileEntry, error) {
    46  	parts := strings.Split(migrationEntry, "_")
    47  	partLength := len(parts)
    48  
    49  	if partLength == 0 {
    50  		return nil, ErrInvalidMigration
    51  	}
    52  
    53  	commit_status := parts[partLength-1]
    54  	if commit_status != "up" && commit_status != "down" {
    55  		return nil, ErrInvalidMigration
    56  	}
    57  
    58  	name := strings.Join(parts[1:partLength-1], "_")
    59  
    60  	number, err := strconv.Atoi(parts[0])
    61  	if err != nil {
    62  		return nil, ErrInvalidMigration
    63  	}
    64  
    65  	rawMigration := &MigrationFileEntry{
    66  		Name:      name,
    67  		Number:    number,
    68  		Committed: commit_status == "up",
    69  	}
    70  
    71  	return rawMigration, nil
    72  }
    73  
    74  func (r *MigrationFileEntry) String() string {
    75  	m_String := fmt.Sprintf("%d", r.Number)
    76  	if r.Name != "" {
    77  		m_String = m_String + "_" + r.Name
    78  	}
    79  	return m_String
    80  }
    81  
    82  func (r *MigrationFileEntry) EntryString() string {
    83  	entryString := fmt.Sprintf("%d", r.Number)
    84  	if r.Name != "" {
    85  		entryString += "_" + r.Name
    86  	}
    87  	if r.Committed {
    88  		entryString += "_up"
    89  	} else {
    90  		entryString += "_down"
    91  	}
    92  	return entryString
    93  }
    94  
    95  func (r *MigrationFileEntry) FileNames(extention string) (string, string) {
    96  	fileName := fmt.Sprintf("%d", r.Number)
    97  	if r.Name != "" {
    98  		fileName += "_" + r.Name
    99  	}
   100  	upFileName := fileName + "_up." + extention
   101  	downFileName := fileName + "_down." + extention
   102  	return upFileName, downFileName
   103  }
   104  
   105  func (r *MigrationFileEntry) LoadFileContent() (string, string) {
   106  
   107  	upContent, err := os.ReadFile(r.UpFilePath)
   108  	if err != nil {
   109  		return "", ""
   110  	}
   111  
   112  	downContent, err := os.ReadFile(r.DownFilePath)
   113  	if err != nil {
   114  		return "", ""
   115  	}
   116  
   117  	return string(upContent), string(downContent)
   118  }
   119  
// Context holds the settings and loaded migration entries for a migrator run.
type Context struct {
	// WorkDir is the directory that holds the log file and the migration files.
	WorkDir    string
	// LogFile is the commit-status log file name, joined onto WorkDir.
	LogFile    string
	// Database is the target engine; LoadMigrationEntries passes it to
	// SelectExtention to choose the migration file extension.
	Database   Database
	// DryRun stores the dry flag passed to NewContext; nothing in this
	// section reads it.
	DryRun     bool
	// Migrations is filled by LoadMigrationEntries and persisted by
	// WriteMigrationEntries.
	Migrations Migrations
}
   127  
   128  func DefaultContextConfig() (string, Database, string) {
   129  	rootDirectory := viper.GetString(consts.CONFIG_MIGRATOR_WORKDIR)
   130  	dbChoice := viper.GetString(consts.CONFIG_MIGRATOR_DB_TYPE)
   131  	logFile := utils.GetFirst(viper.GetString(consts.CONFIG_MIGRATOR_LOGFILE), DEFAULT_LOG_FILE)
   132  
   133  	dbType := SelectDatabase(dbChoice)
   134  	subDir := SelectSubDirectory(dbType)
   135  
   136  	workDir := path.Join(rootDirectory, subDir)
   137  
   138  	return workDir, dbType, logFile
   139  }
   140  
   141  func NewContext(workDir string, dbType Database, logFile string, dry bool) *Context {
   142  
   143  	ctx := &Context{
   144  		WorkDir:  workDir,
   145  		Database: dbType,
   146  		LogFile:  logFile,
   147  		DryRun:   dry,
   148  	}
   149  
   150  	err := InitializeMigrationsFolder(ctx)
   151  	if err != nil {
   152  		return nil
   153  	}
   154  
   155  	return ctx
   156  }
   157  
   158  func (ctx *Context) LoadMigrationEntries() error {
   159  	migrations := []*MigrationFileEntry{}
   160  	entires, err := ReadLines(path.Join(ctx.WorkDir, ctx.LogFile))
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	for _, entry := range entires {
   166  		migration, err := ParseMigrationEntry(entry)
   167  		if err != nil {
   168  			return err
   169  		}
   170  
   171  		upFileName, downFileName := migration.FileNames(SelectExtention(ctx.Database))
   172  		migration.UpFilePath = path.Join(ctx.WorkDir, upFileName)
   173  		migration.DownFilePath = path.Join(ctx.WorkDir, downFileName)
   174  
   175  		migrations = append(migrations, migration)
   176  	}
   177  
   178  	ctx.Migrations = migrations
   179  	return nil
   180  }
   181  
   182  func (ctx *Context) WriteMigrationEntries() error {
   183  	filePath := path.Join(ctx.WorkDir, ctx.LogFile)
   184  	file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_TRUNC, 0644)
   185  	if err != nil {
   186  		return err
   187  	}
   188  
   189  	defer file.Close()
   190  	for _, migration := range ctx.Migrations {
   191  		_, err := file.WriteString(migration.EntryString() + "\n")
   192  		if err != nil {
   193  			return err
   194  		}
   195  	}
   196  
   197  	return nil
   198  }
   199  
// MigrationDBEntry describes one migration record with a direction and a
// creation time. NOTE(review): presumably this mirrors a row of a
// migration-history table; the code that reads or writes it is not in this
// section, so confirm against the database layer.
type MigrationDBEntry struct {
	Number    int
	Name      string
	// Direction is a plain string; presumably "up" or "down" (the
	// MigrationType values). TODO confirm.
	Direction string
	Created   time.Time
}