github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/pertoken/pertoken.go (about)

     1  package pertoken
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
    10  	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
    11  	"google.golang.org/grpc"
    12  
    13  	"github.com/authzed/spicedb/internal/datastore/memdb"
    14  	log "github.com/authzed/spicedb/internal/logging"
    15  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    16  	"github.com/authzed/spicedb/pkg/datastore"
    17  	"github.com/authzed/spicedb/pkg/validationfile"
    18  )
    19  
    20  const (
    21  	gcWindow             = 1 * time.Hour
    22  	revisionQuantization = 10 * time.Millisecond
    23  )
    24  
    25  // MiddlewareForTesting is used to create a unique datastore for each token. It is intended for use in the
    26  // testserver only.
    27  type MiddlewareForTesting struct {
    28  	datastoreByToken *sync.Map
    29  	configFilePaths  []string
    30  }
    31  
    32  // NewMiddleware returns a new per-token datastore middleware that initializes each datastore with the data in the
    33  // config files.
    34  func NewMiddleware(configFilePaths []string) *MiddlewareForTesting {
    35  	return &MiddlewareForTesting{
    36  		datastoreByToken: &sync.Map{},
    37  		configFilePaths:  configFilePaths,
    38  	}
    39  }
    40  
    41  type squashable interface {
    42  	SquashRevisionsForTesting()
    43  }
    44  
    45  func (m *MiddlewareForTesting) getOrCreateDatastore(ctx context.Context) (datastore.Datastore, error) {
    46  	tokenStr, _ := grpcauth.AuthFromMD(ctx, "bearer")
    47  	tokenDatastore, ok := m.datastoreByToken.Load(tokenStr)
    48  	if ok {
    49  		return tokenDatastore.(datastore.Datastore), nil
    50  	}
    51  
    52  	log.Ctx(ctx).Debug().Str("token", tokenStr).Msg("initializing new upstream for token")
    53  	ds, err := memdb.NewMemdbDatastore(0, revisionQuantization, gcWindow)
    54  	if err != nil {
    55  		return nil, fmt.Errorf("failed to init datastore: %w", err)
    56  	}
    57  
    58  	_, _, err = validationfile.PopulateFromFiles(ctx, ds, m.configFilePaths)
    59  	if err != nil {
    60  		return nil, fmt.Errorf("failed to load config files: %w", err)
    61  	}
    62  
    63  	// Squash the revisions so that the caller sees all the populated data.
    64  	ds.(squashable).SquashRevisionsForTesting()
    65  
    66  	m.datastoreByToken.Store(tokenStr, ds)
    67  	return ds, nil
    68  }
    69  
    70  // UnaryServerInterceptor returns a new unary server interceptor that sets a separate in-memory datastore per token
    71  func (m *MiddlewareForTesting) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    72  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    73  		tokenDatastore, err := m.getOrCreateDatastore(ctx)
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  
    78  		newCtx := datastoremw.ContextWithHandle(ctx)
    79  		if err := datastoremw.SetInContext(newCtx, tokenDatastore); err != nil {
    80  			return nil, err
    81  		}
    82  
    83  		return handler(newCtx, req)
    84  	}
    85  }
    86  
    87  // StreamServerInterceptor returns a new stream server interceptor that sets a separate in-memory datastore per token
    88  func (m *MiddlewareForTesting) StreamServerInterceptor() grpc.StreamServerInterceptor {
    89  	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    90  		tokenDatastore, err := m.getOrCreateDatastore(stream.Context())
    91  		if err != nil {
    92  			return err
    93  		}
    94  
    95  		wrapped := middleware.WrapServerStream(stream)
    96  		wrapped.WrappedContext = datastoremw.ContextWithHandle(wrapped.WrappedContext)
    97  		if err := datastoremw.SetInContext(wrapped.WrappedContext, tokenDatastore); err != nil {
    98  			return err
    99  		}
   100  		return handler(srv, wrapped)
   101  	}
   102  }