github.com/acoshift/pgsql@v0.15.3/pgctx/pgctx_test.go (about)

     1  package pgctx_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  
    10  	"github.com/DATA-DOG/go-sqlmock"
    11  	"github.com/stretchr/testify/assert"
    12  
    13  	"github.com/acoshift/pgsql"
    14  	"github.com/acoshift/pgsql/pgctx"
    15  )
    16  
    17  func newCtx(t *testing.T) (context.Context, sqlmock.Sqlmock) {
    18  	t.Helper()
    19  
    20  	db, mock, err := sqlmock.New()
    21  	assert.NoError(t, err)
    22  	return pgctx.NewContext(context.Background(), db), mock
    23  }
    24  
    25  func TestNewContext(t *testing.T) {
    26  	t.Parallel()
    27  
    28  	assert.NotPanics(t, func() {
    29  		newCtx(t)
    30  	})
    31  }
    32  
// testKey1 is a test-only context key used to register a database under a
// non-default key (NewKeyContext / KeyMiddleware) and to select it again
// with pgctx.With.
type testKey1 struct{}
    34  
    35  func TestNewKeyContext(t *testing.T) {
    36  	t.Parallel()
    37  
    38  	assert.NotPanics(t, func() {
    39  		db, _, err := sqlmock.New()
    40  		assert.NoError(t, err)
    41  		ctx := pgctx.NewKeyContext(context.Background(), testKey1{}, db)
    42  		assert.NotNil(t, ctx)
    43  	})
    44  }
    45  
    46  func TestMiddleware(t *testing.T) {
    47  	t.Parallel()
    48  
    49  	db, _, err := sqlmock.New()
    50  	assert.NoError(t, err)
    51  
    52  	called := false
    53  	w := httptest.NewRecorder()
    54  	r := httptest.NewRequest("GET", "/", nil)
    55  	pgctx.Middleware(db)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    56  		called = true
    57  		ctx := r.Context()
    58  		assert.NotPanics(t, func() {
    59  			pgctx.QueryRow(ctx, "select 1")
    60  		})
    61  		assert.NotPanics(t, func() {
    62  			pgctx.Query(ctx, "select 1")
    63  		})
    64  		assert.NotPanics(t, func() {
    65  			pgctx.Exec(ctx, "select 1")
    66  		})
    67  	})).ServeHTTP(w, r)
    68  	assert.True(t, called)
    69  }
    70  
    71  func TestKeyMiddleware(t *testing.T) {
    72  	t.Parallel()
    73  
    74  	db, _, err := sqlmock.New()
    75  	assert.NoError(t, err)
    76  
    77  	called := false
    78  	w := httptest.NewRecorder()
    79  	r := httptest.NewRequest("GET", "/", nil)
    80  	pgctx.KeyMiddleware(testKey1{}, db)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    81  		called = true
    82  		ctx := r.Context()
    83  		assert.NotPanics(t, func() {
    84  			pgctx.QueryRow(pgctx.With(ctx, testKey1{}), "select 1")
    85  		})
    86  		assert.NotPanics(t, func() {
    87  			pgctx.Query(pgctx.With(ctx, testKey1{}), "select 1")
    88  		})
    89  		assert.NotPanics(t, func() {
    90  			pgctx.Exec(pgctx.With(ctx, testKey1{}), "select 1")
    91  		})
    92  		assert.Panics(t, func() {
    93  			pgctx.QueryRow(ctx, "select 1")
    94  		})
    95  	})).ServeHTTP(w, r)
    96  	assert.True(t, called)
    97  }
    98  
    99  func TestRunInTx(t *testing.T) {
   100  	t.Parallel()
   101  
   102  	t.Run("Committed", func(t *testing.T) {
   103  		ctx, mock := newCtx(t)
   104  
   105  		called := false
   106  		mock.ExpectBegin()
   107  		mock.ExpectCommit()
   108  		err := pgctx.RunInTx(ctx, func(ctx context.Context) error {
   109  			called = true
   110  			return nil
   111  		})
   112  		assert.NoError(t, err)
   113  		assert.True(t, called)
   114  	})
   115  
   116  	t.Run("Rollback with error", func(t *testing.T) {
   117  		ctx, mock := newCtx(t)
   118  
   119  		mock.ExpectBegin()
   120  		mock.ExpectRollback()
   121  		var retErr = fmt.Errorf("error")
   122  		err := pgctx.RunInTx(ctx, func(ctx context.Context) error {
   123  			return retErr
   124  		})
   125  		assert.Error(t, err)
   126  		assert.Equal(t, retErr, err)
   127  	})
   128  
   129  	t.Run("Abort Tx", func(t *testing.T) {
   130  		ctx, mock := newCtx(t)
   131  
   132  		mock.ExpectBegin()
   133  		mock.ExpectCommit()
   134  		err := pgctx.RunInTx(ctx, func(ctx context.Context) error {
   135  			return pgsql.ErrAbortTx
   136  		})
   137  		assert.NoError(t, err)
   138  	})
   139  
   140  	t.Run("Nested Tx", func(t *testing.T) {
   141  		ctx, mock := newCtx(t)
   142  
   143  		mock.ExpectBegin()
   144  		mock.ExpectCommit()
   145  		err := pgctx.RunInTx(ctx, func(ctx context.Context) error {
   146  			return pgctx.RunInTx(ctx, func(ctx context.Context) error {
   147  				return nil
   148  			})
   149  		})
   150  		assert.NoError(t, err)
   151  	})
   152  }
   153  
   154  func TestCommitted(t *testing.T) {
   155  	t.Parallel()
   156  
   157  	t.Run("Outside Tx", func(t *testing.T) {
   158  		ctx, _ := newCtx(t)
   159  		var called bool
   160  		pgctx.Committed(ctx, func(ctx context.Context) {
   161  			called = true
   162  		})
   163  		assert.True(t, called)
   164  	})
   165  
   166  	t.Run("Nil func", func(t *testing.T) {
   167  		ctx, mock := newCtx(t)
   168  
   169  		mock.ExpectBegin()
   170  		mock.ExpectCommit()
   171  		pgctx.RunInTx(ctx, func(ctx context.Context) error {
   172  			pgctx.Committed(ctx, nil)
   173  			return nil
   174  		})
   175  	})
   176  
   177  	t.Run("Committed", func(t *testing.T) {
   178  		ctx, mock := newCtx(t)
   179  
   180  		called := false
   181  		mock.ExpectBegin()
   182  		mock.ExpectCommit()
   183  		err := pgctx.RunInTx(ctx, func(ctx context.Context) error {
   184  			pgctx.Committed(ctx, func(ctx context.Context) {
   185  				called = true
   186  			})
   187  			return nil
   188  		})
   189  		assert.NoError(t, err)
   190  		assert.True(t, called)
   191  	})
   192  
   193  	t.Run("Rollback", func(t *testing.T) {
   194  		ctx, mock := newCtx(t)
   195  
   196  		mock.ExpectBegin()
   197  		mock.ExpectRollback()
   198  		err := pgctx.RunInTx(ctx, func(ctx context.Context) error {
   199  			pgctx.Committed(ctx, func(ctx context.Context) {
   200  				assert.Fail(t, "should not be called")
   201  			})
   202  			return pgsql.ErrAbortTx
   203  		})
   204  		assert.NoError(t, err)
   205  	})
   206  }