github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/internal/repo/testdb/generic_repo_create_tests.go (about)

     1  package testdb
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"reflect"
     8  	"testing"
     9  
    10  	"github.com/DATA-DOG/go-sqlmock"
    11  	"github.com/kyma-incubator/compass/components/director/pkg/apperrors"
    12  	"github.com/kyma-incubator/compass/components/director/pkg/persistence"
    13  	"github.com/pkg/errors"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/mock"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  // Mock represents a mockery Mock.
    20  type Mock interface {
    21  	AssertExpectations(t mock.TestingT) bool
    22  	On(methodName string, arguments ...interface{}) *mock.Call
    23  }
    24  
    25  // SQLQueryDetails represent an SQL expected query details to provide to the DB mock.
    26  type SQLQueryDetails struct {
    27  	Query               string
    28  	IsSelect            bool
    29  	Args                []driver.Value
    30  	ValidResult         driver.Result
    31  	InvalidResult       driver.Result
    32  	ValidRowsProvider   func() []*sqlmock.Rows
    33  	InvalidRowsProvider func() []*sqlmock.Rows
    34  }
    35  
    36  // RepoCreateTestSuite represents a generic test suite for repository Create method of any global entity or entity that has externally managed tenants in m2m table/view.
    37  type RepoCreateTestSuite struct {
    38  	Name                      string
    39  	SQLQueryDetails           []SQLQueryDetails
    40  	ConverterMockProvider     func() Mock
    41  	RepoConstructorFunc       interface{}
    42  	ModelEntity               interface{}
    43  	DBEntity                  interface{}
    44  	NilModelEntity            interface{}
    45  	TenantID                  string
    46  	DisableConverterErrorTest bool
    47  	MethodName                string
    48  	IsTopLevelEntity          bool
    49  	IsGlobal                  bool
    50  }
    51  
    52  // Run runs the generic repo create test suite
    53  func (suite *RepoCreateTestSuite) Run(t *testing.T) bool {
    54  	if len(suite.MethodName) == 0 {
    55  		suite.MethodName = "Create"
    56  	}
    57  
    58  	return t.Run(suite.Name, func(t *testing.T) {
    59  		testErr := errors.New("test error")
    60  
    61  		t.Run("success", func(t *testing.T) {
    62  			sqlxDB, sqlMock := MockDatabase(t)
    63  			ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
    64  
    65  			configureValidSQLQueries(sqlMock, suite.SQLQueryDetails)
    66  
    67  			convMock := suite.ConverterMockProvider()
    68  			convMock.On("ToEntity", suite.ModelEntity).Return(suite.DBEntity, nil).Once()
    69  			pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
    70  
    71  			// WHEN
    72  			err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity)
    73  
    74  			// THEN
    75  			require.NoError(t, err)
    76  			sqlMock.AssertExpectations(t)
    77  			convMock.AssertExpectations(t)
    78  		})
    79  
    80  		if !suite.IsTopLevelEntity && !suite.IsGlobal {
    81  			t.Run("error when parent access is missing", func(t *testing.T) {
    82  				sqlxDB, sqlMock := MockDatabase(t)
    83  				ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
    84  
    85  				configureInvalidSelect(sqlMock, suite.SQLQueryDetails)
    86  
    87  				convMock := suite.ConverterMockProvider()
    88  				convMock.On("ToEntity", suite.ModelEntity).Return(suite.DBEntity, nil).Once()
    89  				pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
    90  				// WHEN
    91  				err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity)
    92  				// THEN
    93  				require.Error(t, err)
    94  				require.Equal(t, apperrors.Unauthorized, apperrors.ErrorCode(err))
    95  				require.Contains(t, err.Error(), fmt.Sprintf("Tenant %s does not have access to the parent", suite.TenantID))
    96  
    97  				sqlMock.AssertExpectations(t)
    98  				convMock.AssertExpectations(t)
    99  			})
   100  		}
   101  
   102  		for i := range suite.SQLQueryDetails {
   103  			t.Run(fmt.Sprintf("error if SQL query %d fail", i), func(t *testing.T) {
   104  				sqlxDB, sqlMock := MockDatabase(t)
   105  				ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
   106  
   107  				configureFailureForSQLQueryOnIndex(sqlMock, suite.SQLQueryDetails, i, testErr)
   108  
   109  				convMock := suite.ConverterMockProvider()
   110  				convMock.On("ToEntity", suite.ModelEntity).Return(suite.DBEntity, nil).Once()
   111  				pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
   112  				// WHEN
   113  				err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity)
   114  				// THEN
   115  				require.Error(t, err)
   116  				if suite.SQLQueryDetails[i].IsSelect {
   117  					require.Equal(t, apperrors.Unauthorized, apperrors.ErrorCode(err))
   118  					require.Contains(t, err.Error(), fmt.Sprintf("Tenant %s does not have access to the parent", suite.TenantID))
   119  				} else {
   120  					require.Equal(t, apperrors.InternalError, apperrors.ErrorCode(err))
   121  					require.Contains(t, err.Error(), "Internal Server Error: Unexpected error while executing SQL query")
   122  				}
   123  				sqlMock.AssertExpectations(t)
   124  				convMock.AssertExpectations(t)
   125  			})
   126  		}
   127  
   128  		if !suite.DisableConverterErrorTest {
   129  			t.Run("error when conversion fail", func(t *testing.T) {
   130  				sqlxDB, sqlMock := MockDatabase(t)
   131  				ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
   132  
   133  				convMock := suite.ConverterMockProvider()
   134  				convMock.On("ToEntity", suite.ModelEntity).Return(nil, testErr).Once()
   135  				pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
   136  				// WHEN
   137  				err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.ModelEntity)
   138  				// THEN
   139  				require.Error(t, err)
   140  				require.Contains(t, err.Error(), testErr.Error())
   141  
   142  				sqlMock.AssertExpectations(t)
   143  				convMock.AssertExpectations(t)
   144  			})
   145  		}
   146  
   147  		t.Run("returns error when item is nil", func(t *testing.T) {
   148  			ctx := context.TODO()
   149  			convMock := suite.ConverterMockProvider()
   150  			pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
   151  			// WHEN
   152  			err := callCreate(pgRepository, suite.MethodName, ctx, suite.TenantID, suite.NilModelEntity)
   153  			// THEN
   154  			require.Error(t, err)
   155  			assert.Contains(t, err.Error(), "Internal Server Error")
   156  			convMock.AssertExpectations(t)
   157  		})
   158  	})
   159  }
   160  
   161  // callCreate calls the Create method of the given repository.
   162  // In order to do this for all the different repository implementations we need to do it via reflection.
   163  func callCreate(repo interface{}, methodName string, ctx context.Context, tenant string, modelEntity interface{}) error {
   164  	args := []reflect.Value{reflect.ValueOf(ctx)}
   165  	if len(tenant) > 0 {
   166  		args = append(args, reflect.ValueOf(tenant))
   167  	}
   168  	args = append(args, reflect.ValueOf(modelEntity))
   169  	results := reflect.ValueOf(repo).MethodByName(methodName).Call(args)
   170  	if len(results) != 1 {
   171  		panic("Create should return one argument")
   172  	}
   173  	result := results[0].Interface()
   174  	if result == nil {
   175  		return nil
   176  	}
   177  	err, ok := result.(error)
   178  	if !ok {
   179  		panic("Expected result to be an error")
   180  	}
   181  	return err
   182  }
   183  
   184  // createRepo creates a new repository by the provided constructor func.
   185  // In order to do this for all the different repository implementations we need to do it via reflection.
   186  func createRepo(repoConstructorFunc interface{}, convMock interface{}) interface{} {
   187  	v := reflect.ValueOf(repoConstructorFunc)
   188  	if v.Kind() != reflect.Func {
   189  		panic("Repo constructor should be a function")
   190  	}
   191  	t := v.Type()
   192  
   193  	if t.NumOut() != 1 {
   194  		panic("Repo constructor should return only one argument")
   195  	}
   196  
   197  	if t.NumIn() == 0 {
   198  		return v.Call(nil)[0].Interface()
   199  	}
   200  
   201  	if t.NumIn() != 1 {
   202  		panic("Repo constructor should accept zero or one arguments")
   203  	}
   204  
   205  	mockVal := reflect.ValueOf(convMock)
   206  	return v.Call([]reflect.Value{mockVal})[0].Interface()
   207  }
   208  
   209  func configureValidSQLQueries(sqlMock DBMock, sqlQueryDetails []SQLQueryDetails) {
   210  	for _, sqlDetails := range sqlQueryDetails {
   211  		if sqlDetails.IsSelect {
   212  			sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnRows(sqlDetails.ValidRowsProvider()...)
   213  		} else {
   214  			sqlMock.ExpectExec(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnResult(sqlDetails.ValidResult)
   215  		}
   216  	}
   217  }
   218  
   219  func configureInvalidSelect(sqlMock DBMock, sqlQueryDetails []SQLQueryDetails) {
   220  	for _, sqlDetails := range sqlQueryDetails {
   221  		if sqlDetails.IsSelect {
   222  			sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnRows(sqlDetails.InvalidRowsProvider()...)
   223  			break
   224  		}
   225  	}
   226  }
   227  
   228  func configureFailureForSQLQueryOnIndex(sqlMock DBMock, sqlQueryDetails []SQLQueryDetails, i int, expectedErr error) {
   229  	for _, sqlDetails := range sqlQueryDetails {
   230  		if sqlDetails.IsSelect {
   231  			if sqlDetails.Query == sqlQueryDetails[i].Query {
   232  				sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnError(expectedErr)
   233  				break
   234  			} else {
   235  				sqlMock.ExpectQuery(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnRows(sqlDetails.ValidRowsProvider()...)
   236  			}
   237  		} else {
   238  			if sqlDetails.Query == sqlQueryDetails[i].Query {
   239  				sqlMock.ExpectExec(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnError(expectedErr)
   240  				break
   241  			} else {
   242  				sqlMock.ExpectExec(sqlDetails.Query).WithArgs(sqlDetails.Args...).WillReturnResult(sqlDetails.ValidResult)
   243  			}
   244  		}
   245  	}
   246  }