github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/single/db_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package single
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"database/sql/driver"
    21  	"errors"
    22  	"fmt"
    23  	"testing"
    24  
    25  	"github.com/ecodeclub/eorm/internal/datasource"
    26  
    27  	"github.com/DATA-DOG/go-sqlmock"
    28  	_ "github.com/mattn/go-sqlite3"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/suite"
    31  )
    32  
    33  type SingleSuite struct {
    34  	suite.Suite
    35  	mockDB *sql.DB
    36  	mock   sqlmock.Sqlmock
    37  }
    38  
    39  func (s *SingleSuite) SetupTest() {
    40  	t := s.T()
    41  	s.initMock(t)
    42  }
    43  
    44  func (s *SingleSuite) TearDownTest() {
    45  	_ = s.mockDB.Close()
    46  }
    47  
    48  func (s *SingleSuite) initMock(t *testing.T) {
    49  	var err error
    50  	s.mockDB, s.mock, err = sqlmock.New()
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  }
    55  
    56  func (s *SingleSuite) TestDBQuery() {
    57  	//s.mock.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows([]string{"mark"}).AddRow("value"))
    58  
    59  	testCases := []struct {
    60  		name     string
    61  		query    datasource.Query
    62  		mockRows *sqlmock.Rows
    63  		wantResp []string
    64  		wantErr  error
    65  	}{
    66  		{
    67  			name: "one row",
    68  			query: datasource.Query{
    69  				SQL: "SELECT `first_name` FROM `test_model`",
    70  			},
    71  			mockRows: sqlmock.NewRows([]string{"first_name"}).AddRow("value"),
    72  			wantResp: []string{"value"},
    73  		},
    74  		{
    75  			name: "multi row",
    76  			query: datasource.Query{
    77  				SQL: "SELECT `first_name` FROM `test_model`",
    78  			},
    79  			mockRows: func() *sqlmock.Rows {
    80  				res := sqlmock.NewRows([]string{"first_name"})
    81  				res.AddRow("value1")
    82  				res.AddRow("value2")
    83  				return res
    84  			}(),
    85  			wantResp: []string{"value1", "value2"},
    86  		},
    87  	}
    88  	for _, tc := range testCases {
    89  		s.mock.ExpectQuery(tc.query.SQL).WillReturnRows(tc.mockRows)
    90  	}
    91  	for _, tc := range testCases {
    92  		s.T().Run(tc.name, func(t *testing.T) {
    93  			db := NewDB(s.mockDB)
    94  			rows, queryErr := db.Query(context.Background(), tc.query)
    95  			assert.Equal(t, queryErr, tc.wantErr)
    96  			if queryErr != nil {
    97  				return
    98  			}
    99  			assert.NotNil(t, rows)
   100  			var resp []string
   101  			for rows.Next() {
   102  				val := new(string)
   103  				err := rows.Scan(val)
   104  				assert.Nil(t, err)
   105  				if err != nil {
   106  					return
   107  				}
   108  				assert.NotNil(t, val)
   109  				resp = append(resp, *val)
   110  			}
   111  
   112  			assert.ElementsMatch(t, tc.wantResp, resp)
   113  		})
   114  	}
   115  }
   116  
   117  func (s *SingleSuite) TestDBExec() {
   118  	testCases := []struct {
   119  		name         string
   120  		lastInsertId int64
   121  		rowsAffected int64
   122  		wantErr      error
   123  		mockResult   driver.Result
   124  		query        datasource.Query
   125  	}{
   126  		{
   127  			name: "res 1",
   128  			query: datasource.Query{
   129  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4)",
   130  			},
   131  			mockResult: func() driver.Result {
   132  				return sqlmock.NewResult(2, 1)
   133  			}(),
   134  			lastInsertId: int64(2),
   135  			rowsAffected: int64(1),
   136  		},
   137  		{
   138  			name: "res 2",
   139  			query: datasource.Query{
   140  				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES(1,2,3,4) (1,2,3,4)",
   141  			},
   142  			mockResult: func() driver.Result {
   143  				return sqlmock.NewResult(4, 2)
   144  			}(),
   145  			lastInsertId: int64(4),
   146  			rowsAffected: int64(2),
   147  		},
   148  	}
   149  	for _, tc := range testCases {
   150  		s.mock.ExpectExec("^INSERT INTO (.+)").WillReturnResult(tc.mockResult)
   151  	}
   152  	for _, tc := range testCases {
   153  		s.T().Run(tc.name, func(t *testing.T) {
   154  			db := NewDB(s.mockDB)
   155  			res, err := db.Exec(context.Background(), tc.query)
   156  			assert.Nil(t, err)
   157  			lastInsertId, err := res.LastInsertId()
   158  			assert.Nil(t, err)
   159  			assert.EqualValues(t, tc.lastInsertId, lastInsertId)
   160  			rowsAffected, err := res.RowsAffected()
   161  			assert.Nil(t, err)
   162  			assert.EqualValues(t, tc.rowsAffected, rowsAffected)
   163  		})
   164  	}
   165  }
   166  
   167  func TestSingleSuite(t *testing.T) {
   168  	suite.Run(t, &SingleSuite{})
   169  }
   170  
   171  func TestDB_BeginTx(t *testing.T) {
   172  	mockDB, mock, err := sqlmock.New()
   173  	if err != nil {
   174  		t.Fatal(err)
   175  	}
   176  	defer func() { _ = mockDB.Close() }()
   177  
   178  	db := NewDB(mockDB)
   179  	// Begin 失败
   180  	mock.ExpectBegin().WillReturnError(errors.New("begin failed"))
   181  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
   182  	assert.Equal(t, errors.New("begin failed"), err)
   183  	assert.Nil(t, tx)
   184  
   185  	mock.ExpectBegin()
   186  	tx, err = db.BeginTx(context.Background(), &sql.TxOptions{})
   187  	assert.Nil(t, err)
   188  	assert.NotNil(t, tx)
   189  }
   190  
   191  func TestDB_Wait(t *testing.T) {
   192  	mockDB, mock, err := sqlmock.New()
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  	defer func() { _ = mockDB.Close() }()
   197  
   198  	db := NewDB(mockDB)
   199  	if err != nil {
   200  		t.Fatal(err)
   201  	}
   202  	mock.ExpectPing()
   203  	err = db.Wait()
   204  	assert.Nil(t, err)
   205  }
   206  
   207  func ExampleDB_BeginTx() {
   208  	db, _ := OpenDB("sqlite3", "file:test.db?cache=shared&mode=memory")
   209  	defer func() {
   210  		_ = db.Close()
   211  	}()
   212  	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
   213  	if err == nil {
   214  		fmt.Println("Begin")
   215  	}
   216  	// 或者 tx.Rollback()
   217  	err = tx.Commit()
   218  	if err == nil {
   219  		fmt.Println("Commit")
   220  	}
   221  	// Output:
   222  	// Begin
   223  	// Commit
   224  }
   225  
   226  func ExampleDB_Close() {
   227  	db, _ := OpenDB("sqlite3", "file:test.db?cache=shared&mode=memory")
   228  	err := db.Close()
   229  	if err == nil {
   230  		fmt.Println("close")
   231  	}
   232  
   233  	// Output:
   234  	// close
   235  }