github.com/SigNoz/golang-migrate/v4@v4.0.0-20231005133642-7493dbaf5f5b/internal/cli/commands.go (about)

     1  package cli
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/golang-migrate/migrate/v4"
    13  	_ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again
    14  	_ "github.com/golang-migrate/migrate/v4/source/file"
    15  )
    16  
    17  var (
    18  	errInvalidSequenceWidth     = errors.New("Digits must be positive")
    19  	errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive")
    20  	errInvalidTimeFormat        = errors.New("Time format may not be empty")
    21  )
    22  
    23  func nextSeqVersion(matches []string, seqDigits int) (string, error) {
    24  	if seqDigits <= 0 {
    25  		return "", errInvalidSequenceWidth
    26  	}
    27  
    28  	nextSeq := uint64(1)
    29  
    30  	if len(matches) > 0 {
    31  		filename := matches[len(matches)-1]
    32  		matchSeqStr := filepath.Base(filename)
    33  		idx := strings.Index(matchSeqStr, "_")
    34  
    35  		if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit
    36  			return "", fmt.Errorf("Malformed migration filename: %s", filename)
    37  		}
    38  
    39  		var err error
    40  		matchSeqStr = matchSeqStr[0:idx]
    41  		nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64)
    42  
    43  		if err != nil {
    44  			return "", err
    45  		}
    46  
    47  		nextSeq++
    48  	}
    49  
    50  	version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits)
    51  
    52  	if len(version) > seqDigits {
    53  		return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits)
    54  	}
    55  
    56  	return version, nil
    57  }
    58  
    59  func timeVersion(startTime time.Time, format string) (version string, err error) {
    60  	switch format {
    61  	case "":
    62  		err = errInvalidTimeFormat
    63  	case "unix":
    64  		version = strconv.FormatInt(startTime.Unix(), 10)
    65  	case "unixNano":
    66  		version = strconv.FormatInt(startTime.UnixNano(), 10)
    67  	default:
    68  		version = startTime.Format(format)
    69  	}
    70  
    71  	return
    72  }
    73  
    74  // createCmd (meant to be called via a CLI command) creates a new migration
    75  func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error {
    76  	if seq && format != defaultTimeFormat {
    77  		return errIncompatibleSeqAndFormat
    78  	}
    79  
    80  	var version string
    81  	var err error
    82  
    83  	dir = filepath.Clean(dir)
    84  	ext = "." + strings.TrimPrefix(ext, ".")
    85  
    86  	if seq {
    87  		matches, err := filepath.Glob(filepath.Join(dir, "*"+ext))
    88  
    89  		if err != nil {
    90  			return err
    91  		}
    92  
    93  		version, err = nextSeqVersion(matches, seqDigits)
    94  
    95  		if err != nil {
    96  			return err
    97  		}
    98  	} else {
    99  		version, err = timeVersion(startTime, format)
   100  
   101  		if err != nil {
   102  			return err
   103  		}
   104  	}
   105  
   106  	versionGlob := filepath.Join(dir, version+"_*"+ext)
   107  	matches, err := filepath.Glob(versionGlob)
   108  
   109  	if err != nil {
   110  		return err
   111  	}
   112  
   113  	if len(matches) > 0 {
   114  		return fmt.Errorf("duplicate migration version: %s", version)
   115  	}
   116  
   117  	if err = os.MkdirAll(dir, os.ModePerm); err != nil {
   118  		return err
   119  	}
   120  
   121  	for _, direction := range []string{"up", "down"} {
   122  		basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext)
   123  		filename := filepath.Join(dir, basename)
   124  
   125  		if err = createFile(filename); err != nil {
   126  			return err
   127  		}
   128  
   129  		if print {
   130  			absPath, _ := filepath.Abs(filename)
   131  			log.Println(absPath)
   132  		}
   133  	}
   134  
   135  	return nil
   136  }
   137  
   138  func createFile(filename string) error {
   139  	// create exclusive (fails if file already exists)
   140  	// os.Create() specifies 0666 as the FileMode, so we're doing the same
   141  	f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
   142  
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	return f.Close()
   148  }
   149  
   150  func gotoCmd(m *migrate.Migrate, v uint) error {
   151  	if err := m.Migrate(v); err != nil {
   152  		if err != migrate.ErrNoChange {
   153  			return err
   154  		}
   155  		log.Println(err)
   156  	}
   157  	return nil
   158  }
   159  
   160  func upCmd(m *migrate.Migrate, limit int) error {
   161  	if limit >= 0 {
   162  		if err := m.Steps(limit); err != nil {
   163  			if err != migrate.ErrNoChange {
   164  				return err
   165  			}
   166  			log.Println(err)
   167  		}
   168  	} else {
   169  		if err := m.Up(); err != nil {
   170  			if err != migrate.ErrNoChange {
   171  				return err
   172  			}
   173  			log.Println(err)
   174  		}
   175  	}
   176  	return nil
   177  }
   178  
   179  func downCmd(m *migrate.Migrate, limit int) error {
   180  	if limit >= 0 {
   181  		if err := m.Steps(-limit); err != nil {
   182  			if err != migrate.ErrNoChange {
   183  				return err
   184  			}
   185  			log.Println(err)
   186  		}
   187  	} else {
   188  		if err := m.Down(); err != nil {
   189  			if err != migrate.ErrNoChange {
   190  				return err
   191  			}
   192  			log.Println(err)
   193  		}
   194  	}
   195  	return nil
   196  }
   197  
   198  func dropCmd(m *migrate.Migrate) error {
   199  	if err := m.Drop(); err != nil {
   200  		return err
   201  	}
   202  	return nil
   203  }
   204  
   205  func forceCmd(m *migrate.Migrate, v int) error {
   206  	if err := m.Force(v); err != nil {
   207  		return err
   208  	}
   209  	return nil
   210  }
   211  
   212  func versionCmd(m *migrate.Migrate) error {
   213  	v, dirty, err := m.Version()
   214  	if err != nil {
   215  		return err
   216  	}
   217  	if dirty {
   218  		log.Printf("%v (dirty)\n", v)
   219  	} else {
   220  		log.Println(v)
   221  	}
   222  	return nil
   223  }
   224  
   225  // numDownMigrationsFromArgs returns an int for number of migrations to apply
   226  // and a bool indicating if we need a confirm before applying
   227  func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) {
   228  	if applyAll {
   229  		if len(args) > 0 {
   230  			return 0, false, errors.New("-all cannot be used with other arguments")
   231  		}
   232  		return -1, false, nil
   233  	}
   234  
   235  	switch len(args) {
   236  	case 0:
   237  		return -1, true, nil
   238  	case 1:
   239  		downValue := args[0]
   240  		n, err := strconv.ParseUint(downValue, 10, 64)
   241  		if err != nil {
   242  			return 0, false, errors.New("can't read limit argument N")
   243  		}
   244  		return int(n), false, nil
   245  	default:
   246  		return 0, false, errors.New("too many arguments")
   247  	}
   248  }