github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/stmt_test.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
// errMockedPlaceholder is the error the mocked Query methods return when no
// explicit error has been configured on the mock.
var errMockedPlaceholder = errors.New("placeholder")
    13  
    14  func TestStmt_exec(t *testing.T) {
    15  	tests := []struct {
    16  		name         string
    17  		query        string
    18  		args         []interface{}
    19  		delay        bool
    20  		hasError     bool
    21  		err          error
    22  		lastInsertId int64
    23  		rowsAffected int64
    24  	}{
    25  		{
    26  			name:         "normal",
    27  			query:        "select user from users where id=?",
    28  			args:         []interface{}{1},
    29  			lastInsertId: 1,
    30  			rowsAffected: 2,
    31  		},
    32  		{
    33  			name:     "exec error",
    34  			query:    "select user from users where id=?",
    35  			args:     []interface{}{1},
    36  			hasError: true,
    37  			err:      errors.New("exec"),
    38  		},
    39  		{
    40  			name:     "exec more args error",
    41  			query:    "select user from users where id=? and name=?",
    42  			args:     []interface{}{1},
    43  			hasError: true,
    44  			err:      errors.New("exec"),
    45  		},
    46  		{
    47  			name:         "slowcall",
    48  			query:        "select user from users where id=?",
    49  			args:         []interface{}{1},
    50  			delay:        true,
    51  			lastInsertId: 1,
    52  			rowsAffected: 2,
    53  		},
    54  	}
    55  
    56  	for _, test := range tests {
    57  		test := test
    58  		fns := []func(args ...interface{}) (sql.Result, error){
    59  			func(args ...interface{}) (sql.Result, error) {
    60  				return exec(&mockedSessionConn{
    61  					lastInsertId: test.lastInsertId,
    62  					rowsAffected: test.rowsAffected,
    63  					err:          test.err,
    64  					delay:        test.delay,
    65  				}, test.query, args...)
    66  			},
    67  			func(args ...interface{}) (sql.Result, error) {
    68  				return execStmt(&mockedStmtConn{
    69  					lastInsertId: test.lastInsertId,
    70  					rowsAffected: test.rowsAffected,
    71  					err:          test.err,
    72  					delay:        test.delay,
    73  				}, test.query, args...)
    74  			},
    75  		}
    76  
    77  		for _, fn := range fns {
    78  			fn := fn
    79  			t.Run(test.name, func(t *testing.T) {
    80  				t.Parallel()
    81  
    82  				res, err := fn(test.args...)
    83  				if test.hasError {
    84  					assert.NotNil(t, err)
    85  					return
    86  				}
    87  
    88  				assert.Nil(t, err)
    89  				lastInsertId, err := res.LastInsertId()
    90  				assert.Nil(t, err)
    91  				assert.Equal(t, test.lastInsertId, lastInsertId)
    92  				rowsAffected, err := res.RowsAffected()
    93  				assert.Nil(t, err)
    94  				assert.Equal(t, test.rowsAffected, rowsAffected)
    95  			})
    96  		}
    97  	}
    98  }
    99  
   100  func TestStmt_query(t *testing.T) {
   101  	tests := []struct {
   102  		name     string
   103  		query    string
   104  		args     []interface{}
   105  		delay    bool
   106  		hasError bool
   107  		err      error
   108  	}{
   109  		{
   110  			name:  "normal",
   111  			query: "select user from users where id=?",
   112  			args:  []interface{}{1},
   113  		},
   114  		{
   115  			name:     "query error",
   116  			query:    "select user from users where id=?",
   117  			args:     []interface{}{1},
   118  			hasError: true,
   119  			err:      errors.New("exec"),
   120  		},
   121  		{
   122  			name:     "query more args error",
   123  			query:    "select user from users where id=? and name=?",
   124  			args:     []interface{}{1},
   125  			hasError: true,
   126  			err:      errors.New("exec"),
   127  		},
   128  		{
   129  			name:  "slowcall",
   130  			query: "select user from users where id=?",
   131  			args:  []interface{}{1},
   132  			delay: true,
   133  		},
   134  	}
   135  
   136  	for _, test := range tests {
   137  		test := test
   138  		fns := []func(args ...interface{}) error{
   139  			func(args ...interface{}) error {
   140  				return query(&mockedSessionConn{
   141  					err:   test.err,
   142  					delay: test.delay,
   143  				}, func(rows *sql.Rows) error {
   144  					return nil
   145  				}, test.query, args...)
   146  			},
   147  			func(args ...interface{}) error {
   148  				return queryStmt(&mockedStmtConn{
   149  					err:   test.err,
   150  					delay: test.delay,
   151  				}, func(rows *sql.Rows) error {
   152  					return nil
   153  				}, test.query, args...)
   154  			},
   155  		}
   156  
   157  		for _, fn := range fns {
   158  			fn := fn
   159  			t.Run(test.name, func(t *testing.T) {
   160  				t.Parallel()
   161  
   162  				err := fn(test.args...)
   163  				if test.hasError {
   164  					assert.NotNil(t, err)
   165  					return
   166  				}
   167  
   168  				assert.NotNil(t, err)
   169  			})
   170  		}
   171  	}
   172  }
   173  
   174  func TestSetSlowThreshold(t *testing.T) {
   175  	assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
   176  	SetSlowThreshold(time.Second)
   177  	assert.Equal(t, time.Second, slowThreshold.Load())
   178  }
   179  
   180  type mockedSessionConn struct {
   181  	lastInsertId int64
   182  	rowsAffected int64
   183  	err          error
   184  	delay        bool
   185  }
   186  
   187  func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
   188  	if m.delay {
   189  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   190  	}
   191  	return mockedResult{
   192  		lastInsertId: m.lastInsertId,
   193  		rowsAffected: m.rowsAffected,
   194  	}, m.err
   195  }
   196  
   197  func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
   198  	if m.delay {
   199  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   200  	}
   201  
   202  	err := errMockedPlaceholder
   203  	if m.err != nil {
   204  		err = m.err
   205  	}
   206  	return new(sql.Rows), err
   207  }
   208  
   209  type mockedStmtConn struct {
   210  	lastInsertId int64
   211  	rowsAffected int64
   212  	err          error
   213  	delay        bool
   214  }
   215  
   216  func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
   217  	if m.delay {
   218  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   219  	}
   220  	return mockedResult{
   221  		lastInsertId: m.lastInsertId,
   222  		rowsAffected: m.rowsAffected,
   223  	}, m.err
   224  }
   225  
   226  func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
   227  	if m.delay {
   228  		time.Sleep(defaultSlowThreshold + time.Millisecond)
   229  	}
   230  
   231  	err := errMockedPlaceholder
   232  	if m.err != nil {
   233  		err = m.err
   234  	}
   235  	return new(sql.Rows), err
   236  }
   237  
   238  type mockedResult struct {
   239  	lastInsertId int64
   240  	rowsAffected int64
   241  }
   242  
   243  func (m mockedResult) LastInsertId() (int64, error) {
   244  	return m.lastInsertId, nil
   245  }
   246  
   247  func (m mockedResult) RowsAffected() (int64, error) {
   248  	return m.rowsAffected, nil
   249  }