github.com/gobuffalo/buffalo-cli/v2@v2.0.0-alpha.15.0.20200919213536-a7350c8e6799/cli/internal/plugins/pop/test/tester.go (about)

     1  package test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"io/ioutil"
     8  	"os"
     9  	"path/filepath"
    10  
    11  	"github.com/gobuffalo/buffalo-cli/v2/cli/cmds/test"
    12  	"github.com/gobuffalo/plugins"
    13  	"github.com/gobuffalo/pop/v5"
    14  )
    15  
    16  var _ plugins.Plugin = &Tester{}
    17  var _ test.Argumenter = &Tester{}
    18  var _ test.BeforeTester = &Tester{}
    19  
    20  type Tester struct{}
    21  
    22  func (t *Tester) TestArgs(ctx context.Context, root string) ([]string, error) {
    23  	args := []string{"-p", "1"}
    24  
    25  	dy := filepath.Join(root, "database.yml")
    26  	if _, err := os.Stat(dy); err != nil {
    27  		return args, nil
    28  	}
    29  
    30  	b, err := ioutil.ReadFile(dy)
    31  	if err != nil {
    32  		return nil, plugins.Wrap(t, err)
    33  	}
    34  	if bytes.Contains(b, []byte("sqlite")) {
    35  		args = append(args, "-tags", "sqlite")
    36  	}
    37  	return args, nil
    38  }
    39  
    40  func (Tester) PluginName() string {
    41  	return "pop/tester"
    42  }
    43  
    44  func (t *Tester) BeforeTest(ctx context.Context, root string, args []string) error {
    45  	if err := pop.AddLookupPaths(root); err != nil {
    46  		return plugins.Wrap(t, err)
    47  	}
    48  
    49  	var err error
    50  	db, ok := ctx.Value("tx").(*pop.Connection)
    51  	if !ok {
    52  		if _, err := os.Stat(filepath.Join(root, "database.yml")); err != nil {
    53  			return plugins.Wrap(t, err)
    54  		}
    55  
    56  		db, err = pop.Connect("test")
    57  		if err != nil {
    58  			return plugins.Wrap(t, err)
    59  		}
    60  	}
    61  	// drop the test db:
    62  	db.Dialect.DropDB()
    63  
    64  	// create the test db:
    65  	if err := db.Dialect.CreateDB(); err != nil {
    66  		return plugins.Wrap(t, err)
    67  	}
    68  
    69  	for _, a := range args {
    70  		if a == "--force-migrations" {
    71  			return t.forceMigrations(root, db)
    72  		}
    73  	}
    74  
    75  	schema, err := t.findSchema(root)
    76  	if err != nil {
    77  		return plugins.Wrap(t, err)
    78  	}
    79  	if schema == nil {
    80  		return t.forceMigrations(root, db)
    81  	}
    82  
    83  	err = db.Dialect.LoadSchema(schema)
    84  	return plugins.Wrap(t, err)
    85  }
    86  
    87  func (t *Tester) forceMigrations(root string, db *pop.Connection) error {
    88  	ms := filepath.Join(root, "migrations")
    89  	fm, err := pop.NewFileMigrator(ms, db)
    90  	if err != nil {
    91  		return plugins.Wrap(t, err)
    92  	}
    93  	return fm.Up()
    94  }
    95  
    96  func (t *Tester) findSchema(root string) (io.Reader, error) {
    97  	ms := filepath.Join(root, "migrations", "schema.sql")
    98  	if f, err := os.Open(ms); err == nil {
    99  		return f, nil
   100  	}
   101  
   102  	if dev, err := pop.Connect("development"); err == nil {
   103  		schema := &bytes.Buffer{}
   104  		if err = dev.Dialect.DumpSchema(schema); err == nil {
   105  			return schema, nil
   106  		}
   107  	}
   108  
   109  	return nil, nil
   110  }