github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/caveats/run.go (about)

     1  package caveats
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"maps"
     8  	"strings"
     9  
    10  	"google.golang.org/protobuf/types/known/structpb"
    11  
    12  	"github.com/authzed/spicedb/pkg/caveats"
    13  	caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
    14  	"github.com/authzed/spicedb/pkg/datastore"
    15  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    16  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    17  	"github.com/authzed/spicedb/pkg/spiceerrors"
    18  )
    19  
    20  // RunCaveatExpressionDebugOption are the options for running caveat expression evaluation
    21  // with debugging enabled or disabled.
    22  type RunCaveatExpressionDebugOption int
    23  
    24  const (
    25  	// RunCaveatExpressionNoDebugging runs the evaluation without debugging enabled.
    26  	RunCaveatExpressionNoDebugging RunCaveatExpressionDebugOption = 0
    27  
    28  	// RunCaveatExpressionNoDebugging runs the evaluation with debugging enabled.
    29  	RunCaveatExpressionWithDebugInformation RunCaveatExpressionDebugOption = 1
    30  )
    31  
    32  // RunCaveatExpression runs a caveat expression over the given context and returns the result.
    33  func RunCaveatExpression(
    34  	ctx context.Context,
    35  	expr *core.CaveatExpression,
    36  	context map[string]any,
    37  	reader datastore.CaveatReader,
    38  	debugOption RunCaveatExpressionDebugOption,
    39  ) (ExpressionResult, error) {
    40  	env := caveats.NewEnvironment()
    41  	return runExpression(ctx, env, expr, context, reader, debugOption)
    42  }
    43  
// ExpressionResult is the result of a caveat expression being run. It is the
// value returned by RunCaveatExpression.
type ExpressionResult interface {
	// Value is the resolved value for the expression. For partially applied expressions, this value will be false.
	Value() bool

	// IsPartial returns whether the expression was only partially applied.
	IsPartial() bool

	// MissingVarNames returns the names of the parameters missing from the context.
	// Implementations may return an error when the result is not partial.
	MissingVarNames() ([]string, error)

	// ContextValues returns the context values used when computing this result.
	ContextValues() map[string]any

	// ContextStruct returns the context values as a structpb Struct.
	ContextStruct() (*structpb.Struct, error)

	// ExpressionString returns the human-readable expression for the caveat expression.
	ExpressionString() (string, error)
}
    64  
    65  type syntheticResult struct {
    66  	value         bool
    67  	contextValues map[string]any
    68  	exprString    string
    69  }
    70  
    71  func (sr syntheticResult) Value() bool {
    72  	return sr.value
    73  }
    74  
    75  func (sr syntheticResult) IsPartial() bool {
    76  	return false
    77  }
    78  
    79  func (sr syntheticResult) MissingVarNames() ([]string, error) {
    80  	return nil, fmt.Errorf("not a partial value")
    81  }
    82  
    83  func (sr syntheticResult) ContextValues() map[string]any {
    84  	return sr.contextValues
    85  }
    86  
    87  func (sr syntheticResult) ContextStruct() (*structpb.Struct, error) {
    88  	return caveats.ConvertContextToStruct(sr.contextValues)
    89  }
    90  
    91  func (sr syntheticResult) ExpressionString() (string, error) {
    92  	return sr.exprString, nil
    93  }
    94  
    95  func runExpression(
    96  	ctx context.Context,
    97  	env *caveats.Environment,
    98  	expr *core.CaveatExpression,
    99  	context map[string]any,
   100  	reader datastore.CaveatReader,
   101  	debugOption RunCaveatExpressionDebugOption,
   102  ) (ExpressionResult, error) {
   103  	// Collect all referenced caveat definitions in the expression.
   104  	caveatNames := mapz.NewSet[string]()
   105  	collectCaveatNames(expr, caveatNames)
   106  
   107  	if caveatNames.IsEmpty() {
   108  		return nil, fmt.Errorf("received empty caveat expression")
   109  	}
   110  
   111  	// Bulk lookup all of the referenced caveat definitions.
   112  	caveatDefs, err := reader.LookupCaveatsWithNames(ctx, caveatNames.AsSlice())
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	lc := loadedCaveats{
   118  		caveatDefs:          map[string]*core.CaveatDefinition{},
   119  		deserializedCaveats: map[string]*caveats.CompiledCaveat{},
   120  	}
   121  
   122  	for _, cd := range caveatDefs {
   123  		lc.caveatDefs[cd.Definition.GetName()] = cd.Definition
   124  	}
   125  
   126  	return runExpressionWithCaveats(ctx, env, expr, context, lc, debugOption)
   127  }
   128  
   129  type loadedCaveats struct {
   130  	caveatDefs          map[string]*core.CaveatDefinition
   131  	deserializedCaveats map[string]*caveats.CompiledCaveat
   132  }
   133  
   134  func (lc loadedCaveats) Get(caveatDefName string) (*core.CaveatDefinition, *caveats.CompiledCaveat, error) {
   135  	caveat, ok := lc.caveatDefs[caveatDefName]
   136  	if !ok {
   137  		return nil, nil, datastore.NewCaveatNameNotFoundErr(caveatDefName)
   138  	}
   139  
   140  	deserialized, ok := lc.deserializedCaveats[caveatDefName]
   141  	if ok {
   142  		return caveat, deserialized, nil
   143  	}
   144  
   145  	parameterTypes, err := caveattypes.DecodeParameterTypes(caveat.ParameterTypes)
   146  	if err != nil {
   147  		return nil, nil, err
   148  	}
   149  
   150  	justDeserialized, err := caveats.DeserializeCaveat(caveat.SerializedExpression, parameterTypes)
   151  	if err != nil {
   152  		return caveat, nil, err
   153  	}
   154  
   155  	lc.deserializedCaveats[caveatDefName] = justDeserialized
   156  	return caveat, justDeserialized, nil
   157  }
   158  
   159  func runExpressionWithCaveats(
   160  	ctx context.Context,
   161  	env *caveats.Environment,
   162  	expr *core.CaveatExpression,
   163  	context map[string]any,
   164  	loadedCaveats loadedCaveats,
   165  	debugOption RunCaveatExpressionDebugOption,
   166  ) (ExpressionResult, error) {
   167  	if expr.GetCaveat() != nil {
   168  		caveat, compiled, err := loadedCaveats.Get(expr.GetCaveat().CaveatName)
   169  		if err != nil {
   170  			return nil, err
   171  		}
   172  
   173  		// Create a combined context, with the written context taking precedence over that specified.
   174  		untypedFullContext := maps.Clone(context)
   175  		if untypedFullContext == nil {
   176  			untypedFullContext = map[string]any{}
   177  		}
   178  
   179  		relationshipContext := expr.GetCaveat().GetContext().AsMap()
   180  		maps.Copy(untypedFullContext, relationshipContext)
   181  
   182  		// Perform type checking and conversion on the context map.
   183  		typedParameters, err := caveats.ConvertContextToParameters(
   184  			untypedFullContext,
   185  			caveat.ParameterTypes,
   186  			caveats.SkipUnknownParameters,
   187  		)
   188  		if err != nil {
   189  			return nil, NewParameterTypeError(expr, err)
   190  		}
   191  
   192  		result, err := caveats.EvaluateCaveat(compiled, typedParameters)
   193  		if err != nil {
   194  			var evalErr caveats.EvaluationErr
   195  			if errors.As(err, &evalErr) {
   196  				return nil, NewEvaluationErr(expr, evalErr)
   197  			}
   198  
   199  			return nil, err
   200  		}
   201  
   202  		return result, nil
   203  	}
   204  
   205  	cop := expr.GetOperation()
   206  	boolResult := false
   207  	if cop.Op == core.CaveatOperation_AND {
   208  		boolResult = true
   209  	}
   210  
   211  	var contextValues map[string]any
   212  	var exprStringPieces []string
   213  
   214  	buildExprString := func() (string, error) {
   215  		switch cop.Op {
   216  		case core.CaveatOperation_AND:
   217  			return strings.Join(exprStringPieces, " && "), nil
   218  
   219  		case core.CaveatOperation_OR:
   220  			return strings.Join(exprStringPieces, " || "), nil
   221  
   222  		case core.CaveatOperation_NOT:
   223  			return strings.Join(exprStringPieces, " "), nil
   224  
   225  		default:
   226  			return "", spiceerrors.MustBugf("unknown caveat operation: %v", cop.Op)
   227  		}
   228  	}
   229  
   230  	for _, child := range cop.Children {
   231  		childResult, err := runExpressionWithCaveats(ctx, env, child, context, loadedCaveats, debugOption)
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  
   236  		if childResult.IsPartial() {
   237  			return childResult, nil
   238  		}
   239  
   240  		switch cop.Op {
   241  		case core.CaveatOperation_AND:
   242  			boolResult = boolResult && childResult.Value()
   243  
   244  			if debugOption == RunCaveatExpressionWithDebugInformation {
   245  				contextValues = combineMaps(contextValues, childResult.ContextValues())
   246  				exprString, err := childResult.ExpressionString()
   247  				if err != nil {
   248  					return nil, err
   249  				}
   250  
   251  				exprStringPieces = append(exprStringPieces, exprString)
   252  			}
   253  
   254  			if !boolResult {
   255  				built, err := buildExprString()
   256  				if err != nil {
   257  					return nil, err
   258  				}
   259  
   260  				return syntheticResult{false, contextValues, built}, nil
   261  			}
   262  
   263  		case core.CaveatOperation_OR:
   264  			boolResult = boolResult || childResult.Value()
   265  
   266  			if debugOption == RunCaveatExpressionWithDebugInformation {
   267  				contextValues = combineMaps(contextValues, childResult.ContextValues())
   268  				exprString, err := childResult.ExpressionString()
   269  				if err != nil {
   270  					return nil, err
   271  				}
   272  
   273  				exprStringPieces = append(exprStringPieces, exprString)
   274  			}
   275  
   276  			if boolResult {
   277  				built, err := buildExprString()
   278  				if err != nil {
   279  					return nil, err
   280  				}
   281  
   282  				return syntheticResult{true, contextValues, built}, nil
   283  			}
   284  
   285  		case core.CaveatOperation_NOT:
   286  			if debugOption == RunCaveatExpressionWithDebugInformation {
   287  				contextValues = combineMaps(contextValues, childResult.ContextValues())
   288  				exprString, err := childResult.ExpressionString()
   289  				if err != nil {
   290  					return nil, err
   291  				}
   292  
   293  				exprStringPieces = append(exprStringPieces, "!("+exprString+")")
   294  			}
   295  
   296  			built, err := buildExprString()
   297  			if err != nil {
   298  				return nil, err
   299  			}
   300  
   301  			return syntheticResult{!childResult.Value(), contextValues, built}, nil
   302  
   303  		default:
   304  			return nil, spiceerrors.MustBugf("unknown caveat operation: %v", cop.Op)
   305  		}
   306  	}
   307  
   308  	built, err := buildExprString()
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  
   313  	return syntheticResult{boolResult, contextValues, built}, nil
   314  }
   315  
   316  func combineMaps(first map[string]any, second map[string]any) map[string]any {
   317  	if first == nil {
   318  		first = make(map[string]any, len(second))
   319  	}
   320  
   321  	cloned := maps.Clone(first)
   322  	maps.Copy(cloned, second)
   323  	return cloned
   324  }
   325  
   326  func collectCaveatNames(expr *core.CaveatExpression, caveatNames *mapz.Set[string]) {
   327  	if expr.GetCaveat() != nil {
   328  		caveatNames.Add(expr.GetCaveat().CaveatName)
   329  		return
   330  	}
   331  
   332  	cop := expr.GetOperation()
   333  	for _, child := range cop.Children {
   334  		collectCaveatNames(child, caveatNames)
   335  	}
   336  }