github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/mock/database/manager.go (about)

     1  package database
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"github.com/DATA-DOG/go-sqlmock"
     8  	"github.com/johnnyeven/libtools/sqlx"
     9  	"github.com/pkg/errors"
    10  )
    11  
// MockDB bundles a go-sqlmock backed *sql.DB with the project's sqlx wrapper
// so that code depending on *sqlx.DB can be driven by scripted expectations.
type MockDB struct {
	// mockDB is the *sql.DB returned by sqlmock.New in Init.
	mockDB *sql.DB
	// mock is the sqlmock handle returned by sqlmock.New in Init;
	// LoadTestSuite registers its expectations on it.
	mock sqlmock.Sqlmock
	// db is the sqlx wrapper created on the first call to Open; it is nil
	// until then.
	db *sqlx.DB
}
    17  
    18  func (m *MockDB) Init() error {
    19  	var err error
    20  	m.mockDB, m.mock, err = sqlmock.New()
    21  	if err != nil {
    22  		return err
    23  	}
    24  
    25  	return nil
    26  }
    27  
    28  func (m *MockDB) Open() *sqlx.DB {
    29  	if m.db == nil {
    30  		m.db, _ = sqlx.Open("", "", func(driverName string, dataSourceName string) (db *sql.DB, err error) {
    31  			return m.mockDB, nil
    32  		})
    33  	}
    34  
    35  	return m.db
    36  }
    37  
// GetDB returns the sqlx wrapper stored on the receiver without creating it.
// The result is nil if Open has not been called yet.
func (m *MockDB) GetDB() *sqlx.DB {
	return m.db
}
    41  
    42  func (m *MockDB) LoadTestSuite(path string) error {
    43  	queries := make([]Query, 0)
    44  	err := LoadAndParse(path, &queries)
    45  	if err != nil {
    46  		return err
    47  	}
    48  
    49  	for _, q := range queries {
    50  		switch q.Type {
    51  		case "begin":
    52  			m.mock.ExpectBegin()
    53  		case "commit":
    54  			m.mock.ExpectCommit()
    55  		case "rollback":
    56  			m.mock.ExpectRollback()
    57  		case "exec":
    58  			exec := m.mock.ExpectExec(q.ExpectedSQLKeyWord)
    59  			if q.WithArgs != nil {
    60  				exec = exec.WithArgs(convertDBValue(q.WithArgs)...)
    61  			}
    62  			if q.ReturnError != "" {
    63  				exec.WillReturnError(errors.Errorf(q.ReturnError))
    64  			} else if q.ReturnResult != nil {
    65  				exec.WillReturnResult(sqlmock.NewResult(q.ReturnResult.LastInsertID, q.ReturnResult.RowsEffected))
    66  			} else {
    67  				return fmt.Errorf("error or result are all nil")
    68  			}
    69  		case "query":
    70  			query := m.mock.ExpectQuery(q.ExpectedSQLKeyWord)
    71  			if q.WithArgs != nil {
    72  				query = query.WithArgs(convertDBValue(q.WithArgs)...)
    73  			}
    74  			if q.ReturnError != "" {
    75  				query.WillReturnError(errors.Errorf(q.ReturnError))
    76  			} else if q.ReturnRows != nil {
    77  				rows := sqlmock.NewRows(q.ReturnRows.Columns)
    78  				for _, r := range q.ReturnRows.Rows {
    79  					values := make([]driver.Value, 0)
    80  					for _, v := range r {
    81  						values = append(values, v)
    82  					}
    83  					rows.AddRow(values...)
    84  				}
    85  				query.WillReturnRows(rows)
    86  			} else {
    87  				return fmt.Errorf("error or rows are all nil")
    88  			}
    89  		default:
    90  			return fmt.Errorf("not supported type %s", q.Type)
    91  		}
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func convertDBValue(data []interface{}) []driver.Value {
    98  	args := make([]driver.Value, 0)
    99  	for _, a := range data {
   100  		var v driver.Value
   101  		f := a.(float64)
   102  		if isValidFloatValue(f) {
   103  			v = f
   104  		} else {
   105  			v = int64(f)
   106  		}
   107  		args = append(args, v)
   108  	}
   109  
   110  	return args
   111  }
   112  
   113  func isValidFloatValue(v float64) bool {
   114  	compare := v
   115  	if v == float64(int64(compare)) {
   116  		return false
   117  	}
   118  
   119  	return true
   120  }