github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/tx_test.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  const (
    12  	mockCommit   = 1
    13  	mockRollback = 2
    14  )
    15  
    16  type mockTx struct {
    17  	status int
    18  }
    19  
    20  func (mt *mockTx) Commit() error {
    21  	mt.status |= mockCommit
    22  	return nil
    23  }
    24  
    25  func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
    26  	return nil, nil
    27  }
    28  
    29  func (mt *mockTx) Prepare(query string) (StmtSession, error) {
    30  	return nil, nil
    31  }
    32  
    33  func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
    34  	return nil
    35  }
    36  
    37  func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
    38  	return nil
    39  }
    40  
    41  func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
    42  	return nil
    43  }
    44  
    45  func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
    46  	return nil
    47  }
    48  
    49  func (mt *mockTx) Rollback() error {
    50  	mt.status |= mockRollback
    51  	return nil
    52  }
    53  
    54  func beginMock(mock *mockTx) beginnable {
    55  	return func(*sql.DB) (trans, error) {
    56  		return mock, nil
    57  	}
    58  }
    59  
    60  func TestTransactCommit(t *testing.T) {
    61  	mock := &mockTx{}
    62  	err := transactOnConn(nil, beginMock(mock), func(Session) error {
    63  		return nil
    64  	})
    65  	assert.Equal(t, mockCommit, mock.status)
    66  	assert.Nil(t, err)
    67  }
    68  
    69  func TestTransactRollback(t *testing.T) {
    70  	mock := &mockTx{}
    71  	err := transactOnConn(nil, beginMock(mock), func(Session) error {
    72  		return errors.New("rollback")
    73  	})
    74  	assert.Equal(t, mockRollback, mock.status)
    75  	assert.NotNil(t, err)
    76  }