github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/transaction/transaction_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package transaction_test 16 17 import ( 18 "context" 19 "database/sql" 20 "testing" 21 22 "github.com/ecodeclub/eorm/internal/datasource/transaction" 23 24 "github.com/stretchr/testify/suite" 25 26 "github.com/DATA-DOG/go-sqlmock" 27 "github.com/ecodeclub/eorm/internal/datasource" 28 "github.com/stretchr/testify/assert" 29 ) 30 31 func TestTx_Commit(t *testing.T) { 32 mockDB, mock, err := sqlmock.New() 33 if err != nil { 34 t.Fatal(err) 35 } 36 defer func() { _ = mockDB.Close() }() 37 38 db := openMockDB("mysql", mockDB) 39 if err != nil { 40 t.Fatal(err) 41 } 42 defer func() { 43 mock.ExpectClose() 44 _ = db.Close() 45 }() 46 47 // 事务正常提交 48 mock.ExpectBegin() 49 mock.ExpectCommit() 50 51 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 52 assert.Nil(t, err) 53 err = tx.Commit() 54 assert.Nil(t, err) 55 56 } 57 58 func TestTx_Rollback(t *testing.T) { 59 mockDB, mock, err := sqlmock.New() 60 if err != nil { 61 t.Fatal(err) 62 } 63 defer func() { _ = mockDB.Close() }() 64 65 db := openMockDB("mysql", mockDB) 66 if err != nil { 67 t.Fatal(err) 68 } 69 70 // 事务回滚 71 mock.ExpectBegin() 72 mock.ExpectRollback() 73 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 74 assert.Nil(t, err) 75 err = tx.Rollback() 76 assert.Nil(t, err) 77 } 78 79 type testMockDB struct { 80 driver string 81 db *sql.DB 82 } 83 84 func (*testMockDB) Query(_ context.Context, _ datasource.Query) (*sql.Rows, error) { 85 return &sql.Rows{}, nil 86 } 87 88 func (*testMockDB) Exec(_ context.Context, _ datasource.Query) (sql.Result, error) { 89 return nil, nil 90 } 91 92 func openMockDB(driver string, db *sql.DB) *testMockDB { 93 return &testMockDB{driver: driver, db: db} 94 } 95 96 func (db *testMockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) { 97 tx, err := db.db.BeginTx(ctx, opts) 98 if err != nil { 99 return nil, err 100 } 101 return transaction.NewTx(tx, db), nil 102 } 103 104 func (db *testMockDB) Close() error { 105 return db.db.Close() 106 } 107 108 type TransactionSuite struct { 109 suite.Suite 110 mockDB1 *sql.DB 111 mock1 sqlmock.Sqlmock 112 113 mockDB2 *sql.DB 114 mock2 sqlmock.Sqlmock 115 116 mockDB3 *sql.DB 117 mock3 sqlmock.Sqlmock 118 } 119 120 func (s *TransactionSuite) SetupTest() { 121 t := s.T() 122 s.initMock(t) 123 } 124 125 func (s *TransactionSuite) TearDownTest() { 126 _ = s.mockDB1.Close() 127 _ = s.mockDB2.Close() 128 _ = s.mockDB3.Close() 129 } 130 131 func (s *TransactionSuite) initMock(t *testing.T) { 132 var err error 133 s.mockDB1, s.mock1, err = sqlmock.New() 134 if err != nil { 135 t.Fatal(err) 136 } 137 s.mockDB2, s.mock2, err = sqlmock.New() 138 if err != nil { 139 t.Fatal(err) 140 } 141 s.mockDB3, s.mock3, err = sqlmock.New() 142 if err != nil { 143 t.Fatal(err) 144 } 145 } 146 147 func (s *TransactionSuite) TestDBQuery() { 148 //s.mock.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("value")) 149 testCases := []struct { 150 name string 151 tx *transaction.Tx 152 query datasource.Query 153 mockRows *sqlmock.Rows 154 wantResp []string 155 wantErr error 156 }{ 157 { 158 name: "query tx", 159 query: datasource.Query{ 160 SQL: "SELECT `first_name` FROM `test_model`", 161 }, 162 tx: func() *transaction.Tx { 163 s.mock1.ExpectBegin() 164 s.mock1.ExpectQuery("SELECT *").WillReturnRows( 165 sqlmock.NewRows([]string{"first_name"}).AddRow("value")) 166 s.mock1.ExpectCommit() 167 tx, err := s.mockDB1.BeginTx(context.Background(), &sql.TxOptions{}) 168 assert.Nil(s.T(), err) 169 return transaction.NewTx(tx, NewMockDB(s.mockDB1)) 170 }(), 171 wantResp: []string{"value"}, 172 }, 173 } 174 for _, tc := range testCases { 175 s.T().Run(tc.name, func(t *testing.T) { 176 tx := tc.tx 177 rows, queryErr := tx.Query(context.Background(), tc.query) 178 assert.Equal(t, queryErr, tc.wantErr) 179 if queryErr != nil { 180 return 181 } 182 assert.NotNil(t, rows) 183 var resp []string 184 for rows.Next() { 185 val := new(string) 186 err := rows.Scan(val) 187 assert.Nil(t, err) 188 if err != nil { 189 return 190 } 191 assert.NotNil(t, val) 192 resp = append(resp, *val) 193 } 194 assert.Nil(t, tx.Commit()) 195 assert.ElementsMatch(t, tc.wantResp, resp) 196 }) 197 } 198 } 199 200 func (s *TransactionSuite) TestDBExec() { 201 testCases := []struct { 202 name string 203 lastInsertId int64 204 rowsAffected int64 205 wantErr error 206 isCommit bool 207 tx *transaction.Tx 208 query datasource.Query 209 }{ 210 { 211 name: "res 1 rollback", 212 query: datasource.Query{ 213 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", 214 }, 215 tx: func() *transaction.Tx { 216 s.mock1.ExpectBegin() 217 s.mock1.ExpectExec("^INSERT INTO (.+)"). 218 WillReturnResult(sqlmock.NewResult(2, 1)) 219 s.mock1.ExpectRollback() 220 tx, err := s.mockDB1.BeginTx(context.Background(), &sql.TxOptions{}) 221 assert.Nil(s.T(), err) 222 return transaction.NewTx(tx, NewMockDB(s.mockDB1)) 223 }(), 224 lastInsertId: int64(2), 225 rowsAffected: int64(1), 226 }, 227 { 228 name: "res 1", 229 query: datasource.Query{ 230 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", 231 }, 232 tx: func() *transaction.Tx { 233 s.mock2.ExpectBegin() 234 s.mock2.ExpectExec("^INSERT INTO (.+)"). 235 WillReturnResult(sqlmock.NewResult(2, 1)) 236 s.mock2.ExpectCommit() 237 tx, err := s.mockDB2.BeginTx(context.Background(), &sql.TxOptions{}) 238 assert.Nil(s.T(), err) 239 return transaction.NewTx(tx, NewMockDB(s.mockDB2)) 240 }(), 241 isCommit: true, 242 lastInsertId: int64(2), 243 rowsAffected: int64(1), 244 }, 245 { 246 name: "res 2", 247 query: datasource.Query{ 248 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4) (1,2,3,4)", 249 }, 250 tx: func() *transaction.Tx { 251 s.mock3.ExpectBegin() 252 s.mock3.ExpectExec("^INSERT INTO (.+)"). 253 WillReturnResult(sqlmock.NewResult(4, 2)) 254 s.mock3.ExpectCommit() 255 tx, err := s.mockDB3.BeginTx(context.Background(), &sql.TxOptions{}) 256 assert.Nil(s.T(), err) 257 return transaction.NewTx(tx, NewMockDB(s.mockDB3)) 258 }(), 259 isCommit: true, 260 lastInsertId: int64(4), 261 rowsAffected: int64(2), 262 }, 263 } 264 for _, tc := range testCases { 265 s.T().Run(tc.name, func(t *testing.T) { 266 tx := tc.tx 267 res, err := tx.Exec(context.Background(), tc.query) 268 assert.Nil(t, err) 269 lastInsertId, err := res.LastInsertId() 270 assert.Nil(t, err) 271 assert.EqualValues(t, tc.lastInsertId, lastInsertId) 272 rowsAffected, err := res.RowsAffected() 273 assert.Nil(t, err) 274 if tc.isCommit { 275 assert.Nil(t, tx.Commit()) 276 } else { 277 assert.Nil(t, tx.Rollback()) 278 } 279 assert.EqualValues(t, tc.rowsAffected, rowsAffected) 280 }) 281 } 282 } 283 284 func TestSingleSuite(t *testing.T) { 285 suite.Run(t, &TransactionSuite{}) 286 } 287 288 type mockDB struct { 289 db *sql.DB 290 } 291 292 func (m *mockDB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { 293 return m.db.QueryContext(ctx, query.SQL, query.Args...) 294 } 295 296 func (m *mockDB) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) { 297 return m.db.ExecContext(ctx, query.SQL, query.Args...) 298 } 299 300 func (m *mockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) { 301 tx, err := m.db.BeginTx(ctx, opts) 302 if err != nil { 303 return nil, err 304 } 305 return transaction.NewTx(tx, m), nil 306 } 307 308 func (m *mockDB) Close() error { 309 return m.db.Close() 310 } 311 312 func NewMockDB(db *sql.DB) datasource.DataSource { 313 return &mockDB{ 314 db: db, 315 } 316 }