github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/migrator.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"os"
     7  	"path/filepath"
     8  	"regexp"
     9  	"sort"
    10  	"text/tabwriter"
    11  	"time"
    12  
    13  	"github.com/gobuffalo/pop/v6/logging"
    14  )
    15  
    16  var mrx = regexp.MustCompile(`^(\d+)_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.(sql|fizz)$`)
    17  
    18  // NewMigrator returns a new "blank" migrator. It is recommended
    19  // to use something like MigrationBox or FileMigrator. A "blank"
    20  // Migrator should only be used as the basis for a new type of
    21  // migration system.
    22  func NewMigrator(c *Connection) Migrator {
    23  	return Migrator{
    24  		Connection: c,
    25  	}
    26  }
    27  
// Migrator forms the basis of all migration systems.
// It does the actual heavy lifting of running migrations.
// When building a new migration system, you should embed this
// type into your migrator.
type Migrator struct {
	// Connection is the database connection every migration operation
	// runs against.
	Connection     *Connection
	// SchemaPath is the directory in which DumpMigrationSchema writes
	// "schema.sql". When it is empty no schema dump is produced.
	SchemaPath     string
	// UpMigrations holds the migrations considered by Up, UpTo and
	// UpLogOnly, and listed by Status.
	UpMigrations   UpMigrations
	// DownMigrations holds the migrations considered by Down.
	DownMigrations DownMigrations
}
    38  
    39  func (m Migrator) migrationIsCompatible(d dialect, mi Migration) bool {
    40  	if mi.DBType == "all" || mi.DBType == d.Name() {
    41  		return true
    42  	}
    43  	return false
    44  }
    45  
    46  // UpLogOnly insert pending "up" migrations logs only, without applying the patch.
    47  // It's used when loading the schema dump, instead of the migrations.
    48  func (m Migrator) UpLogOnly() error {
    49  	c := m.Connection
    50  	return m.exec(func() error {
    51  		mtn := c.MigrationTableName()
    52  		mfs := m.UpMigrations
    53  		sort.Sort(mfs)
    54  		return c.Transaction(func(tx *Connection) error {
    55  			for _, mi := range mfs.Migrations {
    56  				if !m.migrationIsCompatible(c.Dialect, mi) {
    57  					continue
    58  				}
    59  				exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
    60  				if err != nil {
    61  					return fmt.Errorf("problem checking for migration version %s: %w", mi.Version, err)
    62  				}
    63  				if exists {
    64  					continue
    65  				}
    66  				_, err = tx.Store.Exec(fmt.Sprintf("insert into %s (version) values ('%s')", mtn, mi.Version))
    67  				if err != nil {
    68  					return fmt.Errorf("problem inserting migration version %s: %w", mi.Version, err)
    69  				}
    70  			}
    71  			return nil
    72  		})
    73  	})
    74  }
    75  
    76  // Up runs pending "up" migrations and applies them to the database.
    77  func (m Migrator) Up() error {
    78  	_, err := m.UpTo(0)
    79  	return err
    80  }
    81  
    82  // UpTo runs up to step "up" migrations and applies them to the database.
    83  // If step <= 0 all pending migrations are run.
    84  func (m Migrator) UpTo(step int) (applied int, err error) {
    85  	c := m.Connection
    86  	err = m.exec(func() error {
    87  		mtn := c.MigrationTableName()
    88  		mfs := m.UpMigrations
    89  		mfs.Filter(func(mf Migration) bool {
    90  			return m.migrationIsCompatible(c.Dialect, mf)
    91  		})
    92  		sort.Sort(mfs)
    93  		for _, mi := range mfs.Migrations {
    94  			exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
    95  			if err != nil {
    96  				return fmt.Errorf("problem checking for migration version %s: %w", mi.Version, err)
    97  			}
    98  			if exists {
    99  				continue
   100  			}
   101  			err = c.Transaction(func(tx *Connection) error {
   102  				err := mi.Run(tx)
   103  				if err != nil {
   104  					return err
   105  				}
   106  				_, err = tx.Store.Exec(fmt.Sprintf("insert into %s (version) values ('%s')", mtn, mi.Version))
   107  				if err != nil {
   108  					return fmt.Errorf("problem inserting migration version %s: %w", mi.Version, err)
   109  				}
   110  				return nil
   111  			})
   112  			if err != nil {
   113  				return err
   114  			}
   115  			log(logging.Info, "> %s", mi.Name)
   116  			applied++
   117  			if step > 0 && applied >= step {
   118  				break
   119  			}
   120  		}
   121  		if applied == 0 {
   122  			log(logging.Info, "Migrations already up to date, nothing to apply")
   123  		} else {
   124  			log(logging.Info, "Successfully applied %d migrations.", applied)
   125  		}
   126  		return nil
   127  	})
   128  	return
   129  }
   130  
   131  // Down runs pending "down" migrations and rolls back the
   132  // database by the specified number of steps.
   133  func (m Migrator) Down(step int) error {
   134  	c := m.Connection
   135  	return m.exec(func() error {
   136  		mtn := c.MigrationTableName()
   137  		count, err := c.Count(mtn)
   138  		if err != nil {
   139  			return fmt.Errorf("migration down: unable count existing migration: %w", err)
   140  		}
   141  		mfs := m.DownMigrations
   142  		mfs.Filter(func(mf Migration) bool {
   143  			return m.migrationIsCompatible(c.Dialect, mf)
   144  		})
   145  		sort.Sort(mfs)
   146  		// skip all ran migration
   147  		if len(mfs.Migrations) > count {
   148  			mfs.Migrations = mfs.Migrations[len(mfs.Migrations)-count:]
   149  		}
   150  		// run only required steps
   151  		if step > 0 && len(mfs.Migrations) >= step {
   152  			mfs.Migrations = mfs.Migrations[:step]
   153  		}
   154  		for _, mi := range mfs.Migrations {
   155  			exists, err := c.Where("version = ?", mi.Version).Exists(mtn)
   156  			if err != nil {
   157  				return fmt.Errorf("problem checking for migration version %s: %w", mi.Version, err)
   158  			}
   159  			if !exists {
   160  				return fmt.Errorf("migration version %s does not exist", mi.Version)
   161  			}
   162  			err = c.Transaction(func(tx *Connection) error {
   163  				err := mi.Run(tx)
   164  				if err != nil {
   165  					return err
   166  				}
   167  				err = tx.RawQuery(fmt.Sprintf("delete from %s where version = ?", mtn), mi.Version).Exec()
   168  				if err != nil {
   169  					return fmt.Errorf("problem deleting migration version %s: %w", mi.Version, err)
   170  				}
   171  				return nil
   172  			})
   173  			if err != nil {
   174  				return err
   175  			}
   176  
   177  			log(logging.Info, "< %s", mi.Name)
   178  		}
   179  		return nil
   180  	})
   181  }
   182  
   183  // Reset the database by running the down migrations followed by the up migrations.
   184  func (m Migrator) Reset() error {
   185  	err := m.Down(-1)
   186  	if err != nil {
   187  		return err
   188  	}
   189  	return m.Up()
   190  }
   191  
   192  // CreateSchemaMigrations sets up a table to track migrations. This is an idempotent
   193  // operation.
   194  func CreateSchemaMigrations(c *Connection) error {
   195  	mtn := c.MigrationTableName()
   196  	err := c.Open()
   197  	if err != nil {
   198  		return fmt.Errorf("could not open connection: %w", err)
   199  	}
   200  	_, err = c.Store.Exec(fmt.Sprintf("select * from %s", mtn))
   201  	if err == nil {
   202  		return nil
   203  	}
   204  
   205  	return c.Transaction(func(tx *Connection) error {
   206  		schemaMigrations := newSchemaMigrations(mtn)
   207  		smSQL, err := c.Dialect.FizzTranslator().CreateTable(schemaMigrations)
   208  		if err != nil {
   209  			return fmt.Errorf("could not build SQL for schema migration table: %w", err)
   210  		}
   211  		err = tx.RawQuery(smSQL).Exec()
   212  		if err != nil {
   213  			return fmt.Errorf("could not execute %s: %w", smSQL, err)
   214  		}
   215  		return nil
   216  	})
   217  }
   218  
// CreateSchemaMigrations sets up a table to track migrations on the
// Migrator's Connection by delegating to the package-level
// CreateSchemaMigrations function. This is an idempotent operation.
func (m Migrator) CreateSchemaMigrations() error {
	return CreateSchemaMigrations(m.Connection)
}
   224  
   225  // Status prints out the status of applied/pending migrations.
   226  func (m Migrator) Status(out io.Writer) error {
   227  	err := m.CreateSchemaMigrations()
   228  	if err != nil {
   229  		return err
   230  	}
   231  	w := tabwriter.NewWriter(out, 0, 0, 3, ' ', tabwriter.TabIndent)
   232  	_, _ = fmt.Fprintln(w, "Version\tName\tStatus\t")
   233  	for _, mf := range m.UpMigrations.Migrations {
   234  		exists, err := m.Connection.Where("version = ?", mf.Version).Exists(m.Connection.MigrationTableName())
   235  		if err != nil {
   236  			return fmt.Errorf("problem with migration: %w", err)
   237  		}
   238  		state := "Pending"
   239  		if exists {
   240  			state = "Applied"
   241  		}
   242  		_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t\n", mf.Version, mf.Name, state)
   243  	}
   244  	return w.Flush()
   245  }
   246  
   247  // DumpMigrationSchema will generate a file of the current database schema
   248  // based on the value of Migrator.SchemaPath
   249  func (m Migrator) DumpMigrationSchema() error {
   250  	if m.SchemaPath == "" {
   251  		return nil
   252  	}
   253  	c := m.Connection
   254  	schema := filepath.Join(m.SchemaPath, "schema.sql")
   255  	f, err := os.Create(schema)
   256  	if err != nil {
   257  		return err
   258  	}
   259  	err = c.Dialect.DumpSchema(f)
   260  	if err != nil {
   261  		os.RemoveAll(schema)
   262  		return err
   263  	}
   264  	return nil
   265  }
   266  
   267  func (m Migrator) exec(fn func() error) error {
   268  	now := time.Now()
   269  	defer func() {
   270  		err := m.DumpMigrationSchema()
   271  		if err != nil {
   272  			log(logging.Warn, "Migrator: unable to dump schema: %v", err)
   273  		}
   274  	}()
   275  	defer printTimer(now)
   276  
   277  	err := m.CreateSchemaMigrations()
   278  	if err != nil {
   279  		return fmt.Errorf("Migrator: problem creating schema migrations: %w", err)
   280  	}
   281  	return fn()
   282  }
   283  
   284  func printTimer(timerStart time.Time) {
   285  	diff := time.Since(timerStart).Seconds()
   286  	if diff > 60 {
   287  		log(logging.Info, "%.4f minutes", diff/60)
   288  	} else {
   289  		log(logging.Info, "%.4f seconds", diff)
   290  	}
   291  }