github.com/orderbynull/buffalo@v0.11.1/middleware/pop_transaction_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"io/ioutil"
     5  	"os"
     6  	"path/filepath"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/gobuffalo/buffalo"
    11  	"github.com/gobuffalo/pop"
    12  	"github.com/gobuffalo/uuid"
    13  	"github.com/markbates/willie"
    14  	"github.com/pkg/errors"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
// widget is the minimal model inserted by the test handlers built in app.
// Its db tags ("id", "created_at", "updated_at") correspond to the columns
// of the "widgets" table created by mig.
//
// NOTE(review): handlers create it with a zero-value ID and zero timestamps;
// presumably pop maps the type to the "widgets" table and fills in the UUID
// and timestamps on Create — confirm against the pop version in use.
type widget struct {
	ID        uuid.UUID `db:"id"`
	CreatedAt time.Time `db:"created_at"`
	UpdatedAt time.Time `db:"updated_at"`
}
    23  
    24  func tx(fn func(tx *pop.Connection)) error {
    25  	pop.Debug = true
    26  	defer func() { pop.Debug = false }()
    27  	d, err := ioutil.TempDir("", "")
    28  	if err != nil {
    29  		return errors.WithStack(err)
    30  	}
    31  	path := filepath.Join(d, "pt_test.sqlite")
    32  	defer os.RemoveAll(path)
    33  
    34  	db, err := pop.NewConnection(&pop.ConnectionDetails{
    35  		Dialect: "sqlite",
    36  		URL:     path,
    37  	})
    38  	if err != nil {
    39  		return errors.WithStack(err)
    40  	}
    41  	if err := db.Dialect.CreateDB(); err != nil {
    42  		return errors.WithStack(err)
    43  	}
    44  	if err := db.Open(); err != nil {
    45  		return err
    46  	}
    47  	if err := db.RawQuery(mig).Exec(); err != nil {
    48  		return err
    49  	}
    50  	fn(db)
    51  	return nil
    52  }
    53  
    54  func app(db *pop.Connection) *buffalo.App {
    55  	app := buffalo.New(buffalo.Options{})
    56  	app.Use(PopTransaction(db))
    57  	app.GET("/success", func(c buffalo.Context) error {
    58  		w := &widget{}
    59  		tx := c.Value("tx").(*pop.Connection)
    60  		if err := tx.Create(w); err != nil {
    61  			return err
    62  		}
    63  		return c.Render(201, nil)
    64  	})
    65  	app.GET("/non-success", func(c buffalo.Context) error {
    66  		w := &widget{}
    67  		tx := c.Value("tx").(*pop.Connection)
    68  		if err := tx.Create(w); err != nil {
    69  			return err
    70  		}
    71  		return c.Render(301, nil)
    72  	})
    73  	app.GET("/error", func(c buffalo.Context) error {
    74  		w := &widget{}
    75  		tx := c.Value("tx").(*pop.Connection)
    76  		if err := tx.Create(w); err != nil {
    77  			return err
    78  		}
    79  		return errors.New("boom")
    80  	})
    81  	return app
    82  }
    83  
    84  func Test_PopTransaction(t *testing.T) {
    85  	r := require.New(t)
    86  	err := tx(func(db *pop.Connection) {
    87  		w := willie.New(app(db))
    88  		res := w.HTML("/success").Get()
    89  		r.Equal(201, res.Code)
    90  		count, err := db.Count("widgets")
    91  		r.NoError(err)
    92  		r.Equal(1, count)
    93  	})
    94  	r.NoError(err)
    95  }
    96  
    97  func Test_PopTransaction_Error(t *testing.T) {
    98  	r := require.New(t)
    99  	err := tx(func(db *pop.Connection) {
   100  		w := willie.New(app(db))
   101  		res := w.HTML("/error").Get()
   102  		r.Equal(500, res.Code)
   103  		count, err := db.Count("widgets")
   104  		r.NoError(err)
   105  		r.Equal(0, count)
   106  	})
   107  	r.NoError(err)
   108  }
   109  
   110  func Test_PopTransaction_NonSuccess(t *testing.T) {
   111  	r := require.New(t)
   112  	err := tx(func(db *pop.Connection) {
   113  		w := willie.New(app(db))
   114  		res := w.HTML("/non-success").Get()
   115  		r.Equal(301, res.Code)
   116  		count, err := db.Count("widgets")
   117  		r.NoError(err)
   118  		r.Equal(1, count)
   119  	})
   120  	r.NoError(err)
   121  }
   122  
// mig is the schema executed by tx (via RawQuery) on every freshly created
// test database. It creates a single "widgets" table whose columns match the
// db tags on widget; "id" is stored as TEXT and is the primary key.
const mig = `CREATE TABLE "widgets" (
  "created_at" DATETIME NOT NULL,
  "updated_at" DATETIME NOT NULL,
  "id" TEXT PRIMARY KEY
);`