github.com/wawandco/oxplugins@v0.7.11/tools/liquibase/command.go (about)

     1  package liquibase
     2  
     3  import (
     4  	"context"
     5  	"encoding/xml"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  
    10  	"github.com/jackc/pgx/v4"
    11  	"github.com/spf13/pflag"
    12  	"github.com/wawandco/oxplugins/plugins"
    13  )
    14  
    15  var _ plugins.Command = (*Command)(nil)
    16  var _ plugins.HelpTexter = (*Command)(nil)
    17  
    18  var ErrInvalidInstruction = errors.New("Invalid instruction please specify up or down")
    19  
// Command implements the "migrate" subcommand of "db": it applies
// (Up) or rolls back (Rollback) XML changesets against a named connection.
type Command struct {
	// connectionName selects the entry in connections to use. ParseFlags
	// binds it to --conn (default "development"); RunBeforeTest forces "test".
	connectionName string
	// steps is bound to --steps/-s; Rollback treats 0 as 1.
	steps          int
	// connections maps connection names to providers of database URLs.
	connections    map[string]URLProvider
	// flags is the flag set created by ParseFlags (nil until then).
	flags          *pflag.FlagSet
}
    26  
    27  func (lb Command) Name() string {
    28  	return "migrate"
    29  }
    30  
    31  func (lb Command) ParentName() string {
    32  	return "db"
    33  }
    34  
    35  func (lb Command) HelpText() string {
    36  	return "runs Liquibase command to update database specified with --conn flag"
    37  }
    38  
    39  func (lb *Command) Run(ctx context.Context, root string, args []string) error {
    40  	if len(args) < 3 {
    41  		return lb.Up()
    42  	}
    43  
    44  	direction := args[2]
    45  	if direction == "up" {
    46  		return lb.Up()
    47  	}
    48  
    49  	if direction == "down" {
    50  		return lb.Rollback()
    51  	}
    52  
    53  	return ErrInvalidInstruction
    54  }
    55  
    56  func (lb *Command) RunBeforeTest(ctx context.Context, root string, args []string) error {
    57  	lb.connectionName = "test"
    58  
    59  	return lb.Up()
    60  }
    61  
    62  func (lb Command) Up() error {
    63  	cx := lb.connections[lb.connectionName]
    64  	if cx == nil {
    65  		return errors.New("connection not found")
    66  	}
    67  
    68  	conn, err := pgx.Connect(context.Background(), cx.URL())
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	err = lb.EnsureTables(conn)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	cl, err := lb.ReadChangelog()
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	for _, v := range cl.Migrations {
    84  		// Read the file
    85  		m, err := lb.ReadMigration(v.File)
    86  		if err != nil {
    87  			return err
    88  		}
    89  
    90  		for _, mc := range m.ChangeSets {
    91  			err = mc.Execute(conn, v.File)
    92  			if err != nil {
    93  				fmt.Printf("[error] error executing `%v`.\n", mc.ID)
    94  				return err
    95  			}
    96  		}
    97  	}
    98  
    99  	fmt.Println("[info] Database up to date.")
   100  
   101  	return nil
   102  }
   103  
   104  func (lb *Command) Rollback() error {
   105  	cx := lb.connections[lb.connectionName]
   106  	if cx == nil {
   107  		return errors.New("connection not found")
   108  	}
   109  
   110  	conn, err := pgx.Connect(context.Background(), cx.URL())
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	err = lb.EnsureTables(conn)
   116  	if err != nil {
   117  		return err
   118  	}
   119  
   120  	// Default to 1 on down.
   121  	if lb.steps == 0 {
   122  		lb.steps = 1
   123  	}
   124  
   125  	for i := 0; i < lb.steps; i++ {
   126  		var id, file string
   127  		row := conn.QueryRow(context.Background(), `SELECT filename, id FROM databasechangelog ORDER BY orderexecuted desc`)
   128  		err = row.Scan(&file, &id)
   129  		if err != nil && !errors.Is(err, pgx.ErrNoRows) {
   130  			return err
   131  		}
   132  
   133  		if errors.Is(err, pgx.ErrNoRows) {
   134  			fmt.Printf("[info] no migrations to run down.")
   135  
   136  			return nil
   137  		}
   138  
   139  		m, err := lb.ReadMigration(file)
   140  		if err != nil {
   141  			return err
   142  		}
   143  
   144  		for _, v := range m.ChangeSets {
   145  			if v.ID != id {
   146  				continue
   147  			}
   148  
   149  			err := v.Rollback(conn)
   150  			if err != nil {
   151  				fmt.Printf("[error] error rolling back `%v`.\n", v.ID)
   152  
   153  				return err
   154  			}
   155  		}
   156  	}
   157  
   158  	return nil
   159  }
   160  
   161  func (lb *Command) ParseFlags(args []string) {
   162  	lb.flags = pflag.NewFlagSet(lb.Name(), pflag.ContinueOnError)
   163  	lb.flags.StringVarP(&lb.connectionName, "conn", "", "development", "the name of the connection to use")
   164  	lb.flags.IntVarP(&lb.steps, "steps", "s", 0, "number of migrations to run")
   165  	lb.flags.Parse(args) //nolint:errcheck,we don't care hence the flag
   166  }
   167  
   168  func (lb *Command) Flags() *pflag.FlagSet {
   169  	return lb.flags
   170  }
   171  
   172  func (lb Command) ReadChangelog() (*ChangeLog, error) {
   173  	d, err := ioutil.ReadFile("migrations/changelog.xml")
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	cl := &ChangeLog{}
   179  	err = xml.Unmarshal([]byte(d), cl)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  
   184  	return cl, nil
   185  }
   186  
   187  func (lb Command) ReadMigration(path string) (*Migration, error) {
   188  	d, err := ioutil.ReadFile(path)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	m := &Migration{}
   194  	err = xml.Unmarshal([]byte(d), m)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  
   199  	return m, nil
   200  }