github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/transaction_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"testing"
    22  
    23  	"github.com/DATA-DOG/go-sqlmock"
    24  	"github.com/ecodeclub/eorm/internal/datasource"
    25  	"github.com/ecodeclub/eorm/internal/datasource/single"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  func TestTx_Commit(t *testing.T) {
    31  	mockDB, mock, err := sqlmock.New()
    32  	if err != nil {
    33  		t.Fatal(err)
    34  	}
    35  	defer func() { _ = mockDB.Close() }()
    36  
    37  	db, err := OpenDS("mysql", single.NewDB(mockDB))
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	defer func() {
    42  		mock.ExpectClose()
    43  		_ = db.Close()
    44  	}()
    45  
    46  	// 事务正常提交
    47  	mock.ExpectBegin()
    48  	mock.ExpectCommit()
    49  
    50  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
    51  	assert.Nil(t, err)
    52  	err = tx.Commit()
    53  	assert.Nil(t, err)
    54  
    55  }
    56  
    57  func TestTx_Rollback(t *testing.T) {
    58  	mockDB, mock, err := sqlmock.New()
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	defer func() { _ = mockDB.Close() }()
    63  
    64  	db, err := OpenDS("mysql", single.NewDB(mockDB))
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  
    69  	// 事务回滚
    70  	mock.ExpectBegin()
    71  	mock.ExpectRollback()
    72  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
    73  	assert.Nil(t, err)
    74  	err = tx.Rollback()
    75  	assert.Nil(t, err)
    76  }
    77  
    78  func TestTx_QueryContext(t *testing.T) {
    79  	testCases := []struct {
    80  		name       string
    81  		query      Query
    82  		mockOrder  func(mock sqlmock.Sqlmock)
    83  		sourceFunc func(db *sql.DB, t *testing.T) datasource.DataSource
    84  		wantResp   []string
    85  		wantErr    error
    86  		isCommit   bool
    87  	}{
    88  		{
    89  			name: "err",
    90  			mockOrder: func(mock sqlmock.Sqlmock) {
    91  				mock.ExpectBegin()
    92  				mock.ExpectQuery("SELECT `xx` FROM `test_model`").
    93  					WillReturnError(errors.New("未知字段"))
    94  				mock.ExpectRollback()
    95  			},
    96  			sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
    97  				return single.NewDB(db)
    98  			},
    99  			query: Query{
   100  				SQL: "SELECT `xx` FROM `test_model`",
   101  			},
   102  			wantErr:  errors.New("未知字段"),
   103  			isCommit: false,
   104  		},
   105  		{
   106  			name: "commit",
   107  			mockOrder: func(mock sqlmock.Sqlmock) {
   108  				mock.ExpectBegin()
   109  				mock.ExpectQuery("SELECT `first_name` FROM `test_model`").
   110  					WillReturnRows(sqlmock.NewRows([]string{"first_name"}).AddRow("value"))
   111  				mock.ExpectCommit()
   112  			},
   113  			sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
   114  				return single.NewDB(db)
   115  			},
   116  			query: Query{
   117  				SQL: "SELECT `first_name` FROM `test_model`",
   118  			},
   119  			isCommit: true,
   120  		},
   121  	}
   122  	for _, tc := range testCases {
   123  		t.Run(tc.name, func(t *testing.T) {
   124  			mockDB, mock, err := sqlmock.New()
   125  			if err != nil {
   126  				t.Fatal(err)
   127  			}
   128  			defer func(db *sql.DB) { _ = db.Close() }(mockDB)
   129  			tc.mockOrder(mock)
   130  			source := tc.sourceFunc(mockDB, t)
   131  			orm, err := OpenDS("mysql", source)
   132  			require.NoError(t, err)
   133  			tx, err := orm.BeginTx(context.Background(), &sql.TxOptions{})
   134  			require.NoError(t, err)
   135  			rows, queryErr := tx.queryContext(context.Background(), datasource.Query(tc.query))
   136  			assert.Equal(t, queryErr, tc.wantErr)
   137  			if queryErr != nil {
   138  				return
   139  			}
   140  
   141  			if tc.isCommit {
   142  				err = tx.Commit()
   143  			} else {
   144  				err = tx.Rollback()
   145  			}
   146  			assert.Equal(t, tc.wantErr, err)
   147  			if err != nil {
   148  				return
   149  			}
   150  
   151  			assert.NotNil(t, rows)
   152  			var resp []string
   153  			for rows.Next() {
   154  				val := new(string)
   155  				err := rows.Scan(val)
   156  				assert.Nil(t, err)
   157  				if err != nil {
   158  					return
   159  				}
   160  				assert.NotNil(t, val)
   161  				resp = append(resp, *val)
   162  			}
   163  
   164  			assert.ElementsMatch(t, tc.wantResp, resp)
   165  			if err = mock.ExpectationsWereMet(); err != nil {
   166  				t.Error(err)
   167  			}
   168  		})
   169  	}
   170  }
   171  
   172  func TestTx_ExecContext(t *testing.T) {
   173  	testCases := []struct {
   174  		name           string
   175  		query          Query
   176  		mockOrder      func(mock sqlmock.Sqlmock)
   177  		sourceFunc     func(db *sql.DB, t *testing.T) datasource.DataSource
   178  		wantVal        sql.Result
   179  		wantBeginTxErr error
   180  		wantErr        error
   181  		isCommit       bool
   182  	}{
   183  		//{
   184  		//	name: "source err",
   185  		//	mockOrder: func(mock sqlmock.Sqlmock) {
   186  		//		mock.ExpectBegin()
   187  		//		mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20))
   188  		//		mock.ExpectCommit()
   189  		//	},
   190  		//	sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
   191  		//		clusterDB := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{
   192  		//			"db0": masterslave.NewMasterSlavesDB(db),
   193  		//		})
   194  		//		return clusterDB
   195  		//	},
   196  		//	query: Query{
   197  		//		SQL:  "DELETE FROM `test_model` WHERE `id`=",
   198  		//		Args: []any{1},
   199  		//	},
   200  		//	wantBeginTxErr: errors.New("eorm: 未实现 TxBeginner 接口"),
   201  		//},
   202  		{
   203  			name: "commit err",
   204  			mockOrder: func(mock sqlmock.Sqlmock) {
   205  				mock.ExpectBegin()
   206  				mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20))
   207  				mock.ExpectCommit().WillReturnError(errors.New("commit 错误"))
   208  			},
   209  			sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
   210  				return single.NewDB(db)
   211  			},
   212  			query: Query{
   213  				SQL:  "DELETE FROM `test_model` WHERE `id`=",
   214  				Args: []any{1},
   215  			},
   216  			wantErr:  errors.New("commit 错误"),
   217  			isCommit: true,
   218  		},
   219  		{
   220  			name: "rollback err",
   221  			mockOrder: func(mock sqlmock.Sqlmock) {
   222  				mock.ExpectBegin()
   223  				mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20))
   224  				mock.ExpectRollback().WillReturnError(errors.New("rollback 错误"))
   225  			},
   226  			sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
   227  				return single.NewDB(db)
   228  			},
   229  			query: Query{
   230  				SQL:  "DELETE FROM `test_model` WHERE `id`=",
   231  				Args: []any{1},
   232  			},
   233  			wantErr: errors.New("rollback 错误"),
   234  		},
   235  		{
   236  			name: "commit",
   237  			mockOrder: func(mock sqlmock.Sqlmock) {
   238  				mock.ExpectBegin()
   239  				mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20))
   240  				mock.ExpectCommit()
   241  			},
   242  			sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
   243  				return single.NewDB(db)
   244  			},
   245  			query: Query{
   246  				SQL:  "DELETE FROM `test_model` WHERE `id`=",
   247  				Args: []any{1},
   248  			},
   249  			wantVal:  sqlmock.NewResult(10, 20),
   250  			isCommit: true,
   251  		},
   252  		{
   253  			name: "rollback",
   254  			mockOrder: func(mock sqlmock.Sqlmock) {
   255  				mock.ExpectBegin()
   256  				mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20))
   257  				mock.ExpectRollback()
   258  			},
   259  			sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource {
   260  				return single.NewDB(db)
   261  			},
   262  			query: Query{
   263  				SQL:  "DELETE FROM `test_model` WHERE `id`=",
   264  				Args: []any{1},
   265  			},
   266  			wantVal: sqlmock.NewResult(10, 20),
   267  		},
   268  	}
   269  
   270  	for _, tc := range testCases {
   271  		t.Run(tc.name, func(t *testing.T) {
   272  			mockDB, mock, err := sqlmock.New()
   273  			if err != nil {
   274  				t.Fatal(err)
   275  			}
   276  			defer func(db *sql.DB) { _ = db.Close() }(mockDB)
   277  			tc.mockOrder(mock)
   278  
   279  			source := tc.sourceFunc(mockDB, t)
   280  			orm, err := OpenDS("mysql", source)
   281  			require.NoError(t, err)
   282  			tx, err := orm.BeginTx(context.Background(), &sql.TxOptions{})
   283  			assert.Equal(t, tc.wantBeginTxErr, err)
   284  			if err != nil {
   285  				return
   286  			}
   287  			result, err := tx.execContext(context.Background(), datasource.Query(tc.query))
   288  			require.NoError(t, err)
   289  
   290  			if tc.isCommit {
   291  				err = tx.Commit()
   292  			} else {
   293  				err = tx.Rollback()
   294  			}
   295  			assert.Equal(t, tc.wantErr, err)
   296  			if err != nil {
   297  				return
   298  			}
   299  
   300  			rowsAffectedExpect, err := tc.wantVal.RowsAffected()
   301  			require.NoError(t, err)
   302  			rowsAffected, err := result.RowsAffected()
   303  			require.NoError(t, err)
   304  			assert.Equal(t, rowsAffectedExpect, rowsAffected)
   305  
   306  			lastInsertIdExpected, err := tc.wantVal.LastInsertId()
   307  			require.NoError(t, err)
   308  			lastInsertId, err := result.LastInsertId()
   309  			require.NoError(t, err)
   310  			assert.Equal(t, lastInsertIdExpected, lastInsertId)
   311  
   312  			if err = mock.ExpectationsWereMet(); err != nil {
   313  				t.Error(err)
   314  			}
   315  		})
   316  	}
   317  }