github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/scenario/directive.go (about)

     1  package scenario
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/kyma-incubator/compass/components/director/pkg/str"
     7  
     8  	"github.com/kyma-incubator/compass/components/director/pkg/apperrors"
     9  
    10  	"github.com/kyma-incubator/compass/components/director/pkg/persistence"
    11  
    12  	"github.com/kyma-incubator/compass/components/director/pkg/log"
    13  
    14  	"github.com/kyma-incubator/compass/components/director/internal/domain/bundleinstanceauth"
    15  
    16  	"github.com/kyma-incubator/compass/components/director/internal/domain/bundle"
    17  
    18  	"github.com/kyma-incubator/compass/components/director/internal/model"
    19  
    20  	"github.com/kyma-incubator/compass/components/director/internal/domain/label"
    21  	"github.com/kyma-incubator/compass/components/director/internal/domain/tenant"
    22  	"github.com/pkg/errors"
    23  
    24  	"github.com/kyma-incubator/compass/components/director/pkg/consumer"
    25  
    26  	"github.com/99designs/gqlgen/graphql"
    27  )
    28  
    29  const (
    30  	// GetApplicationID missing godoc
    31  	GetApplicationID = "GetApplicationID"
    32  	// GetApplicationIDByBundle missing godoc
    33  	GetApplicationIDByBundle = "GetApplicationIDByBundle"
    34  	// GetApplicationIDByBundleInstanceAuth missing godoc
    35  	GetApplicationIDByBundleInstanceAuth = "GetApplicationIDByBundleInstanceAuth"
    36  )
    37  
    38  // ErrMissingScenario missing godoc
    39  var ErrMissingScenario = errors.New("Forbidden: Missing scenarios")
    40  
    41  type directive struct {
    42  	labelRepo label.LabelRepository
    43  	transact  persistence.Transactioner
    44  
    45  	applicationProviders map[string]func(context.Context, string, string) (string, error)
    46  }
    47  
    48  // NewDirective returns a new scenario directive
    49  func NewDirective(transact persistence.Transactioner, labelRepo label.LabelRepository, bundleRepo bundle.BundleRepository, bundleInstanceAuthRepo bundleinstanceauth.Repository) *directive {
    50  	getApplicationIDByBundleFunc := func(ctx context.Context, tenantID, bundleID string) (string, error) {
    51  		bndl, err := bundleRepo.GetByID(ctx, tenantID, bundleID)
    52  		if err != nil {
    53  			return "", errors.Wrapf(err, "while getting Bundle with id %s", bundleID)
    54  		}
    55  		return str.PtrStrToStr(bndl.ApplicationID), nil
    56  	}
    57  
    58  	return &directive{
    59  		transact:  transact,
    60  		labelRepo: labelRepo,
    61  		applicationProviders: map[string]func(context.Context, string, string) (string, error){
    62  			GetApplicationID: func(ctx context.Context, tenantID string, appID string) (string, error) {
    63  				return appID, nil
    64  			},
    65  			GetApplicationIDByBundle: getApplicationIDByBundleFunc,
    66  			GetApplicationIDByBundleInstanceAuth: func(ctx context.Context, tenantID, bundleInstanceAuthID string) (string, error) {
    67  				bundleInstanceAuth, err := bundleInstanceAuthRepo.GetByID(ctx, tenantID, bundleInstanceAuthID)
    68  				if err != nil {
    69  					return "", errors.Wrapf(err, "while getting Bundle instance auth with id %s", bundleInstanceAuthID)
    70  				}
    71  
    72  				return getApplicationIDByBundleFunc(ctx, tenantID, bundleInstanceAuth.BundleID)
    73  			},
    74  		},
    75  	}
    76  }
    77  
    78  // HasScenario ensures that the runtime is in a scenario with the application which resources are being manipulated.
    79  // If the caller is not a Runtime, then request is forwarded to the next resolver.
    80  func (d *directive) HasScenario(ctx context.Context, _ interface{}, next graphql.Resolver, applicationProvider string, idField string) (res interface{}, err error) {
    81  	consumerInfo, err := consumer.LoadFromContext(ctx)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	if consumerInfo.ConsumerType != consumer.Runtime {
    87  		log.C(ctx).Debugf("Consumer type %v is not of type %v. Skipping verification directive...", consumerInfo.ConsumerType, consumer.Runtime)
    88  		return next(ctx)
    89  	}
    90  	log.C(ctx).Infof("Attempting to verify that the requesting runtime is in scenario with the owning application entity")
    91  
    92  	runtimeID := consumerInfo.ConsumerID
    93  	log.C(ctx).Debugf("Found Runtime ID for the requesting runtime: %v", runtimeID)
    94  
    95  	commonScenarios, err := d.extractCommonScenarios(ctx, runtimeID, applicationProvider, idField)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	if len(commonScenarios) == 0 {
   101  		return nil, apperrors.NewInvalidOperationError("requesting runtime should be in same scenario as the requested application resource")
   102  	}
   103  	log.C(ctx).Debugf("Found the following common scenarios: %+v", commonScenarios)
   104  
   105  	log.C(ctx).Infof("Runtime with ID %s is in scenario with the owning application entity", runtimeID)
   106  	return next(ctx)
   107  }
   108  
   109  func (d *directive) extractCommonScenarios(ctx context.Context, runtimeID, applicationProvider, idField string) ([]string, error) {
   110  	tenantID, err := tenant.LoadFromContext(ctx)
   111  	if err != nil {
   112  		return nil, errors.Wrapf(err, "while loading tenant from context")
   113  	}
   114  
   115  	resCtx := graphql.GetFieldContext(ctx)
   116  	id, ok := resCtx.Args[idField].(string)
   117  	if !ok {
   118  		return nil, errors.Errorf("Could not get idField: %s from request context", idField)
   119  	}
   120  
   121  	appProviderFunc, ok := d.applicationProviders[applicationProvider]
   122  	if !ok {
   123  		return nil, errors.Errorf("Could not get app provider func: %s from provider list", applicationProvider)
   124  	}
   125  
   126  	tx, err := d.transact.Begin()
   127  	if err != nil {
   128  		log.C(ctx).WithError(err).Errorf("An error occurred while opening the db transaction: %v", err)
   129  		return nil, err
   130  	}
   131  	defer d.transact.RollbackUnlessCommitted(ctx, tx)
   132  
   133  	ctx = persistence.SaveToContext(ctx, tx)
   134  
   135  	appID, err := appProviderFunc(ctx, tenantID, id)
   136  	if err != nil {
   137  		return nil, errors.Wrapf(err, "Could not derive app id, an error occurred")
   138  	}
   139  	log.C(ctx).Infof("Found owning Application ID based on the request parameter %s: %s", idField, appID)
   140  
   141  	appScenarios, err := d.getObjectScenarios(ctx, tenantID, model.ApplicationLabelableObject, appID)
   142  	if err != nil {
   143  		return nil, errors.Wrap(err, "while fetching scenarios for application")
   144  	}
   145  	log.C(ctx).Debugf("Found the following application scenarios: %s", appScenarios)
   146  
   147  	runtimeScenarios, err := d.getObjectScenarios(ctx, tenantID, model.RuntimeLabelableObject, runtimeID)
   148  	if err != nil {
   149  		return nil, errors.Wrap(err, "while fetching scenarios for runtime")
   150  	}
   151  	log.C(ctx).Debugf("Found the following runtime scenarios: %s", runtimeScenarios)
   152  
   153  	if err := tx.Commit(); err != nil {
   154  		log.C(ctx).WithError(err).Errorf("An error occurred while committing transaction: %v", err)
   155  		return nil, err
   156  	}
   157  
   158  	commonScenarios := stringsIntersection(appScenarios, runtimeScenarios)
   159  	return commonScenarios, nil
   160  }
   161  
   162  func (d *directive) getObjectScenarios(ctx context.Context, tenantID string, objectType model.LabelableObject, objectID string) ([]string, error) {
   163  	scenariosLabel, err := d.labelRepo.GetByKey(ctx, tenantID, objectType, objectID, model.ScenariosKey)
   164  	if err != nil {
   165  		if apperrors.IsNotFoundError(err) {
   166  			return make([]string, 0), nil
   167  		}
   168  		return nil, errors.Wrapf(err, "while fetching scenarios for object with id: %s and type: %s", objectID, objectType)
   169  	}
   170  	return label.ValueToStringsSlice(scenariosLabel.Value)
   171  }
   172  
   173  // stringsIntersection returns the common elements in two string slices.
   174  func stringsIntersection(str1, str2 []string) []string {
   175  	var intersection []string
   176  	strings := make(map[string]bool)
   177  	for _, v := range str1 {
   178  		strings[v] = true
   179  	}
   180  	for _, v := range str2 {
   181  		if strings[v] {
   182  			intersection = append(intersection, v)
   183  		}
   184  	}
   185  	return intersection
   186  }