goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/util/session/session_test.go (about) 1 package session 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "testing" 8 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 "gorm.io/gorm" 12 "gorm.io/gorm/clause" 13 "gorm.io/gorm/utils/tests" 14 "goyave.dev/goyave/v5/config" 15 "goyave.dev/goyave/v5/database" 16 "goyave.dev/goyave/v5/util/errors" 17 ) 18 19 type testKey struct{} 20 21 type testCommitter struct { 22 gorm.ConnPool 23 beginError error 24 commitError error 25 committed bool 26 rolledback bool 27 } 28 29 func (c *testCommitter) Commit() error { 30 c.committed = true 31 return c.commitError 32 } 33 34 func (c *testCommitter) Rollback() error { 35 c.rolledback = true 36 return nil 37 } 38 39 func (c *testCommitter) BeginTx(_ context.Context, _ *sql.TxOptions) (gorm.ConnPool, error) { 40 return c, c.beginError 41 } 42 43 func TestGormSession(t *testing.T) { 44 cfg := config.LoadDefault() 45 cfg.Set("database.config.disableAutomaticPing", true) 46 47 t.Run("New", func(t *testing.T) { 48 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 49 require.NoError(t, err) 50 51 opts := &sql.TxOptions{ 52 Isolation: sql.LevelReadCommitted, 53 ReadOnly: true, 54 } 55 session := GORM(db, opts) 56 57 assert.Equal(t, Gorm{ 58 ctx: context.Background(), 59 db: db, 60 TxOptions: opts, 61 }, session) 62 }) 63 64 t.Run("Manual", func(t *testing.T) { 65 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 66 require.NoError(t, err) 67 committer := &testCommitter{} 68 db.Statement.ConnPool = committer 69 opts := &sql.TxOptions{ 70 Isolation: sql.LevelReadCommitted, 71 ReadOnly: true, 72 } 73 session := GORM(db, opts) 74 75 ctx := context.WithValue(context.Background(), testKey{}, "testvalue") 76 tx, err := session.Begin(ctx) 77 require.NoError(t, err) 78 assert.NotEqual(t, session, tx) 79 assert.Equal(t, opts, tx.(Gorm).TxOptions) 80 assert.Equal(t, tx.(Gorm).ctx, tx.Context()) 81 assert.Equal(t, "testvalue", tx.Context().Value(testKey{})) 82 assert.Equal(t, tx.(Gorm).db, tx.Context().Value(dbKey{})) 83 84 require.NoError(t, tx.Commit()) 85 assert.True(t, committer.committed) 86 87 require.NoError(t, tx.Rollback()) 88 assert.True(t, committer.rolledback) 89 }) 90 91 t.Run("Begin_error", func(t *testing.T) { 92 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 93 require.NoError(t, err) 94 beginErr := fmt.Errorf("begin error") 95 committer := &testCommitter{ 96 beginError: beginErr, 97 } 98 db.Statement.ConnPool = committer 99 session := GORM(db, nil) 100 101 tx, err := session.Begin(context.Background()) 102 require.ErrorIs(t, err, beginErr) 103 assert.Nil(t, tx) 104 105 err = session.Transaction(context.Background(), func(_ context.Context) error { 106 return nil 107 }) 108 require.ErrorIs(t, err, beginErr) 109 }) 110 111 t.Run("Nested_manual", func(t *testing.T) { 112 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 113 require.NoError(t, err) 114 committer := &testCommitter{} 115 db.Statement.ConnPool = committer 116 session := GORM(db, nil) 117 118 ctx := context.WithValue(context.Background(), testKey{}, "testvalue") 119 tx, err := session.Begin(ctx) 120 tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB 121 require.NoError(t, err) 122 assert.NotNil(t, tx) 123 124 subtx, err := session.Begin(tx.Context()) 125 require.NoError(t, err) 126 assert.Equal(t, "testvalue", subtx.(Gorm).db.Statement.Context.Value(testKey{})) // Parent context is kept 127 assert.Contains(t, subtx.(Gorm).db.Statement.Clauses, "testclause") // Parent DB is used 128 }) 129 130 t.Run("Transaction", func(t *testing.T) { 131 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 132 require.NoError(t, err) 133 committer := &testCommitter{} 134 db.Statement.ConnPool = committer 135 session := GORM(db, nil) 136 137 var ctxValue any 138 ctx := context.WithValue(context.Background(), testKey{}, "testvalue") 139 err = session.Transaction(ctx, func(ctx context.Context) error { 140 ctxValue = ctx.Value(testKey{}) 141 db := ctx.Value(dbKey{}) 142 assert.NotNil(t, db) 143 _, ok := db.(*gorm.DB) 144 assert.True(t, ok) 145 return nil 146 }) 147 require.NoError(t, err) 148 assert.Equal(t, "testvalue", ctxValue) 149 assert.True(t, committer.committed) 150 assert.False(t, committer.rolledback) 151 }) 152 153 t.Run("Nested_Transaction", func(t *testing.T) { 154 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 155 require.NoError(t, err) 156 committer := &testCommitter{} 157 db.Statement.ConnPool = committer 158 session := GORM(db, nil) 159 160 ctx := context.WithValue(context.Background(), testKey{}, "testvalue") 161 tx, err := session.Begin(ctx) 162 tx.(Gorm).db.Statement.Clauses["testclause"] = clause.Clause{} // Use this to check the nested db is based on the parent DB 163 require.NoError(t, err) 164 assert.NotNil(t, tx) 165 166 err = session.Transaction(tx.Context(), func(ctx context.Context) error { 167 db := DB(ctx, nil) 168 assert.NotNil(t, db) 169 assert.Contains(t, db.Statement.Clauses, "testclause") // Parent DB is used 170 return nil 171 }) 172 require.NoError(t, err) 173 }) 174 175 t.Run("TransactionError", func(t *testing.T) { 176 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 177 require.NoError(t, err) 178 committer := &testCommitter{} 179 db.Statement.ConnPool = committer 180 session := GORM(db, nil) 181 182 var ctxValue any 183 ctx := context.WithValue(context.Background(), testKey{}, "testvalue") 184 err = session.Transaction(ctx, func(ctx context.Context) error { 185 ctxValue = ctx.Value(testKey{}) 186 return fmt.Errorf("test err") 187 }) 188 require.Error(t, err) 189 assert.Equal(t, errors.New(fmt.Errorf("test err")).Error(), err.Error()) 190 assert.Equal(t, "testvalue", ctxValue) 191 assert.True(t, committer.rolledback) 192 assert.False(t, committer.committed) 193 }) 194 195 t.Run("Transaction_Commit_error", func(t *testing.T) { 196 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 197 require.NoError(t, err) 198 commitErr := fmt.Errorf("commit error") 199 committer := &testCommitter{ 200 commitError: commitErr, 201 } 202 db.Statement.ConnPool = committer 203 session := GORM(db, nil) 204 205 err = session.Transaction(context.Background(), func(_ context.Context) error { 206 return nil 207 }) 208 require.ErrorIs(t, err, commitErr) 209 }) 210 211 t.Run("DB", func(t *testing.T) { 212 db, err := database.NewFromDialector(cfg, nil, tests.DummyDialector{}) 213 require.NoError(t, err) 214 fallback := &gorm.DB{} 215 216 cases := []struct { 217 ctx context.Context 218 expect *gorm.DB 219 desc string 220 }{ 221 { 222 desc: "missing_from_context", 223 ctx: context.Background(), 224 expect: fallback, 225 }, 226 { 227 desc: "fallback", 228 ctx: context.Background(), 229 expect: fallback, 230 }, 231 { 232 desc: "found", 233 ctx: context.WithValue(context.Background(), dbKey{}, db), 234 expect: db, 235 }, 236 } 237 238 for _, c := range cases { 239 c := c 240 t.Run(c.desc, func(t *testing.T) { 241 db := DB(c.ctx, fallback) 242 assert.Equal(t, c.expect, db) 243 }) 244 } 245 }) 246 }