github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/single/db_test.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package single 16 17 import ( 18 "context" 19 "database/sql" 20 "database/sql/driver" 21 "errors" 22 "fmt" 23 "testing" 24 25 "github.com/ecodeclub/eorm/internal/datasource" 26 27 "github.com/DATA-DOG/go-sqlmock" 28 _ "github.com/mattn/go-sqlite3" 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/suite" 31 ) 32 33 type SingleSuite struct { 34 suite.Suite 35 mockDB *sql.DB 36 mock sqlmock.Sqlmock 37 } 38 39 func (s *SingleSuite) SetupTest() { 40 t := s.T() 41 s.initMock(t) 42 } 43 44 func (s *SingleSuite) TearDownTest() { 45 _ = s.mockDB.Close() 46 } 47 48 func (s *SingleSuite) initMock(t *testing.T) { 49 var err error 50 s.mockDB, s.mock, err = sqlmock.New() 51 if err != nil { 52 t.Fatal(err) 53 } 54 } 55 56 func (s *SingleSuite) TestDBQuery() { 57 //s.mock.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("value")) 58 59 testCases := []struct { 60 name string 61 query datasource.Query 62 mockRows *sqlmock.Rows 63 wantResp []string 64 wantErr error 65 }{ 66 { 67 name: "one row", 68 query: datasource.Query{ 69 SQL: "SELECT `first_name` FROM `test_model`", 70 }, 71 mockRows: sqlmock.NewRows([]string{"first_name"}).AddRow("value"), 72 wantResp: []string{"value"}, 73 }, 74 { 75 name: "multi row", 76 query: datasource.Query{ 77 SQL: "SELECT `first_name` FROM `test_model`", 78 }, 79 mockRows: func() *sqlmock.Rows { 80 res := sqlmock.NewRows([]string{"first_name"}) 81 res.AddRow("value1") 82 res.AddRow("value2") 83 return res 84 }(), 85 wantResp: []string{"value1", "value2"}, 86 }, 87 } 88 for _, tc := range testCases { 89 s.mock.ExpectQuery(tc.query.SQL).WillReturnRows(tc.mockRows) 90 } 91 for _, tc := range testCases { 92 s.T().Run(tc.name, func(t *testing.T) { 93 db := NewDB(s.mockDB) 94 rows, queryErr := db.Query(context.Background(), tc.query) 95 assert.Equal(t, queryErr, tc.wantErr) 96 if queryErr != nil { 97 return 98 } 99 assert.NotNil(t, rows) 100 var resp []string 101 for rows.Next() { 102 val := new(string) 103 err := rows.Scan(val) 104 assert.Nil(t, err) 105 if err != nil { 106 return 107 } 108 assert.NotNil(t, val) 109 resp = append(resp, *val) 110 } 111 112 assert.ElementsMatch(t, tc.wantResp, resp) 113 }) 114 } 115 } 116 117 func (s *SingleSuite) TestDBExec() { 118 testCases := []struct { 119 name string 120 lastInsertId int64 121 rowsAffected int64 122 wantErr error 123 mockResult driver.Result 124 query datasource.Query 125 }{ 126 { 127 name: "res 1", 128 query: datasource.Query{ 129 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)", 130 }, 131 mockResult: func() driver.Result { 132 return sqlmock.NewResult(2, 1) 133 }(), 134 lastInsertId: int64(2), 135 rowsAffected: int64(1), 136 }, 137 { 138 name: "res 2", 139 query: datasource.Query{ 140 SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4) (1,2,3,4)", 141 }, 142 mockResult: func() driver.Result { 143 return sqlmock.NewResult(4, 2) 144 }(), 145 lastInsertId: int64(4), 146 rowsAffected: int64(2), 147 }, 148 } 149 for _, tc := range testCases { 150 s.mock.ExpectExec("^INSERT INTO (.+)").WillReturnResult(tc.mockResult) 151 } 152 for _, tc := range testCases { 153 s.T().Run(tc.name, func(t *testing.T) { 154 db := NewDB(s.mockDB) 155 res, err := db.Exec(context.Background(), tc.query) 156 assert.Nil(t, err) 157 lastInsertId, err := res.LastInsertId() 158 assert.Nil(t, err) 159 assert.EqualValues(t, tc.lastInsertId, lastInsertId) 160 rowsAffected, err := res.RowsAffected() 161 assert.Nil(t, err) 162 assert.EqualValues(t, tc.rowsAffected, rowsAffected) 163 }) 164 } 165 } 166 167 func TestSingleSuite(t *testing.T) { 168 suite.Run(t, &SingleSuite{}) 169 } 170 171 func TestDB_BeginTx(t *testing.T) { 172 mockDB, mock, err := sqlmock.New() 173 if err != nil { 174 t.Fatal(err) 175 } 176 defer func() { _ = mockDB.Close() }() 177 178 db := NewDB(mockDB) 179 // Begin 失败 180 mock.ExpectBegin().WillReturnError(errors.New("begin failed")) 181 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 182 assert.Equal(t, errors.New("begin failed"), err) 183 assert.Nil(t, tx) 184 185 mock.ExpectBegin() 186 tx, err = db.BeginTx(context.Background(), &sql.TxOptions{}) 187 assert.Nil(t, err) 188 assert.NotNil(t, tx) 189 } 190 191 func TestDB_Wait(t *testing.T) { 192 mockDB, mock, err := sqlmock.New() 193 if err != nil { 194 t.Fatal(err) 195 } 196 defer func() { _ = mockDB.Close() }() 197 198 db := NewDB(mockDB) 199 if err != nil { 200 t.Fatal(err) 201 } 202 mock.ExpectPing() 203 err = db.Wait() 204 assert.Nil(t, err) 205 } 206 207 func ExampleDB_BeginTx() { 208 db, _ := OpenDB("sqlite3", "file:test.db?cache=shared&mode=memory") 209 defer func() { 210 _ = db.Close() 211 }() 212 tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) 213 if err == nil { 214 fmt.Println("Begin") 215 } 216 // 或者 tx.Rollback() 217 err = tx.Commit() 218 if err == nil { 219 fmt.Println("Commit") 220 } 221 // Output: 222 // Begin 223 // Commit 224 } 225 226 func ExampleDB_Close() { 227 db, _ := OpenDB("sqlite3", "file:test.db?cache=shared&mode=memory") 228 err := db.Close() 229 if err == nil { 230 fmt.Println("close") 231 } 232 233 // Output: 234 // close 235 }