github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/transaction_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "testing" 22 23 "github.com/DATA-DOG/go-sqlmock" 24 "github.com/ecodeclub/eorm/internal/datasource" 25 "github.com/ecodeclub/eorm/internal/datasource/single" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 ) 29 30 func TestTx_Commit(t *testing.T) { 31 mockDB, mock, err := sqlmock.New() 32 if err != nil { 33 t.Fatal(err) 34 } 35 defer func() { _ = mockDB.Close() }() 36 37 db, err := OpenDS("mysql", single.NewDB(mockDB)) 38 if err != nil { 39 t.Fatal(err) 40 } 41 defer func() { 42 mock.ExpectClose() 43 _ = db.Close() 44 }() 45 46 // 事务正常提交 47 mock.ExpectBegin() 48 mock.ExpectCommit() 49 50 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 51 assert.Nil(t, err) 52 err = tx.Commit() 53 assert.Nil(t, err) 54 55 } 56 57 func TestTx_Rollback(t *testing.T) { 58 mockDB, mock, err := sqlmock.New() 59 if err != nil { 60 t.Fatal(err) 61 } 62 defer func() { _ = mockDB.Close() }() 63 64 db, err := OpenDS("mysql", single.NewDB(mockDB)) 65 if err != nil { 66 t.Fatal(err) 67 } 68 69 // 事务回滚 70 mock.ExpectBegin() 71 mock.ExpectRollback() 72 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 73 assert.Nil(t, err) 74 err = tx.Rollback() 75 assert.Nil(t, err) 76 } 77 78 func TestTx_QueryContext(t *testing.T) { 79 testCases := []struct { 80 name string 81 query Query 82 mockOrder func(mock sqlmock.Sqlmock) 83 sourceFunc func(db *sql.DB, t *testing.T) datasource.DataSource 84 wantResp []string 85 wantErr error 86 isCommit bool 87 }{ 88 { 89 name: "err", 90 mockOrder: func(mock sqlmock.Sqlmock) { 91 mock.ExpectBegin() 92 mock.ExpectQuery("SELECT `xx` FROM `test_model`"). 93 WillReturnError(errors.New("未知字段")) 94 mock.ExpectRollback() 95 }, 96 sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 97 return single.NewDB(db) 98 }, 99 query: Query{ 100 SQL: "SELECT `xx` FROM `test_model`", 101 }, 102 wantErr: errors.New("未知字段"), 103 isCommit: false, 104 }, 105 { 106 name: "commit", 107 mockOrder: func(mock sqlmock.Sqlmock) { 108 mock.ExpectBegin() 109 mock.ExpectQuery("SELECT `first_name` FROM `test_model`"). 110 WillReturnRows(sqlmock.NewRows([]string{"first_name"}).AddRow("value")) 111 mock.ExpectCommit() 112 }, 113 sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 114 return single.NewDB(db) 115 }, 116 query: Query{ 117 SQL: "SELECT `first_name` FROM `test_model`", 118 }, 119 isCommit: true, 120 }, 121 } 122 for _, tc := range testCases { 123 t.Run(tc.name, func(t *testing.T) { 124 mockDB, mock, err := sqlmock.New() 125 if err != nil { 126 t.Fatal(err) 127 } 128 defer func(db *sql.DB) { _ = db.Close() }(mockDB) 129 tc.mockOrder(mock) 130 source := tc.sourceFunc(mockDB, t) 131 orm, err := OpenDS("mysql", source) 132 require.NoError(t, err) 133 tx, err := orm.BeginTx(context.Background(), &sql.TxOptions{}) 134 require.NoError(t, err) 135 rows, queryErr := tx.queryContext(context.Background(), datasource.Query(tc.query)) 136 assert.Equal(t, queryErr, tc.wantErr) 137 if queryErr != nil { 138 return 139 } 140 141 if tc.isCommit { 142 err = tx.Commit() 143 } else { 144 err = tx.Rollback() 145 } 146 assert.Equal(t, tc.wantErr, err) 147 if err != nil { 148 return 149 } 150 151 assert.NotNil(t, rows) 152 var resp []string 153 for rows.Next() { 154 val := new(string) 155 err := rows.Scan(val) 156 assert.Nil(t, err) 157 if err != nil { 158 return 159 } 160 assert.NotNil(t, val) 161 resp = append(resp, *val) 162 } 163 164 assert.ElementsMatch(t, tc.wantResp, resp) 165 if err = mock.ExpectationsWereMet(); err != nil { 166 t.Error(err) 167 } 168 }) 169 } 170 } 171 172 func TestTx_ExecContext(t *testing.T) { 173 testCases := []struct { 174 name string 175 query Query 176 mockOrder func(mock sqlmock.Sqlmock) 177 sourceFunc func(db *sql.DB, t *testing.T) datasource.DataSource 178 wantVal sql.Result 179 wantBeginTxErr error 180 wantErr error 181 isCommit bool 182 }{ 183 //{ 184 // name: "source err", 185 // mockOrder: func(mock sqlmock.Sqlmock) { 186 // mock.ExpectBegin() 187 // mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) 188 // mock.ExpectCommit() 189 // }, 190 // sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 191 // clusterDB := cluster.NewClusterDB(map[string]*masterslave.MasterSlavesDB{ 192 // "db0": masterslave.NewMasterSlavesDB(db), 193 // }) 194 // return clusterDB 195 // }, 196 // query: Query{ 197 // SQL: "DELETE FROM `test_model` WHERE `id`=", 198 // Args: []any{1}, 199 // }, 200 // wantBeginTxErr: errors.New("eorm: 未实现 TxBeginner 接口"), 201 //}, 202 { 203 name: "commit err", 204 mockOrder: func(mock sqlmock.Sqlmock) { 205 mock.ExpectBegin() 206 mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) 207 mock.ExpectCommit().WillReturnError(errors.New("commit 错误")) 208 }, 209 sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 210 return single.NewDB(db) 211 }, 212 query: Query{ 213 SQL: "DELETE FROM `test_model` WHERE `id`=", 214 Args: []any{1}, 215 }, 216 wantErr: errors.New("commit 错误"), 217 isCommit: true, 218 }, 219 { 220 name: "rollback err", 221 mockOrder: func(mock sqlmock.Sqlmock) { 222 mock.ExpectBegin() 223 mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) 224 mock.ExpectRollback().WillReturnError(errors.New("rollback 错误")) 225 }, 226 sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 227 return single.NewDB(db) 228 }, 229 query: Query{ 230 SQL: "DELETE FROM `test_model` WHERE `id`=", 231 Args: []any{1}, 232 }, 233 wantErr: errors.New("rollback 错误"), 234 }, 235 { 236 name: "commit", 237 mockOrder: func(mock sqlmock.Sqlmock) { 238 mock.ExpectBegin() 239 mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) 240 mock.ExpectCommit() 241 }, 242 sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 243 return single.NewDB(db) 244 }, 245 query: Query{ 246 SQL: "DELETE FROM `test_model` WHERE `id`=", 247 Args: []any{1}, 248 }, 249 wantVal: sqlmock.NewResult(10, 20), 250 isCommit: true, 251 }, 252 { 253 name: "rollback", 254 mockOrder: func(mock sqlmock.Sqlmock) { 255 mock.ExpectBegin() 256 mock.ExpectExec("DELETE FROM `test_model` WHERE `id`=").WithArgs(1).WillReturnResult(sqlmock.NewResult(10, 20)) 257 mock.ExpectRollback() 258 }, 259 sourceFunc: func(db *sql.DB, t *testing.T) datasource.DataSource { 260 return single.NewDB(db) 261 }, 262 query: Query{ 263 SQL: "DELETE FROM `test_model` WHERE `id`=", 264 Args: []any{1}, 265 }, 266 wantVal: sqlmock.NewResult(10, 20), 267 }, 268 } 269 270 for _, tc := range testCases { 271 t.Run(tc.name, func(t *testing.T) { 272 mockDB, mock, err := sqlmock.New() 273 if err != nil { 274 t.Fatal(err) 275 } 276 defer func(db *sql.DB) { _ = db.Close() }(mockDB) 277 tc.mockOrder(mock) 278 279 source := tc.sourceFunc(mockDB, t) 280 orm, err := OpenDS("mysql", source) 281 require.NoError(t, err) 282 tx, err := orm.BeginTx(context.Background(), &sql.TxOptions{}) 283 assert.Equal(t, tc.wantBeginTxErr, err) 284 if err != nil { 285 return 286 } 287 result, err := tx.execContext(context.Background(), datasource.Query(tc.query)) 288 require.NoError(t, err) 289 290 if tc.isCommit { 291 err = tx.Commit() 292 } else { 293 err = tx.Rollback() 294 } 295 assert.Equal(t, tc.wantErr, err) 296 if err != nil { 297 return 298 } 299 300 rowsAffectedExpect, err := tc.wantVal.RowsAffected() 301 require.NoError(t, err) 302 rowsAffected, err := result.RowsAffected() 303 require.NoError(t, err) 304 assert.Equal(t, rowsAffectedExpect, rowsAffected) 305 306 lastInsertIdExpected, err := tc.wantVal.LastInsertId() 307 require.NoError(t, err) 308 lastInsertId, err := result.LastInsertId() 309 require.NoError(t, err) 310 assert.Equal(t, lastInsertIdExpected, lastInsertId) 311 312 if err = mock.ExpectationsWereMet(); err != nil { 313 t.Error(err) 314 } 315 }) 316 } 317 }