github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/internal/domain/tenant/tenant_test.go (about)

     1  package tenant_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/kyma-incubator/compass/components/director/internal/domain/tenant"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestLoadFromContext(t *testing.T) {
    15  	value := "foo"
    16  	tenants := tenant.TenantCtx{InternalID: value, ExternalID: value}
    17  	tenantsEmptyInternalID := tenant.TenantCtx{InternalID: "", ExternalID: value}
    18  
    19  	testCases := []struct {
    20  		Name    string
    21  		Context context.Context
    22  
    23  		ExpectedResult     string
    24  		ExpectedErrMessage string
    25  	}{
    26  		{
    27  			Name:               "Success",
    28  			Context:            context.WithValue(context.TODO(), tenant.TenantContextKey, tenants),
    29  			ExpectedResult:     value,
    30  			ExpectedErrMessage: "",
    31  		},
    32  		{
    33  			Name:               "Error",
    34  			Context:            context.TODO(),
    35  			ExpectedResult:     "",
    36  			ExpectedErrMessage: "cannot read tenant from context",
    37  		},
    38  		{
    39  			Name:               "Error required tenant",
    40  			Context:            context.WithValue(context.TODO(), tenant.TenantContextKey, tenantsEmptyInternalID),
    41  			ExpectedResult:     "",
    42  			ExpectedErrMessage: "Tenant is required",
    43  		},
    44  	}
    45  
    46  	for i, testCase := range testCases {
    47  		t.Run(fmt.Sprintf("%d: %s", i, testCase.Name), func(t *testing.T) {
    48  			// WHEN
    49  			result, err := tenant.LoadFromContext(testCase.Context)
    50  
    51  			// then
    52  			if testCase.ExpectedErrMessage != "" {
    53  				require.Equal(t, testCase.ExpectedErrMessage, err.Error())
    54  				return
    55  			}
    56  
    57  			assert.Equal(t, testCase.ExpectedResult, result)
    58  		})
    59  	}
    60  }
    61  
    62  func TestLoadTenantPairFromContext(t *testing.T) {
    63  	value := "foo"
    64  	tenants := tenant.TenantCtx{InternalID: value, ExternalID: value}
    65  	tenantsEmptyInternalID := tenant.TenantCtx{InternalID: "", ExternalID: value}
    66  
    67  	testCases := []struct {
    68  		Name    string
    69  		Context context.Context
    70  
    71  		ExpectedResult     tenant.TenantCtx
    72  		ExpectedErrMessage string
    73  	}{
    74  		{
    75  			Name:               "Success",
    76  			Context:            context.WithValue(context.TODO(), tenant.TenantContextKey, tenants),
    77  			ExpectedResult:     tenants,
    78  			ExpectedErrMessage: "",
    79  		},
    80  		{
    81  			Name:               "Error",
    82  			Context:            context.TODO(),
    83  			ExpectedResult:     tenant.TenantCtx{},
    84  			ExpectedErrMessage: "cannot read tenant from context",
    85  		},
    86  		{
    87  			Name:               "Error required tenant",
    88  			Context:            context.WithValue(context.TODO(), tenant.TenantContextKey, tenantsEmptyInternalID),
    89  			ExpectedResult:     tenant.TenantCtx{},
    90  			ExpectedErrMessage: "Tenant is required",
    91  		},
    92  	}
    93  
    94  	for i, testCase := range testCases {
    95  		t.Run(fmt.Sprintf("%d: %s", i, testCase.Name), func(t *testing.T) {
    96  			// WHEN
    97  			result, err := tenant.LoadTenantPairFromContext(testCase.Context)
    98  
    99  			// then
   100  			if testCase.ExpectedErrMessage != "" {
   101  				require.Equal(t, testCase.ExpectedErrMessage, err.Error())
   102  				return
   103  			}
   104  
   105  			assert.Equal(t, testCase.ExpectedResult, result)
   106  		})
   107  	}
   108  }
   109  
   110  func TestLoadTenantPairFromContextNoChecks(t *testing.T) {
   111  	value := "foo"
   112  	tenants := tenant.TenantCtx{InternalID: value, ExternalID: value}
   113  
   114  	testCases := []struct {
   115  		Name    string
   116  		Context context.Context
   117  
   118  		ExpectedResult     tenant.TenantCtx
   119  		ExpectedErrMessage string
   120  	}{
   121  		{
   122  			Name:               "Success",
   123  			Context:            context.WithValue(context.TODO(), tenant.TenantContextKey, tenants),
   124  			ExpectedResult:     tenants,
   125  			ExpectedErrMessage: "",
   126  		},
   127  		{
   128  			Name:               "Error",
   129  			Context:            context.TODO(),
   130  			ExpectedResult:     tenant.TenantCtx{},
   131  			ExpectedErrMessage: "cannot read tenant from context",
   132  		},
   133  	}
   134  
   135  	for i, testCase := range testCases {
   136  		t.Run(fmt.Sprintf("%d: %s", i, testCase.Name), func(t *testing.T) {
   137  			// WHEN
   138  			result, err := tenant.LoadTenantPairFromContextNoChecks(testCase.Context)
   139  
   140  			// then
   141  			if testCase.ExpectedErrMessage != "" {
   142  				require.Equal(t, testCase.ExpectedErrMessage, err.Error())
   143  				return
   144  			}
   145  
   146  			assert.Equal(t, testCase.ExpectedResult, result)
   147  		})
   148  	}
   149  }
   150  
   151  func TestSaveToLoadFromContext(t *testing.T) {
   152  	// GIVEN
   153  	value := "foo"
   154  	externalValue := "bar"
   155  	ctx := context.TODO()
   156  
   157  	tenants := tenant.TenantCtx{InternalID: value, ExternalID: externalValue}
   158  	// WHEN
   159  	result := tenant.SaveToContext(ctx, value, externalValue)
   160  
   161  	// then
   162  	assert.Equal(t, tenants, result.Value(tenant.TenantContextKey))
   163  }