github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/remote/expr.go (about)

     1  package remote
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/authzed/cel-go/cel"
     7  	"github.com/authzed/cel-go/common"
     8  	"github.com/authzed/cel-go/common/types"
     9  	"github.com/authzed/cel-go/common/types/ref"
    10  	"google.golang.org/protobuf/proto"
    11  
    12  	corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
    13  	dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    14  )
    15  
    16  // DispatchExpr is a CEL expression that can be run to determine the secondary dispatchers, if any,
    17  // to invoke for the incoming request.
    18  type DispatchExpr struct {
    19  	env        *cel.Env
    20  	registry   *types.Registry
    21  	methodName string
    22  	exprAst    *cel.Ast
    23  }
    24  
    25  var dispatchRequestTypes = []proto.Message{
    26  	&dispatchv1.DispatchCheckRequest{},
    27  	&corev1.RelationReference{},
    28  	&corev1.ObjectAndRelation{},
    29  }
    30  
    31  // ParseDispatchExpression parses a dispatch expression via CEL.
    32  func ParseDispatchExpression(methodName string, exprString string) (*DispatchExpr, error) {
    33  	registry, err := types.NewRegistry(dispatchRequestTypes...)
    34  	if err != nil {
    35  		return nil, fmt.Errorf("unable to initialize dispatch expression type registry")
    36  	}
    37  
    38  	opts := make([]cel.EnvOption, 0)
    39  	opts = append(opts, cel.OptionalTypes(cel.OptionalTypesVersion(0)))
    40  	opts = append(opts, cel.Variable("request", cel.DynType))
    41  
    42  	celEnv, err := cel.NewEnv(opts...)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	ast, issues := celEnv.CompileSource(common.NewStringSource(exprString, methodName))
    48  	if issues != nil && issues.Err() != nil {
    49  		return nil, issues.Err()
    50  	}
    51  
    52  	if !ast.OutputType().IsEquivalentType(cel.ListType(cel.StringType)) {
    53  		return nil, fmt.Errorf("dispatch expression must result in a list[string] value: found `%s`", ast.OutputType().String())
    54  	}
    55  
    56  	return &DispatchExpr{
    57  		env:        celEnv,
    58  		registry:   registry,
    59  		methodName: methodName,
    60  		exprAst:    ast,
    61  	}, nil
    62  }
    63  
    64  // RunDispatchExpr runs a dispatch CEL expression over the given request and returns the secondary dispatchers
    65  // to invoke, if any.
    66  func RunDispatchExpr[R any](de *DispatchExpr, request R) ([]string, error) {
    67  	celopts := make([]cel.ProgramOption, 0, 3)
    68  
    69  	celopts = append(celopts, cel.EvalOptions(cel.OptTrackState))
    70  	celopts = append(celopts, cel.EvalOptions(cel.OptPartialEval))
    71  	celopts = append(celopts, cel.CostLimit(50))
    72  
    73  	prg, err := de.env.Program(de.exprAst, celopts...)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	// Mark any unspecified variables as unknown, to ensure that partial application
    79  	// will result in producing a type of Unknown.
    80  	activation, err := de.env.PartialVars(map[string]any{
    81  		"request": de.registry.NativeToValue(request),
    82  	})
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	val, _, err := prg.Eval(activation)
    88  	if err != nil {
    89  		return nil, fmt.Errorf("unable to evaluate dispatch expression: %w", err)
    90  	}
    91  
    92  	// If the value produced has Unknown type, then it means required context was missing.
    93  	if types.IsUnknown(val) {
    94  		return nil, fmt.Errorf("unable to eval dispatch expression; did you make sure you use `request.`?")
    95  	}
    96  
    97  	values := val.Value().([]ref.Val)
    98  	convertedValues := make([]string, 0, len(values))
    99  	for _, value := range values {
   100  		convertedValues = append(convertedValues, value.Value().(string))
   101  	}
   102  	return convertedValues, nil
   103  }