github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/stmt_test.go (about) 1 package sqlx 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "testing" 8 "time" 9 10 "github.com/stretchr/testify/assert" 11 ) 12 13 var errMockedPlaceholder = errors.New("placeholder") 14 15 func TestStmt_exec(t *testing.T) { 16 tests := []struct { 17 name string 18 query string 19 args []interface{} 20 delay bool 21 hasError bool 22 err error 23 lastInsertId int64 24 rowsAffected int64 25 }{ 26 { 27 name: "normal", 28 query: "select user from users where id=?", 29 args: []interface{}{1}, 30 lastInsertId: 1, 31 rowsAffected: 2, 32 }, 33 { 34 name: "exec error", 35 query: "select user from users where id=?", 36 args: []interface{}{1}, 37 hasError: true, 38 err: errors.New("exec"), 39 }, 40 { 41 name: "exec more args error", 42 query: "select user from users where id=? and name=?", 43 args: []interface{}{1}, 44 hasError: true, 45 err: errors.New("exec"), 46 }, 47 { 48 name: "slowcall", 49 query: "select user from users where id=?", 50 args: []interface{}{1}, 51 delay: true, 52 lastInsertId: 1, 53 rowsAffected: 2, 54 }, 55 } 56 57 for _, test := range tests { 58 test := test 59 fns := []func(args ...interface{}) (sql.Result, error){ 60 func(args ...interface{}) (sql.Result, error) { 61 return exec(context.Background(), &mockedSessionConn{ 62 lastInsertId: test.lastInsertId, 63 rowsAffected: test.rowsAffected, 64 err: test.err, 65 delay: test.delay, 66 }, test.query, args...) 67 }, 68 func(args ...interface{}) (sql.Result, error) { 69 return execStmt(context.Background(), &mockedStmtConn{ 70 lastInsertId: test.lastInsertId, 71 rowsAffected: test.rowsAffected, 72 err: test.err, 73 delay: test.delay, 74 }, test.query, args...) 75 }, 76 } 77 78 for _, fn := range fns { 79 fn := fn 80 t.Run(test.name, func(t *testing.T) { 81 t.Parallel() 82 83 res, err := fn(test.args...) 84 if test.hasError { 85 assert.NotNil(t, err) 86 return 87 } 88 89 assert.Nil(t, err) 90 lastInsertId, err := res.LastInsertId() 91 assert.Nil(t, err) 92 assert.Equal(t, test.lastInsertId, lastInsertId) 93 rowsAffected, err := res.RowsAffected() 94 assert.Nil(t, err) 95 assert.Equal(t, test.rowsAffected, rowsAffected) 96 }) 97 } 98 } 99 } 100 101 func TestStmt_query(t *testing.T) { 102 tests := []struct { 103 name string 104 query string 105 args []interface{} 106 delay bool 107 hasError bool 108 err error 109 }{ 110 { 111 name: "normal", 112 query: "select user from users where id=?", 113 args: []interface{}{1}, 114 }, 115 { 116 name: "query error", 117 query: "select user from users where id=?", 118 args: []interface{}{1}, 119 hasError: true, 120 err: errors.New("exec"), 121 }, 122 { 123 name: "query more args error", 124 query: "select user from users where id=? and name=?", 125 args: []interface{}{1}, 126 hasError: true, 127 err: errors.New("exec"), 128 }, 129 { 130 name: "slowcall", 131 query: "select user from users where id=?", 132 args: []interface{}{1}, 133 delay: true, 134 }, 135 } 136 137 for _, test := range tests { 138 test := test 139 fns := []func(args ...interface{}) error{ 140 func(args ...interface{}) error { 141 return query(context.Background(), &mockedSessionConn{ 142 err: test.err, 143 delay: test.delay, 144 }, func(rows *sql.Rows) error { 145 return nil 146 }, test.query, args...) 147 }, 148 func(args ...interface{}) error { 149 return queryStmt(context.Background(), &mockedStmtConn{ 150 err: test.err, 151 delay: test.delay, 152 }, func(rows *sql.Rows) error { 153 return nil 154 }, test.query, args...) 155 }, 156 } 157 158 for _, fn := range fns { 159 fn := fn 160 t.Run(test.name, func(t *testing.T) { 161 t.Parallel() 162 163 err := fn(test.args...) 164 if test.hasError { 165 assert.NotNil(t, err) 166 return 167 } 168 169 assert.NotNil(t, err) 170 }) 171 } 172 } 173 } 174 175 func TestSetSlowThreshold(t *testing.T) { 176 assert.Equal(t, defaultSlowThreshold, slowThreshold.Load()) 177 SetSlowThreshold(time.Second) 178 assert.Equal(t, time.Second, slowThreshold.Load()) 179 } 180 181 func TestDisableLog(t *testing.T) { 182 assert.True(t, logSql.True()) 183 assert.True(t, logSlowSql.True()) 184 defer func() { 185 logSql.Set(true) 186 logSlowSql.Set(true) 187 }() 188 189 DisableLog() 190 assert.False(t, logSql.True()) 191 assert.False(t, logSlowSql.True()) 192 } 193 194 func TestDisableStmtLog(t *testing.T) { 195 assert.True(t, logSql.True()) 196 assert.True(t, logSlowSql.True()) 197 defer func() { 198 logSql.Set(true) 199 logSlowSql.Set(true) 200 }() 201 202 DisableStmtLog() 203 assert.False(t, logSql.True()) 204 assert.True(t, logSlowSql.True()) 205 } 206 207 func TestNilGuard(t *testing.T) { 208 assert.True(t, logSql.True()) 209 assert.True(t, logSlowSql.True()) 210 defer func() { 211 logSql.Set(true) 212 logSlowSql.Set(true) 213 }() 214 215 DisableLog() 216 guard := newGuard("any") 217 assert.Nil(t, guard.start("foo", "bar")) 218 guard.finish(context.Background(), nil) 219 assert.Equal(t, nilGuard{}, guard) 220 } 221 222 type mockedSessionConn struct { 223 lastInsertId int64 224 rowsAffected int64 225 err error 226 delay bool 227 } 228 229 func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) { 230 return m.ExecContext(context.Background(), query, args...) 231 } 232 233 func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 234 if m.delay { 235 time.Sleep(defaultSlowThreshold + time.Millisecond) 236 } 237 return mockedResult{ 238 lastInsertId: m.lastInsertId, 239 rowsAffected: m.rowsAffected, 240 }, m.err 241 } 242 243 func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) { 244 return m.QueryContext(context.Background(), query, args...) 245 } 246 247 func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 248 if m.delay { 249 time.Sleep(defaultSlowThreshold + time.Millisecond) 250 } 251 252 err := errMockedPlaceholder 253 if m.err != nil { 254 err = m.err 255 } 256 return new(sql.Rows), err 257 } 258 259 type mockedStmtConn struct { 260 lastInsertId int64 261 rowsAffected int64 262 err error 263 delay bool 264 } 265 266 func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) { 267 return m.ExecContext(context.Background(), args...) 268 } 269 270 func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) { 271 if m.delay { 272 time.Sleep(defaultSlowThreshold + time.Millisecond) 273 } 274 return mockedResult{ 275 lastInsertId: m.lastInsertId, 276 rowsAffected: m.rowsAffected, 277 }, m.err 278 } 279 280 func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) { 281 return m.QueryContext(context.Background(), args...) 282 } 283 284 func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) { 285 if m.delay { 286 time.Sleep(defaultSlowThreshold + time.Millisecond) 287 } 288 289 err := errMockedPlaceholder 290 if m.err != nil { 291 err = m.err 292 } 293 return new(sql.Rows), err 294 } 295 296 type mockedResult struct { 297 lastInsertId int64 298 rowsAffected int64 299 } 300 301 func (m mockedResult) LastInsertId() (int64, error) { 302 return m.lastInsertId, nil 303 } 304 305 func (m mockedResult) RowsAffected() (int64, error) { 306 return m.rowsAffected, nil 307 }