github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/caveats/compile.go (about)

     1  package caveats
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/authzed/cel-go/cel"
     8  	"github.com/authzed/cel-go/common"
     9  	"google.golang.org/protobuf/proto"
    10  
    11  	"github.com/authzed/spicedb/pkg/caveats/types"
    12  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    13  	impl "github.com/authzed/spicedb/pkg/proto/impl/v1"
    14  )
    15  
    16  const anonymousCaveat = ""
    17  
    18  const maxCaveatExpressionSize = 100_000 // characters
    19  
// CompiledCaveat is a compiled form of a caveat.
//
// It pairs the CEL AST of the caveat expression with the CEL environment it is
// associated with and the caveat's name. Values are produced either by compiling
// source (CompileCaveatWithSource and friends) or by DeserializeCaveat.
type CompiledCaveat struct {
	// celEnv is the environment under which the CEL program was compiled. For a
	// deserialized caveat it is an environment rebuilt from the parameter types
	// supplied to DeserializeCaveat.
	celEnv *cel.Env

	// ast is the AST form of the CEL program.
	ast *cel.Ast

	// name is the caveat's name, as returned by Name and stored by Serialize.
	name string
}
    31  
    32  // Name represents a user-friendly reference to a caveat
    33  func (cc CompiledCaveat) Name() string {
    34  	return cc.name
    35  }
    36  
    37  // ExprString returns the string-form of the caveat.
    38  func (cc CompiledCaveat) ExprString() (string, error) {
    39  	return cel.AstToString(cc.ast)
    40  }
    41  
    42  // Serialize serializes the compiled caveat into a byte string for storage.
    43  func (cc CompiledCaveat) Serialize() ([]byte, error) {
    44  	cexpr, err := cel.AstToCheckedExpr(cc.ast)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	caveat := &impl.DecodedCaveat{
    50  		KindOneof: &impl.DecodedCaveat_Cel{
    51  			Cel: cexpr,
    52  		},
    53  		Name: cc.name,
    54  	}
    55  
    56  	// TODO(jschorr): change back to MarshalVT once stable is supported.
    57  	// See: https://github.com/planetscale/vtprotobuf/pull/133
    58  	return proto.MarshalOptions{Deterministic: true}.Marshal(caveat)
    59  }
    60  
    61  // ReferencedParameters returns the names of the parameters referenced in the expression.
    62  func (cc CompiledCaveat) ReferencedParameters(parameters []string) (*mapz.Set[string], error) {
    63  	referencedParams := mapz.NewSet[string]()
    64  	definedParameters := mapz.NewSet[string]()
    65  	definedParameters.Extend(parameters)
    66  
    67  	checked, err := cel.AstToCheckedExpr(cc.ast)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	referencedParameters(definedParameters, checked.Expr, referencedParams)
    73  	return referencedParams, nil
    74  }
    75  
    76  // CompileCaveatWithName compiles a caveat string into a compiled caveat with a given name,
    77  // or returns the compilation errors.
    78  func CompileCaveatWithName(env *Environment, exprString, name string) (*CompiledCaveat, error) {
    79  	c, err := CompileCaveatWithSource(env, name, common.NewStringSource(exprString, name), nil)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	c.name = name
    84  	return c, nil
    85  }
    86  
    87  // CompileCaveatWithSource compiles a caveat source into a compiled caveat, or returns the compilation errors.
    88  func CompileCaveatWithSource(env *Environment, name string, source common.Source, startPosition SourcePosition) (*CompiledCaveat, error) {
    89  	celEnv, err := env.asCelEnvironment()
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	if len(strings.TrimSpace(source.Content())) > maxCaveatExpressionSize {
    95  		return nil, fmt.Errorf("caveat expression provided exceeds maximum allowed size of %d characters", maxCaveatExpressionSize)
    96  	}
    97  
    98  	ast, issues := celEnv.CompileSource(source)
    99  	if issues != nil && issues.Err() != nil {
   100  		if startPosition == nil {
   101  			return nil, CompilationErrors{issues.Err(), issues}
   102  		}
   103  
   104  		// Construct errors with the source location adjusted based on the starting source position
   105  		// in the parent schema (if any). This ensures that the errors coming out of CEL show the correct
   106  		// *overall* location information..
   107  		line, col, err := startPosition.LineAndColumn()
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  
   112  		adjustedErrors := common.NewErrors(source)
   113  		for _, existingErr := range issues.Errors() {
   114  			location := existingErr.Location
   115  
   116  			// NOTE: Our locations are zero-indexed while CEL is 1-indexed, so we need to adjust the line/column values accordingly.
   117  			if location.Line() == 1 {
   118  				location = common.NewLocation(line+location.Line(), col+location.Column())
   119  			} else {
   120  				location = common.NewLocation(line+location.Line(), location.Column())
   121  			}
   122  
   123  			adjustedError := &common.Error{
   124  				Message:  existingErr.Message,
   125  				ExprID:   existingErr.ExprID,
   126  				Location: location,
   127  			}
   128  
   129  			adjustedErrors = adjustedErrors.Append([]*common.Error{
   130  				adjustedError,
   131  			})
   132  		}
   133  
   134  		adjustedIssues := cel.NewIssues(adjustedErrors)
   135  		return nil, CompilationErrors{adjustedIssues.Err(), adjustedIssues}
   136  	}
   137  
   138  	if ast.OutputType() != cel.BoolType {
   139  		return nil, CompilationErrors{fmt.Errorf("caveat expression must result in a boolean value: found `%s`", ast.OutputType().String()), nil}
   140  	}
   141  
   142  	compiled := &CompiledCaveat{celEnv, ast, anonymousCaveat}
   143  	compiled.name = name
   144  	return compiled, nil
   145  }
   146  
   147  // compileCaveat compiles a caveat string into a compiled caveat, or returns the compilation errors.
   148  func compileCaveat(env *Environment, exprString string) (*CompiledCaveat, error) {
   149  	s := common.NewStringSource(exprString, "caveat")
   150  	return CompileCaveatWithSource(env, "caveat", s, nil)
   151  }
   152  
   153  // DeserializeCaveat deserializes a byte-serialized caveat back into a CompiledCaveat.
   154  func DeserializeCaveat(serialized []byte, parameterTypes map[string]types.VariableType) (*CompiledCaveat, error) {
   155  	if len(serialized) == 0 {
   156  		return nil, fmt.Errorf("given empty serialized")
   157  	}
   158  
   159  	caveat := &impl.DecodedCaveat{}
   160  	err := caveat.UnmarshalVT(serialized)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	env, err := EnvForVariables(parameterTypes)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	celEnv, err := env.asCelEnvironment()
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	ast := cel.CheckedExprToAst(caveat.GetCel())
   176  	return &CompiledCaveat{celEnv, ast, caveat.Name}, nil
   177  }