github.com/openfga/openfga@v1.5.4-rc1/tests/check/check.go (about)

     1  // Package check contains integration tests for the Check API.
     2  package check
     3  
     4  import (
     5  	"context"
     6  	"fmt"
     7  	"math"
     8  	"testing"
     9  
    10  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
    11  	"github.com/stretchr/testify/require"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/status"
    14  	"sigs.k8s.io/yaml"
    15  
    16  	"github.com/openfga/openfga/pkg/testutils"
    17  
    18  	"github.com/openfga/openfga/assets"
    19  	checktest "github.com/openfga/openfga/internal/test/check"
    20  	"github.com/openfga/openfga/pkg/typesystem"
    21  	"github.com/openfga/openfga/tests"
    22  )
    23  
// writeMaxChunkSize is the maximum number of tuples sent in a single Write
// request when a test stage's tuples are written to the store; larger tuple
// sets are split into chunks of at most this size.
var writeMaxChunkSize = 40 // chunk write requests into chunks of at most this size
    25  
// individualTest is a single named test case loaded from the YAML test files.
// All of its stages are run, in order, against one freshly created store.
type individualTest struct {
	// Name is used as the subtest name and as the name of the store created
	// for the test.
	Name string
	// Stages are executed sequentially; later stages see the tuples written
	// by earlier ones.
	Stages []*stage
}
    30  
// checkTests is the top-level shape of a YAML check test file: a list of
// independent test cases.
type checkTests struct {
	Tests []individualTest
}
    34  
// testParams bundles the settings shared by every check test run.
// NOTE: it is constructed positionally in RunAllTests, so field order matters.
type testParams struct {
	// schemaVersion selects the embedded test files and is sent with every
	// WriteAuthorizationModel request.
	schemaVersion string
	// client is the server the tests are run against.
	client ClientInterface
}
    39  
// stage is one step of a test. All stages of a test run against the same
// store, so tuples written in one stage remain visible to later stages.
type stage struct {
	// Model is the authorization model in DSL form; it is converted with
	// testutils.MustTransformDSLToProtoWithID and written before the assertions run.
	Model string
	// Tuples are written to the store before the assertions run, or, in the
	// "_ctxTuples" variant of a test, sent as contextual tuples with every
	// assertion instead.
	Tuples []*openfgav1.TupleKey
	// CheckAssertions are evaluated against this stage's model and tuples.
	CheckAssertions []*checktest.Assertion `json:"checkAssertions"`
}
    46  
// ClientInterface defines the client interface for running check tests.
//
// The embedded tests.TestClientBootstrapper supplies the setup calls the tests
// use (CreateStore, WriteAuthorizationModel and Write); Check is the call under
// test and is issued once per check assertion.
type ClientInterface interface {
	tests.TestClientBootstrapper
	Check(ctx context.Context, in *openfgav1.CheckRequest, opts ...grpc.CallOption) (*openfgav1.CheckResponse, error)
}
    52  
    53  // RunAllTests will run all check tests.
    54  func RunAllTests(t *testing.T, client ClientInterface) {
    55  	t.Run("RunAllTests", func(t *testing.T) {
    56  		t.Run("Check", func(t *testing.T) {
    57  			t.Parallel()
    58  			runTests(t, testParams{typesystem.SchemaVersion1_1, client})
    59  		})
    60  	})
    61  }
    62  
    63  func runTests(t *testing.T, params testParams) {
    64  	files := []string{
    65  		"tests/consolidated_1_1_tests.yaml",
    66  		"tests/abac_tests.yaml",
    67  	}
    68  
    69  	var allTestCases []individualTest
    70  
    71  	for _, file := range files {
    72  		var b []byte
    73  		var err error
    74  		schemaVersion := params.schemaVersion
    75  		if schemaVersion == typesystem.SchemaVersion1_1 {
    76  			b, err = assets.EmbedTests.ReadFile(file)
    77  		}
    78  		require.NoError(t, err)
    79  
    80  		var testCases checkTests
    81  		err = yaml.Unmarshal(b, &testCases)
    82  		require.NoError(t, err)
    83  
    84  		allTestCases = append(allTestCases, testCases.Tests...)
    85  	}
    86  
    87  	for _, test := range allTestCases {
    88  		test := test
    89  		runTest(t, test, params, false)
    90  		runTest(t, test, params, true)
    91  	}
    92  }
    93  
    94  func runTest(t *testing.T, test individualTest, params testParams, contextTupleTest bool) {
    95  	schemaVersion := params.schemaVersion
    96  	client := params.client
    97  	name := test.Name
    98  
    99  	if contextTupleTest {
   100  		name += "_ctxTuples"
   101  	}
   102  
   103  	t.Run(name, func(t *testing.T) {
   104  		if contextTupleTest && len(test.Stages) > 1 {
   105  			// we don't want to run special contextual tuples test for these cases
   106  			// as multi-stages test has expectation tuples are in system
   107  			t.Skipf("multi-stages test has expectation tuples are in system")
   108  		}
   109  
   110  		t.Parallel()
   111  		ctx := context.Background()
   112  
   113  		resp, err := client.CreateStore(ctx, &openfgav1.CreateStoreRequest{Name: name})
   114  		require.NoError(t, err)
   115  
   116  		storeID := resp.GetId()
   117  
   118  		for stageNumber, stage := range test.Stages {
   119  			t.Run(fmt.Sprintf("stage_%d", stageNumber), func(t *testing.T) {
   120  				if contextTupleTest && len(stage.Tuples) > 20 {
   121  					// https://github.com/openfga/api/blob/05de9d8be3ee12fa4e796b92dbdd4bbbf87107f2/openfga/v1/openfga.proto#L151
   122  					t.Skipf("cannot send more than 20 contextual tuples in one request")
   123  				}
   124  				// arrange: write model
   125  				model := testutils.MustTransformDSLToProtoWithID(stage.Model)
   126  
   127  				writeModelResponse, err := client.WriteAuthorizationModel(ctx, &openfgav1.WriteAuthorizationModelRequest{
   128  					StoreId:         storeID,
   129  					SchemaVersion:   schemaVersion,
   130  					TypeDefinitions: model.GetTypeDefinitions(),
   131  					Conditions:      model.GetConditions(),
   132  				})
   133  				require.NoError(t, err)
   134  
   135  				tuples := stage.Tuples
   136  				tuplesLength := len(tuples)
   137  				// arrange: write tuples
   138  				if tuplesLength > 0 && !contextTupleTest {
   139  					for i := 0; i < tuplesLength; i += writeMaxChunkSize {
   140  						end := int(math.Min(float64(i+writeMaxChunkSize), float64(tuplesLength)))
   141  						writeChunk := (tuples)[i:end]
   142  						_, err = client.Write(ctx, &openfgav1.WriteRequest{
   143  							StoreId:              storeID,
   144  							AuthorizationModelId: writeModelResponse.GetAuthorizationModelId(),
   145  							Writes: &openfgav1.WriteRequestWrites{
   146  								TupleKeys: writeChunk,
   147  							},
   148  						})
   149  						require.NoError(t, err)
   150  					}
   151  				}
   152  
   153  				if len(stage.CheckAssertions) == 0 {
   154  					t.Skipf("no check assertions defined")
   155  				}
   156  				for assertionNumber, assertion := range stage.CheckAssertions {
   157  					t.Run(fmt.Sprintf("assertion_%d", assertionNumber), func(t *testing.T) {
   158  						detailedInfo := fmt.Sprintf("Check request: %s. Model: %s. Tuples: %s. Contextual tuples: %s", assertion.Tuple, stage.Model, stage.Tuples, assertion.ContextualTuples)
   159  
   160  						ctxTuples := assertion.ContextualTuples
   161  						if contextTupleTest {
   162  							ctxTuples = append(ctxTuples, stage.Tuples...)
   163  						}
   164  
   165  						var tupleKey *openfgav1.CheckRequestTupleKey
   166  						if assertion.Tuple != nil {
   167  							tupleKey = &openfgav1.CheckRequestTupleKey{
   168  								User:     assertion.Tuple.GetUser(),
   169  								Relation: assertion.Tuple.GetRelation(),
   170  								Object:   assertion.Tuple.GetObject(),
   171  							}
   172  						}
   173  						resp, err := client.Check(ctx, &openfgav1.CheckRequest{
   174  							StoreId:              storeID,
   175  							AuthorizationModelId: writeModelResponse.GetAuthorizationModelId(),
   176  							TupleKey:             tupleKey,
   177  							ContextualTuples: &openfgav1.ContextualTupleKeys{
   178  								TupleKeys: ctxTuples,
   179  							},
   180  							Context: assertion.Context,
   181  							Trace:   true,
   182  						})
   183  
   184  						if assertion.ErrorCode == 0 {
   185  							require.NoError(t, err, detailedInfo)
   186  							require.Equal(t, assertion.Expectation, resp.GetAllowed(), detailedInfo)
   187  						} else {
   188  							require.Error(t, err, detailedInfo)
   189  							e, ok := status.FromError(err)
   190  							require.True(t, ok, detailedInfo)
   191  							require.Equal(t, assertion.ErrorCode, int(e.Code()), detailedInfo)
   192  						}
   193  					})
   194  				}
   195  			})
   196  		}
   197  	})
   198  }