github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/tx_test.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  const (
    13  	mockCommit   = 1
    14  	mockRollback = 2
    15  )
    16  
    17  type mockTx struct {
    18  	status int
    19  }
    20  
    21  func (mt *mockTx) Commit() error {
    22  	mt.status |= mockCommit
    23  	return nil
    24  }
    25  
    26  func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
    27  	return nil, nil
    28  }
    29  
    30  func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
    31  	return nil, nil
    32  }
    33  
    34  func (mt *mockTx) Prepare(query string) (StmtSession, error) {
    35  	return nil, nil
    36  }
    37  
    38  func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
    39  	return nil, nil
    40  }
    41  
    42  func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
    43  	return nil
    44  }
    45  
    46  func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    47  	return nil
    48  }
    49  
    50  func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
    51  	return nil
    52  }
    53  
    54  func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    55  	return nil
    56  }
    57  
    58  func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
    59  	return nil
    60  }
    61  
    62  func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    63  	return nil
    64  }
    65  
    66  func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
    67  	return nil
    68  }
    69  
    70  func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    71  	return nil
    72  }
    73  
    74  func (mt *mockTx) Rollback() error {
    75  	mt.status |= mockRollback
    76  	return nil
    77  }
    78  
    79  func beginMock(mock *mockTx) beginnable {
    80  	return func(*sql.DB) (trans, error) {
    81  		return mock, nil
    82  	}
    83  }
    84  
    85  func TestTransactCommit(t *testing.T) {
    86  	mock := &mockTx{}
    87  	err := transactOnConn(context.Background(), nil, beginMock(mock),
    88  		func(context.Context, Session) error {
    89  			return nil
    90  		})
    91  	assert.Equal(t, mockCommit, mock.status)
    92  	assert.Nil(t, err)
    93  }
    94  
    95  func TestTransactRollback(t *testing.T) {
    96  	mock := &mockTx{}
    97  	err := transactOnConn(context.Background(), nil, beginMock(mock),
    98  		func(context.Context, Session) error {
    99  			return errors.New("rollback")
   100  		})
   101  	assert.Equal(t, mockRollback, mock.status)
   102  	assert.NotNil(t, err)
   103  }