github.com/Elate-DevOps/migrate/v4@v4.0.12/internal/cli/commands.go (about)

     1  package cli
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/Elate-DevOps/migrate/v4"
    13  	_ "github.com/Elate-DevOps/migrate/v4/database/stub" // TODO remove again
    14  	_ "github.com/Elate-DevOps/migrate/v4/source/file"
    15  )
    16  
    17  var (
    18  	errInvalidSequenceWidth     = errors.New("Digits must be positive")
    19  	errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive")
    20  	errInvalidTimeFormat        = errors.New("Time format may not be empty")
    21  )
    22  
    23  func nextSeqVersion(matches []string, seqDigits int) (string, error) {
    24  	if seqDigits <= 0 {
    25  		return "", errInvalidSequenceWidth
    26  	}
    27  
    28  	nextSeq := uint64(1)
    29  
    30  	if len(matches) > 0 {
    31  		filename := matches[len(matches)-1]
    32  		matchSeqStr := filepath.Base(filename)
    33  		idx := strings.Index(matchSeqStr, "_")
    34  
    35  		if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit
    36  			return "", fmt.Errorf("Malformed migration filename: %s", filename)
    37  		}
    38  
    39  		var err error
    40  		matchSeqStr = matchSeqStr[0:idx]
    41  		nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64)
    42  
    43  		if err != nil {
    44  			return "", err
    45  		}
    46  
    47  		nextSeq++
    48  	}
    49  
    50  	version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits)
    51  
    52  	if len(version) > seqDigits {
    53  		return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits)
    54  	}
    55  
    56  	return version, nil
    57  }
    58  
    59  func timeVersion(startTime time.Time, format string) (version string, err error) {
    60  	switch format {
    61  	case "":
    62  		err = errInvalidTimeFormat
    63  	case "unix":
    64  		version = strconv.FormatInt(startTime.Unix(), 10)
    65  	case "unixNano":
    66  		version = strconv.FormatInt(startTime.UnixNano(), 10)
    67  	default:
    68  		version = startTime.Format(format)
    69  	}
    70  
    71  	return
    72  }
    73  
    74  // createCmd (meant to be called via a CLI command) creates a new migration
    75  func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error {
    76  	if seq && format != defaultTimeFormat {
    77  		return errIncompatibleSeqAndFormat
    78  	}
    79  
    80  	var version string
    81  	var err error
    82  
    83  	dir = filepath.Clean(dir)
    84  	ext = "." + strings.TrimPrefix(ext, ".")
    85  
    86  	if seq {
    87  		matches, err := filepath.Glob(filepath.Join(dir, "*"+ext))
    88  		if err != nil {
    89  			return err
    90  		}
    91  
    92  		version, err = nextSeqVersion(matches, seqDigits)
    93  
    94  		if err != nil {
    95  			return err
    96  		}
    97  	} else {
    98  		version, err = timeVersion(startTime, format)
    99  
   100  		if err != nil {
   101  			return err
   102  		}
   103  	}
   104  
   105  	versionGlob := filepath.Join(dir, version+"_*"+ext)
   106  	matches, err := filepath.Glob(versionGlob)
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	if len(matches) > 0 {
   112  		return fmt.Errorf("duplicate migration version: %s", version)
   113  	}
   114  
   115  	if err = os.MkdirAll(dir, os.ModePerm); err != nil {
   116  		return err
   117  	}
   118  
   119  	for _, direction := range []string{"up", "down"} {
   120  		basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext)
   121  		filename := filepath.Join(dir, basename)
   122  
   123  		if err = createFile(filename); err != nil {
   124  			return err
   125  		}
   126  
   127  		if print {
   128  			absPath, _ := filepath.Abs(filename)
   129  			log.Println(absPath)
   130  		}
   131  	}
   132  
   133  	return nil
   134  }
   135  
   136  func createFile(filename string) error {
   137  	// create exclusive (fails if file already exists)
   138  	// os.Create() specifies 0666 as the FileMode, so we're doing the same
   139  	f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o666)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	return f.Close()
   145  }
   146  
   147  func gotoCmd(m *migrate.Migrate, v uint) error {
   148  	if err := m.Migrate(v); err != nil {
   149  		if err != migrate.ErrNoChange {
   150  			return err
   151  		}
   152  		log.Println(err)
   153  	}
   154  	return nil
   155  }
   156  
   157  func upCmd(m *migrate.Migrate, limit int) error {
   158  	if limit >= 0 {
   159  		if err := m.Steps(limit); err != nil {
   160  			if err != migrate.ErrNoChange {
   161  				return err
   162  			}
   163  			log.Println(err)
   164  		}
   165  	} else {
   166  		if err := m.Up(); err != nil {
   167  			if err != migrate.ErrNoChange {
   168  				return err
   169  			}
   170  			log.Println(err)
   171  		}
   172  	}
   173  	return nil
   174  }
   175  
   176  func downCmd(m *migrate.Migrate, limit int) error {
   177  	if limit >= 0 {
   178  		if err := m.Steps(-limit); err != nil {
   179  			if err != migrate.ErrNoChange {
   180  				return err
   181  			}
   182  			log.Println(err)
   183  		}
   184  	} else {
   185  		if err := m.Down(); err != nil {
   186  			if err != migrate.ErrNoChange {
   187  				return err
   188  			}
   189  			log.Println(err)
   190  		}
   191  	}
   192  	return nil
   193  }
   194  
   195  func dropCmd(m *migrate.Migrate) error {
   196  	if err := m.Drop(); err != nil {
   197  		return err
   198  	}
   199  	return nil
   200  }
   201  
   202  func forceCmd(m *migrate.Migrate, v int) error {
   203  	if err := m.Force(v); err != nil {
   204  		return err
   205  	}
   206  	return nil
   207  }
   208  
   209  func versionCmd(m *migrate.Migrate) error {
   210  	v, dirty, err := m.Version()
   211  	if err != nil {
   212  		return err
   213  	}
   214  	if dirty {
   215  		log.Printf("%v (dirty)\n", v)
   216  	} else {
   217  		log.Println(v)
   218  	}
   219  	return nil
   220  }
   221  
   222  // numDownMigrationsFromArgs returns an int for number of migrations to apply
   223  // and a bool indicating if we need a confirm before applying
   224  func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) {
   225  	if applyAll {
   226  		if len(args) > 0 {
   227  			return 0, false, errors.New("-all cannot be used with other arguments")
   228  		}
   229  		return -1, false, nil
   230  	}
   231  
   232  	switch len(args) {
   233  	case 0:
   234  		return -1, true, nil
   235  	case 1:
   236  		downValue := args[0]
   237  		n, err := strconv.ParseUint(downValue, 10, 64)
   238  		if err != nil {
   239  			return 0, false, errors.New("can't read limit argument N")
   240  		}
   241  		return int(n), false, nil
   242  	default:
   243  		return 0, false, errors.New("too many arguments")
   244  	}
   245  }