github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/bulkinserter_test.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"strconv"
     8  	"testing"
     9  
    10  	"github.com/DATA-DOG/go-sqlmock"
    11  	"github.com/lingyao2333/mo-zero/core/logx"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  type mockedConn struct {
    16  	query   string
    17  	args    []interface{}
    18  	execErr error
    19  }
    20  
    21  func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...interface{}) (sql.Result, error) {
    22  	c.query = query
    23  	c.args = args
    24  	return nil, c.execErr
    25  }
    26  
    27  func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
    28  	panic("implement me")
    29  }
    30  
    31  func (c *mockedConn) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    32  	panic("implement me")
    33  }
    34  
    35  func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    36  	panic("implement me")
    37  }
    38  
    39  func (c *mockedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    40  	panic("implement me")
    41  }
    42  
    43  func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
    44  	panic("implement me")
    45  }
    46  
    47  func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
    48  	panic("should not called")
    49  }
    50  
    51  func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
    52  	return c.ExecCtx(context.Background(), query, args...)
    53  }
    54  
    55  func (c *mockedConn) Prepare(query string) (StmtSession, error) {
    56  	panic("should not called")
    57  }
    58  
    59  func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
    60  	panic("should not called")
    61  }
    62  
    63  func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
    64  	panic("should not called")
    65  }
    66  
    67  func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
    68  	panic("should not called")
    69  }
    70  
    71  func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
    72  	panic("should not called")
    73  }
    74  
    75  func (c *mockedConn) RawDB() (*sql.DB, error) {
    76  	panic("should not called")
    77  }
    78  
    79  func (c *mockedConn) Transact(func(session Session) error) error {
    80  	panic("should not called")
    81  }
    82  
    83  func TestBulkInserter(t *testing.T) {
    84  	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
    85  		var conn mockedConn
    86  		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
    87  		assert.Nil(t, err)
    88  		for i := 0; i < 5; i++ {
    89  			assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
    90  		}
    91  		inserter.Flush()
    92  		assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
    93  			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
    94  			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
    95  			conn.query)
    96  		assert.Nil(t, conn.args)
    97  	})
    98  }
    99  
   100  func TestBulkInserterSuffix(t *testing.T) {
   101  	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
   102  		var conn mockedConn
   103  		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
   104  			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
   105  		assert.Nil(t, err)
   106  		assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
   107  			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
   108  		for i := 0; i < 5; i++ {
   109  			assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
   110  		}
   111  		inserter.SetResultHandler(func(result sql.Result, err error) {})
   112  		inserter.Flush()
   113  		assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
   114  			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
   115  			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
   116  			conn.query)
   117  		assert.Nil(t, conn.args)
   118  	})
   119  }
   120  
   121  func TestBulkInserterBadStatement(t *testing.T) {
   122  	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
   123  		var conn mockedConn
   124  		_, err := NewBulkInserter(&conn, "foo")
   125  		assert.NotNil(t, err)
   126  	})
   127  }
   128  
   129  func TestBulkInserter_Update(t *testing.T) {
   130  	conn := mockedConn{
   131  		execErr: errors.New("foo"),
   132  	}
   133  	_, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`)
   134  	assert.NotNil(t, err)
   135  	_, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`)
   136  	assert.NotNil(t, err)
   137  	inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
   138  	assert.Nil(t, err)
   139  	inserter.inserter.Execute([]string{"bar"})
   140  	inserter.SetResultHandler(func(result sql.Result, err error) {
   141  	})
   142  	inserter.UpdateOrDelete(func() {})
   143  	inserter.inserter.Execute([]string(nil))
   144  	assert.NotNil(t, inserter.UpdateStmt("foo"))
   145  	assert.NotNil(t, inserter.Insert("foo", "bar"))
   146  }
   147  
   148  func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
   149  	logx.Disable()
   150  
   151  	db, mock, err := sqlmock.New()
   152  	if err != nil {
   153  		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
   154  	}
   155  	defer db.Close()
   156  
   157  	fn(db, mock)
   158  
   159  	if err := mock.ExpectationsWereMet(); err != nil {
   160  		t.Errorf("there were unfulfilled expectations: %s", err)
   161  	}
   162  }