github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/internal/repo/testdb/generic_repo_get_tests.go (about)

     1  package testdb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"testing"
     8  
     9  	"github.com/kyma-incubator/compass/components/director/pkg/apperrors"
    10  	"github.com/kyma-incubator/compass/components/director/pkg/persistence"
    11  	"github.com/pkg/errors"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
// RepoGetTestSuite represents a generic test suite for the repository Get method of any global entity, or of an entity
// that has externally managed tenants in an m2m table/view.
// This test suite is not suitable for entities with an embedded tenant in them.
type RepoGetTestSuite struct {
	// Name is the name of the top-level subtest that Run creates.
	Name string

	// SQLQueryDetails are the SQL statements the Get method is expected to execute, in order.
	// Run panics if any of them is not a select.
	SQLQueryDetails []SQLQueryDetails

	// ConverterMockProvider returns a fresh converter mock for every subtest. The suite sets
	// "FromEntity" expectations on it.
	ConverterMockProvider func() Mock

	// RepoConstructorFunc is passed, together with the converter mock, to createRepo to build the
	// repository under test (presumably a constructor function — confirm against createRepo).
	RepoConstructorFunc interface{}

	// ExpectedModelEntity is what the converter mock returns from "FromEntity", and what the Get
	// method is expected to return on success.
	ExpectedModelEntity interface{}

	// ExpectedDBEntity is the first argument the converter's "FromEntity" is expected to receive.
	ExpectedDBEntity interface{}

	// MethodArgs are passed to the Get method after the context.
	MethodArgs []interface{}

	// AdditionalConverterArgs are appended after the DB entity in every "FromEntity" expectation.
	AdditionalConverterArgs []interface{}

	// DisableConverterErrorTest skips the "error when conversion fail" subtest.
	DisableConverterErrorTest bool

	// MethodName is the repository method under test; Run sets it to "GetByID" when empty.
	MethodName string

	// ExpectNotFoundError, when true, makes the no-rows subtest expect the Get method to recover from
	// the missing row: it then runs AfterNotFoundErrorSQLQueryDetails and returns
	// AfterNotFoundErrorExpectedModelEntity without error. When false, the no-rows subtest expects an
	// apperrors.NotFound error and a nil result. It also enables the "error if SQL query fail" subtest
	// for the first fallback query.
	ExpectNotFoundError bool

	// AfterNotFoundErrorSQLQueryDetails are the SQL statements expected after the initial query finds no rows
	// (used only when ExpectNotFoundError is true).
	AfterNotFoundErrorSQLQueryDetails []SQLQueryDetails

	// AfterNotFoundErrorExpectedModelEntity is the converter output and expected Get result on the fallback path.
	AfterNotFoundErrorExpectedModelEntity interface{}

	// AfterNotFoundErrorExpectedDBEntity is the expected "FromEntity" input on the fallback path.
	AfterNotFoundErrorExpectedDBEntity interface{}
}
    33  
    34  // Run runs the generic repo get test suite
    35  func (suite *RepoGetTestSuite) Run(t *testing.T) bool {
    36  	if len(suite.MethodName) == 0 {
    37  		suite.MethodName = "GetByID"
    38  	}
    39  
    40  	for _, queryDetails := range suite.SQLQueryDetails {
    41  		if !queryDetails.IsSelect {
    42  			panic("get suite should expect only select SQL statements")
    43  		}
    44  	}
    45  
    46  	return t.Run(suite.Name, func(t *testing.T) {
    47  		testErr := errors.New("test error")
    48  
    49  		t.Run("success", func(t *testing.T) {
    50  			sqlxDB, sqlMock := MockDatabase(t)
    51  			ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
    52  
    53  			configureValidSQLQueries(sqlMock, suite.SQLQueryDetails)
    54  
    55  			convMock := suite.ConverterMockProvider()
    56  			convMock.On("FromEntity", append([]interface{}{suite.ExpectedDBEntity}, suite.AdditionalConverterArgs...)...).Return(suite.ExpectedModelEntity, nil).Once()
    57  			pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
    58  			// WHEN
    59  			res, err := callGet(pgRepository, ctx, suite.MethodName, suite.MethodArgs)
    60  			// THEN
    61  			require.NoError(t, err)
    62  			require.Equal(t, suite.ExpectedModelEntity, res)
    63  			sqlMock.AssertExpectations(t)
    64  			convMock.AssertExpectations(t)
    65  		})
    66  
    67  		t.Run("returns not found error when no rows", func(t *testing.T) {
    68  			sqlxDB, sqlMock := MockDatabase(t)
    69  			ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
    70  
    71  			configureInvalidSelect(sqlMock, suite.SQLQueryDetails)
    72  
    73  			convMock := suite.ConverterMockProvider()
    74  			if suite.ExpectNotFoundError {
    75  				convMock.On("FromEntity", append([]interface{}{suite.AfterNotFoundErrorExpectedDBEntity}, suite.AdditionalConverterArgs...)...).Return(suite.AfterNotFoundErrorExpectedModelEntity, nil).Once()
    76  				configureValidSQLQueries(sqlMock, suite.AfterNotFoundErrorSQLQueryDetails)
    77  			}
    78  			pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
    79  			// WHEN
    80  			res, err := callGet(pgRepository, ctx, suite.MethodName, suite.MethodArgs)
    81  			// THEN
    82  			if !suite.ExpectNotFoundError {
    83  				require.Error(t, err)
    84  				require.Equal(t, apperrors.NotFound, apperrors.ErrorCode(err))
    85  				require.Contains(t, err.Error(), apperrors.NotFoundMsg)
    86  				require.Nil(t, res)
    87  				sqlMock.AssertExpectations(t)
    88  				convMock.AssertExpectations(t)
    89  			} else {
    90  				require.NoError(t, err)
    91  				require.Equal(t, suite.AfterNotFoundErrorExpectedModelEntity, res)
    92  				sqlMock.AssertExpectations(t)
    93  				convMock.AssertExpectations(t)
    94  			}
    95  		})
    96  
    97  		for i := range suite.SQLQueryDetails {
    98  			t.Run(fmt.Sprintf("error if SQL query %d fail", i), func(t *testing.T) {
    99  				sqlxDB, sqlMock := MockDatabase(t)
   100  				ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
   101  
   102  				configureFailureForSQLQueryOnIndex(sqlMock, suite.SQLQueryDetails, i, testErr)
   103  
   104  				convMock := suite.ConverterMockProvider()
   105  				pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
   106  
   107  				// WHEN
   108  				res, err := callGet(pgRepository, ctx, suite.MethodName, suite.MethodArgs)
   109  
   110  				// THEN
   111  				require.Nil(t, res)
   112  
   113  				require.Error(t, err)
   114  				require.Equal(t, apperrors.InternalError, apperrors.ErrorCode(err))
   115  				require.Contains(t, err.Error(), "Internal Server Error: Unexpected error while executing SQL query")
   116  
   117  				sqlMock.AssertExpectations(t)
   118  				convMock.AssertExpectations(t)
   119  			})
   120  		}
   121  
   122  		if suite.ExpectNotFoundError {
   123  			t.Run("error if SQL query fail", func(t *testing.T) {
   124  				sqlxDB, sqlMock := MockDatabase(t)
   125  				ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
   126  
   127  				configureInvalidSelect(sqlMock, suite.SQLQueryDetails)
   128  
   129  				convMock := suite.ConverterMockProvider()
   130  				configureFailureForSQLQueryOnIndex(sqlMock, suite.AfterNotFoundErrorSQLQueryDetails, 0, testErr)
   131  
   132  				pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
   133  				// WHEN
   134  				res, err := callGet(pgRepository, ctx, suite.MethodName, suite.MethodArgs)
   135  				// THEN
   136  				require.Nil(t, res)
   137  
   138  				require.Error(t, err)
   139  				require.Equal(t, apperrors.InternalError, apperrors.ErrorCode(err))
   140  				require.Contains(t, err.Error(), "Internal Server Error: Unexpected error while executing SQL query")
   141  
   142  				sqlMock.AssertExpectations(t)
   143  				convMock.AssertExpectations(t)
   144  			})
   145  		}
   146  
   147  		if !suite.DisableConverterErrorTest {
   148  			t.Run("error when conversion fail", func(t *testing.T) {
   149  				sqlxDB, sqlMock := MockDatabase(t)
   150  				ctx := persistence.SaveToContext(context.TODO(), sqlxDB)
   151  
   152  				configureValidSQLQueries(sqlMock, suite.SQLQueryDetails)
   153  
   154  				convMock := suite.ConverterMockProvider()
   155  				convMock.On("FromEntity", append([]interface{}{suite.ExpectedDBEntity}, suite.AdditionalConverterArgs...)...).Return(nil, testErr).Once()
   156  				pgRepository := createRepo(suite.RepoConstructorFunc, convMock)
   157  				// WHEN
   158  				res, err := callGet(pgRepository, ctx, suite.MethodName, suite.MethodArgs)
   159  				// THEN
   160  				require.Nil(t, res)
   161  
   162  				require.Error(t, err)
   163  				require.Contains(t, err.Error(), testErr.Error())
   164  
   165  				sqlMock.AssertExpectations(t)
   166  				convMock.AssertExpectations(t)
   167  			})
   168  		}
   169  	})
   170  }
   171  
   172  func callGet(repo interface{}, ctx context.Context, methodName string, args []interface{}) (interface{}, error) {
   173  	argsVals := make([]reflect.Value, 1, len(args))
   174  	argsVals[0] = reflect.ValueOf(ctx)
   175  	for _, arg := range args {
   176  		argsVals = append(argsVals, reflect.ValueOf(arg))
   177  	}
   178  	results := reflect.ValueOf(repo).MethodByName(methodName).Call(argsVals)
   179  	if len(results) != 2 {
   180  		panic("Get should return two argument")
   181  	}
   182  
   183  	errResult := results[1].Interface()
   184  	if errResult == nil {
   185  		return results[0].Interface(), nil
   186  	}
   187  	err, ok := errResult.(error)
   188  	if !ok {
   189  		panic("Expected result to be an error")
   190  	}
   191  	return results[0].Interface(), err
   192  }