github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/mock/database/manager.go (about) 1 package database 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "fmt" 7 "github.com/DATA-DOG/go-sqlmock" 8 "github.com/johnnyeven/libtools/sqlx" 9 "github.com/pkg/errors" 10 ) 11 12 type MockDB struct { 13 mockDB *sql.DB 14 mock sqlmock.Sqlmock 15 db *sqlx.DB 16 } 17 18 func (m *MockDB) Init() error { 19 var err error 20 m.mockDB, m.mock, err = sqlmock.New() 21 if err != nil { 22 return err 23 } 24 25 return nil 26 } 27 28 func (m *MockDB) Open() *sqlx.DB { 29 if m.db == nil { 30 m.db, _ = sqlx.Open("", "", func(driverName string, dataSourceName string) (db *sql.DB, err error) { 31 return m.mockDB, nil 32 }) 33 } 34 35 return m.db 36 } 37 38 func (m *MockDB) GetDB() *sqlx.DB { 39 return m.db 40 } 41 42 func (m *MockDB) LoadTestSuite(path string) error { 43 queries := make([]Query, 0) 44 err := LoadAndParse(path, &queries) 45 if err != nil { 46 return err 47 } 48 49 for _, q := range queries { 50 switch q.Type { 51 case "begin": 52 m.mock.ExpectBegin() 53 case "commit": 54 m.mock.ExpectCommit() 55 case "rollback": 56 m.mock.ExpectRollback() 57 case "exec": 58 exec := m.mock.ExpectExec(q.ExpectedSQLKeyWord) 59 if q.WithArgs != nil { 60 exec = exec.WithArgs(convertDBValue(q.WithArgs)...) 61 } 62 if q.ReturnError != "" { 63 exec.WillReturnError(errors.Errorf(q.ReturnError)) 64 } else if q.ReturnResult != nil { 65 exec.WillReturnResult(sqlmock.NewResult(q.ReturnResult.LastInsertID, q.ReturnResult.RowsEffected)) 66 } else { 67 return fmt.Errorf("error or result are all nil") 68 } 69 case "query": 70 query := m.mock.ExpectQuery(q.ExpectedSQLKeyWord) 71 if q.WithArgs != nil { 72 query = query.WithArgs(convertDBValue(q.WithArgs)...) 73 } 74 if q.ReturnError != "" { 75 query.WillReturnError(errors.Errorf(q.ReturnError)) 76 } else if q.ReturnRows != nil { 77 rows := sqlmock.NewRows(q.ReturnRows.Columns) 78 for _, r := range q.ReturnRows.Rows { 79 values := make([]driver.Value, 0) 80 for _, v := range r { 81 values = append(values, v) 82 } 83 rows.AddRow(values...) 84 } 85 query.WillReturnRows(rows) 86 } else { 87 return fmt.Errorf("error or rows are all nil") 88 } 89 default: 90 return fmt.Errorf("not supported type %s", q.Type) 91 } 92 } 93 94 return nil 95 } 96 97 func convertDBValue(data []interface{}) []driver.Value { 98 args := make([]driver.Value, 0) 99 for _, a := range data { 100 var v driver.Value 101 f := a.(float64) 102 if isValidFloatValue(f) { 103 v = f 104 } else { 105 v = int64(f) 106 } 107 args = append(args, v) 108 } 109 110 return args 111 } 112 113 func isValidFloatValue(v float64) bool { 114 compare := v 115 if v == float64(int64(compare)) { 116 return false 117 } 118 119 return true 120 }