github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/server_test.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"log"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/authzed/spicedb/internal/datastore/memdb"
    11  	"github.com/authzed/spicedb/internal/logging"
    12  	"github.com/authzed/spicedb/pkg/cmd/datastore"
    13  	"github.com/authzed/spicedb/pkg/cmd/util"
    14  
    15  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
    16  	"github.com/stretchr/testify/require"
    17  	"go.opentelemetry.io/otel"
    18  	"go.opentelemetry.io/otel/sdk/trace"
    19  	"go.opentelemetry.io/otel/sdk/trace/tracetest"
    20  	"go.uber.org/goleak"
    21  	"google.golang.org/grpc"
    22  )
    23  
    24  func TestServerGracefulTermination(t *testing.T) {
    25  	defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
    26  
    27  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    28  	ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second)
    29  	require.NoError(t, err)
    30  
    31  	c := ConfigWithOptions(
    32  		&Config{},
    33  		WithPresharedSecureKey("psk"),
    34  		WithDatastore(ds),
    35  		WithGRPCServer(util.GRPCServerConfig{
    36  			Network: util.BufferedNetwork,
    37  			Enabled: true,
    38  		}),
    39  		WithNamespaceCacheConfig(CacheConfig{Enabled: true}),
    40  		WithDispatchCacheConfig(CacheConfig{Enabled: true}),
    41  		WithClusterDispatchCacheConfig(CacheConfig{Enabled: true}),
    42  		WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: true, HTTPAddress: ":"}),
    43  		WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: true, HTTPAddress: ":"}),
    44  	)
    45  	rs, err := c.Complete(ctx)
    46  	require.NoError(t, err)
    47  
    48  	ch := make(chan struct{}, 1)
    49  	st := make(chan struct{}, 1)
    50  	go func() {
    51  		st <- struct{}{}
    52  		_ = rs.Run(ctx)
    53  		ch <- struct{}{}
    54  	}()
    55  	<-st
    56  	time.Sleep(10 * time.Millisecond)
    57  	cancel()
    58  	<-ch
    59  }
    60  
    61  func TestOTelReporting(t *testing.T) {
    62  	defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
    63  
    64  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    65  	defer cancel()
    66  
    67  	ds, err := datastore.NewDatastore(ctx,
    68  		datastore.DefaultDatastoreConfig().ToOption(),
    69  		datastore.WithRequestHedgingEnabled(false),
    70  	)
    71  	if err != nil {
    72  		log.Fatalf("unable to start memdb datastore: %s", err)
    73  	}
    74  
    75  	configOpts := []ConfigOption{
    76  		WithGRPCServer(util.GRPCServerConfig{
    77  			Network: util.BufferedNetwork,
    78  			Enabled: true,
    79  		}),
    80  		WithGRPCAuthFunc(func(ctx context.Context) (context.Context, error) {
    81  			return ctx, nil
    82  		}),
    83  		WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: false}),
    84  		WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: false}),
    85  		WithDispatchCacheConfig(CacheConfig{Enabled: false, Metrics: false}),
    86  		WithNamespaceCacheConfig(CacheConfig{Enabled: false, Metrics: false}),
    87  		WithClusterDispatchCacheConfig(CacheConfig{Enabled: false, Metrics: false}),
    88  		WithDatastore(ds),
    89  	}
    90  
    91  	srv, err := NewConfigWithOptionsAndDefaults(configOpts...).Complete(ctx)
    92  	require.NoError(t, err)
    93  
    94  	conn, err := srv.GRPCDialContext(ctx)
    95  	require.NoError(t, err)
    96  	defer conn.Close()
    97  
    98  	schemaSrv := v1.NewSchemaServiceClient(conn)
    99  
   100  	go func() {
   101  		require.NoError(t, srv.Run(ctx))
   102  	}()
   103  
   104  	spanrecorder, restoreOtel := setupSpanRecorder()
   105  	defer restoreOtel()
   106  
   107  	// test unary OTel middleware
   108  	_, err = schemaSrv.WriteSchema(ctx, &v1.WriteSchemaRequest{
   109  		Schema: `definition user {}`,
   110  	})
   111  	require.NoError(t, err)
   112  	requireSpanExists(t, spanrecorder, "authzed.api.v1.SchemaService/WriteSchema")
   113  
   114  	// test streaming OTel middleware
   115  	permSrv := v1.NewPermissionsServiceClient(conn)
   116  	rrCli, err := permSrv.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{})
   117  	require.NoError(t, err)
   118  
   119  	_, err = rrCli.Recv()
   120  	require.Error(t, err)
   121  
   122  	requireSpanExists(t, spanrecorder, "authzed.api.v1.PermissionsService/ReadRelationships")
   123  
   124  	lrCli, err := permSrv.LookupResources(ctx, &v1.LookupResourcesRequest{})
   125  	require.NoError(t, err)
   126  
   127  	_, err = lrCli.Recv()
   128  	require.Error(t, err)
   129  
   130  	requireSpanExists(t, spanrecorder, "authzed.api.v1.PermissionsService/LookupResources")
   131  }
   132  
   133  func requireSpanExists(t *testing.T, spanrecorder *tracetest.SpanRecorder, spanName string) {
   134  	t.Helper()
   135  
   136  	ended := spanrecorder.Ended()
   137  	var present bool
   138  	for _, span := range ended {
   139  		if span.Name() == spanName {
   140  			present = true
   141  		}
   142  	}
   143  
   144  	require.True(t, present, "missing trace for Streaming gRPC call")
   145  }
   146  
   147  func setupSpanRecorder() (*tracetest.SpanRecorder, func()) {
   148  	defaultProvider := otel.GetTracerProvider()
   149  
   150  	provider := trace.NewTracerProvider(
   151  		trace.WithSampler(trace.AlwaysSample()),
   152  	)
   153  	spanrecorder := tracetest.NewSpanRecorder()
   154  	provider.RegisterSpanProcessor(spanrecorder)
   155  	otel.SetTracerProvider(provider)
   156  
   157  	return spanrecorder, func() {
   158  		otel.SetTracerProvider(defaultProvider)
   159  	}
   160  }
   161  
   162  func TestServerGracefulTerminationOnError(t *testing.T) {
   163  	defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
   164  
   165  	ctx, cancel := context.WithCancel(context.Background())
   166  	ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second)
   167  	require.NoError(t, err)
   168  
   169  	c := ConfigWithOptions(&Config{
   170  		GRPCServer: util.GRPCServerConfig{
   171  			Network: util.BufferedNetwork,
   172  		},
   173  	}, WithPresharedSecureKey("psk"), WithDatastore(ds))
   174  	cancel()
   175  	_, err = c.Complete(ctx)
   176  	require.NoError(t, err)
   177  }
   178  
   179  func TestReplaceUnaryMiddleware(t *testing.T) {
   180  	c := Config{UnaryMiddlewareModification: []MiddlewareModification[grpc.UnaryServerInterceptor]{
   181  		{
   182  			Operation: OperationReplaceAllUnsafe,
   183  			Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   184  				{
   185  					Name:       "foobar",
   186  					Middleware: mockUnaryInterceptor{val: 1}.unaryIntercept,
   187  				},
   188  			},
   189  		},
   190  	}}
   191  	unary, err := c.buildUnaryMiddleware(nil)
   192  	require.NoError(t, err)
   193  	require.Len(t, unary, 1)
   194  
   195  	val, _ := unary[0](context.Background(), nil, nil, nil)
   196  	require.Equal(t, 1, val)
   197  }
   198  
   199  func TestReplaceStreamingMiddleware(t *testing.T) {
   200  	c := Config{StreamingMiddlewareModification: []MiddlewareModification[grpc.StreamServerInterceptor]{
   201  		{
   202  			Operation: OperationReplaceAllUnsafe,
   203  			Middlewares: []ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   204  				{
   205  					Name:       "foobar",
   206  					Middleware: mockStreamInterceptor{val: errors.New("hi")}.streamIntercept,
   207  				},
   208  			},
   209  		},
   210  	}}
   211  	streaming, err := c.buildStreamingMiddleware(nil)
   212  	require.NoError(t, err)
   213  	require.Len(t, streaming, 1)
   214  
   215  	err = streaming[0](context.Background(), nil, nil, nil)
   216  	require.ErrorContains(t, err, "hi")
   217  }
   218  
   219  func TestModifyUnaryMiddleware(t *testing.T) {
   220  	c := Config{UnaryMiddlewareModification: []MiddlewareModification[grpc.UnaryServerInterceptor]{
   221  		{
   222  			Operation:                OperationPrepend,
   223  			DependencyMiddlewareName: DefaultMiddlewareLog,
   224  			Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   225  				{
   226  					Name:       "foobar",
   227  					Middleware: mockUnaryInterceptor{val: 1}.unaryIntercept,
   228  				},
   229  			},
   230  		},
   231  	}}
   232  
   233  	opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false}
   234  	defaultMw, err := DefaultUnaryMiddleware(opt)
   235  	require.NoError(t, err)
   236  
   237  	unary, err := c.buildUnaryMiddleware(defaultMw)
   238  	require.NoError(t, err)
   239  	require.Len(t, unary, len(defaultMw.chain)+1)
   240  
   241  	val, _ := unary[1](context.Background(), nil, nil, nil)
   242  	require.Equal(t, 1, val)
   243  }
   244  
   245  func TestModifyStreamingMiddleware(t *testing.T) {
   246  	c := Config{StreamingMiddlewareModification: []MiddlewareModification[grpc.StreamServerInterceptor]{
   247  		{
   248  			Operation:                OperationPrepend,
   249  			DependencyMiddlewareName: DefaultMiddlewareLog,
   250  			Middlewares: []ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   251  				{
   252  					Name:       "foobar",
   253  					Middleware: mockStreamInterceptor{val: errors.New("hi")}.streamIntercept,
   254  				},
   255  			},
   256  		},
   257  	}}
   258  
   259  	opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false}
   260  	defaultMw, err := DefaultStreamingMiddleware(opt)
   261  	require.NoError(t, err)
   262  
   263  	streaming, err := c.buildStreamingMiddleware(defaultMw)
   264  	require.NoError(t, err)
   265  	require.Len(t, streaming, len(defaultMw.chain)+1)
   266  
   267  	err = streaming[1](context.Background(), nil, nil, nil)
   268  	require.ErrorContains(t, err, "hi")
   269  }