github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/schemadsl/generator/generator.go (about)

     1  package generator
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"golang.org/x/exp/maps"
    10  
    11  	"github.com/authzed/spicedb/pkg/caveats"
    12  	caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
    13  	"github.com/authzed/spicedb/pkg/graph"
    14  	"github.com/authzed/spicedb/pkg/namespace"
    15  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    16  	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
    17  	"github.com/authzed/spicedb/pkg/spiceerrors"
    18  )
    19  
    20  // Ellipsis is the relation name for terminal subjects.
    21  const Ellipsis = "..."
    22  
    23  // MaxSingleLineCommentLength sets the maximum length for a comment to made single line.
    24  const MaxSingleLineCommentLength = 70 // 80 - the comment parts and some padding
    25  
    26  // GenerateSchema generates a DSL view of the given schema.
    27  func GenerateSchema(definitions []compiler.SchemaDefinition) (string, bool, error) {
    28  	generated := make([]string, 0, len(definitions))
    29  	result := true
    30  	for _, definition := range definitions {
    31  		switch def := definition.(type) {
    32  		case *core.CaveatDefinition:
    33  			generatedCaveat, ok, err := GenerateCaveatSource(def)
    34  			if err != nil {
    35  				return "", false, err
    36  			}
    37  
    38  			result = result && ok
    39  			generated = append(generated, generatedCaveat)
    40  
    41  		case *core.NamespaceDefinition:
    42  			generatedSchema, ok, err := GenerateSource(def)
    43  			if err != nil {
    44  				return "", false, err
    45  			}
    46  
    47  			result = result && ok
    48  			generated = append(generated, generatedSchema)
    49  
    50  		default:
    51  			return "", false, spiceerrors.MustBugf("unknown type of definition %T in GenerateSchema", def)
    52  		}
    53  	}
    54  
    55  	return strings.Join(generated, "\n\n"), result, nil
    56  }
    57  
    58  // GenerateCaveatSource generates a DSL view of the given caveat definition.
    59  func GenerateCaveatSource(caveat *core.CaveatDefinition) (string, bool, error) {
    60  	generator := &sourceGenerator{
    61  		indentationLevel: 0,
    62  		hasNewline:       true,
    63  		hasBlankline:     true,
    64  		hasNewScope:      true,
    65  	}
    66  
    67  	err := generator.emitCaveat(caveat)
    68  	if err != nil {
    69  		return "", false, err
    70  	}
    71  
    72  	return generator.buf.String(), !generator.hasIssue, nil
    73  }
    74  
    75  // GenerateSource generates a DSL view of the given namespace definition.
    76  func GenerateSource(namespace *core.NamespaceDefinition) (string, bool, error) {
    77  	generator := &sourceGenerator{
    78  		indentationLevel: 0,
    79  		hasNewline:       true,
    80  		hasBlankline:     true,
    81  		hasNewScope:      true,
    82  	}
    83  
    84  	err := generator.emitNamespace(namespace)
    85  	if err != nil {
    86  		return "", false, err
    87  	}
    88  
    89  	return generator.buf.String(), !generator.hasIssue, nil
    90  }
    91  
    92  // GenerateRelationSource generates a DSL view of the given relation definition.
    93  func GenerateRelationSource(relation *core.Relation) (string, error) {
    94  	generator := &sourceGenerator{
    95  		indentationLevel: 0,
    96  		hasNewline:       true,
    97  		hasBlankline:     true,
    98  		hasNewScope:      true,
    99  	}
   100  
   101  	err := generator.emitRelation(relation)
   102  	if err != nil {
   103  		return "", err
   104  	}
   105  
   106  	return generator.buf.String(), nil
   107  }
   108  
   109  func (sg *sourceGenerator) emitCaveat(caveat *core.CaveatDefinition) error {
   110  	sg.emitComments(caveat.Metadata)
   111  	sg.append("caveat ")
   112  	sg.append(caveat.Name)
   113  	sg.append("(")
   114  
   115  	parameterNames := maps.Keys(caveat.ParameterTypes)
   116  	sort.Strings(parameterNames)
   117  
   118  	for index, paramName := range parameterNames {
   119  		if index > 0 {
   120  			sg.append(", ")
   121  		}
   122  
   123  		decoded, err := caveattypes.DecodeParameterType(caveat.ParameterTypes[paramName])
   124  		if err != nil {
   125  			return fmt.Errorf("invalid parameter type on caveat: %w", err)
   126  		}
   127  
   128  		sg.append(paramName)
   129  		sg.append(" ")
   130  		sg.append(decoded.String())
   131  	}
   132  
   133  	sg.append(")")
   134  
   135  	sg.append(" {")
   136  	sg.appendLine()
   137  	sg.indent()
   138  	sg.markNewScope()
   139  
   140  	parameterTypes, err := caveattypes.DecodeParameterTypes(caveat.ParameterTypes)
   141  	if err != nil {
   142  		return fmt.Errorf("invalid caveat parameters: %w", err)
   143  	}
   144  
   145  	deserializedExpression, err := caveats.DeserializeCaveat(caveat.SerializedExpression, parameterTypes)
   146  	if err != nil {
   147  		return fmt.Errorf("invalid caveat expression bytes: %w", err)
   148  	}
   149  
   150  	exprString, err := deserializedExpression.ExprString()
   151  	if err != nil {
   152  		return fmt.Errorf("invalid caveat expression: %w", err)
   153  	}
   154  
   155  	sg.append(strings.TrimSpace(exprString))
   156  	sg.appendLine()
   157  
   158  	sg.dedent()
   159  	sg.append("}")
   160  	return nil
   161  }
   162  
   163  func (sg *sourceGenerator) emitNamespace(namespace *core.NamespaceDefinition) error {
   164  	sg.emitComments(namespace.Metadata)
   165  	sg.append("definition ")
   166  	sg.append(namespace.Name)
   167  
   168  	if len(namespace.Relation) == 0 {
   169  		sg.append(" {}")
   170  		return nil
   171  	}
   172  
   173  	sg.append(" {")
   174  	sg.appendLine()
   175  	sg.indent()
   176  	sg.markNewScope()
   177  
   178  	for _, relation := range namespace.Relation {
   179  		err := sg.emitRelation(relation)
   180  		if err != nil {
   181  			return err
   182  		}
   183  	}
   184  
   185  	sg.dedent()
   186  	sg.append("}")
   187  	return nil
   188  }
   189  
   190  func (sg *sourceGenerator) emitRelation(relation *core.Relation) error {
   191  	hasThis, err := graph.HasThis(relation.UsersetRewrite)
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	isPermission := relation.UsersetRewrite != nil && !hasThis
   197  
   198  	sg.emitComments(relation.Metadata)
   199  	if isPermission {
   200  		sg.append("permission ")
   201  	} else {
   202  		sg.append("relation ")
   203  	}
   204  
   205  	sg.append(relation.Name)
   206  
   207  	if !isPermission {
   208  		sg.append(": ")
   209  		if relation.TypeInformation == nil || relation.TypeInformation.AllowedDirectRelations == nil || len(relation.TypeInformation.AllowedDirectRelations) == 0 {
   210  			sg.appendIssue("missing allowed types")
   211  		} else {
   212  			for index, allowedRelation := range relation.TypeInformation.AllowedDirectRelations {
   213  				if index > 0 {
   214  					sg.append(" | ")
   215  				}
   216  
   217  				sg.emitAllowedRelation(allowedRelation)
   218  			}
   219  		}
   220  	}
   221  
   222  	if relation.UsersetRewrite != nil {
   223  		sg.append(" = ")
   224  		sg.emitRewrite(relation.UsersetRewrite)
   225  	}
   226  
   227  	sg.appendLine()
   228  	return nil
   229  }
   230  
   231  func (sg *sourceGenerator) emitAllowedRelation(allowedRelation *core.AllowedRelation) {
   232  	sg.append(allowedRelation.Namespace)
   233  	if allowedRelation.GetRelation() != "" && allowedRelation.GetRelation() != Ellipsis {
   234  		sg.append("#")
   235  		sg.append(allowedRelation.GetRelation())
   236  	}
   237  	if allowedRelation.GetPublicWildcard() != nil {
   238  		sg.append(":*")
   239  	}
   240  	if allowedRelation.GetRequiredCaveat() != nil {
   241  		sg.append(" with ")
   242  		sg.append(allowedRelation.RequiredCaveat.CaveatName)
   243  	}
   244  }
   245  
   246  func (sg *sourceGenerator) emitRewrite(rewrite *core.UsersetRewrite) {
   247  	switch rw := rewrite.RewriteOperation.(type) {
   248  	case *core.UsersetRewrite_Union:
   249  		sg.emitRewriteOps(rw.Union, "+")
   250  	case *core.UsersetRewrite_Intersection:
   251  		sg.emitRewriteOps(rw.Intersection, "&")
   252  	case *core.UsersetRewrite_Exclusion:
   253  		sg.emitRewriteOps(rw.Exclusion, "-")
   254  	}
   255  }
   256  
   257  func (sg *sourceGenerator) emitRewriteOps(setOp *core.SetOperation, op string) {
   258  	for index, child := range setOp.Child {
   259  		if index > 0 {
   260  			sg.append(" " + op + " ")
   261  		}
   262  
   263  		sg.emitSetOpChild(child)
   264  	}
   265  }
   266  
   267  func (sg *sourceGenerator) isAllUnion(rewrite *core.UsersetRewrite) bool {
   268  	switch rw := rewrite.RewriteOperation.(type) {
   269  	case *core.UsersetRewrite_Union:
   270  		for _, setOpChild := range rw.Union.Child {
   271  			switch child := setOpChild.ChildType.(type) {
   272  			case *core.SetOperation_Child_UsersetRewrite:
   273  				if !sg.isAllUnion(child.UsersetRewrite) {
   274  					return false
   275  				}
   276  			default:
   277  				continue
   278  			}
   279  		}
   280  		return true
   281  	default:
   282  		return false
   283  	}
   284  }
   285  
   286  func (sg *sourceGenerator) emitSetOpChild(setOpChild *core.SetOperation_Child) {
   287  	switch child := setOpChild.ChildType.(type) {
   288  	case *core.SetOperation_Child_UsersetRewrite:
   289  		if sg.isAllUnion(child.UsersetRewrite) {
   290  			sg.emitRewrite(child.UsersetRewrite)
   291  			break
   292  		}
   293  
   294  		sg.append("(")
   295  		sg.emitRewrite(child.UsersetRewrite)
   296  		sg.append(")")
   297  
   298  	case *core.SetOperation_Child_XThis:
   299  		sg.appendIssue("_this unsupported here. Please rewrite into a relation and permission")
   300  
   301  	case *core.SetOperation_Child_XNil:
   302  		sg.append("nil")
   303  
   304  	case *core.SetOperation_Child_ComputedUserset:
   305  		sg.append(child.ComputedUserset.Relation)
   306  
   307  	case *core.SetOperation_Child_TupleToUserset:
   308  		sg.append(child.TupleToUserset.Tupleset.Relation)
   309  		sg.append("->")
   310  		sg.append(child.TupleToUserset.ComputedUserset.Relation)
   311  	}
   312  }
   313  
   314  func (sg *sourceGenerator) emitComments(metadata *core.Metadata) {
   315  	if len(namespace.GetComments(metadata)) > 0 {
   316  		sg.ensureBlankLineOrNewScope()
   317  	}
   318  
   319  	for _, comment := range namespace.GetComments(metadata) {
   320  		sg.appendComment(comment)
   321  	}
   322  }
   323  
   324  func (sg *sourceGenerator) appendComment(comment string) {
   325  	switch {
   326  	case strings.HasPrefix(comment, "/*"):
   327  		stripped := strings.TrimSpace(comment)
   328  
   329  		if strings.HasPrefix(stripped, "/**") {
   330  			stripped = strings.TrimPrefix(stripped, "/**")
   331  			sg.append("/**")
   332  		} else {
   333  			stripped = strings.TrimPrefix(stripped, "/*")
   334  			sg.append("/*")
   335  		}
   336  
   337  		stripped = strings.TrimSuffix(stripped, "*/")
   338  		stripped = strings.TrimSpace(stripped)
   339  
   340  		requireMultiline := len(stripped) > MaxSingleLineCommentLength || strings.ContainsRune(stripped, '\n')
   341  
   342  		if requireMultiline {
   343  			sg.appendLine()
   344  			scanner := bufio.NewScanner(strings.NewReader(stripped))
   345  			for scanner.Scan() {
   346  				sg.append(" * ")
   347  				sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(scanner.Text()), "*")))
   348  				sg.appendLine()
   349  			}
   350  			sg.append(" */")
   351  			sg.appendLine()
   352  		} else {
   353  			sg.append(" ")
   354  			sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(stripped), "*")))
   355  			sg.append(" */")
   356  			sg.appendLine()
   357  		}
   358  
   359  	case strings.HasPrefix(comment, "//"):
   360  		sg.append("// ")
   361  		sg.append(strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(comment), "//")))
   362  		sg.appendLine()
   363  	}
   364  }