github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/middleware/grpc_auth.go (about)

     1  package middleware
     2  
     3  import (
     4  	"golang.org/x/net/context"
     5  	"google.golang.org/grpc"
     6  
     7  	"github.com/weaveworks/common/user"
     8  )
     9  
    10  // ClientUserHeaderInterceptor propagates the user ID from the context to gRPC metadata, which eventually ends up as a HTTP2 header.
    11  func ClientUserHeaderInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    12  	ctx, err := user.InjectIntoGRPCRequest(ctx)
    13  	if err != nil {
    14  		return err
    15  	}
    16  
    17  	return invoker(ctx, method, req, reply, cc, opts...)
    18  }
    19  
    20  // StreamClientUserHeaderInterceptor propagates the user ID from the context to gRPC metadata, which eventually ends up as a HTTP2 header.
    21  // For streaming gRPC requests.
    22  func StreamClientUserHeaderInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    23  	ctx, err := user.InjectIntoGRPCRequest(ctx)
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  
    28  	return streamer(ctx, desc, cc, method, opts...)
    29  }
    30  
    31  // ServerUserHeaderInterceptor propagates the user ID from the gRPC metadata back to our context.
    32  func ServerUserHeaderInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    33  	_, ctx, err := user.ExtractFromGRPCRequest(ctx)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	return handler(ctx, req)
    39  }
    40  
    41  // StreamServerUserHeaderInterceptor propagates the user ID from the gRPC metadata back to our context.
    42  func StreamServerUserHeaderInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    43  	_, ctx, err := user.ExtractFromGRPCRequest(ss.Context())
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	return handler(srv, serverStream{
    49  		ctx:          ctx,
    50  		ServerStream: ss,
    51  	})
    52  }
    53  
    54  type serverStream struct {
    55  	ctx context.Context
    56  	grpc.ServerStream
    57  }
    58  
    59  func (ss serverStream) Context() context.Context {
    60  	return ss.ctx
    61  }