github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/middleware_test.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/authzed/spicedb/pkg/cmd/datastore"
     9  	"github.com/authzed/spicedb/pkg/cmd/util"
    10  
    11  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
    12  	"github.com/stretchr/testify/require"
    13  	"google.golang.org/grpc"
    14  )
    15  
    16  func TestInvalidModification(t *testing.T) {
    17  	for _, tt := range []struct {
    18  		name string
    19  		mod  MiddlewareModification[grpc.UnaryServerInterceptor]
    20  		err  string
    21  	}{
    22  		{
    23  			name: "invalid operation without dependency",
    24  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
    25  				Operation: OperationReplace,
    26  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
    27  					{
    28  						Middleware: nil,
    29  					},
    30  				},
    31  			},
    32  			err: "without a dependency",
    33  		},
    34  		{
    35  			name: "valid replace all without dependency",
    36  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
    37  				Operation: OperationReplaceAllUnsafe,
    38  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
    39  					{
    40  						Name:       "foobar",
    41  						Middleware: nil,
    42  					},
    43  				},
    44  			},
    45  		},
    46  		{
    47  			name: "invalid unnamed middleware",
    48  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
    49  				Operation:                OperationAppend,
    50  				DependencyMiddlewareName: "foobar",
    51  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
    52  					{
    53  						Middleware: nil,
    54  					},
    55  				},
    56  			},
    57  			err: "unnamed middleware",
    58  		},
    59  		{
    60  			name: "invalid modification with duplicate middlewares",
    61  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
    62  				Operation:                OperationAppend,
    63  				DependencyMiddlewareName: "foobar",
    64  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
    65  					{
    66  						Name:       "foo",
    67  						Middleware: nil,
    68  					},
    69  					{
    70  						Name:       "foo",
    71  						Middleware: nil,
    72  					},
    73  				},
    74  			},
    75  			err: "duplicate names in middleware modification",
    76  		},
    77  	} {
    78  		tt := tt
    79  		t.Run(tt.name, func(t *testing.T) {
    80  			err := tt.mod.validate()
    81  			if tt.err != "" {
    82  				require.ErrorContains(t, err, tt.err)
    83  			} else {
    84  				require.NoError(t, err)
    85  			}
    86  		})
    87  	}
    88  }
    89  
    90  func TestNewMiddlewareChain(t *testing.T) {
    91  	mc, err := NewMiddlewareChain(ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
    92  		Name:       "foobar",
    93  		Middleware: nil,
    94  	})
    95  	require.NoError(t, err)
    96  	require.Len(t, mc.chain, 1)
    97  
    98  	_, err = NewMiddlewareChain(ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
    99  		Middleware: nil,
   100  	})
   101  	require.ErrorContains(t, err, "unnamed middleware")
   102  }
   103  
   104  func TestChainValidate(t *testing.T) {
   105  	mc, err := NewMiddlewareChain(ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   106  		Name:       "public",
   107  		Middleware: nil,
   108  	}, ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   109  		Name:       "internal",
   110  		Internal:   true,
   111  		Middleware: nil,
   112  	})
   113  	require.NoError(t, err)
   114  
   115  	for _, tt := range []struct {
   116  		name string
   117  		mod  MiddlewareModification[grpc.UnaryServerInterceptor]
   118  		err  string
   119  	}{
   120  		{
   121  			name: "invalid modification on non-existing dependency",
   122  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
   123  				Operation:                OperationReplace,
   124  				DependencyMiddlewareName: "doesnotexist",
   125  			},
   126  			err: "dependency does not exist",
   127  		},
   128  		{
   129  			name: "invalid modification that causes a duplicate",
   130  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
   131  				Operation:                OperationAppend,
   132  				DependencyMiddlewareName: "public",
   133  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   134  					{
   135  						Name:       "public",
   136  						Middleware: nil,
   137  					},
   138  				},
   139  			},
   140  			err: "will cause a duplicate",
   141  		},
   142  		{
   143  			name: "invalid replace of an internal middlewares",
   144  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
   145  				Operation:                OperationReplace,
   146  				DependencyMiddlewareName: "internal",
   147  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   148  					{
   149  						Name:       "foobar",
   150  						Middleware: nil,
   151  					},
   152  				},
   153  			},
   154  			err: "attempts to replace an internal middleware",
   155  		},
   156  		{
   157  			name: "valid replace of a public middleware",
   158  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
   159  				Operation:                OperationReplace,
   160  				DependencyMiddlewareName: "public",
   161  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   162  					{
   163  						Name:       "foobar",
   164  						Middleware: nil,
   165  					},
   166  				},
   167  			},
   168  		},
   169  		{
   170  			name: "valid replace of a public middleware with same name",
   171  			mod: MiddlewareModification[grpc.UnaryServerInterceptor]{
   172  				Operation:                OperationReplace,
   173  				DependencyMiddlewareName: "public",
   174  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   175  					{
   176  						Name:       "public",
   177  						Middleware: nil,
   178  					},
   179  				},
   180  			},
   181  		},
   182  	} {
   183  		tt := tt
   184  		t.Run(tt.name, func(t *testing.T) {
   185  			err := mc.validate(tt.mod)
   186  			if tt.err != "" {
   187  				require.ErrorContains(t, err, tt.err)
   188  			} else {
   189  				require.NoError(t, err)
   190  			}
   191  		})
   192  	}
   193  }
   194  
   195  func TestReplaceAllMiddleware(t *testing.T) {
   196  	// Test fully replacing default Middleware
   197  	var replaceUnary grpc.UnaryServerInterceptor = mockUnaryInterceptor{val: 1}.unaryIntercept
   198  	mod := MiddlewareModification[grpc.UnaryServerInterceptor]{
   199  		Operation: OperationReplaceAllUnsafe,
   200  		Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   201  			{
   202  				Name:       "foobar",
   203  				Middleware: replaceUnary,
   204  			},
   205  		},
   206  	}
   207  
   208  	mc, err := NewMiddlewareChain[grpc.UnaryServerInterceptor]()
   209  	require.NoError(t, err)
   210  
   211  	err = mc.modify(mod)
   212  	require.NoError(t, err)
   213  
   214  	outUnary := mc.ToGRPCInterceptors()
   215  	require.NoError(t, err)
   216  
   217  	expectedReplaceUnary, _ := replaceUnary(context.Background(), nil, nil, nil)
   218  	receivedReplaceUnary, _ := outUnary[0](context.Background(), nil, nil, nil)
   219  	require.Equal(t, expectedReplaceUnary, receivedReplaceUnary)
   220  }
   221  
   222  func TestPrependMiddleware(t *testing.T) {
   223  	// Test prepending and appending to Default middleware
   224  	var replaceUnary grpc.UnaryServerInterceptor = mockUnaryInterceptor{val: 1}.unaryIntercept
   225  	var prependUnary grpc.UnaryServerInterceptor = mockUnaryInterceptor{val: 2}.unaryIntercept
   226  	defaultMiddleware, err := NewMiddlewareChain(
   227  		ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   228  			Name:       "foobar",
   229  			Middleware: replaceUnary,
   230  		})
   231  	require.NoError(t, err)
   232  
   233  	mod := MiddlewareModification[grpc.UnaryServerInterceptor]{
   234  		Operation:                OperationPrepend,
   235  		DependencyMiddlewareName: "foobar",
   236  		Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   237  			{
   238  				Name:       "prepended",
   239  				Middleware: prependUnary,
   240  			},
   241  		},
   242  	}
   243  	err = defaultMiddleware.modify(mod)
   244  	require.NoError(t, err)
   245  
   246  	// testing function equality is not possible, so we test the results of executing the functions
   247  	outUnary := defaultMiddleware.ToGRPCInterceptors()
   248  	require.NoError(t, err)
   249  	require.Len(t, outUnary, 2)
   250  
   251  	expectedPrepend, _ := prependUnary(context.Background(), nil, nil, nil)
   252  	receivedPrepend, _ := outUnary[0](context.Background(), nil, nil, nil)
   253  	require.Equal(t, expectedPrepend, receivedPrepend)
   254  
   255  	expectedReplace, _ := replaceUnary(context.Background(), nil, nil, nil)
   256  	receivedReplace, _ := outUnary[1](context.Background(), nil, nil, nil)
   257  	require.Equal(t, expectedReplace, receivedReplace)
   258  	require.NotEqual(t, receivedPrepend, receivedReplace)
   259  }
   260  
   261  func TestAppendMiddleware(t *testing.T) {
   262  	// Test prepending and appending to Default middleware
   263  	var replaceUnary grpc.UnaryServerInterceptor = mockUnaryInterceptor{val: 1}.unaryIntercept
   264  	var appendUnary grpc.UnaryServerInterceptor = mockUnaryInterceptor{val: 3}.unaryIntercept
   265  	defaultMiddleware := &MiddlewareChain[grpc.UnaryServerInterceptor]{
   266  		chain: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   267  			{
   268  				Name:       "foobar",
   269  				Middleware: replaceUnary,
   270  			},
   271  		},
   272  	}
   273  	mod := MiddlewareModification[grpc.UnaryServerInterceptor]{
   274  		Operation:                OperationAppend,
   275  		DependencyMiddlewareName: "foobar",
   276  		Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   277  			{
   278  				Name:       "appended",
   279  				Middleware: appendUnary,
   280  			},
   281  		},
   282  	}
   283  	err := defaultMiddleware.modify(mod)
   284  	require.NoError(t, err)
   285  
   286  	// testing function equality is not possible, so we test the results of executing the functions
   287  	outUnary := defaultMiddleware.ToGRPCInterceptors()
   288  	require.NoError(t, err)
   289  	require.Len(t, outUnary, 2)
   290  
   291  	expectedReplace, _ := replaceUnary(context.Background(), nil, nil, nil)
   292  	receivedReplace, _ := outUnary[0](context.Background(), nil, nil, nil)
   293  	expectedAppend, _ := appendUnary(context.Background(), nil, nil, nil)
   294  	receivedAppend, _ := outUnary[1](context.Background(), nil, nil, nil)
   295  	require.Equal(t, expectedReplace, receivedReplace)
   296  	require.Equal(t, expectedAppend, receivedAppend)
   297  }
   298  
   299  func TestDeleteMiddleware(t *testing.T) {
   300  	defaultMiddleware := &MiddlewareChain[grpc.UnaryServerInterceptor]{
   301  		chain: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   302  			{
   303  				Name:       "foobar",
   304  				Middleware: mockUnaryInterceptor{}.unaryIntercept,
   305  			},
   306  		},
   307  	}
   308  	mod := MiddlewareModification[grpc.UnaryServerInterceptor]{
   309  		Operation:                OperationReplace,
   310  		DependencyMiddlewareName: "foobar",
   311  	}
   312  	err := defaultMiddleware.modify(mod)
   313  	require.NoError(t, err)
   314  
   315  	outUnary := defaultMiddleware.ToGRPCInterceptors()
   316  	require.NoError(t, err)
   317  	require.Len(t, outUnary, 0)
   318  }
   319  
   320  func TestCannotReplaceInternalMiddleware(t *testing.T) {
   321  	defaultMiddleware := &MiddlewareChain[grpc.UnaryServerInterceptor]{
   322  		chain: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   323  			{
   324  				Name:       "foobar",
   325  				Internal:   true,
   326  				Middleware: mockUnaryInterceptor{}.unaryIntercept,
   327  			},
   328  		},
   329  	}
   330  	mod := MiddlewareModification[grpc.UnaryServerInterceptor]{
   331  		Operation:                OperationReplace,
   332  		DependencyMiddlewareName: "foobar",
   333  	}
   334  	err := defaultMiddleware.modify(mod)
   335  	require.ErrorContains(t, err, "replace an internal middleware")
   336  }
   337  
   338  type mockUnaryInterceptor struct {
   339  	val int
   340  }
   341  
   342  func (m mockUnaryInterceptor) unaryIntercept(_ context.Context, _ interface{}, _ *grpc.UnaryServerInfo, _ grpc.UnaryHandler) (resp interface{}, err error) {
   343  	return m.val, nil
   344  }
   345  
   346  type mockStreamInterceptor struct {
   347  	val error
   348  }
   349  
   350  func (m mockStreamInterceptor) streamIntercept(_ interface{}, _ grpc.ServerStream, _ *grpc.StreamServerInfo, _ grpc.StreamHandler) error {
   351  	return m.val
   352  }
   353  
   354  func TestMiddlewareOrdering(t *testing.T) {
   355  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   356  	defer cancel()
   357  
   358  	ds, err := datastore.NewDatastore(ctx,
   359  		datastore.DefaultDatastoreConfig().ToOption(),
   360  		datastore.WithBootstrapFiles("testdata/test_schema.yaml"),
   361  		datastore.WithRequestHedgingEnabled(false),
   362  	)
   363  	require.NoError(t, err)
   364  
   365  	c := ConfigWithOptions(
   366  		&Config{},
   367  		WithPresharedSecureKey("psk"),
   368  		WithDatastore(ds),
   369  		WithGRPCServer(util.GRPCServerConfig{
   370  			Network: util.BufferedNetwork,
   371  			Enabled: true,
   372  		}),
   373  	)
   374  	rs, err := c.Complete(ctx)
   375  	require.NoError(t, err)
   376  
   377  	clientConn, err := rs.GRPCDialContext(ctx)
   378  	require.NoError(t, err)
   379  
   380  	psc := v1.NewPermissionsServiceClient(clientConn)
   381  
   382  	go func() {
   383  		_ = rs.Run(ctx)
   384  	}()
   385  	time.Sleep(100 * time.Millisecond)
   386  
   387  	req := &v1.CheckPermissionRequest{
   388  		Resource: &v1.ObjectReference{
   389  			ObjectType: "resource",
   390  			ObjectId:   "resource1",
   391  		},
   392  		Subject: &v1.SubjectReference{
   393  			Object: &v1.ObjectReference{
   394  				ObjectType: "user",
   395  				ObjectId:   "user1",
   396  			},
   397  		},
   398  		Permission: "read",
   399  	}
   400  
   401  	_, err = psc.CheckPermission(ctx, req)
   402  	require.NoError(t, err)
   403  
   404  	lrreq := &v1.LookupResourcesRequest{
   405  		ResourceObjectType: "resource",
   406  		Subject: &v1.SubjectReference{
   407  			Object: &v1.ObjectReference{
   408  				ObjectType: "user",
   409  				ObjectId:   "user1",
   410  			},
   411  		},
   412  		Permission: "read",
   413  	}
   414  	lrc, err := psc.LookupResources(ctx, lrreq)
   415  	require.NoError(t, err)
   416  
   417  	_, err = lrc.Recv()
   418  	require.NoError(t, err)
   419  }
   420  
   421  func TestIncorrectOrderAssertionFails(t *testing.T) {
   422  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   423  	defer cancel()
   424  
   425  	ds, err := datastore.NewDatastore(ctx,
   426  		datastore.DefaultDatastoreConfig().ToOption(),
   427  		datastore.WithBootstrapFiles("testdata/test_schema.yaml"),
   428  		datastore.WithRequestHedgingEnabled(false),
   429  	)
   430  	require.NoError(t, err)
   431  	noopUnary := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
   432  		return nil, nil
   433  	}
   434  	noopStreaming := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   435  		return handler(srv, ss)
   436  	}
   437  
   438  	c := ConfigWithOptions(
   439  		&Config{},
   440  		WithPresharedSecureKey("psk"),
   441  		WithDatastore(ds),
   442  		WithGRPCServer(util.GRPCServerConfig{
   443  			Network: util.BufferedNetwork,
   444  			Enabled: true,
   445  		}),
   446  		SetUnaryMiddlewareModification([]MiddlewareModification[grpc.UnaryServerInterceptor]{
   447  			{
   448  				Operation: OperationReplaceAllUnsafe,
   449  				Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   450  					NewUnaryMiddleware().
   451  						WithName("test").
   452  						WithInterceptor(noopUnary).
   453  						EnsureAlreadyExecuted("does-not-exist").
   454  						Done(),
   455  				},
   456  			},
   457  		}),
   458  		SetStreamingMiddlewareModification([]MiddlewareModification[grpc.StreamServerInterceptor]{
   459  			{
   460  				Operation: OperationReplaceAllUnsafe,
   461  				Middlewares: []ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   462  					NewStreamMiddleware().
   463  						WithName("test").
   464  						WithInterceptor(noopStreaming).
   465  						EnsureWrapperAlreadyExecuted("does-not-exist").
   466  						Done(),
   467  				},
   468  			},
   469  		}),
   470  	)
   471  	rs, err := c.Complete(ctx)
   472  	require.NoError(t, err)
   473  
   474  	clientConn, err := rs.GRPCDialContext(ctx)
   475  	require.NoError(t, err)
   476  
   477  	psc := v1.NewPermissionsServiceClient(clientConn)
   478  
   479  	go func() {
   480  		_ = rs.Run(ctx)
   481  	}()
   482  	time.Sleep(100 * time.Millisecond)
   483  
   484  	req := &v1.CheckPermissionRequest{
   485  		Resource: &v1.ObjectReference{
   486  			ObjectType: "resource",
   487  			ObjectId:   "resource1",
   488  		},
   489  		Subject: &v1.SubjectReference{
   490  			Object: &v1.ObjectReference{
   491  				ObjectType: "user",
   492  				ObjectId:   "user1",
   493  			},
   494  		},
   495  		Permission: "read",
   496  	}
   497  
   498  	_, err = psc.CheckPermission(ctx, req)
   499  	require.ErrorContains(t, err, "expected interceptor does-not-exist to be already executed")
   500  
   501  	lrreq := &v1.LookupResourcesRequest{
   502  		ResourceObjectType: "resource",
   503  		Subject: &v1.SubjectReference{
   504  			Object: &v1.ObjectReference{
   505  				ObjectType: "user",
   506  				ObjectId:   "user1",
   507  			},
   508  		},
   509  		Permission: "read",
   510  	}
   511  
   512  	lrc, err := psc.LookupResources(ctx, lrreq)
   513  	require.NoError(t, err)
   514  
   515  	_, err = lrc.Recv()
   516  	require.ErrorContains(t, err, "expected interceptor does-not-exist to be already executed")
   517  }