github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/bulkinserter_test.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 "errors" 6 "strconv" 7 "testing" 8 9 "github.com/DATA-DOG/go-sqlmock" 10 "github.com/stretchr/testify/assert" 11 "github.com/shuguocloud/go-zero/core/logx" 12 ) 13 14 type mockedConn struct { 15 query string 16 args []interface{} 17 execErr error 18 } 19 20 func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) { 21 c.query = query 22 c.args = args 23 return nil, c.execErr 24 } 25 26 func (c *mockedConn) Prepare(query string) (StmtSession, error) { 27 panic("should not called") 28 } 29 30 func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error { 31 panic("should not called") 32 } 33 34 func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error { 35 panic("should not called") 36 } 37 38 func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error { 39 panic("should not called") 40 } 41 42 func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error { 43 panic("should not called") 44 } 45 46 func (c *mockedConn) RawDB() (*sql.DB, error) { 47 panic("should not called") 48 } 49 50 func (c *mockedConn) Transact(func(session Session) error) error { 51 panic("should not called") 52 } 53 54 func TestBulkInserter(t *testing.T) { 55 runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { 56 var conn mockedConn 57 inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`) 58 assert.Nil(t, err) 59 for i := 0; i < 5; i++ { 60 assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i)) 61 } 62 inserter.Flush() 63 assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+ 64 `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+ 65 `('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`, 66 conn.query) 67 assert.Nil(t, conn.args) 68 }) 69 } 70 71 func TestBulkInserterSuffix(t *testing.T) { 72 runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { 73 var conn mockedConn 74 inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+ 75 `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`) 76 assert.Nil(t, err) 77 assert.Nil(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user, count) VALUES`+ 78 `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)) 79 for i := 0; i < 5; i++ { 80 assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i)) 81 } 82 inserter.SetResultHandler(func(result sql.Result, err error) {}) 83 inserter.Flush() 84 assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+ 85 `('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+ 86 `('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`, 87 conn.query) 88 assert.Nil(t, conn.args) 89 }) 90 } 91 92 func TestBulkInserterBadStatement(t *testing.T) { 93 runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { 94 var conn mockedConn 95 _, err := NewBulkInserter(&conn, "foo") 96 assert.NotNil(t, err) 97 }) 98 } 99 100 func TestBulkInserter_Update(t *testing.T) { 101 conn := mockedConn{ 102 execErr: errors.New("foo"), 103 } 104 _, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES()`) 105 assert.NotNil(t, err) 106 _, err = NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?)`) 107 assert.NotNil(t, err) 108 inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`) 109 assert.Nil(t, err) 110 inserter.inserter.Execute([]string{"bar"}) 111 inserter.SetResultHandler(func(result sql.Result, err error) { 112 }) 113 inserter.UpdateOrDelete(func() {}) 114 inserter.inserter.Execute([]string(nil)) 115 assert.NotNil(t, inserter.UpdateStmt("foo")) 116 assert.NotNil(t, inserter.Insert("foo", "bar")) 117 } 118 119 func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { 120 logx.Disable() 121 122 db, mock, err := sqlmock.New() 123 if err != nil { 124 t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 125 } 126 defer db.Close() 127 128 fn(db, mock) 129 130 if err := mock.ExpectationsWereMet(); err != nil { 131 t.Errorf("there were unfulfilled expectations: %s", err) 132 } 133 }