github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/remote/cluster.go (about)

     1  package remote
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/authzed/consistent"
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"github.com/rs/zerolog"
    14  	"google.golang.org/grpc"
    15  	"google.golang.org/grpc/connectivity"
    16  	"google.golang.org/protobuf/proto"
    17  
    18  	"github.com/authzed/spicedb/internal/dispatch"
    19  	"github.com/authzed/spicedb/internal/dispatch/keys"
    20  	log "github.com/authzed/spicedb/internal/logging"
    21  	v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    22  	"github.com/authzed/spicedb/pkg/spiceerrors"
    23  )
    24  
    25  var dispatchCounter = prometheus.NewCounterVec(prometheus.CounterOpts{
    26  	Namespace: "spicedb",
    27  	Subsystem: "dispatch",
    28  	Name:      "remote_dispatch_handler_total",
    29  	Help:      "which dispatcher handled a request",
    30  }, []string{"request_kind", "handler_name"})
    31  
    32  func init() {
    33  	prometheus.MustRegister(dispatchCounter)
    34  }
    35  
    36  type ClusterClient interface {
    37  	DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest, opts ...grpc.CallOption) (*v1.DispatchCheckResponse, error)
    38  	DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest, opts ...grpc.CallOption) (*v1.DispatchExpandResponse, error)
    39  	DispatchReachableResources(ctx context.Context, in *v1.DispatchReachableResourcesRequest, opts ...grpc.CallOption) (v1.DispatchService_DispatchReachableResourcesClient, error)
    40  	DispatchLookupResources(ctx context.Context, in *v1.DispatchLookupResourcesRequest, opts ...grpc.CallOption) (v1.DispatchService_DispatchLookupResourcesClient, error)
    41  	DispatchLookupSubjects(ctx context.Context, in *v1.DispatchLookupSubjectsRequest, opts ...grpc.CallOption) (v1.DispatchService_DispatchLookupSubjectsClient, error)
    42  }
    43  
    44  type ClusterDispatcherConfig struct {
    45  	// KeyHandler is then handler to use for generating dispatch hash ring keys.
    46  	KeyHandler keys.Handler
    47  
    48  	// DispatchOverallTimeout is the maximum duration of a dispatched request
    49  	// before it should timeout.
    50  	DispatchOverallTimeout time.Duration
    51  }
    52  
    53  // SecondaryDispatch defines a struct holding a client and its name for secondary
    54  // dispatching.
    55  type SecondaryDispatch struct {
    56  	Name   string
    57  	Client ClusterClient
    58  }
    59  
    60  // NewClusterDispatcher creates a dispatcher implementation that uses the provided client
    61  // to dispatch requests to peer nodes in the cluster.
    62  func NewClusterDispatcher(client ClusterClient, conn *grpc.ClientConn, config ClusterDispatcherConfig, secondaryDispatch map[string]SecondaryDispatch, secondaryDispatchExprs map[string]*DispatchExpr) dispatch.Dispatcher {
    63  	keyHandler := config.KeyHandler
    64  	if keyHandler == nil {
    65  		keyHandler = &keys.DirectKeyHandler{}
    66  	}
    67  
    68  	dispatchOverallTimeout := config.DispatchOverallTimeout
    69  	if dispatchOverallTimeout <= 0 {
    70  		dispatchOverallTimeout = 60 * time.Second
    71  	}
    72  
    73  	return &clusterDispatcher{
    74  		clusterClient:          client,
    75  		conn:                   conn,
    76  		keyHandler:             keyHandler,
    77  		dispatchOverallTimeout: dispatchOverallTimeout,
    78  		secondaryDispatch:      secondaryDispatch,
    79  		secondaryDispatchExprs: secondaryDispatchExprs,
    80  	}
    81  }
    82  
    83  type clusterDispatcher struct {
    84  	clusterClient          ClusterClient
    85  	conn                   *grpc.ClientConn
    86  	keyHandler             keys.Handler
    87  	dispatchOverallTimeout time.Duration
    88  	secondaryDispatch      map[string]SecondaryDispatch
    89  	secondaryDispatchExprs map[string]*DispatchExpr
    90  }
    91  
    92  func (cr *clusterDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) {
    93  	if err := dispatch.CheckDepth(ctx, req); err != nil {
    94  		return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, err
    95  	}
    96  
    97  	requestKey, err := cr.keyHandler.CheckDispatchKey(ctx, req)
    98  	if err != nil {
    99  		return &v1.DispatchCheckResponse{Metadata: emptyMetadata}, err
   100  	}
   101  
   102  	ctx = context.WithValue(ctx, consistent.CtxKey, requestKey)
   103  
   104  	resp, err := dispatchRequest(ctx, cr, "check", req, func(ctx context.Context, client ClusterClient) (*v1.DispatchCheckResponse, error) {
   105  		resp, err := client.DispatchCheck(ctx, req)
   106  		if err != nil {
   107  			return resp, err
   108  		}
   109  
   110  		err = adjustMetadataForDispatch(resp.Metadata)
   111  		return resp, err
   112  	})
   113  	if err != nil {
   114  		return &v1.DispatchCheckResponse{Metadata: requestFailureMetadata}, err
   115  	}
   116  
   117  	return resp, err
   118  }
   119  
   120  type requestMessage interface {
   121  	zerolog.LogObjectMarshaler
   122  
   123  	GetMetadata() *v1.ResolverMeta
   124  }
   125  
   126  type responseMessage interface {
   127  	proto.Message
   128  
   129  	GetMetadata() *v1.ResponseMeta
   130  }
   131  
   132  type respTuple[S responseMessage] struct {
   133  	resp S
   134  	err  error
   135  }
   136  
   137  type secondaryRespTuple[S responseMessage] struct {
   138  	handlerName string
   139  	resp        S
   140  }
   141  
   142  func dispatchRequest[Q requestMessage, S responseMessage](ctx context.Context, cr *clusterDispatcher, reqKey string, req Q, handler func(context.Context, ClusterClient) (S, error)) (S, error) {
   143  	withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout)
   144  	defer cancelFn()
   145  
   146  	if len(cr.secondaryDispatchExprs) == 0 || len(cr.secondaryDispatch) == 0 {
   147  		return handler(withTimeout, cr.clusterClient)
   148  	}
   149  
   150  	// If no secondary dispatches are defined, just invoke directly.
   151  	expr, ok := cr.secondaryDispatchExprs[reqKey]
   152  	if !ok {
   153  		return handler(withTimeout, cr.clusterClient)
   154  	}
   155  
   156  	// Otherwise invoke in parallel with any secondary matches.
   157  	primaryResultChan := make(chan respTuple[S], 1)
   158  	secondaryResultChan := make(chan secondaryRespTuple[S], len(cr.secondaryDispatch))
   159  
   160  	// Run the main dispatch.
   161  	go func() {
   162  		resp, err := handler(withTimeout, cr.clusterClient)
   163  		primaryResultChan <- respTuple[S]{resp, err}
   164  	}()
   165  
   166  	result, err := RunDispatchExpr(expr, req)
   167  	if err != nil {
   168  		log.Warn().Err(err).Msg("error when trying to evaluate the dispatch expression")
   169  	}
   170  
   171  	log.Trace().Str("secondary-dispatchers", strings.Join(result, ",")).Object("request", req).Msg("running secondary dispatchers")
   172  
   173  	for _, secondaryDispatchName := range result {
   174  		secondary, ok := cr.secondaryDispatch[secondaryDispatchName]
   175  		if !ok {
   176  			log.Warn().Str("secondary-dispatcher-name", secondaryDispatchName).Msg("received unknown secondary dispatcher")
   177  			continue
   178  		}
   179  
   180  		log.Trace().Str("secondary-dispatcher", secondary.Name).Object("request", req).Msg("running secondary dispatcher")
   181  		go func() {
   182  			resp, err := handler(withTimeout, secondary.Client)
   183  			if err != nil {
   184  				// For secondary dispatches, ignore any errors, as only the primary will be handled in
   185  				// that scenario.
   186  				log.Trace().Str("secondary", secondary.Name).Err(err).Msg("got ignored secondary dispatch error")
   187  				return
   188  			}
   189  
   190  			secondaryResultChan <- secondaryRespTuple[S]{resp: resp, handlerName: secondary.Name}
   191  		}()
   192  	}
   193  
   194  	var foundError error
   195  	select {
   196  	case <-withTimeout.Done():
   197  		return *new(S), fmt.Errorf("check dispatch has timed out")
   198  
   199  	case r := <-primaryResultChan:
   200  		if r.err == nil {
   201  			dispatchCounter.WithLabelValues(reqKey, "(primary)").Add(1)
   202  			return r.resp, nil
   203  		}
   204  
   205  		// Otherwise, if an error was found, log it and we'll return after *all* the secondaries have run.
   206  		// This allows an otherwise error-state to be handled by one of the secondaries.
   207  		foundError = r.err
   208  
   209  	case r := <-secondaryResultChan:
   210  		dispatchCounter.WithLabelValues(reqKey, r.handlerName).Add(1)
   211  		return r.resp, nil
   212  	}
   213  
   214  	dispatchCounter.WithLabelValues(reqKey, "(primary)").Add(1)
   215  	return *new(S), foundError
   216  }
   217  
   218  func adjustMetadataForDispatch(metadata *v1.ResponseMeta) error {
   219  	if metadata == nil {
   220  		return spiceerrors.MustBugf("received a nil metadata")
   221  	}
   222  
   223  	// NOTE: We only add 1 to the dispatch count if it was not already handled by the downstream dispatch,
   224  	// which will only be the case in a fully cached or further undispatched call.
   225  	if metadata.DispatchCount == 0 {
   226  		metadata.DispatchCount++
   227  	}
   228  
   229  	return nil
   230  }
   231  
   232  func (cr *clusterDispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) {
   233  	if err := dispatch.CheckDepth(ctx, req); err != nil {
   234  		return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
   235  	}
   236  
   237  	requestKey, err := cr.keyHandler.ExpandDispatchKey(ctx, req)
   238  	if err != nil {
   239  		return &v1.DispatchExpandResponse{Metadata: emptyMetadata}, err
   240  	}
   241  
   242  	ctx = context.WithValue(ctx, consistent.CtxKey, requestKey)
   243  
   244  	withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout)
   245  	defer cancelFn()
   246  
   247  	resp, err := cr.clusterClient.DispatchExpand(withTimeout, req)
   248  	if err != nil {
   249  		return &v1.DispatchExpandResponse{Metadata: requestFailureMetadata}, err
   250  	}
   251  
   252  	err = adjustMetadataForDispatch(resp.Metadata)
   253  	return resp, err
   254  }
   255  
   256  func (cr *clusterDispatcher) DispatchReachableResources(
   257  	req *v1.DispatchReachableResourcesRequest,
   258  	stream dispatch.ReachableResourcesStream,
   259  ) error {
   260  	requestKey, err := cr.keyHandler.ReachableResourcesDispatchKey(stream.Context(), req)
   261  	if err != nil {
   262  		return err
   263  	}
   264  
   265  	ctx := context.WithValue(stream.Context(), consistent.CtxKey, requestKey)
   266  	stream = dispatch.StreamWithContext(ctx, stream)
   267  
   268  	if err := dispatch.CheckDepth(ctx, req); err != nil {
   269  		return err
   270  	}
   271  
   272  	withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout)
   273  	defer cancelFn()
   274  
   275  	client, err := cr.clusterClient.DispatchReachableResources(withTimeout, req)
   276  	if err != nil {
   277  		return err
   278  	}
   279  
   280  	for {
   281  		select {
   282  		case <-withTimeout.Done():
   283  			return withTimeout.Err()
   284  
   285  		default:
   286  			result, err := client.Recv()
   287  			if errors.Is(err, io.EOF) {
   288  				return nil
   289  			} else if err != nil {
   290  				return err
   291  			}
   292  
   293  			merr := adjustMetadataForDispatch(result.Metadata)
   294  			if merr != nil {
   295  				return merr
   296  			}
   297  
   298  			serr := stream.Publish(result)
   299  			if serr != nil {
   300  				return serr
   301  			}
   302  		}
   303  	}
   304  }
   305  
   306  func (cr *clusterDispatcher) DispatchLookupResources(
   307  	req *v1.DispatchLookupResourcesRequest,
   308  	stream dispatch.LookupResourcesStream,
   309  ) error {
   310  	requestKey, err := cr.keyHandler.LookupResourcesDispatchKey(stream.Context(), req)
   311  	if err != nil {
   312  		return err
   313  	}
   314  
   315  	ctx := context.WithValue(stream.Context(), consistent.CtxKey, requestKey)
   316  	stream = dispatch.StreamWithContext(ctx, stream)
   317  
   318  	if err := dispatch.CheckDepth(ctx, req); err != nil {
   319  		return err
   320  	}
   321  
   322  	withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout)
   323  	defer cancelFn()
   324  
   325  	client, err := cr.clusterClient.DispatchLookupResources(withTimeout, req)
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	for {
   331  		select {
   332  		case <-withTimeout.Done():
   333  			return withTimeout.Err()
   334  
   335  		default:
   336  			result, err := client.Recv()
   337  			if errors.Is(err, io.EOF) {
   338  				return nil
   339  			} else if err != nil {
   340  				return err
   341  			}
   342  
   343  			merr := adjustMetadataForDispatch(result.Metadata)
   344  			if merr != nil {
   345  				return merr
   346  			}
   347  
   348  			serr := stream.Publish(result)
   349  			if serr != nil {
   350  				return serr
   351  			}
   352  		}
   353  	}
   354  }
   355  
   356  func (cr *clusterDispatcher) DispatchLookupSubjects(
   357  	req *v1.DispatchLookupSubjectsRequest,
   358  	stream dispatch.LookupSubjectsStream,
   359  ) error {
   360  	requestKey, err := cr.keyHandler.LookupSubjectsDispatchKey(stream.Context(), req)
   361  	if err != nil {
   362  		return err
   363  	}
   364  
   365  	ctx := context.WithValue(stream.Context(), consistent.CtxKey, requestKey)
   366  	stream = dispatch.StreamWithContext(ctx, stream)
   367  
   368  	if err := dispatch.CheckDepth(ctx, req); err != nil {
   369  		return err
   370  	}
   371  
   372  	withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout)
   373  	defer cancelFn()
   374  
   375  	client, err := cr.clusterClient.DispatchLookupSubjects(withTimeout, req)
   376  	if err != nil {
   377  		return err
   378  	}
   379  
   380  	for {
   381  		select {
   382  		case <-withTimeout.Done():
   383  			return withTimeout.Err()
   384  
   385  		default:
   386  			result, err := client.Recv()
   387  			if errors.Is(err, io.EOF) {
   388  				return nil
   389  			} else if err != nil {
   390  				return err
   391  			}
   392  
   393  			merr := adjustMetadataForDispatch(result.Metadata)
   394  			if merr != nil {
   395  				return merr
   396  			}
   397  
   398  			serr := stream.Publish(result)
   399  			if serr != nil {
   400  				return serr
   401  			}
   402  		}
   403  	}
   404  }
   405  
   406  func (cr *clusterDispatcher) Close() error {
   407  	return nil
   408  }
   409  
   410  // ReadyState returns whether the underlying dispatch connection is available
   411  func (cr *clusterDispatcher) ReadyState() dispatch.ReadyState {
   412  	state := cr.conn.GetState()
   413  	log.Trace().Interface("connection-state", state).Msg("checked if cluster dispatcher is ready")
   414  	return dispatch.ReadyState{
   415  		IsReady: state == connectivity.Ready || state == connectivity.Idle,
   416  		Message: fmt.Sprintf("found expected state when trying to connect to cluster: %v", state),
   417  	}
   418  }
   419  
   420  // Always verify that we implement the interface
   421  var _ dispatch.Dispatcher = &clusterDispatcher{}
   422  
   423  var emptyMetadata = &v1.ResponseMeta{
   424  	DispatchCount: 0,
   425  }
   426  
   427  var requestFailureMetadata = &v1.ResponseMeta{
   428  	DispatchCount: 1,
   429  }