github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/services/v1/bulkcheck.go (about)

     1  package v1
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
     8  	"github.com/jzelinskie/stringz"
     9  	"google.golang.org/grpc/status"
    10  
    11  	"github.com/authzed/spicedb/internal/dispatch"
    12  	"github.com/authzed/spicedb/internal/graph"
    13  	"github.com/authzed/spicedb/internal/graph/computed"
    14  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    15  	"github.com/authzed/spicedb/internal/middleware/usagemetrics"
    16  	"github.com/authzed/spicedb/internal/namespace"
    17  	"github.com/authzed/spicedb/internal/services/shared"
    18  	"github.com/authzed/spicedb/internal/taskrunner"
    19  	"github.com/authzed/spicedb/pkg/genutil"
    20  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    21  	"github.com/authzed/spicedb/pkg/genutil/slicez"
    22  	"github.com/authzed/spicedb/pkg/middleware/consistency"
    23  	dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    24  	"github.com/authzed/spicedb/pkg/spiceerrors"
    25  )
    26  
// bulkChecker contains the logic to allow ExperimentalService/BulkCheckPermission and
// PermissionsService/CheckBulkPermissions to share the same implementation.
type bulkChecker struct {
	// maxAPIDepth is the maximum API depth used both when grouping request items and when
	// rewriting per-item errors into gRPC statuses.
	maxAPIDepth uint32
	// maxCaveatContextSize is forwarded to item grouping as the maximum allowed caveat
	// context size.
	maxCaveatContextSize int
	// maxConcurrency is the concurrency handed to the task runner that executes the
	// chunked checks.
	maxConcurrency uint16

	// dispatch is the dispatcher through which each chunk of checks is computed.
	dispatch dispatch.Dispatcher
}
    36  
    37  func (bc *bulkChecker) checkBulkPermissions(ctx context.Context, req *v1.CheckBulkPermissionsRequest) (*v1.CheckBulkPermissionsResponse, error) {
    38  	atRevision, checkedAt, err := consistency.RevisionFromContext(ctx)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	if len(req.Items) > maxBulkCheckCount {
    44  		return nil, NewExceedsMaximumChecksErr(uint64(len(req.Items)), maxBulkCheckCount)
    45  	}
    46  
    47  	// Compute a hash for each requested item and record its index(es) for the items, to be used for sorting of results.
    48  	itemCount, err := genutil.EnsureUInt32(len(req.Items))
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	itemIndexByHash := mapz.NewMultiMapWithCap[string, int](itemCount)
    54  	for index, item := range req.Items {
    55  		itemHash, err := computeCheckBulkPermissionsItemHash(item)
    56  		if err != nil {
    57  			return nil, err
    58  		}
    59  
    60  		itemIndexByHash.Add(itemHash, index)
    61  	}
    62  
    63  	// Identify checks with same permission+subject over different resources and group them. This is doable because
    64  	// the dispatching system already internally supports this kind of batching for performance.
    65  	groupedItems, err := groupItems(ctx, groupingParameters{
    66  		atRevision:           atRevision,
    67  		maxCaveatContextSize: bc.maxCaveatContextSize,
    68  		maximumAPIDepth:      bc.maxAPIDepth,
    69  	}, req.Items)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	bulkResponseMutex := sync.Mutex{}
    75  
    76  	tr := taskrunner.NewPreloadedTaskRunner(ctx, bc.maxConcurrency, len(groupedItems))
    77  
    78  	respMetadata := &dispatchv1.ResponseMeta{
    79  		DispatchCount:       1,
    80  		CachedDispatchCount: 0,
    81  		DepthRequired:       1,
    82  		DebugInfo:           nil,
    83  	}
    84  	usagemetrics.SetInContext(ctx, respMetadata)
    85  
    86  	orderedPairs := make([]*v1.CheckBulkPermissionsPair, len(req.Items))
    87  
    88  	addPair := func(pair *v1.CheckBulkPermissionsPair) error {
    89  		pairItemHash, err := computeCheckBulkPermissionsItemHash(pair.Request)
    90  		if err != nil {
    91  			return err
    92  		}
    93  
    94  		found, ok := itemIndexByHash.Get(pairItemHash)
    95  		if !ok {
    96  			return spiceerrors.MustBugf("missing expected item hash")
    97  		}
    98  
    99  		for _, index := range found {
   100  			orderedPairs[index] = pair
   101  		}
   102  
   103  		return nil
   104  	}
   105  
   106  	appendResultsForError := func(params *computed.CheckParameters, resourceIDs []string, err error) error {
   107  		rewritten := shared.RewriteError(ctx, err, &shared.ConfigForErrors{
   108  			MaximumAPIDepth: bc.maxAPIDepth,
   109  		})
   110  		statusResp, ok := status.FromError(rewritten)
   111  		if !ok {
   112  			// If error is not a gRPC Status, fail the entire bulk check request.
   113  			return err
   114  		}
   115  
   116  		bulkResponseMutex.Lock()
   117  		defer bulkResponseMutex.Unlock()
   118  
   119  		for _, resourceID := range resourceIDs {
   120  			reqItem, err := requestItemFromResourceAndParameters(params, resourceID)
   121  			if err != nil {
   122  				return err
   123  			}
   124  
   125  			if err := addPair(&v1.CheckBulkPermissionsPair{
   126  				Request: reqItem,
   127  				Response: &v1.CheckBulkPermissionsPair_Error{
   128  					Error: statusResp.Proto(),
   129  				},
   130  			}); err != nil {
   131  				return err
   132  			}
   133  		}
   134  
   135  		return nil
   136  	}
   137  
   138  	appendResultsForCheck := func(params *computed.CheckParameters, resourceIDs []string, metadata *dispatchv1.ResponseMeta, results map[string]*dispatchv1.ResourceCheckResult) error {
   139  		bulkResponseMutex.Lock()
   140  		defer bulkResponseMutex.Unlock()
   141  
   142  		for _, resourceID := range resourceIDs {
   143  			reqItem, err := requestItemFromResourceAndParameters(params, resourceID)
   144  			if err != nil {
   145  				return err
   146  			}
   147  
   148  			if err := addPair(&v1.CheckBulkPermissionsPair{
   149  				Request:  reqItem,
   150  				Response: pairItemFromCheckResult(results[resourceID]),
   151  			}); err != nil {
   152  				return err
   153  			}
   154  		}
   155  
   156  		respMetadata.DispatchCount += metadata.DispatchCount
   157  		respMetadata.CachedDispatchCount += metadata.CachedDispatchCount
   158  		return nil
   159  	}
   160  
   161  	for _, group := range groupedItems {
   162  		group := group
   163  
   164  		slicez.ForEachChunk(group.resourceIDs, MaxBulkCheckDispatchChunkSize, func(resourceIDs []string) {
   165  			tr.Add(func(ctx context.Context) error {
   166  				ds := datastoremw.MustFromContext(ctx).SnapshotReader(atRevision)
   167  
   168  				// Ensure the check namespaces and relations are valid.
   169  				err := namespace.CheckNamespaceAndRelations(ctx,
   170  					[]namespace.TypeAndRelationToCheck{
   171  						{
   172  							NamespaceName: group.params.ResourceType.Namespace,
   173  							RelationName:  group.params.ResourceType.Relation,
   174  							AllowEllipsis: false,
   175  						},
   176  						{
   177  							NamespaceName: group.params.Subject.Namespace,
   178  							RelationName:  stringz.DefaultEmpty(group.params.Subject.Relation, graph.Ellipsis),
   179  							AllowEllipsis: true,
   180  						},
   181  					}, ds)
   182  				if err != nil {
   183  					return appendResultsForError(group.params, resourceIDs, err)
   184  				}
   185  
   186  				// Call bulk check to compute the check result(s) for the resource ID(s).
   187  				rcr, metadata, err := computed.ComputeBulkCheck(ctx, bc.dispatch, *group.params, resourceIDs)
   188  				if err != nil {
   189  					return appendResultsForError(group.params, resourceIDs, err)
   190  				}
   191  
   192  				return appendResultsForCheck(group.params, resourceIDs, metadata, rcr)
   193  			})
   194  		})
   195  	}
   196  
   197  	// Run the checks in parallel.
   198  	if err := tr.StartAndWait(); err != nil {
   199  		return nil, err
   200  	}
   201  
   202  	return &v1.CheckBulkPermissionsResponse{CheckedAt: checkedAt, Pairs: orderedPairs}, nil
   203  }
   204  
   205  func toCheckBulkPermissionsRequest(req *v1.BulkCheckPermissionRequest) *v1.CheckBulkPermissionsRequest {
   206  	items := make([]*v1.CheckBulkPermissionsRequestItem, len(req.Items))
   207  	for i, item := range req.Items {
   208  		items[i] = &v1.CheckBulkPermissionsRequestItem{
   209  			Resource:   item.Resource,
   210  			Permission: item.Permission,
   211  			Subject:    item.Subject,
   212  			Context:    item.Context,
   213  		}
   214  	}
   215  
   216  	return &v1.CheckBulkPermissionsRequest{Items: items}
   217  }
   218  
   219  func toBulkCheckPermissionResponse(resp *v1.CheckBulkPermissionsResponse) *v1.BulkCheckPermissionResponse {
   220  	pairs := make([]*v1.BulkCheckPermissionPair, len(resp.Pairs))
   221  	for i, pair := range resp.Pairs {
   222  		pairs[i] = &v1.BulkCheckPermissionPair{}
   223  		pairs[i].Request = &v1.BulkCheckPermissionRequestItem{
   224  			Resource:   pair.Request.Resource,
   225  			Permission: pair.Request.Permission,
   226  			Subject:    pair.Request.Subject,
   227  			Context:    pair.Request.Context,
   228  		}
   229  
   230  		switch t := pair.Response.(type) {
   231  		case *v1.CheckBulkPermissionsPair_Item:
   232  			pairs[i].Response = &v1.BulkCheckPermissionPair_Item{
   233  				Item: &v1.BulkCheckPermissionResponseItem{
   234  					Permissionship:    t.Item.Permissionship,
   235  					PartialCaveatInfo: t.Item.PartialCaveatInfo,
   236  				},
   237  			}
   238  		case *v1.CheckBulkPermissionsPair_Error:
   239  			pairs[i].Response = &v1.BulkCheckPermissionPair_Error{
   240  				Error: t.Error,
   241  			}
   242  		default:
   243  			panic("unknown CheckBulkPermissionResponse pair response type")
   244  		}
   245  	}
   246  
   247  	return &v1.BulkCheckPermissionResponse{
   248  		CheckedAt: resp.CheckedAt,
   249  		Pairs:     pairs,
   250  	}
   251  }