github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/server_test.go (about) 1 package server 2 3 import ( 4 "context" 5 "errors" 6 "log" 7 "testing" 8 "time" 9 10 "github.com/authzed/spicedb/internal/datastore/memdb" 11 "github.com/authzed/spicedb/internal/logging" 12 "github.com/authzed/spicedb/pkg/cmd/datastore" 13 "github.com/authzed/spicedb/pkg/cmd/util" 14 15 v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" 16 "github.com/stretchr/testify/require" 17 "go.opentelemetry.io/otel" 18 "go.opentelemetry.io/otel/sdk/trace" 19 "go.opentelemetry.io/otel/sdk/trace/tracetest" 20 "go.uber.org/goleak" 21 "google.golang.org/grpc" 22 ) 23 24 func TestServerGracefulTermination(t *testing.T) { 25 defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) 26 27 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 28 ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second) 29 require.NoError(t, err) 30 31 c := ConfigWithOptions( 32 &Config{}, 33 WithPresharedSecureKey("psk"), 34 WithDatastore(ds), 35 WithGRPCServer(util.GRPCServerConfig{ 36 Network: util.BufferedNetwork, 37 Enabled: true, 38 }), 39 WithNamespaceCacheConfig(CacheConfig{Enabled: true}), 40 WithDispatchCacheConfig(CacheConfig{Enabled: true}), 41 WithClusterDispatchCacheConfig(CacheConfig{Enabled: true}), 42 WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: true, HTTPAddress: ":"}), 43 WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: true, HTTPAddress: ":"}), 44 ) 45 rs, err := c.Complete(ctx) 46 require.NoError(t, err) 47 48 ch := make(chan struct{}, 1) 49 st := make(chan struct{}, 1) 50 go func() { 51 st <- struct{}{} 52 _ = rs.Run(ctx) 53 ch <- struct{}{} 54 }() 55 <-st 56 time.Sleep(10 * time.Millisecond) 57 cancel() 58 <-ch 59 } 60 61 func TestOTelReporting(t *testing.T) { 62 defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) 63 64 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 65 defer cancel() 66 67 ds, err := datastore.NewDatastore(ctx, 68 datastore.DefaultDatastoreConfig().ToOption(), 69 datastore.WithRequestHedgingEnabled(false), 70 ) 71 if err != nil { 72 log.Fatalf("unable to start memdb datastore: %s", err) 73 } 74 75 configOpts := []ConfigOption{ 76 WithGRPCServer(util.GRPCServerConfig{ 77 Network: util.BufferedNetwork, 78 Enabled: true, 79 }), 80 WithGRPCAuthFunc(func(ctx context.Context) (context.Context, error) { 81 return ctx, nil 82 }), 83 WithHTTPGateway(util.HTTPServerConfig{HTTPEnabled: false}), 84 WithMetricsAPI(util.HTTPServerConfig{HTTPEnabled: false}), 85 WithDispatchCacheConfig(CacheConfig{Enabled: false, Metrics: false}), 86 WithNamespaceCacheConfig(CacheConfig{Enabled: false, Metrics: false}), 87 WithClusterDispatchCacheConfig(CacheConfig{Enabled: false, Metrics: false}), 88 WithDatastore(ds), 89 } 90 91 srv, err := NewConfigWithOptionsAndDefaults(configOpts...).Complete(ctx) 92 require.NoError(t, err) 93 94 conn, err := srv.GRPCDialContext(ctx) 95 require.NoError(t, err) 96 defer conn.Close() 97 98 schemaSrv := v1.NewSchemaServiceClient(conn) 99 100 go func() { 101 require.NoError(t, srv.Run(ctx)) 102 }() 103 104 spanrecorder, restoreOtel := setupSpanRecorder() 105 defer restoreOtel() 106 107 // test unary OTel middleware 108 _, err = schemaSrv.WriteSchema(ctx, &v1.WriteSchemaRequest{ 109 Schema: `definition user {}`, 110 }) 111 require.NoError(t, err) 112 requireSpanExists(t, spanrecorder, "authzed.api.v1.SchemaService/WriteSchema") 113 114 // test streaming OTel middleware 115 permSrv := v1.NewPermissionsServiceClient(conn) 116 rrCli, err := permSrv.ReadRelationships(ctx, &v1.ReadRelationshipsRequest{}) 117 require.NoError(t, err) 118 119 _, err = rrCli.Recv() 120 require.Error(t, err) 121 122 requireSpanExists(t, spanrecorder, "authzed.api.v1.PermissionsService/ReadRelationships") 123 124 lrCli, err := permSrv.LookupResources(ctx, &v1.LookupResourcesRequest{}) 125 require.NoError(t, err) 126 127 _, err = lrCli.Recv() 128 require.Error(t, err) 129 130 requireSpanExists(t, spanrecorder, "authzed.api.v1.PermissionsService/LookupResources") 131 } 132 133 func requireSpanExists(t *testing.T, spanrecorder *tracetest.SpanRecorder, spanName string) { 134 t.Helper() 135 136 ended := spanrecorder.Ended() 137 var present bool 138 for _, span := range ended { 139 if span.Name() == spanName { 140 present = true 141 } 142 } 143 144 require.True(t, present, "missing trace for Streaming gRPC call") 145 } 146 147 func setupSpanRecorder() (*tracetest.SpanRecorder, func()) { 148 defaultProvider := otel.GetTracerProvider() 149 150 provider := trace.NewTracerProvider( 151 trace.WithSampler(trace.AlwaysSample()), 152 ) 153 spanrecorder := tracetest.NewSpanRecorder() 154 provider.RegisterSpanProcessor(spanrecorder) 155 otel.SetTracerProvider(provider) 156 157 return spanrecorder, func() { 158 otel.SetTracerProvider(defaultProvider) 159 } 160 } 161 162 func TestServerGracefulTerminationOnError(t *testing.T) { 163 defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) 164 165 ctx, cancel := context.WithCancel(context.Background()) 166 ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second) 167 require.NoError(t, err) 168 169 c := ConfigWithOptions(&Config{ 170 GRPCServer: util.GRPCServerConfig{ 171 Network: util.BufferedNetwork, 172 }, 173 }, WithPresharedSecureKey("psk"), WithDatastore(ds)) 174 cancel() 175 _, err = c.Complete(ctx) 176 require.NoError(t, err) 177 } 178 179 func TestReplaceUnaryMiddleware(t *testing.T) { 180 c := Config{UnaryMiddlewareModification: []MiddlewareModification[grpc.UnaryServerInterceptor]{ 181 { 182 Operation: OperationReplaceAllUnsafe, 183 Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ 184 { 185 Name: "foobar", 186 Middleware: mockUnaryInterceptor{val: 1}.unaryIntercept, 187 }, 188 }, 189 }, 190 }} 191 unary, err := c.buildUnaryMiddleware(nil) 192 require.NoError(t, err) 193 require.Len(t, unary, 1) 194 195 val, _ := unary[0](context.Background(), nil, nil, nil) 196 require.Equal(t, 1, val) 197 } 198 199 func TestReplaceStreamingMiddleware(t *testing.T) { 200 c := Config{StreamingMiddlewareModification: []MiddlewareModification[grpc.StreamServerInterceptor]{ 201 { 202 Operation: OperationReplaceAllUnsafe, 203 Middlewares: []ReferenceableMiddleware[grpc.StreamServerInterceptor]{ 204 { 205 Name: "foobar", 206 Middleware: mockStreamInterceptor{val: errors.New("hi")}.streamIntercept, 207 }, 208 }, 209 }, 210 }} 211 streaming, err := c.buildStreamingMiddleware(nil) 212 require.NoError(t, err) 213 require.Len(t, streaming, 1) 214 215 err = streaming[0](context.Background(), nil, nil, nil) 216 require.ErrorContains(t, err, "hi") 217 } 218 219 func TestModifyUnaryMiddleware(t *testing.T) { 220 c := Config{UnaryMiddlewareModification: []MiddlewareModification[grpc.UnaryServerInterceptor]{ 221 { 222 Operation: OperationPrepend, 223 DependencyMiddlewareName: DefaultMiddlewareLog, 224 Middlewares: []ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ 225 { 226 Name: "foobar", 227 Middleware: mockUnaryInterceptor{val: 1}.unaryIntercept, 228 }, 229 }, 230 }, 231 }} 232 233 opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false} 234 defaultMw, err := DefaultUnaryMiddleware(opt) 235 require.NoError(t, err) 236 237 unary, err := c.buildUnaryMiddleware(defaultMw) 238 require.NoError(t, err) 239 require.Len(t, unary, len(defaultMw.chain)+1) 240 241 val, _ := unary[1](context.Background(), nil, nil, nil) 242 require.Equal(t, 1, val) 243 } 244 245 func TestModifyStreamingMiddleware(t *testing.T) { 246 c := Config{StreamingMiddlewareModification: []MiddlewareModification[grpc.StreamServerInterceptor]{ 247 { 248 Operation: OperationPrepend, 249 DependencyMiddlewareName: DefaultMiddlewareLog, 250 Middlewares: []ReferenceableMiddleware[grpc.StreamServerInterceptor]{ 251 { 252 Name: "foobar", 253 Middleware: mockStreamInterceptor{val: errors.New("hi")}.streamIntercept, 254 }, 255 }, 256 }, 257 }} 258 259 opt := MiddlewareOption{logging.Logger, nil, false, nil, nil, false, false, false} 260 defaultMw, err := DefaultStreamingMiddleware(opt) 261 require.NoError(t, err) 262 263 streaming, err := c.buildStreamingMiddleware(defaultMw) 264 require.NoError(t, err) 265 require.Len(t, streaming, len(defaultMw.chain)+1) 266 267 err = streaming[1](context.Background(), nil, nil, nil) 268 require.ErrorContains(t, err, "hi") 269 }