github.com/hashicorp/vault/sdk@v0.11.0/helper/pluginutil/multiplexing.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package pluginutil
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"strings"
    12  
    13  	"github.com/hashicorp/go-secure-stdlib/strutil"
    14  	"google.golang.org/grpc"
    15  	"google.golang.org/grpc/codes"
    16  	"google.golang.org/grpc/metadata"
    17  	"google.golang.org/grpc/status"
    18  )
    19  
    var (
    	// ErrNoMultiplexingIDFound is the sentinel returned by
    	// GetMultiplexIDFromContext when the incoming metadata holds no value
    	// under MultiplexingCtxKey. It is returned unwrapped, so callers can
    	// test for it with errors.Is (or ==).
    	ErrNoMultiplexingIDFound = errors.New("no multiplexing ID found")
    )
    21  
    22  type PluginMultiplexingServerImpl struct {
    23  	UnimplementedPluginMultiplexingServer
    24  
    25  	Supported bool
    26  }
    27  
    28  func (pm PluginMultiplexingServerImpl) MultiplexingSupport(_ context.Context, _ *MultiplexingSupportRequest) (*MultiplexingSupportResponse, error) {
    29  	return &MultiplexingSupportResponse{
    30  		Supported: pm.Supported,
    31  	}, nil
    32  }
    33  
    34  func MultiplexingSupported(ctx context.Context, cc grpc.ClientConnInterface, name string) (bool, error) {
    35  	if cc == nil {
    36  		return false, fmt.Errorf("client connection is nil")
    37  	}
    38  
    39  	out := strings.Split(os.Getenv(PluginMultiplexingOptOut), ",")
    40  	if strutil.StrListContains(out, name) {
    41  		return false, nil
    42  	}
    43  
    44  	req := new(MultiplexingSupportRequest)
    45  	resp, err := NewPluginMultiplexingClient(cc).MultiplexingSupport(ctx, req)
    46  	if err != nil {
    47  
    48  		// If the server does not implement the multiplexing server then we can
    49  		// assume it is not multiplexed
    50  		if status.Code(err) == codes.Unimplemented {
    51  			return false, nil
    52  		}
    53  
    54  		return false, err
    55  	}
    56  	if resp == nil {
    57  		// Somehow got a nil response, assume not multiplexed
    58  		return false, nil
    59  	}
    60  
    61  	return resp.Supported, nil
    62  }
    63  
    64  func GetMultiplexIDFromContext(ctx context.Context) (string, error) {
    65  	md, ok := metadata.FromIncomingContext(ctx)
    66  	if !ok {
    67  		return "", fmt.Errorf("missing plugin multiplexing metadata")
    68  	}
    69  
    70  	multiplexIDs := md[MultiplexingCtxKey]
    71  	if len(multiplexIDs) == 0 {
    72  		return "", ErrNoMultiplexingIDFound
    73  	} else if len(multiplexIDs) != 1 {
    74  		return "", fmt.Errorf("unexpected number of IDs in metadata: (%d)", len(multiplexIDs))
    75  	}
    76  
    77  	multiplexID := multiplexIDs[0]
    78  	if multiplexID == "" {
    79  		return "", fmt.Errorf("empty multiplex ID in metadata")
    80  	}
    81  
    82  	return multiplexID, nil
    83  }