github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/stmt_test.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 "errors" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/assert" 10 ) 11 12 var errMockedPlaceholder = errors.New("placeholder") 13 14 func TestStmt_exec(t *testing.T) { 15 tests := []struct { 16 name string 17 query string 18 args []interface{} 19 delay bool 20 hasError bool 21 err error 22 lastInsertId int64 23 rowsAffected int64 24 }{ 25 { 26 name: "normal", 27 query: "select user from users where id=?", 28 args: []interface{}{1}, 29 lastInsertId: 1, 30 rowsAffected: 2, 31 }, 32 { 33 name: "exec error", 34 query: "select user from users where id=?", 35 args: []interface{}{1}, 36 hasError: true, 37 err: errors.New("exec"), 38 }, 39 { 40 name: "exec more args error", 41 query: "select user from users where id=? and name=?", 42 args: []interface{}{1}, 43 hasError: true, 44 err: errors.New("exec"), 45 }, 46 { 47 name: "slowcall", 48 query: "select user from users where id=?", 49 args: []interface{}{1}, 50 delay: true, 51 lastInsertId: 1, 52 rowsAffected: 2, 53 }, 54 } 55 56 for _, test := range tests { 57 test := test 58 fns := []func(args ...interface{}) (sql.Result, error){ 59 func(args ...interface{}) (sql.Result, error) { 60 return exec(&mockedSessionConn{ 61 lastInsertId: test.lastInsertId, 62 rowsAffected: test.rowsAffected, 63 err: test.err, 64 delay: test.delay, 65 }, test.query, args...) 66 }, 67 func(args ...interface{}) (sql.Result, error) { 68 return execStmt(&mockedStmtConn{ 69 lastInsertId: test.lastInsertId, 70 rowsAffected: test.rowsAffected, 71 err: test.err, 72 delay: test.delay, 73 }, test.query, args...) 74 }, 75 } 76 77 for _, fn := range fns { 78 fn := fn 79 t.Run(test.name, func(t *testing.T) { 80 t.Parallel() 81 82 res, err := fn(test.args...) 83 if test.hasError { 84 assert.NotNil(t, err) 85 return 86 } 87 88 assert.Nil(t, err) 89 lastInsertId, err := res.LastInsertId() 90 assert.Nil(t, err) 91 assert.Equal(t, test.lastInsertId, lastInsertId) 92 rowsAffected, err := res.RowsAffected() 93 assert.Nil(t, err) 94 assert.Equal(t, test.rowsAffected, rowsAffected) 95 }) 96 } 97 } 98 } 99 100 func TestStmt_query(t *testing.T) { 101 tests := []struct { 102 name string 103 query string 104 args []interface{} 105 delay bool 106 hasError bool 107 err error 108 }{ 109 { 110 name: "normal", 111 query: "select user from users where id=?", 112 args: []interface{}{1}, 113 }, 114 { 115 name: "query error", 116 query: "select user from users where id=?", 117 args: []interface{}{1}, 118 hasError: true, 119 err: errors.New("exec"), 120 }, 121 { 122 name: "query more args error", 123 query: "select user from users where id=? and name=?", 124 args: []interface{}{1}, 125 hasError: true, 126 err: errors.New("exec"), 127 }, 128 { 129 name: "slowcall", 130 query: "select user from users where id=?", 131 args: []interface{}{1}, 132 delay: true, 133 }, 134 } 135 136 for _, test := range tests { 137 test := test 138 fns := []func(args ...interface{}) error{ 139 func(args ...interface{}) error { 140 return query(&mockedSessionConn{ 141 err: test.err, 142 delay: test.delay, 143 }, func(rows *sql.Rows) error { 144 return nil 145 }, test.query, args...) 146 }, 147 func(args ...interface{}) error { 148 return queryStmt(&mockedStmtConn{ 149 err: test.err, 150 delay: test.delay, 151 }, func(rows *sql.Rows) error { 152 return nil 153 }, test.query, args...) 154 }, 155 } 156 157 for _, fn := range fns { 158 fn := fn 159 t.Run(test.name, func(t *testing.T) { 160 t.Parallel() 161 162 err := fn(test.args...) 163 if test.hasError { 164 assert.NotNil(t, err) 165 return 166 } 167 168 assert.NotNil(t, err) 169 }) 170 } 171 } 172 } 173 174 func TestSetSlowThreshold(t *testing.T) { 175 assert.Equal(t, defaultSlowThreshold, slowThreshold.Load()) 176 SetSlowThreshold(time.Second) 177 assert.Equal(t, time.Second, slowThreshold.Load()) 178 } 179 180 type mockedSessionConn struct { 181 lastInsertId int64 182 rowsAffected int64 183 err error 184 delay bool 185 } 186 187 func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) { 188 if m.delay { 189 time.Sleep(defaultSlowThreshold + time.Millisecond) 190 } 191 return mockedResult{ 192 lastInsertId: m.lastInsertId, 193 rowsAffected: m.rowsAffected, 194 }, m.err 195 } 196 197 func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) { 198 if m.delay { 199 time.Sleep(defaultSlowThreshold + time.Millisecond) 200 } 201 202 err := errMockedPlaceholder 203 if m.err != nil { 204 err = m.err 205 } 206 return new(sql.Rows), err 207 } 208 209 type mockedStmtConn struct { 210 lastInsertId int64 211 rowsAffected int64 212 err error 213 delay bool 214 } 215 216 func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) { 217 if m.delay { 218 time.Sleep(defaultSlowThreshold + time.Millisecond) 219 } 220 return mockedResult{ 221 lastInsertId: m.lastInsertId, 222 rowsAffected: m.rowsAffected, 223 }, m.err 224 } 225 226 func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) { 227 if m.delay { 228 time.Sleep(defaultSlowThreshold + time.Millisecond) 229 } 230 231 err := errMockedPlaceholder 232 if m.err != nil { 233 err = m.err 234 } 235 return new(sql.Rows), err 236 } 237 238 type mockedResult struct { 239 lastInsertId int64 240 rowsAffected int64 241 } 242 243 func (m mockedResult) LastInsertId() (int64, error) { 244 return m.lastInsertId, nil 245 } 246 247 func (m mockedResult) RowsAffected() (int64, error) { 248 return m.rowsAffected, nil 249 }