github.com/TBD54566975/ftl@v0.219.0/internal/rpc/context.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  
     8  	"connectrpc.com/connect"
     9  	"connectrpc.com/otelconnect"
    10  	"github.com/alecthomas/types/optional"
    11  	"golang.org/x/mod/semver"
    12  
    13  	"github.com/TBD54566975/ftl"
    14  	"github.com/TBD54566975/ftl/backend/schema"
    15  	"github.com/TBD54566975/ftl/internal/log"
    16  	"github.com/TBD54566975/ftl/internal/model"
    17  	"github.com/TBD54566975/ftl/internal/rpc/headers"
    18  )
    19  
    20  type ftlDirectRoutingKey struct{}
    21  type ftlVerbKey struct{}
    22  type requestIDKey struct{}
    23  
    24  // WithDirectRouting ensures any hops in Verb routing do not redirect.
    25  //
    26  // This is used so that eg. calls from Drives do not create recursive loops
    27  // when calling back to the Agent.
    28  func WithDirectRouting(ctx context.Context) context.Context {
    29  	return context.WithValue(ctx, ftlDirectRoutingKey{}, "1")
    30  }
    31  
    32  // WithVerbs adds the module.verb chain from the current request to the context.
    33  func WithVerbs(ctx context.Context, verbs []*schema.Ref) context.Context {
    34  	return context.WithValue(ctx, ftlVerbKey{}, verbs)
    35  }
    36  
    37  // VerbFromContext returns the current module.verb of the current request.
    38  func VerbFromContext(ctx context.Context) (*schema.Ref, bool) {
    39  	value := ctx.Value(ftlVerbKey{})
    40  	verbs, ok := value.([]*schema.Ref)
    41  	if len(verbs) == 0 {
    42  		return nil, false
    43  	}
    44  	return verbs[len(verbs)-1], ok
    45  }
    46  
    47  // VerbsFromContext returns the module.verb chain of the current request.
    48  func VerbsFromContext(ctx context.Context) ([]*schema.Ref, bool) {
    49  	value := ctx.Value(ftlVerbKey{})
    50  	verbs, ok := value.([]*schema.Ref)
    51  	return verbs, ok
    52  }
    53  
    54  // IsDirectRouted returns true if the incoming request should be directly
    55  // routed and never redirected.
    56  func IsDirectRouted(ctx context.Context) bool {
    57  	return ctx.Value(ftlDirectRoutingKey{}) != nil
    58  }
    59  
    60  // RequestKeyFromContext returns the request key from the context, if any.
    61  //
    62  // TODO: Return an Option here instead of a bool.
    63  func RequestKeyFromContext(ctx context.Context) (optional.Option[model.RequestKey], error) {
    64  	value := ctx.Value(requestIDKey{})
    65  	keyStr, ok := value.(string)
    66  	if !ok {
    67  		return optional.None[model.RequestKey](), nil
    68  	}
    69  	key, err := model.ParseRequestKey(keyStr)
    70  	if err != nil {
    71  		return optional.None[model.RequestKey](), fmt.Errorf("invalid request key: %w", err)
    72  	}
    73  	return optional.Some(key), nil
    74  }
    75  
    76  // WithRequestName adds the request key to the context.
    77  func WithRequestName(ctx context.Context, key model.RequestKey) context.Context {
    78  	return context.WithValue(ctx, requestIDKey{}, key.String())
    79  }
    80  
    81  func DefaultClientOptions(level log.Level) []connect.ClientOption {
    82  	interceptors := []connect.Interceptor{MetadataInterceptor(log.Debug), otelInterceptor()}
    83  	if ftl.Version != "dev" {
    84  		interceptors = append(interceptors, versionInterceptor{})
    85  	}
    86  	return []connect.ClientOption{
    87  		connect.WithGRPC(), // Use gRPC because some servers will not be using Connect.
    88  		connect.WithInterceptors(interceptors...),
    89  	}
    90  }
    91  
    92  func DefaultHandlerOptions() []connect.HandlerOption {
    93  	interceptors := []connect.Interceptor{MetadataInterceptor(log.Debug), otelInterceptor()}
    94  	if ftl.Version != "dev" {
    95  		interceptors = append(interceptors, versionInterceptor{})
    96  	}
    97  	return []connect.HandlerOption{connect.WithInterceptors(interceptors...)}
    98  }
    99  
   100  func otelInterceptor() connect.Interceptor {
   101  	otel, err := otelconnect.NewInterceptor(otelconnect.WithTrustRemote(), otelconnect.WithoutServerPeerAttributes())
   102  	if err != nil {
   103  		panic(err)
   104  	}
   105  	return otel
   106  }
   107  
   108  // MetadataInterceptor propagates FTL metadata through servers and clients.
   109  //
   110  // "errorLevel" is the level at which errors will be logged
   111  func MetadataInterceptor(errorLevel log.Level) connect.Interceptor {
   112  	return &metadataInterceptor{
   113  		errorLevel: errorLevel,
   114  	}
   115  }
   116  
   117  type metadataInterceptor struct {
   118  	errorLevel log.Level
   119  }
   120  
   121  func (*metadataInterceptor) WrapStreamingClient(req connect.StreamingClientFunc) connect.StreamingClientFunc {
   122  	return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
   123  		// TODO(aat): I can't figure out how to get the client headers here.
   124  		logger := log.FromContext(ctx)
   125  		logger.Tracef("%s (streaming client)", s.Procedure)
   126  		return req(ctx, s)
   127  	}
   128  }
   129  
   130  func (m *metadataInterceptor) WrapStreamingHandler(req connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
   131  	return func(ctx context.Context, s connect.StreamingHandlerConn) error {
   132  		logger := log.FromContext(ctx)
   133  		logger.Tracef("%s (streaming handler)", s.Spec().Procedure)
   134  		ctx, err := propagateHeaders(ctx, s.Spec().IsClient, s.RequestHeader())
   135  		if err != nil {
   136  			return err
   137  		}
   138  		err = req(ctx, s)
   139  		if err != nil {
   140  			if connect.CodeOf(err) == connect.CodeCanceled {
   141  				return nil
   142  			}
   143  			logger.Logf(m.errorLevel, "Streaming RPC failed: %s: %s", err, s.Spec().Procedure)
   144  			return err
   145  		}
   146  		return nil
   147  	}
   148  }
   149  
   150  func (m *metadataInterceptor) WrapUnary(uf connect.UnaryFunc) connect.UnaryFunc {
   151  	return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
   152  		logger := log.FromContext(ctx)
   153  		logger.Tracef("%s (unary)", req.Spec().Procedure)
   154  		ctx, err := propagateHeaders(ctx, req.Spec().IsClient, req.Header())
   155  		if err != nil {
   156  			return nil, err
   157  		}
   158  		resp, err := uf(ctx, req)
   159  		if err != nil {
   160  			logger.Logf(m.errorLevel, "Unary RPC failed: %s: %s", err, req.Spec().Procedure)
   161  			return nil, err
   162  		}
   163  		return resp, nil
   164  	}
   165  }
   166  
   167  type clientKey[Client Pingable] struct{}
   168  
   169  // ContextWithClient returns a context with an RPC client attached.
   170  func ContextWithClient[Client Pingable](ctx context.Context, client Client) context.Context {
   171  	return context.WithValue(ctx, clientKey[Client]{}, client)
   172  }
   173  
   174  // ClientFromContext returns the given RPC client from the context, or panics.
   175  func ClientFromContext[Client Pingable](ctx context.Context) Client {
   176  	value := ctx.Value(clientKey[Client]{})
   177  	if value == nil {
   178  		panic("no RPC client in context")
   179  	}
   180  	return value.(Client) //nolint:forcetypeassert
   181  }
   182  
   183  func IsClientAvailableInContext[Client Pingable](ctx context.Context) bool {
   184  	return ctx.Value(clientKey[Client]{}) != nil
   185  }
   186  
   187  func propagateHeaders(ctx context.Context, isClient bool, header http.Header) (context.Context, error) {
   188  	if isClient {
   189  		if IsDirectRouted(ctx) {
   190  			headers.SetDirectRouted(header)
   191  		}
   192  		if verbs, ok := VerbsFromContext(ctx); ok {
   193  			headers.SetCallers(header, verbs)
   194  		}
   195  		if key, err := RequestKeyFromContext(ctx); err != nil {
   196  			return nil, err
   197  		} else if key, ok := key.Get(); ok {
   198  			headers.SetRequestKey(header, key)
   199  		}
   200  	} else {
   201  		if headers.IsDirectRouted(header) {
   202  			ctx = WithDirectRouting(ctx)
   203  		}
   204  		if verbs, err := headers.GetCallers(header); err != nil {
   205  			return nil, err
   206  		} else { //nolint:revive
   207  			ctx = WithVerbs(ctx, verbs)
   208  		}
   209  		if key, ok, err := headers.GetRequestKey(header); err != nil {
   210  			return nil, err
   211  		} else if ok {
   212  			ctx = WithRequestName(ctx, key)
   213  		}
   214  	}
   215  	return ctx, nil
   216  }
   217  
   218  // versionInterceptor reports a warning to the client if the client is older than the server.
   219  type versionInterceptor struct{}
   220  
   221  func (v versionInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
   222  	return next
   223  }
   224  
   225  func (v versionInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
   226  	return next
   227  }
   228  
   229  func (v versionInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
   230  	return func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
   231  		resp, err := next(ctx, ar)
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  		if ar.Spec().IsClient {
   236  			if err := v.checkVersion(resp.Header()); err != nil {
   237  				log.FromContext(ctx).Warnf("%s", err)
   238  			}
   239  		} else {
   240  			resp.Header().Set("X-FTL-Version", ftl.Version)
   241  		}
   242  		return resp, nil
   243  	}
   244  }
   245  
   246  func (v versionInterceptor) checkVersion(header http.Header) error {
   247  	version := header.Get("X-FTL-Version")
   248  	if semver.Compare(ftl.Version, version) < 0 {
   249  		return fmt.Errorf("FTL client (%s) is older than server (%s), consider upgrading: https://github.com/TBD54566975/ftl/releases", ftl.Version, version)
   250  	}
   251  	return nil
   252  }