github.com/hashicorp/vault/sdk@v0.11.0/helper/pluginutil/multiplexing_test.go

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package pluginutil
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"reflect"
    10  	"testing"
    11  
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  )
    15  
    16  func TestMultiplexingSupported(t *testing.T) {
    17  	type args struct {
    18  		ctx  context.Context
    19  		cc   grpc.ClientConnInterface
    20  		name string
    21  	}
    22  
    23  	type testCase struct {
    24  		name    string
    25  		args    args
    26  		env     string
    27  		want    bool
    28  		wantErr bool
    29  	}
    30  
    31  	tests := []testCase{
    32  		{
    33  			name: "multiplexing is supported if plugin is not opted out",
    34  			args: args{
    35  				ctx:  context.Background(),
    36  				cc:   &MockClientConnInterfaceNoop{},
    37  				name: "plugin",
    38  			},
    39  			env:  "",
    40  			want: true,
    41  		},
    42  		{
    43  			name: "multiplexing is not supported if plugin is opted out",
    44  			args: args{
    45  				ctx:  context.Background(),
    46  				cc:   &MockClientConnInterfaceNoop{},
    47  				name: "optedOutPlugin",
    48  			},
    49  			env:  "optedOutPlugin",
    50  			want: false,
    51  		},
    52  		{
    53  			name: "multiplexing is not supported if plugin among one of the opted out",
    54  			args: args{
    55  				ctx:  context.Background(),
    56  				cc:   &MockClientConnInterfaceNoop{},
    57  				name: "optedOutPlugin",
    58  			},
    59  			env:  "firstPlugin,optedOutPlugin,otherPlugin",
    60  			want: false,
    61  		},
    62  		{
    63  			name: "multiplexing is supported if different plugin is opted out",
    64  			args: args{
    65  				ctx:  context.Background(),
    66  				cc:   &MockClientConnInterfaceNoop{},
    67  				name: "plugin",
    68  			},
    69  			env:  "optedOutPlugin",
    70  			want: true,
    71  		},
    72  	}
    73  	for _, tt := range tests {
    74  		t.Run(tt.name, func(t *testing.T) {
    75  			t.Setenv(PluginMultiplexingOptOut, tt.env)
    76  			got, err := MultiplexingSupported(tt.args.ctx, tt.args.cc, tt.args.name)
    77  			if (err != nil) != tt.wantErr {
    78  				t.Errorf("MultiplexingSupported() error = %v, wantErr %v", err, tt.wantErr)
    79  				return
    80  			}
    81  			if got != tt.want {
    82  				t.Errorf("MultiplexingSupported() got = %v, want %v", got, tt.want)
    83  			}
    84  		})
    85  	}
    86  }
    87  
    88  func TestGetMultiplexIDFromContext(t *testing.T) {
    89  	type testCase struct {
    90  		ctx          context.Context
    91  		expectedResp string
    92  		expectedErr  error
    93  	}
    94  
    95  	tests := map[string]testCase{
    96  		"missing plugin multiplexing metadata": {
    97  			ctx:          context.Background(),
    98  			expectedResp: "",
    99  			expectedErr:  fmt.Errorf("missing plugin multiplexing metadata"),
   100  		},
   101  		"unexpected number of IDs in metadata": {
   102  			ctx:          idCtx(t, "12345", "67891"),
   103  			expectedResp: "",
   104  			expectedErr:  fmt.Errorf("unexpected number of IDs in metadata: (2)"),
   105  		},
   106  		"empty multiplex ID in metadata": {
   107  			ctx:          idCtx(t, ""),
   108  			expectedResp: "",
   109  			expectedErr:  fmt.Errorf("empty multiplex ID in metadata"),
   110  		},
   111  		"happy path, id is returned from metadata": {
   112  			ctx:          idCtx(t, "12345"),
   113  			expectedResp: "12345",
   114  			expectedErr:  nil,
   115  		},
   116  	}
   117  
   118  	for name, test := range tests {
   119  		t.Run(name, func(t *testing.T) {
   120  			resp, err := GetMultiplexIDFromContext(test.ctx)
   121  
   122  			if test.expectedErr != nil && test.expectedErr.Error() != "" && err == nil {
   123  				t.Fatalf("err expected, got nil")
   124  			} else if !reflect.DeepEqual(err, test.expectedErr) {
   125  				t.Fatalf("Actual error: %#v\nExpected error: %#v", err, test.expectedErr)
   126  			}
   127  
   128  			if test.expectedErr != nil && test.expectedErr.Error() == "" && err != nil {
   129  				t.Fatalf("no error expected, got: %s", err)
   130  			}
   131  
   132  			if !reflect.DeepEqual(resp, test.expectedResp) {
   133  				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
   134  			}
   135  		})
   136  	}
   137  }
   138  
   139  // idCtx is a test helper that will return a context with the IDs set in its
   140  // metadata
   141  func idCtx(t *testing.T, ids ...string) context.Context {
   142  	// Context doesn't need to timeout since this is just passed through
   143  	ctx := context.Background()
   144  	md := metadata.MD{}
   145  	for _, id := range ids {
   146  		md.Append(MultiplexingCtxKey, id)
   147  	}
   148  	return metadata.NewIncomingContext(ctx, md)
   149  }
   150  
   151  type MockClientConnInterfaceNoop struct{}
   152  
   153  func (m *MockClientConnInterfaceNoop) Invoke(_ context.Context, _ string, _ interface{}, reply interface{}, _ ...grpc.CallOption) error {
   154  	reply.(*MultiplexingSupportResponse).Supported = true
   155  	return nil
   156  }
   157  
   158  func (m *MockClientConnInterfaceNoop) NewStream(_ context.Context, _ *grpc.StreamDesc, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
   159  	return nil, nil
   160  }