github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/metadata/metadata.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package metadata
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"google.golang.org/grpc"
    25  	"google.golang.org/grpc/metadata"
    26  
    27  	"github.com/gravitational/teleport/api"
    28  )
    29  
    30  const (
    31  	VersionKey = "version"
    32  )
    33  
    34  // defaultMetadata returns the default metadata which will be added to all outgoing calls.
    35  func defaultMetadata() map[string]string {
    36  	return map[string]string{
    37  		VersionKey: api.Version,
    38  	}
    39  }
    40  
    41  // AddMetadataToContext returns a new context copied from ctx with the given
    42  // raw metadata added. Metadata already set on the given context for any key
    43  // will not be overridden, but new key/value pairs will always be added.
    44  func AddMetadataToContext(ctx context.Context, raw map[string]string) context.Context {
    45  	md := metadata.New(raw)
    46  	if existingMd, ok := metadata.FromOutgoingContext(ctx); ok {
    47  		for key, vals := range existingMd {
    48  			md.Set(key, vals...)
    49  		}
    50  	}
    51  	return metadata.NewOutgoingContext(ctx, md)
    52  }
    53  
    54  // DisableInterceptors can be set on the client context with context.WithValue(ctx, DisableInterceptors{}, struct{}{})
    55  // to stop the client interceptors from adding any metadata to the context (useful for testing).
    56  type DisableInterceptors struct{}
    57  
    58  // StreamServerInterceptor intercepts a gRPC client stream call and adds
    59  // default metadata to the context.
    60  func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    61  	if disable := stream.Context().Value(DisableInterceptors{}); disable == nil {
    62  		header := metadata.New(defaultMetadata())
    63  		grpc.SetHeader(stream.Context(), header)
    64  	}
    65  	return handler(srv, stream)
    66  }
    67  
    68  // UnaryServerInterceptor intercepts a gRPC server unary call and adds default
    69  // metadata to the context.
    70  func UnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
    71  	if disable := ctx.Value(DisableInterceptors{}); disable == nil {
    72  		header := metadata.New(defaultMetadata())
    73  		grpc.SetHeader(ctx, header)
    74  	}
    75  	return handler(ctx, req)
    76  }
    77  
    78  // StreamClientInterceptor intercepts a gRPC client stream call and adds
    79  // default metadata to the context.
    80  func StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    81  	if disable := ctx.Value(DisableInterceptors{}); disable == nil {
    82  		ctx = AddMetadataToContext(ctx, defaultMetadata())
    83  	}
    84  	return streamer(ctx, desc, cc, method, opts...)
    85  }
    86  
    87  // UnaryClientInterceptor intercepts a gRPC client unary call and adds default
    88  // metadata to the context.
    89  func UnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    90  	if disable := ctx.Value(DisableInterceptors{}); disable == nil {
    91  		ctx = AddMetadataToContext(ctx, defaultMetadata())
    92  	}
    93  	return invoker(ctx, method, req, reply, cc, opts...)
    94  }
    95  
    96  // ClientVersionFromContext can be called from a gRPC server method to return
    97  // the client version that was added to the gRPC metadata by
    98  // StreamClientInterceptor or UnaryClientInterceptor on the client.
    99  func ClientVersionFromContext(ctx context.Context) (string, bool) {
   100  	md, ok := metadata.FromIncomingContext(ctx)
   101  	if !ok {
   102  		return "", false
   103  	}
   104  
   105  	return VersionFromMetadata(md)
   106  }
   107  
   108  // VersionFromMetadata attempts to extract the standard version metadata value that is
   109  // added to client and server headers by the interceptors in this package.
   110  func VersionFromMetadata(md metadata.MD) (string, bool) {
   111  	versionList := md.Get(VersionKey)
   112  	if len(versionList) != 1 {
   113  		return "", false
   114  	}
   115  	return versionList[0], true
   116  }
   117  
   118  // WithUserAgentFromTeleportComponent returns a grpc.DialOption that reports
   119  // the Teleport component and the API version for user agent.
   120  func WithUserAgentFromTeleportComponent(component string) grpc.DialOption {
   121  	return grpc.WithUserAgent(fmt.Sprintf("%s/%s", component, api.Version))
   122  }
   123  
   124  // UserAgentFromContext returns the user agent from GRPC client metadata.
   125  func UserAgentFromContext(ctx context.Context) string {
   126  	md, ok := metadata.FromIncomingContext(ctx)
   127  	if !ok {
   128  		return ""
   129  	}
   130  	values := md.Get("user-agent")
   131  	if len(values) == 0 {
   132  		return ""
   133  	}
   134  	return strings.Join(values, " ")
   135  }