github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/bulkinserter_test.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"strconv"
     7  	"testing"
     8  
     9  	"github.com/DATA-DOG/go-sqlmock"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/shuguocloud/go-zero/core/logx"
    12  )
    13  
    14  type mockedConn struct {
    15  	query   string
    16  	args    []interface{}
    17  	execErr error
    18  }
    19  
    20  func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
    21  	c.query = query
    22  	c.args = args
    23  	return nil, c.execErr
    24  }
    25  
    26  func (c *mockedConn) Prepare(query string) (StmtSession, error) {
    27  	panic("should not called")
    28  }
    29  
    30  func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
    31  	panic("should not called")
    32  }
    33  
    34  func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
    35  	panic("should not called")
    36  }
    37  
    38  func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
    39  	panic("should not called")
    40  }
    41  
    42  func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
    43  	panic("should not called")
    44  }
    45  
    46  func (c *mockedConn) RawDB() (*sql.DB, error) {
    47  	panic("should not called")
    48  }
    49  
    50  func (c *mockedConn) Transact(func(session Session) error) error {
    51  	panic("should not called")
    52  }
    53  
    54  func TestBulkInserter(t *testing.T) {
    55  	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
    56  		var conn mockedConn
    57  		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
    58  		assert.Nil(t, err)
    59  		for i := 0; i < 5; i++ {
    60  			assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
    61  		}
    62  		inserter.Flush()
    63  		assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
    64  			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
    65  			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
    66  			conn.query)
    67  		assert.Nil(t, conn.args)
    68  	})
    69  }
    70  
    71  func TestBulkInserterSuffix(t *testing.T) {
    72  	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
    73  		var conn mockedConn
    74  		inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
    75  			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
    76  		assert.Nil(t, err)
    77  		assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+
    78  			`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`))
    79  		for i := 0; i < 5; i++ {
    80  			assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
    81  		}
    82  		inserter.SetResultHandler(func(result sql.Result, err error) {})
    83  		inserter.Flush()
    84  		assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
    85  			`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
    86  			`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
    87  			conn.query)
    88  		assert.Nil(t, conn.args)
    89  	})
    90  }
    91  
    92  func TestBulkInserterBadStatement(t *testing.T) {
    93  	runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
    94  		var conn mockedConn
    95  		_, err := NewBulkInserter(&conn, "foo")
    96  		assert.NotNil(t, err)
    97  	})
    98  }
    99  
   100  func TestBulkInserter_Update(t *testing.T) {
   101  	conn := mockedConn{
   102  		execErr: errors.New("foo"),
   103  	}
   104  	_, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`)
   105  	assert.NotNil(t, err)
   106  	_, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`)
   107  	assert.NotNil(t, err)
   108  	inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
   109  	assert.Nil(t, err)
   110  	inserter.inserter.Execute([]string{"bar"})
   111  	inserter.SetResultHandler(func(result sql.Result, err error) {
   112  	})
   113  	inserter.UpdateOrDelete(func() {})
   114  	inserter.inserter.Execute([]string(nil))
   115  	assert.NotNil(t, inserter.UpdateStmt("foo"))
   116  	assert.NotNil(t, inserter.Insert("foo", "bar"))
   117  }
   118  
   119  func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
   120  	logx.Disable()
   121  
   122  	db, mock, err := sqlmock.New()
   123  	if err != nil {
   124  		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
   125  	}
   126  	defer db.Close()
   127  
   128  	fn(db, mock)
   129  
   130  	if err := mock.ExpectationsWereMet(); err != nil {
   131  		t.Errorf("there were unfulfilled expectations: %s", err)
   132  	}
   133  }