github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/stmt_test.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  var errMockedPlaceholder = errors.New("placeholder")
    14  
    15  func TestStmt_exec(t *testing.T) {
    16  	tests := []struct {
    17  		name         string
    18  		query        string
    19  		args         []interface{}
    20  		delay        bool
    21  		hasError     bool
    22  		err          error
    23  		lastInsertId int64
    24  		rowsAffected int64
    25  	}{
    26  		{
    27  			name:         "normal",
    28  			query:        "select user from users where id=?",
    29  			args:         []interface{}{1},
    30  			lastInsertId: 1,
    31  			rowsAffected: 2,
    32  		},
    33  		{
    34  			name:     "exec error",
    35  			query:    "select user from users where id=?",
    36  			args:     []interface{}{1},
    37  			hasError: true,
    38  			err:      errors.New("exec"),
    39  		},
    40  		{
    41  			name:     "exec more args error",
    42  			query:    "select user from users where id=? and name=?",
    43  			args:     []interface{}{1},
    44  			hasError: true,
    45  			err:      errors.New("exec"),
    46  		},
    47  		{
    48  			name:         "slowcall",
    49  			query:        "select user from users where id=?",
    50  			args:         []interface{}{1},
    51  			delay:        true,
    52  			lastInsertId: 1,
    53  			rowsAffected: 2,
    54  		},
    55  	}
    56  
    57  	for _, test := range tests {
    58  		test := test
    59  		fns := []func(args ...interface{}) (sql.Result, error){
    60  			func(args ...interface{}) (sql.Result, error) {
    61  				return exec(context.Background(), &mockedSessionConn{
    62  					lastInsertId: test.lastInsertId,
    63  					rowsAffected: test.rowsAffected,
    64  					err:          test.err,
    65  					delay:        test.delay,
    66  				}, test.query, args...)
    67  			},
    68  			func(args ...interface{}) (sql.Result, error) {
    69  				return execStmt(context.Background(), &mockedStmtConn{
    70  					lastInsertId: test.lastInsertId,
    71  					rowsAffected: test.rowsAffected,
    72  					err:          test.err,
    73  					delay:        test.delay,
    74  				}, test.query, args...)
    75  			},
    76  		}
    77  
    78  		for _, fn := range fns {
    79  			fn := fn
    80  			t.Run(test.name, func(t *testing.T) {
    81  				t.Parallel()
    82  
    83  				res, err := fn(test.args...)
    84  				if test.hasError {
    85  					assert.NotNil(t, err)
    86  					return
    87  				}
    88  
    89  				assert.Nil(t, err)
    90  				lastInsertId, err := res.LastInsertId()
    91  				assert.Nil(t, err)
    92  				assert.Equal(t, test.lastInsertId, lastInsertId)
    93  				rowsAffected, err := res.RowsAffected()
    94  				assert.Nil(t, err)
    95  				assert.Equal(t, test.rowsAffected, rowsAffected)
    96  			})
    97  		}
    98  	}
    99  }
   100  
   101  func TestStmt_query(t *testing.T) {
   102  	tests := []struct {
   103  		name     string
   104  		query    string
   105  		args     []interface{}
   106  		delay    bool
   107  		hasError bool
   108  		err      error
   109  	}{
   110  		{
   111  			name:  "normal",
   112  			query: "select user from users where id=?",
   113  			args:  []interface{}{1},
   114  		},
   115  		{
   116  			name:     "query error",
   117  			query:    "select user from users where id=?",
   118  			args:     []interface{}{1},
   119  			hasError: true,
   120  			err:      errors.New("exec"),
   121  		},
   122  		{
   123  			name:     "query more args error",
   124  			query:    "select user from users where id=? and name=?",
   125  			args:     []interface{}{1},
   126  			hasError: true,
   127  			err:      errors.New("exec"),
   128  		},
   129  		{
   130  			name:  "slowcall",
   131  			query: "select user from users where id=?",
   132  			args:  []interface{}{1},
   133  			delay: true,
   134  		},
   135  	}
   136  
   137  	for _, test := range tests {
   138  		test := test
   139  		fns := []func(args ...interface{}) error{
   140  			func(args ...interface{}) error {
   141  				return query(context.Background(), &mockedSessionConn{
   142  					err:   test.err,
   143  					delay: test.delay,
   144  				}, func(rows *sql.Rows) error {
   145  					return nil
   146  				}, test.query, args...)
   147  			},
   148  			func(args ...interface{}) error {
   149  				return queryStmt(context.Background(), &mockedStmtConn{
   150  					err:   test.err,
   151  					delay: test.delay,
   152  				}, func(rows *sql.Rows) error {
   153  					return nil
   154  				}, test.query, args...)
   155  			},
   156  		}
   157  
   158  		for _, fn := range fns {
   159  			fn := fn
   160  			t.Run(test.name, func(t *testing.T) {
   161  				t.Parallel()
   162  
   163  				err := fn(test.args...)
   164  				if test.hasError {
   165  					assert.NotNil(t, err)
   166  					return
   167  				}
   168  
   169  				assert.NotNil(t, err)
   170  			})
   171  		}
   172  	}
   173  }
   174  
   175  func TestSetSlowThreshold(t *testing.T) {
   176  	assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
   177  	SetSlowThreshold(time.Second)
   178  	assert.Equal(t, time.Second, slowThreshold.Load())
   179  }
   180  
   181  func TestDisableLog(t *testing.T) {
   182  	assert.True(t, logSql.True())
   183  	assert.True(t, logSlowSql.True())
   184  	defer func() {
   185  		logSql.Set(true)
   186  		logSlowSql.Set(true)
   187  	}()
   188  
   189  	DisableLog()
   190  	assert.False(t, logSql.True())
   191  	assert.False(t, logSlowSql.True())
   192  }
   193  
   194  func TestDisableStmtLog(t *testing.T) {
   195  	assert.True(t, logSql.True())
   196  	assert.True(t, logSlowSql.True())
   197  	defer func() {
   198  		logSql.Set(true)
   199  		logSlowSql.Set(true)
   200  	}()
   201  
   202  	DisableStmtLog()
   203  	assert.False(t, logSql.True())
   204  	assert.True(t, logSlowSql.True())
   205  }
   206  
   207  func TestNilGuard(t *testing.T) {
   208  	assert.True(t, logSql.True())
   209  	assert.True(t, logSlowSql.True())
   210  	defer func() {
   211  		logSql.Set(true)
   212  		logSlowSql.Set(true)
   213  	}()
   214  
   215  	DisableLog()
   216  	guard := newGuard("any")
   217  	assert.Nil(t, guard.start("foo", "bar"))
   218  	guard.finish(context.Background(), nil)
   219  	assert.Equal(t, nilGuard{}, guard)
   220  }
   221  
   222  type mockedSessionConn struct {
   223  	lastInsertId int64
   224  	rowsAffected int64
   225  	err          error
   226  	delay        bool
   227  }
   228  
   229  func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
   230  	return m.ExecContext(context.Background(), query, args...)
   231  }
   232  
   233  func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   234  	if m.delay {
   235  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   236  	}
   237  	return mockedResult{
   238  		lastInsertId: m.lastInsertId,
   239  		rowsAffected: m.rowsAffected,
   240  	}, m.err
   241  }
   242  
   243  func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
   244  	return m.QueryContext(context.Background(), query, args...)
   245  }
   246  
   247  func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   248  	if m.delay {
   249  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   250  	}
   251  
   252  	err := errMockedPlaceholder
   253  	if m.err != nil {
   254  		err = m.err
   255  	}
   256  	return new(sql.Rows), err
   257  }
   258  
   259  type mockedStmtConn struct {
   260  	lastInsertId int64
   261  	rowsAffected int64
   262  	err          error
   263  	delay        bool
   264  }
   265  
   266  func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
   267  	return m.ExecContext(context.Background(), args...)
   268  }
   269  
   270  func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) {
   271  	if m.delay {
   272  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   273  	}
   274  	return mockedResult{
   275  		lastInsertId: m.lastInsertId,
   276  		rowsAffected: m.rowsAffected,
   277  	}, m.err
   278  }
   279  
   280  func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
   281  	return m.QueryContext(context.Background(), args...)
   282  }
   283  
   284  func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) {
   285  	if m.delay {
   286  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   287  	}
   288  
   289  	err := errMockedPlaceholder
   290  	if m.err != nil {
   291  		err = m.err
   292  	}
   293  	return new(sql.Rows), err
   294  }
   295  
   296  type mockedResult struct {
   297  	lastInsertId int64
   298  	rowsAffected int64
   299  }
   300  
   301  func (m mockedResult) LastInsertId() (int64, error) {
   302  	return m.lastInsertId, nil
   303  }
   304  
   305  func (m mockedResult) RowsAffected() (int64, error) {
   306  	return m.rowsAffected, nil
   307  }