github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/internal/repo/testdb/generic_repo_create_tests.go (about) 1 package testdb 2 3 import ( 4 "context" 5 "database/sql/driver" 6 "fmt" 7 "reflect" 8 "testing" 9 10 "github.com/DATA-DOG/go-sqlmock" 11 "github.com/kyma-incubator/compass/components/director/pkg/apperrors" 12 "github.com/kyma-incubator/compass/components/director/pkg/persistence" 13 "github.com/pkg/errors" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/mock" 16 "github.com/stretchr/testify/require" 17 ) 18 19 // Mock represents a mockery Mock. 20 type Mock interface { 21 AssertExpectations(t mock.TestingT) bool 22 On(methodName string, arguments ...interface{}) *mock.Call 23 } 24 25 // SQLQueryDetails represent an SQL expected query details to provide to the DB mock. 26 type SQLQueryDetails struct { 27 Query string 28 IsSelect bool 29 Args []driver.Value 30 ValidResult driver.Result 31 InvalidResult driver.Result 32 ValidRowsProvider func() []*sqlmock.Rows 33 InvalidRowsProvider func() []*sqlmock.Rows 34 } 35 36 // RepoCreateTestSuite represents a generic test suite for repository Create method of any global entity or entity that has externally managed tenants in m2m table/view. 37 type RepoCreateTestSuite struct { 38 Name string 39 SQLQueryDetails []SQLQueryDetails 40 ConverterMockProvider func() Mock 41 RepoConstructorFunc interface{} 42 ModelEntity interface{} 43 DBEntity interface{} 44 NilModelEntity interface{} 45 TenantID string 46 DisableConverterErrorTest bool 47 MethodName string 48 IsTopLevelEntity bool 49 IsGlobal bool 50 } 51 52 // Run runs the generic repo create test suite 53 func (suite *RepoCreateTestSuite) Run(t *testing.T) bool { 54 if len(suite.MethodName) == 0 { 55 suite.MethodName = "Create" 56 } 57 58 return t.Run(suite.Name, func(t *testing.T) { 59 testErr := errors.New("test error") 60 61 t.Run("success", func(t *testing.T) { 62 sqlxDB, sqlMock := MockDatabase(t) 63 ctx := persistence.SaveToContext(context.TODO(), sqlxDB) 64 65 configureValidSQLQueries(sqlMock, suite.SQLQueryDetails) 66 67 convMock := suite.ConverterMockProvider() 68 convMock.On("ToEntity", suite.ModelEntity).Return(suite.DBEntity, nil).Once() 69 pgRepository := createRepo(suite.RepoConstructorFunc, convMock) 70 71 // WHEN 72 err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity) 73 74 // THEN 75 require.NoError(t, err) 76 sqlMock.AssertExpectations(t) 77 convMock.AssertExpectations(t) 78 }) 79 80 if !suite.IsTopLevelEntity && !suite.IsGlobal { 81 t.Run("error when parent access is missing", func(t *testing.T) { 82 sqlxDB, sqlMock := MockDatabase(t) 83 ctx := persistence.SaveToContext(context.TODO(), sqlxDB) 84 85 configureInvalidSelect(sqlMock, suite.SQLQueryDetails) 86 87 convMock := suite.ConverterMockProvider() 88 convMock.On("ToEntity", suite.ModelEntity).Return(suite.DBEntity, nil).Once() 89 pgRepository := createRepo(suite.RepoConstructorFunc, convMock) 90 // WHEN 91 err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity) 92 // THEN 93 require.Error(t, err) 94 require.Equal(t, apperrors.Unauthorized, apperrors.ErrorCode(err)) 95 require.Contains(t, err.Error(), fmt.Sprintf("Tenant %s does not have access to the parent", suite.TenantID)) 96 97 sqlMock.AssertExpectations(t) 98 convMock.AssertExpectations(t) 99 }) 100 } 101 102 for i := range suite.SQLQueryDetails { 103 t.Run(fmt.Sprintf("error if SQL query %d fail", i), func(t *testing.T) { 104 sqlxDB, sqlMock := MockDatabase(t) 105 ctx := persistence.SaveToContext(context.TODO(), sqlxDB) 106 107 configureFailureForSQLQueryOnIndex(sqlMock, suite.SQLQueryDetails, i, testErr) 108 109 convMock := suite.ConverterMockProvider() 110 convMock.On("ToEntity", suite.ModelEntity).Return(suite.DBEntity, nil).Once() 111 pgRepository := createRepo(suite.RepoConstructorFunc, convMock) 112 // WHEN 113 err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity) 114 // THEN 115 require.Error(t, err) 116 if suite.SQLQueryDetails[i].IsSelect { 117 require.Equal(t, apperrors.Unauthorized, apperrors.ErrorCode(err)) 118 require.Contains(t, err.Error(), fmt.Sprintf("Tenant %s does not have access to the parent", suite.TenantID)) 119 } else { 120 require.Equal(t, apperrors.InternalError, apperrors.ErrorCode(err)) 121 require.Contains(t, err.Error(), "Internal Server Error: Unexpected error while executing SQL query") 122 } 123 sqlMock.AssertExpectations(t) 124 convMock.AssertExpectations(t) 125 }) 126 } 127 128 if !suite.DisableConverterErrorTest { 129 t.Run("error when conversion fail", func(t *testing.T) { 130 sqlxDB, sqlMock := MockDatabase(t) 131 ctx := persistence.SaveToContext(context.TODO(), sqlxDB) 132 133 convMock := suite.ConverterMockProvider() 134 convMock.On("ToEntity", suite.ModelEntity).Return(nil, testErr).Once() 135 pgRepository := createRepo(suite.RepoConstructorFunc, convMock) 136 // WHEN 137 err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity) 138 // THEN 139 require.Error(t, err) 140 require.Contains(t, err.Error(), testErr.Error()) 141 142 sqlMock.AssertExpectations(t) 143 convMock.AssertExpectations(t) 144 }) 145 } 146 147 t.Run("returns error when item is nil", func(t *testing.T) { 148 ctx := context.TODO() 149 convMock := suite.ConverterMockProvider() 150 pgRepository := createRepo(suite.RepoConstructorFunc, convMock) 151 // WHEN 152 err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.NilModelEntity) 153 // THEN 154 require.Error(t, err) 155 assert.Contains(t, err.Error(), "Internal Server Error") 156 convMock.AssertExpectations(t) 157 }) 158 }) 159 } 160 161 // callCreate calls the Create method of the given repository. 162 // In order to do this for all the different repository implementations we need to do it via reflection. 163 func callCreate(repo interface{}, methodName string, ctx context.Context, tenant string, modelEntity interface{}) error { 164 args := []reflect.Value{reflect.ValueOf(ctx)} 165 if len(tenant) > 0 { 166 args = append(args, reflect.ValueOf(tenant)) 167 } 168 args = append(args, reflect.ValueOf(modelEntity)) 169 results := reflect.ValueOf(repo).MethodByName(methodName).Call(args) 170 if len(results) != 1 { 171 panic("Create should return one argument") 172 } 173 result := results[0].Interface() 174 if result == nil { 175 return nil 176 } 177 err, ok := result.(error) 178 if !ok { 179 panic("Expected result to be an error") 180 } 181 return err 182 } 183 184 // createRepo creates a new repository by the provided constructor func. 185 // In order to do this for all the different repository implementations we need to do it via reflection. 186 func createRepo(repoConstructorFunc interface{}, convMock interface{}) interface{} { 187 v := reflect.ValueOf(repoConstructorFunc) 188 if v.Kind() != reflect.Func { 189 panic("Repo constructor should be a function") 190 } 191 t := v.Type() 192 193 if t.NumOut() != 1 { 194 panic("Repo constructor should return only one argument") 195 } 196 197 if t.NumIn() == 0 { 198 return v.Call(nil)[0].Interface() 199 } 200 201 if t.NumIn() != 1 { 202 panic("Repo constructor should accept zero or one arguments") 203 } 204 205 mockVal := reflect.ValueOf(convMock) 206 return v.Call([]reflect.Value{mockVal})[0].Interface() 207 } 208 209 func configureValidSQLQueries(sqlMock DBMock, sqlQueryDetails []SQLQueryDetails) { 210 for _, sqlDetails := range sqlQueryDetails { 211 if sqlDetails.IsSelect { 212 sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnRows(sqlDetails.ValidRowsProvider()...) 213 } else { 214 sqlMock.ExpectExec(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnResult(sqlDetails.ValidResult) 215 } 216 } 217 } 218 219 func configureInvalidSelect(sqlMock DBMock, sqlQueryDetails []SQLQueryDetails) { 220 for _, sqlDetails := range sqlQueryDetails { 221 if sqlDetails.IsSelect { 222 sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnRows(sqlDetails.InvalidRowsProvider()...) 223 break 224 } 225 } 226 } 227 228 func configureFailureForSQLQueryOnIndex(sqlMock DBMock, sqlQueryDetails []SQLQueryDetails, i int, expectedErr error) { 229 for _, sqlDetails := range sqlQueryDetails { 230 if sqlDetails.IsSelect { 231 if sqlDetails.Query == sqlQueryDetails[i].Query { 232 sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnError(expectedErr) 233 break 234 } else { 235 sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnRows(sqlDetails.ValidRowsProvider()...) 236 } 237 } else { 238 if sqlDetails.Query == sqlQueryDetails[i].Query { 239 sqlMock.ExpectExec(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnError(expectedErr) 240 break 241 } else { 242 sqlMock.ExpectExec(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnResult(sqlDetails.ValidResult) 243 } 244 } 245 } 246 }