github.com/acoshift/pgsql@v0.15.3/pgctx/pgctx_test.go (about) 1 package pgctx_test 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 10 "github.com/DATA-DOG/go-sqlmock" 11 "github.com/stretchr/testify/assert" 12 13 "github.com/acoshift/pgsql" 14 "github.com/acoshift/pgsql/pgctx" 15 ) 16 17 func newCtx(t *testing.T) (context.Context, sqlmock.Sqlmock) { 18 t.Helper() 19 20 db, mock, err := sqlmock.New() 21 assert.NoError(t, err) 22 return pgctx.NewContext(context.Background(), db), mock 23 } 24 25 func TestNewContext(t *testing.T) { 26 t.Parallel() 27 28 assert.NotPanics(t, func() { 29 newCtx(t) 30 }) 31 } 32 33 type testKey1 struct{} 34 35 func TestNewKeyContext(t *testing.T) { 36 t.Parallel() 37 38 assert.NotPanics(t, func() { 39 db, _, err := sqlmock.New() 40 assert.NoError(t, err) 41 ctx := pgctx.NewKeyContext(context.Background(), testKey1{}, db) 42 assert.NotNil(t, ctx) 43 }) 44 } 45 46 func TestMiddleware(t *testing.T) { 47 t.Parallel() 48 49 db, _, err := sqlmock.New() 50 assert.NoError(t, err) 51 52 called := false 53 w := httptest.NewRecorder() 54 r := httptest.NewRequest("GET", "/", nil) 55 pgctx.Middleware(db)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 56 called = true 57 ctx := r.Context() 58 assert.NotPanics(t, func() { 59 pgctx.QueryRow(ctx, "select 1") 60 }) 61 assert.NotPanics(t, func() { 62 pgctx.Query(ctx, "select 1") 63 }) 64 assert.NotPanics(t, func() { 65 pgctx.Exec(ctx, "select 1") 66 }) 67 })).ServeHTTP(w, r) 68 assert.True(t, called) 69 } 70 71 func TestKeyMiddleware(t *testing.T) { 72 t.Parallel() 73 74 db, _, err := sqlmock.New() 75 assert.NoError(t, err) 76 77 called := false 78 w := httptest.NewRecorder() 79 r := httptest.NewRequest("GET", "/", nil) 80 pgctx.KeyMiddleware(testKey1{}, db)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 called = true 82 ctx := r.Context() 83 assert.NotPanics(t, func() { 84 pgctx.QueryRow(pgctx.With(ctx, testKey1{}), "select 1") 85 }) 86 assert.NotPanics(t, func() { 87 pgctx.Query(pgctx.With(ctx, testKey1{}), "select 1") 88 }) 89 assert.NotPanics(t, func() { 90 pgctx.Exec(pgctx.With(ctx, testKey1{}), "select 1") 91 }) 92 assert.Panics(t, func() { 93 pgctx.QueryRow(ctx, "select 1") 94 }) 95 })).ServeHTTP(w, r) 96 assert.True(t, called) 97 } 98 99 func TestRunInTx(t *testing.T) { 100 t.Parallel() 101 102 t.Run("Committed", func(t *testing.T) { 103 ctx, mock := newCtx(t) 104 105 called := false 106 mock.ExpectBegin() 107 mock.ExpectCommit() 108 err := pgctx.RunInTx(ctx, func(ctx context.Context) error { 109 called = true 110 return nil 111 }) 112 assert.NoError(t, err) 113 assert.True(t, called) 114 }) 115 116 t.Run("Rollback with error", func(t *testing.T) { 117 ctx, mock := newCtx(t) 118 119 mock.ExpectBegin() 120 mock.ExpectRollback() 121 var retErr = fmt.Errorf("error") 122 err := pgctx.RunInTx(ctx, func(ctx context.Context) error { 123 return retErr 124 }) 125 assert.Error(t, err) 126 assert.Equal(t, retErr, err) 127 }) 128 129 t.Run("Abort Tx", func(t *testing.T) { 130 ctx, mock := newCtx(t) 131 132 mock.ExpectBegin() 133 mock.ExpectCommit() 134 err := pgctx.RunInTx(ctx, func(ctx context.Context) error { 135 return pgsql.ErrAbortTx 136 }) 137 assert.NoError(t, err) 138 }) 139 140 t.Run("Nested Tx", func(t *testing.T) { 141 ctx, mock := newCtx(t) 142 143 mock.ExpectBegin() 144 mock.ExpectCommit() 145 err := pgctx.RunInTx(ctx, func(ctx context.Context) error { 146 return pgctx.RunInTx(ctx, func(ctx context.Context) error { 147 return nil 148 }) 149 }) 150 assert.NoError(t, err) 151 }) 152 } 153 154 func TestCommitted(t *testing.T) { 155 t.Parallel() 156 157 t.Run("Outside Tx", func(t *testing.T) { 158 ctx, _ := newCtx(t) 159 var called bool 160 pgctx.Committed(ctx, func(ctx context.Context) { 161 called = true 162 }) 163 assert.True(t, called) 164 }) 165 166 t.Run("Nil func", func(t *testing.T) { 167 ctx, mock := newCtx(t) 168 169 mock.ExpectBegin() 170 mock.ExpectCommit() 171 pgctx.RunInTx(ctx, func(ctx context.Context) error { 172 pgctx.Committed(ctx, nil) 173 return nil 174 }) 175 }) 176 177 t.Run("Committed", func(t *testing.T) { 178 ctx, mock := newCtx(t) 179 180 called := false 181 mock.ExpectBegin() 182 mock.ExpectCommit() 183 err := pgctx.RunInTx(ctx, func(ctx context.Context) error { 184 pgctx.Committed(ctx, func(ctx context.Context) { 185 called = true 186 }) 187 return nil 188 }) 189 assert.NoError(t, err) 190 assert.True(t, called) 191 }) 192 193 t.Run("Rollback", func(t *testing.T) { 194 ctx, mock := newCtx(t) 195 196 mock.ExpectBegin() 197 mock.ExpectRollback() 198 err := pgctx.RunInTx(ctx, func(ctx context.Context) error { 199 pgctx.Committed(ctx, func(ctx context.Context) { 200 assert.Fail(t, "should not be called") 201 }) 202 return pgsql.ErrAbortTx 203 }) 204 assert.NoError(t, err) 205 }) 206 }