github.com/walf443/mgr@v0.0.0-20150203144449-6f7a3a548462/cli/main.go (about)

     1  package main
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"github.com/k0kubun/pp"
     7  	"github.com/walf443/mgr/sqlparser/mysql"
     8  	"github.com/walf443/mgr/diff"
     9  	"io/ioutil"
    10  	"os"
    11  )
    12  
    13  func main() {
    14  	var beforeFile = flag.String("before", "", "before schema filename. if \"stdin\" given, read from os.Stdin")
    15  	var afterFile  = flag.String("after", "",  "after schema filename")
    16  	flag.Parse()
    17  	if os.Getenv("DEBUG") == "" {
    18  		pp.SetDefaultOutput(ioutil.Discard)
    19  	}
    20  
    21  	if *beforeFile == "" || *afterFile == "" {
    22  		fmt.Fprintf(os.Stderr, "-before or -after are missing\n")
    23  		os.Exit(1)
    24  	}
    25  
    26  	beforeSchema := ""
    27  	var err error
    28  	if *beforeFile == "stdin" {
    29  		beforeSchema,  err = loadStdin()
    30  		if err != nil {
    31  			fmt.Fprintf(os.Stderr, "failed to load file: %q\n", err)
    32  			os.Exit(1)
    33  		}
    34  	} else {
    35  		beforeSchema, err = loadFile(*beforeFile)
    36  		if err != nil {
    37  			fmt.Fprintf(os.Stderr, "failed to load file: %q\n", err)
    38  			os.Exit(1)
    39  		}
    40  	}
    41  	afterSchema, err := loadFile(*afterFile)
    42  	if err != nil {
    43  		fmt.Fprintf(os.Stderr, "failed to load file: %q\n", err)
    44  		os.Exit(1)
    45  	}
    46  
    47  	beforeStmts, err := parseSchema(beforeSchema)
    48  	if err != nil {
    49  		fmt.Fprintf(os.Stderr, "failed to parse file: %s\n%s\n", *beforeFile, err)
    50  		os.Exit(1)
    51  	}
    52  	afterStmts, err := parseSchema(afterSchema)
    53  	if err != nil {
    54  		fmt.Fprintf(os.Stderr, "failed to parse file: %s\n%s\n", *afterFile, err)
    55  		os.Exit(1)
    56  	}
    57  	result := diff.Extract(beforeStmts, afterStmts)
    58  	for _, stmt := range(result.Changes()) {
    59  		fmt.Println(stmt)
    60  	}
    61  	pp.Print(result)
    62  }
    63  
    64  func loadFile(fname string) (string, error) {
    65  	result, err := ioutil.ReadFile(fname)
    66  	if err != nil {
    67  		return "", err
    68  	}
    69  	return string(result), nil
    70  }
    71  
    72  func loadStdin() (string, error) {
    73  	result, err := ioutil.ReadAll(os.Stdin)
    74  	if err != nil {
    75  		return "", err
    76  	}
    77  	return string(result), nil
    78  }
    79  
    80  func parseSchema(schema string) ([]mysql.Statement, error) {
    81  	s := new(mysql.Scanner)
    82  	s.Init(schema)
    83  	return mysql.Parse(s)
    84  }