goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/util/session/session_test.go (about)

     1  package session
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  	"gorm.io/gorm"
    12  	"gorm.io/gorm/clause"
    13  	"gorm.io/gorm/utils/tests"
    14  	"goyave.dev/goyave/v5/config"
    15  	"goyave.dev/goyave/v5/database"
    16  	"goyave.dev/goyave/v5/util/errors"
    17  )
    18  
// testKey is the context key type the tests use to store a marker value in a
// context and check that it is still reachable from derived contexts.
type testKey struct{}
    20  
    21  type testCommitter struct {
    22  	gorm.ConnPool
    23  	beginError  error
    24  	commitError error
    25  	committed   bool
    26  	rolledback  bool
    27  }
    28  
    29  func (c *testCommitter) Commit() error {
    30  	c.committed = true
    31  	return c.commitError
    32  }
    33  
    34  func (c *testCommitter) Rollback() error {
    35  	c.rolledback = true
    36  	return nil
    37  }
    38  
    39  func (c *testCommitter) BeginTx(_ context.Context, _ *sql.TxOptions) (gorm.ConnPool, error) {
    40  	return c, c.beginError
    41  }
    42  
    43  func TestGormSession(t *testing.T) {
    44  	cfg := config.LoadDefault()
    45  	cfg.Set("database.config.disableAutomaticPing", true)
    46  
    47  	t.Run("New", func(t *testing.T) {
    48  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
    49  		require.NoError(t, err)
    50  
    51  		opts := &sql.TxOptions{
    52  			Isolation: sql.LevelReadCommitted,
    53  			ReadOnly:  true,
    54  		}
    55  		session := GORM(db, opts)
    56  
    57  		assert.Equal(t, Gorm{
    58  			ctx:       context.Background(),
    59  			db:        db,
    60  			TxOptions: opts,
    61  		}, session)
    62  	})
    63  
    64  	t.Run("Manual", func(t *testing.T) {
    65  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
    66  		require.NoError(t, err)
    67  		committer := &testCommitter{}
    68  		db.Statement.ConnPool = committer
    69  		opts := &sql.TxOptions{
    70  			Isolation: sql.LevelReadCommitted,
    71  			ReadOnly:  true,
    72  		}
    73  		session := GORM(db, opts)
    74  
    75  		ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
    76  		tx, err := session.Begin(ctx)
    77  		require.NoError(t, err)
    78  		assert.NotEqual(t, session, tx)
    79  		assert.Equal(t, opts, tx.(Gorm).TxOptions)
    80  		assert.Equal(t, tx.(Gorm).ctx, tx.Context())
    81  		assert.Equal(t, "testvalue", tx.Context().Value(testKey{}))
    82  		assert.Equal(t, tx.(Gorm).db, tx.Context().Value(dbKey{}))
    83  
    84  		require.NoError(t, tx.Commit())
    85  		assert.True(t, committer.committed)
    86  
    87  		require.NoError(t, tx.Rollback())
    88  		assert.True(t, committer.rolledback)
    89  	})
    90  
    91  	t.Run("Begin_error", func(t *testing.T) {
    92  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
    93  		require.NoError(t, err)
    94  		beginErr := fmt.Errorf("begin error")
    95  		committer := &testCommitter{
    96  			beginError: beginErr,
    97  		}
    98  		db.Statement.ConnPool = committer
    99  		session := GORM(db, nil)
   100  
   101  		tx, err := session.Begin(context.Background())
   102  		require.ErrorIs(t, err, beginErr)
   103  		assert.Nil(t, tx)
   104  
   105  		err = session.Transaction(context.Background(), func(_ context.Context) error {
   106  			return nil
   107  		})
   108  		require.ErrorIs(t, err, beginErr)
   109  	})
   110  
   111  	t.Run("Nested_manual", func(t *testing.T) {
   112  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
   113  		require.NoError(t, err)
   114  		committer := &testCommitter{}
   115  		db.Statement.ConnPool = committer
   116  		session := GORM(db, nil)
   117  
   118  		ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
   119  		tx, err := session.Begin(ctx)
   120  		tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB
   121  		require.NoError(t, err)
   122  		assert.NotNil(t, tx)
   123  
   124  		subtx, err := session.Begin(tx.Context())
   125  		require.NoError(t, err)
   126  		assert.Equal(t, "testvalue", subtx.(Gorm).db.Statement.Context.Value(testKey{})) // Parent context is kept
   127  		assert.Contains(t, subtx.(Gorm).db.Statement.Clauses, "testclause")              // Parent DB is used
   128  	})
   129  
   130  	t.Run("Transaction", func(t *testing.T) {
   131  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
   132  		require.NoError(t, err)
   133  		committer := &testCommitter{}
   134  		db.Statement.ConnPool = committer
   135  		session := GORM(db, nil)
   136  
   137  		var ctxValue any
   138  		ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
   139  		err = session.Transaction(ctx, func(ctx context.Context) error {
   140  			ctxValue = ctx.Value(testKey{})
   141  			db := ctx.Value(dbKey{})
   142  			assert.NotNil(t, db)
   143  			_, ok := db.(*gorm.DB)
   144  			assert.True(t, ok)
   145  			return nil
   146  		})
   147  		require.NoError(t, err)
   148  		assert.Equal(t, "testvalue", ctxValue)
   149  		assert.True(t, committer.committed)
   150  		assert.False(t, committer.rolledback)
   151  	})
   152  
   153  	t.Run("Nested_Transaction", func(t *testing.T) {
   154  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
   155  		require.NoError(t, err)
   156  		committer := &testCommitter{}
   157  		db.Statement.ConnPool = committer
   158  		session := GORM(db, nil)
   159  
   160  		ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
   161  		tx, err := session.Begin(ctx)
   162  		tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB
   163  		require.NoError(t, err)
   164  		assert.NotNil(t, tx)
   165  
   166  		err = session.Transaction(tx.Context(), func(ctx context.Context) error {
   167  			db := DB(ctx, nil)
   168  			assert.NotNil(t, db)
   169  			assert.Contains(t, db.Statement.Clauses, "testclause") // Parent DB is used
   170  			return nil
   171  		})
   172  		require.NoError(t, err)
   173  	})
   174  
   175  	t.Run("TransactionError", func(t *testing.T) {
   176  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
   177  		require.NoError(t, err)
   178  		committer := &testCommitter{}
   179  		db.Statement.ConnPool = committer
   180  		session := GORM(db, nil)
   181  
   182  		var ctxValue any
   183  		ctx := context.WithValue(context.Background(), testKey{}, "testvalue")
   184  		err = session.Transaction(ctx, func(ctx context.Context) error {
   185  			ctxValue = ctx.Value(testKey{})
   186  			return fmt.Errorf("test err")
   187  		})
   188  		require.Error(t, err)
   189  		assert.Equal(t, errors.New(fmt.Errorf("test err")).Error(), err.Error())
   190  		assert.Equal(t, "testvalue", ctxValue)
   191  		assert.True(t, committer.rolledback)
   192  		assert.False(t, committer.committed)
   193  	})
   194  
   195  	t.Run("Transaction_Commit_error", func(t *testing.T) {
   196  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
   197  		require.NoError(t, err)
   198  		commitErr := fmt.Errorf("commit error")
   199  		committer := &testCommitter{
   200  			commitError: commitErr,
   201  		}
   202  		db.Statement.ConnPool = committer
   203  		session := GORM(db, nil)
   204  
   205  		err = session.Transaction(context.Background(), func(_ context.Context) error {
   206  			return nil
   207  		})
   208  		require.ErrorIs(t, err, commitErr)
   209  	})
   210  
   211  	t.Run("DB", func(t *testing.T) {
   212  		db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{})
   213  		require.NoError(t, err)
   214  		fallback := &gorm.DB{}
   215  
   216  		cases := []struct {
   217  			ctx    context.Context
   218  			expect *gorm.DB
   219  			desc   string
   220  		}{
   221  			{
   222  				desc:   "missing_from_context",
   223  				ctx:    context.Background(),
   224  				expect: fallback,
   225  			},
   226  			{
   227  				desc:   "fallback",
   228  				ctx:    context.Background(),
   229  				expect: fallback,
   230  			},
   231  			{
   232  				desc:   "found",
   233  				ctx:    context.WithValue(context.Background(), dbKey{}, db),
   234  				expect: db,
   235  			},
   236  		}
   237  
   238  		for _, c := range cases {
   239  			c := c
   240  			t.Run(c.desc, func(t *testing.T) {
   241  				db := DB(c.ctx, fallback)
   242  				assert.Equal(t, c.expect, db)
   243  			})
   244  		}
   245  	})
   246  }