github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/tx_test.go (about) 1 package sqlx 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "testing" 8 9 "github.com/stretchr/testify/assert" 10 ) 11 12 const ( 13 mockCommit = 1 14 mockRollback = 2 15 ) 16 17 type mockTx struct { 18 status int 19 } 20 21 func (mt *mockTx) Commit() error { 22 mt.status |= mockCommit 23 return nil 24 } 25 26 func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) { 27 return nil, nil 28 } 29 30 func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 31 return nil, nil 32 } 33 34 func (mt *mockTx) Prepare(query string) (StmtSession, error) { 35 return nil, nil 36 } 37 38 func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { 39 return nil, nil 40 } 41 42 func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error { 43 return nil 44 } 45 46 func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { 47 return nil 48 } 49 50 func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error { 51 return nil 52 } 53 54 func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { 55 return nil 56 } 57 58 func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error { 59 return nil 60 } 61 62 func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { 63 return nil 64 } 65 66 func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { 67 return nil 68 } 69 70 func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error { 71 return nil 72 } 73 74 func (mt *mockTx) Rollback() error { 75 mt.status |= mockRollback 76 return nil 77 } 78 79 func beginMock(mock *mockTx) beginnable { 80 return func(*sql.DB) (trans, error) { 81 return mock, nil 82 } 83 } 84 85 func TestTransactCommit(t *testing.T) { 86 mock := &mockTx{} 87 err := transactOnConn(context.Background(), nil, beginMock(mock), 88 func(context.Context, Session) error { 89 return nil 90 }) 91 assert.Equal(t, mockCommit, mock.status) 92 assert.Nil(t, err) 93 } 94 95 func TestTransactRollback(t *testing.T) { 96 mock := &mockTx{} 97 err := transactOnConn(context.Background(), nil, beginMock(mock), 98 func(context.Context, Session) error { 99 return errors.New("rollback") 100 }) 101 assert.Equal(t, mockRollback, mock.status) 102 assert.NotNil(t, err) 103 }