github.com/TBD54566975/ftl@v0.219.0/internal/rpc/context.go (about) 1 package rpc 2 3 import ( 4 "context" 5 "fmt" 6 "net/http" 7 8 "connectrpc.com/connect" 9 "connectrpc.com/otelconnect" 10 "github.com/alecthomas/types/optional" 11 "golang.org/x/mod/semver" 12 13 "github.com/TBD54566975/ftl" 14 "github.com/TBD54566975/ftl/backend/schema" 15 "github.com/TBD54566975/ftl/internal/log" 16 "github.com/TBD54566975/ftl/internal/model" 17 "github.com/TBD54566975/ftl/internal/rpc/headers" 18 ) 19 20 type ftlDirectRoutingKey struct{} 21 type ftlVerbKey struct{} 22 type requestIDKey struct{} 23 24 // WithDirectRouting ensures any hops in Verb routing do not redirect. 25 // 26 // This is used so that eg. calls from Drives do not create recursive loops 27 // when calling back to the Agent. 28 func WithDirectRouting(ctx context.Context) context.Context { 29 return context.WithValue(ctx, ftlDirectRoutingKey{}, "1") 30 } 31 32 // WithVerbs adds the module.verb chain from the current request to the context. 33 func WithVerbs(ctx context.Context, verbs []*schema.Ref) context.Context { 34 return context.WithValue(ctx, ftlVerbKey{}, verbs) 35 } 36 37 // VerbFromContext returns the current module.verb of the current request. 38 func VerbFromContext(ctx context.Context) (*schema.Ref, bool) { 39 value := ctx.Value(ftlVerbKey{}) 40 verbs, ok := value.([]*schema.Ref) 41 if len(verbs) == 0 { 42 return nil, false 43 } 44 return verbs[len(verbs)-1], ok 45 } 46 47 // VerbsFromContext returns the module.verb chain of the current request. 48 func VerbsFromContext(ctx context.Context) ([]*schema.Ref, bool) { 49 value := ctx.Value(ftlVerbKey{}) 50 verbs, ok := value.([]*schema.Ref) 51 return verbs, ok 52 } 53 54 // IsDirectRouted returns true if the incoming request should be directly 55 // routed and never redirected. 56 func IsDirectRouted(ctx context.Context) bool { 57 return ctx.Value(ftlDirectRoutingKey{}) != nil 58 } 59 60 // RequestKeyFromContext returns the request key from the context, if any. 61 // 62 // TODO: Return an Option here instead of a bool. 63 func RequestKeyFromContext(ctx context.Context) (optional.Option[model.RequestKey], error) { 64 value := ctx.Value(requestIDKey{}) 65 keyStr, ok := value.(string) 66 if !ok { 67 return optional.None[model.RequestKey](), nil 68 } 69 key, err := model.ParseRequestKey(keyStr) 70 if err != nil { 71 return optional.None[model.RequestKey](), fmt.Errorf("invalid request key: %w", err) 72 } 73 return optional.Some(key), nil 74 } 75 76 // WithRequestName adds the request key to the context. 77 func WithRequestName(ctx context.Context, key model.RequestKey) context.Context { 78 return context.WithValue(ctx, requestIDKey{}, key.String()) 79 } 80 81 func DefaultClientOptions(level log.Level) []connect.ClientOption { 82 interceptors := []connect.Interceptor{MetadataInterceptor(log.Debug), otelInterceptor()} 83 if ftl.Version != "dev" { 84 interceptors = append(interceptors, versionInterceptor{}) 85 } 86 return []connect.ClientOption{ 87 connect.WithGRPC(), // Use gRPC because some servers will not be using Connect. 88 connect.WithInterceptors(interceptors...), 89 } 90 } 91 92 func DefaultHandlerOptions() []connect.HandlerOption { 93 interceptors := []connect.Interceptor{MetadataInterceptor(log.Debug), otelInterceptor()} 94 if ftl.Version != "dev" { 95 interceptors = append(interceptors, versionInterceptor{}) 96 } 97 return []connect.HandlerOption{connect.WithInterceptors(interceptors...)} 98 } 99 100 func otelInterceptor() connect.Interceptor { 101 otel, err := otelconnect.NewInterceptor(otelconnect.WithTrustRemote(), otelconnect.WithoutServerPeerAttributes()) 102 if err != nil { 103 panic(err) 104 } 105 return otel 106 } 107 108 // MetadataInterceptor propagates FTL metadata through servers and clients. 109 // 110 // "errorLevel" is the level at which errors will be logged 111 func MetadataInterceptor(errorLevel log.Level) connect.Interceptor { 112 return &metadataInterceptor{ 113 errorLevel: errorLevel, 114 } 115 } 116 117 type metadataInterceptor struct { 118 errorLevel log.Level 119 } 120 121 func (*metadataInterceptor) WrapStreamingClient(req connect.StreamingClientFunc) connect.StreamingClientFunc { 122 return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { 123 // TODO(aat): I can't figure out how to get the client headers here. 124 logger := log.FromContext(ctx) 125 logger.Tracef("%s (streaming client)", s.Procedure) 126 return req(ctx, s) 127 } 128 } 129 130 func (m *metadataInterceptor) WrapStreamingHandler(req connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { 131 return func(ctx context.Context, s connect.StreamingHandlerConn) error { 132 logger := log.FromContext(ctx) 133 logger.Tracef("%s (streaming handler)", s.Spec().Procedure) 134 ctx, err := propagateHeaders(ctx, s.Spec().IsClient, s.RequestHeader()) 135 if err != nil { 136 return err 137 } 138 err = req(ctx, s) 139 if err != nil { 140 if connect.CodeOf(err) == connect.CodeCanceled { 141 return nil 142 } 143 logger.Logf(m.errorLevel, "Streaming RPC failed: %s: %s", err, s.Spec().Procedure) 144 return err 145 } 146 return nil 147 } 148 } 149 150 func (m *metadataInterceptor) WrapUnary(uf connect.UnaryFunc) connect.UnaryFunc { 151 return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { 152 logger := log.FromContext(ctx) 153 logger.Tracef("%s (unary)", req.Spec().Procedure) 154 ctx, err := propagateHeaders(ctx, req.Spec().IsClient, req.Header()) 155 if err != nil { 156 return nil, err 157 } 158 resp, err := uf(ctx, req) 159 if err != nil { 160 logger.Logf(m.errorLevel, "Unary RPC failed: %s: %s", err, req.Spec().Procedure) 161 return nil, err 162 } 163 return resp, nil 164 } 165 } 166 167 type clientKey[Client Pingable] struct{} 168 169 // ContextWithClient returns a context with an RPC client attached. 170 func ContextWithClient[Client Pingable](ctx context.Context, client Client) context.Context { 171 return context.WithValue(ctx, clientKey[Client]{}, client) 172 } 173 174 // ClientFromContext returns the given RPC client from the context, or panics. 175 func ClientFromContext[Client Pingable](ctx context.Context) Client { 176 value := ctx.Value(clientKey[Client]{}) 177 if value == nil { 178 panic("no RPC client in context") 179 } 180 return value.(Client) //nolint:forcetypeassert 181 } 182 183 func IsClientAvailableInContext[Client Pingable](ctx context.Context) bool { 184 return ctx.Value(clientKey[Client]{}) != nil 185 } 186 187 func propagateHeaders(ctx context.Context, isClient bool, header http.Header) (context.Context, error) { 188 if isClient { 189 if IsDirectRouted(ctx) { 190 headers.SetDirectRouted(header) 191 } 192 if verbs, ok := VerbsFromContext(ctx); ok { 193 headers.SetCallers(header, verbs) 194 } 195 if key, err := RequestKeyFromContext(ctx); err != nil { 196 return nil, err 197 } else if key, ok := key.Get(); ok { 198 headers.SetRequestKey(header, key) 199 } 200 } else { 201 if headers.IsDirectRouted(header) { 202 ctx = WithDirectRouting(ctx) 203 } 204 if verbs, err := headers.GetCallers(header); err != nil { 205 return nil, err 206 } else { //nolint:revive 207 ctx = WithVerbs(ctx, verbs) 208 } 209 if key, ok, err := headers.GetRequestKey(header); err != nil { 210 return nil, err 211 } else if ok { 212 ctx = WithRequestName(ctx, key) 213 } 214 } 215 return ctx, nil 216 } 217 218 // versionInterceptor reports a warning to the client if the client is older than the server. 219 type versionInterceptor struct{} 220 221 func (v versionInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { 222 return next 223 } 224 225 func (v versionInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { 226 return next 227 } 228 229 func (v versionInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { 230 return func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { 231 resp, err := next(ctx, ar) 232 if err != nil { 233 return nil, err 234 } 235 if ar.Spec().IsClient { 236 if err := v.checkVersion(resp.Header()); err != nil { 237 log.FromContext(ctx).Warnf("%s", err) 238 } 239 } else { 240 resp.Header().Set("X-FTL-Version", ftl.Version) 241 } 242 return resp, nil 243 } 244 } 245 246 func (v versionInterceptor) checkVersion(header http.Header) error { 247 version := header.Get("X-FTL-Version") 248 if semver.Compare(ftl.Version, version) < 0 { 249 return fmt.Errorf("FTL client (%s) is older than server (%s), consider upgrading: https://github.com/TBD54566975/ftl/releases", ftl.Version, version) 250 } 251 return nil 252 }