github.com/schemalex/git-schemalex@v0.0.0-20170921120917-b690b7f9e063/gitschemalex.go (about)

     1  package gitschemalex
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"errors"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"os"
    11  	"os/exec"
    12  	"path/filepath"
    13  	"strings"
    14  
    15  	_ "github.com/go-sql-driver/mysql"
    16  	"github.com/schemalex/schemalex"
    17  	"github.com/schemalex/schemalex/diff"
    18  )
    19  
    20  var (
    21  	ErrEqualVersion = errors.New("db version is equal to schema version")
    22  )
    23  
    24  type Runner struct {
    25  	Workspace string
    26  	Deploy    bool
    27  	DSN       string
    28  	Table     string
    29  	Schema    string
    30  }
    31  
    32  func (r *Runner) Run(ctx context.Context) error {
    33  	db, err := r.DB()
    34  
    35  	if err != nil {
    36  		return err
    37  	}
    38  
    39  	defer db.Close()
    40  
    41  	schemaVersion, err := r.SchemaVersion(ctx)
    42  	if err != nil {
    43  		return err
    44  	}
    45  
    46  	var dbVersion string
    47  	if err := r.DatabaseVersion(ctx, db, &dbVersion); err != nil {
    48  		if !strings.Contains(err.Error(), "doesn't exist") {
    49  			return err
    50  		}
    51  		return r.DeploySchema(ctx, db, schemaVersion)
    52  	}
    53  
    54  	if dbVersion == schemaVersion {
    55  		return ErrEqualVersion
    56  	}
    57  
    58  	if err := r.UpgradeSchema(ctx, db, schemaVersion, dbVersion); err != nil {
    59  		return err
    60  	}
    61  
    62  	return nil
    63  }
    64  
    65  func (r *Runner) DB() (*sql.DB, error) {
    66  	return sql.Open("mysql", r.DSN)
    67  }
    68  
    69  func (r *Runner) DatabaseVersion(ctx context.Context, db *sql.DB, version *string) error {
    70  	return db.QueryRowContext(ctx, fmt.Sprintf("SELECT version FROM `%s`", r.Table)).Scan(version)
    71  }
    72  
    73  func (r *Runner) SchemaVersion(ctx context.Context) (string, error) {
    74  	byt, err := r.execGitCmd(ctx, "log", "-n", "1", "--pretty=format:%H", "--", r.Schema)
    75  	if err != nil {
    76  		return "", err
    77  	}
    78  
    79  	return string(byt), nil
    80  }
    81  
    82  func (r *Runner) DeploySchema(ctx context.Context, db *sql.DB, version string) error {
    83  	content, err := r.schemaContent()
    84  	if err != nil {
    85  		return err
    86  	}
    87  	queries := queryListFromString(content)
    88  	queries.AppendStmt(fmt.Sprintf("CREATE TABLE `%s` ( version VARCHAR(40) NOT NULL )", r.Table))
    89  	queries.AppendStmt(fmt.Sprintf("INSERT INTO `%s` (version) VALUES (?)", r.Table), version)
    90  	return r.execSql(ctx, db, queries)
    91  }
    92  
    93  func (r *Runner) UpgradeSchema(ctx context.Context, db *sql.DB, schemaVersion string, dbVersion string) error {
    94  	lastSchema, err := r.schemaSpecificCommit(ctx, dbVersion)
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	currentSchema, err := r.schemaContent()
   100  	if err != nil {
   101  		return err
   102  	}
   103  	stmts := &bytes.Buffer{}
   104  	p := schemalex.New()
   105  	err = diff.Strings(stmts, lastSchema, currentSchema, diff.WithTransaction(true), diff.WithParser(p))
   106  	if err != nil {
   107  		return err
   108  	}
   109  
   110  	queries := queryListFromString(stmts.String())
   111  	queries.AppendStmt(fmt.Sprintf("UPDATE %s SET version = ?", r.Table), schemaVersion)
   112  
   113  	return r.execSql(ctx, db, queries)
   114  }
   115  
   116  // private
   117  
   118  func (r *Runner) schemaSpecificCommit(ctx context.Context, commit string) (string, error) {
   119  	byt, err := r.execGitCmd(ctx, "ls-tree", commit, "--", r.Schema)
   120  
   121  	if err != nil {
   122  		return "", err
   123  	}
   124  
   125  	fields := strings.Fields(string(byt))
   126  
   127  	byt, err = r.execGitCmd(ctx, "cat-file", "blob", fields[2])
   128  	if err != nil {
   129  		return "", err
   130  	}
   131  
   132  	return string(byt), nil
   133  }
   134  
   135  func (r *Runner) execSql(ctx context.Context, db *sql.DB, queries queryList) error {
   136  	if !r.Deploy {
   137  		return queries.dump(os.Stdout)
   138  	}
   139  	return queries.execute(ctx, db)
   140  }
   141  
   142  func (r *Runner) schemaContent() (string, error) {
   143  	byt, err := ioutil.ReadFile(filepath.Join(r.Workspace, r.Schema))
   144  	if err != nil {
   145  		return "", err
   146  	}
   147  	return string(byt), nil
   148  }
   149  
   150  func (r *Runner) execGitCmd(ctx context.Context, args ...string) ([]byte, error) {
   151  	cmd := exec.CommandContext(ctx, "git", args...)
   152  	if r.Workspace != "" {
   153  		cmd.Dir = r.Workspace
   154  	}
   155  
   156  	byt, err := cmd.Output()
   157  	if err != nil {
   158  		return nil, fmt.Errorf("%s got err:%s", cmd.Args, err)
   159  	}
   160  
   161  	return byt, nil
   162  }