github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/transaction/transaction_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transaction_test
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"testing"
    21  
    22  	"github.com/ecodeclub/eorm/internal/datasource/transaction"
    23  
    24  	"github.com/stretchr/testify/suite"
    25  
    26  	"github.com/DATA-DOG/go-sqlmock"
    27  	"github.com/ecodeclub/eorm/internal/datasource"
    28  	"github.com/stretchr/testify/assert"
    29  )
    30  
    31  func TestTx_Commit(t *testing.T) {
    32  	mockDB, mock, err := sqlmock.New()
    33  	if err != nil {
    34  		t.Fatal(err)
    35  	}
    36  	defer func() { _ = mockDB.Close() }()
    37  
    38  	db := openMockDB("mysql", mockDB)
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  	defer func() {
    43  		mock.ExpectClose()
    44  		_ = db.Close()
    45  	}()
    46  
    47  	// 事务正常提交
    48  	mock.ExpectBegin()
    49  	mock.ExpectCommit()
    50  
    51  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
    52  	assert.Nil(t, err)
    53  	err = tx.Commit()
    54  	assert.Nil(t, err)
    55  
    56  }
    57  
    58  func TestTx_Rollback(t *testing.T) {
    59  	mockDB, mock, err := sqlmock.New()
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	defer func() { _ = mockDB.Close() }()
    64  
    65  	db := openMockDB("mysql", mockDB)
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  
    70  	// 事务回滚
    71  	mock.ExpectBegin()
    72  	mock.ExpectRollback()
    73  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
    74  	assert.Nil(t, err)
    75  	err = tx.Rollback()
    76  	assert.Nil(t, err)
    77  }
    78  
    79  type testMockDB struct {
    80  	driver string
    81  	db     *sql.DB
    82  }
    83  
    84  func (*testMockDB) Query(_ context.Context, _ datasource.Query) (*sql.Rows, error) {
    85  	return &sql.Rows{}, nil
    86  }
    87  
    88  func (*testMockDB) Exec(_ context.Context, _ datasource.Query) (sql.Result, error) {
    89  	return nil, nil
    90  }
    91  
    92  func openMockDB(driver string, db *sql.DB) *testMockDB {
    93  	return &testMockDB{driver: driver, db: db}
    94  }
    95  
    96  func (db *testMockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) {
    97  	tx, err := db.db.BeginTx(ctx, opts)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return transaction.NewTx(tx, db), nil
   102  }
   103  
   104  func (db *testMockDB) Close() error {
   105  	return db.db.Close()
   106  }
   107  
   108  type TransactionSuite struct {
   109  	suite.Suite
   110  	mockDB1 *sql.DB
   111  	mock1   sqlmock.Sqlmock
   112  
   113  	mockDB2 *sql.DB
   114  	mock2   sqlmock.Sqlmock
   115  
   116  	mockDB3 *sql.DB
   117  	mock3   sqlmock.Sqlmock
   118  }
   119  
   120  func (s *TransactionSuite) SetupTest() {
   121  	t := s.T()
   122  	s.initMock(t)
   123  }
   124  
   125  func (s *TransactionSuite) TearDownTest() {
   126  	_ = s.mockDB1.Close()
   127  	_ = s.mockDB2.Close()
   128  	_ = s.mockDB3.Close()
   129  }
   130  
   131  func (s *TransactionSuite) initMock(t *testing.T) {
   132  	var err error
   133  	s.mockDB1, s.mock1, err = sqlmock.New()
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	s.mockDB2, s.mock2, err = sqlmock.New()
   138  	if err != nil {
   139  		t.Fatal(err)
   140  	}
   141  	s.mockDB3, s.mock3, err = sqlmock.New()
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  }
   146  
   147  func (s *TransactionSuite) TestDBQuery() {
   148  	//s.mock.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("value"))
   149  	testCases := []struct {
   150  		name     string
   151  		tx       *transaction.Tx
   152  		query    datasource.Query
   153  		mockRows *sqlmock.Rows
   154  		wantResp []string
   155  		wantErr  error
   156  	}{
   157  		{
   158  			name: "query tx",
   159  			query: datasource.Query{
   160  				SQL: "SELECT `first_name` FROM `test_model`",
   161  			},
   162  			tx: func() *transaction.Tx {
   163  				s.mock1.ExpectBegin()
   164  				s.mock1.ExpectQuery("SELECT *").WillReturnRows(
   165  					sqlmock.NewRows([]string{"first_name"}).AddRow("value"))
   166  				s.mock1.ExpectCommit()
   167  				tx, err := s.mockDB1.BeginTx(context.Background(), &sql.TxOptions{})
   168  				assert.Nil(s.T(), err)
   169  				return transaction.NewTx(tx, NewMockDB(s.mockDB1))
   170  			}(),
   171  			wantResp: []string{"value"},
   172  		},
   173  	}
   174  	for _, tc := range testCases {
   175  		s.T().Run(tc.name, func(t *testing.T) {
   176  			tx := tc.tx
   177  			rows, queryErr := tx.Query(context.Background(), tc.query)
   178  			assert.Equal(t, queryErr, tc.wantErr)
   179  			if queryErr != nil {
   180  				return
   181  			}
   182  			assert.NotNil(t, rows)
   183  			var resp []string
   184  			for rows.Next() {
   185  				val := new(string)
   186  				err := rows.Scan(val)
   187  				assert.Nil(t, err)
   188  				if err != nil {
   189  					return
   190  				}
   191  				assert.NotNil(t, val)
   192  				resp = append(resp, *val)
   193  			}
   194  			assert.Nil(t, tx.Commit())
   195  			assert.ElementsMatch(t, tc.wantResp, resp)
   196  		})
   197  	}
   198  }
   199  
   200  func (s *TransactionSuite) TestDBExec() {
   201  	testCases := []struct {
   202  		name         string
   203  		lastInsertId int64
   204  		rowsAffected int64
   205  		wantErr      error
   206  		isCommit     bool
   207  		tx           *transaction.Tx
   208  		query        datasource.Query
   209  	}{
   210  		{
   211  			name: "res 1 rollback",
   212  			query: datasource.Query{
   213  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   214  			},
   215  			tx: func() *transaction.Tx {
   216  				s.mock1.ExpectBegin()
   217  				s.mock1.ExpectExec("^INSERT INTO (.+)").
   218  					WillReturnResult(sqlmock.NewResult(2, 1))
   219  				s.mock1.ExpectRollback()
   220  				tx, err := s.mockDB1.BeginTx(context.Background(), &sql.TxOptions{})
   221  				assert.Nil(s.T(), err)
   222  				return transaction.NewTx(tx, NewMockDB(s.mockDB1))
   223  			}(),
   224  			lastInsertId: int64(2),
   225  			rowsAffected: int64(1),
   226  		},
   227  		{
   228  			name: "res 1",
   229  			query: datasource.Query{
   230  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   231  			},
   232  			tx: func() *transaction.Tx {
   233  				s.mock2.ExpectBegin()
   234  				s.mock2.ExpectExec("^INSERT INTO (.+)").
   235  					WillReturnResult(sqlmock.NewResult(2, 1))
   236  				s.mock2.ExpectCommit()
   237  				tx, err := s.mockDB2.BeginTx(context.Background(), &sql.TxOptions{})
   238  				assert.Nil(s.T(), err)
   239  				return transaction.NewTx(tx, NewMockDB(s.mockDB2))
   240  			}(),
   241  			isCommit:     true,
   242  			lastInsertId: int64(2),
   243  			rowsAffected: int64(1),
   244  		},
   245  		{
   246  			name: "res 2",
   247  			query: datasource.Query{
   248  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4) (1,2,3,4)",
   249  			},
   250  			tx: func() *transaction.Tx {
   251  				s.mock3.ExpectBegin()
   252  				s.mock3.ExpectExec("^INSERT INTO (.+)").
   253  					WillReturnResult(sqlmock.NewResult(4, 2))
   254  				s.mock3.ExpectCommit()
   255  				tx, err := s.mockDB3.BeginTx(context.Background(), &sql.TxOptions{})
   256  				assert.Nil(s.T(), err)
   257  				return transaction.NewTx(tx, NewMockDB(s.mockDB3))
   258  			}(),
   259  			isCommit:     true,
   260  			lastInsertId: int64(4),
   261  			rowsAffected: int64(2),
   262  		},
   263  	}
   264  	for _, tc := range testCases {
   265  		s.T().Run(tc.name, func(t *testing.T) {
   266  			tx := tc.tx
   267  			res, err := tx.Exec(context.Background(), tc.query)
   268  			assert.Nil(t, err)
   269  			lastInsertId, err := res.LastInsertId()
   270  			assert.Nil(t, err)
   271  			assert.EqualValues(t, tc.lastInsertId, lastInsertId)
   272  			rowsAffected, err := res.RowsAffected()
   273  			assert.Nil(t, err)
   274  			if tc.isCommit {
   275  				assert.Nil(t, tx.Commit())
   276  			} else {
   277  				assert.Nil(t, tx.Rollback())
   278  			}
   279  			assert.EqualValues(t, tc.rowsAffected, rowsAffected)
   280  		})
   281  	}
   282  }
   283  
   284  func TestSingleSuite(t *testing.T) {
   285  	suite.Run(t, &TransactionSuite{})
   286  }
   287  
   288  type mockDB struct {
   289  	db *sql.DB
   290  }
   291  
   292  func (m *mockDB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) {
   293  	return m.db.QueryContext(ctx, query.SQL, query.Args...)
   294  }
   295  
   296  func (m *mockDB) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) {
   297  	return m.db.ExecContext(ctx, query.SQL, query.Args...)
   298  }
   299  
   300  func (m *mockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) {
   301  	tx, err := m.db.BeginTx(ctx, opts)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  	return transaction.NewTx(tx, m), nil
   306  }
   307  
   308  func (m *mockDB) Close() error {
   309  	return m.db.Close()
   310  }
   311  
   312  func NewMockDB(db *sql.DB) datasource.DataSource {
   313  	return &mockDB{
   314  		db: db,
   315  	}
   316  }