github.com/hashicorp/vault/sdk@v0.13.0/logical/request_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package logical
     5  
     6  import (
     7  	"context"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  func TestContextDisableReplicationStatusEndpointsValue(t *testing.T) {
    14  	testcases := []struct {
    15  		name          string
    16  		ctx           context.Context
    17  		expectedValue bool
    18  		expectedOk    bool
    19  	}{
    20  		{
    21  			name:          "without-value",
    22  			ctx:           context.Background(),
    23  			expectedValue: false,
    24  			expectedOk:    false,
    25  		},
    26  		{
    27  			name:          "with-nil",
    28  			ctx:           context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, nil),
    29  			expectedValue: false,
    30  			expectedOk:    false,
    31  		},
    32  		{
    33  			name:          "with-incompatible-value",
    34  			ctx:           context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, "true"),
    35  			expectedValue: false,
    36  			expectedOk:    false,
    37  		},
    38  		{
    39  			name:          "with-bool-true",
    40  			ctx:           context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, true),
    41  			expectedValue: true,
    42  			expectedOk:    true,
    43  		},
    44  		{
    45  			name:          "with-bool-false",
    46  			ctx:           context.WithValue(context.Background(), ctxKeyDisableReplicationStatusEndpoints{}, false),
    47  			expectedValue: false,
    48  			expectedOk:    true,
    49  		},
    50  	}
    51  
    52  	for _, testcase := range testcases {
    53  		value, ok := ContextDisableReplicationStatusEndpointsValue(testcase.ctx)
    54  		assert.Equal(t, testcase.expectedValue, value, testcase.name)
    55  		assert.Equal(t, testcase.expectedOk, ok, testcase.name)
    56  	}
    57  }
    58  
    59  func TestCreateContextDisableReplicationStatusEndpoints(t *testing.T) {
    60  	ctx := CreateContextDisableReplicationStatusEndpoints(context.Background(), true)
    61  
    62  	value := ctx.Value(ctxKeyDisableReplicationStatusEndpoints{})
    63  
    64  	assert.NotNil(t, ctx)
    65  	assert.NotNil(t, value)
    66  	assert.IsType(t, bool(false), value)
    67  	assert.Equal(t, true, value.(bool))
    68  
    69  	ctx = CreateContextDisableReplicationStatusEndpoints(context.Background(), false)
    70  
    71  	value = ctx.Value(ctxKeyDisableReplicationStatusEndpoints{})
    72  
    73  	assert.NotNil(t, ctx)
    74  	assert.NotNil(t, value)
    75  	assert.IsType(t, bool(false), value)
    76  	assert.Equal(t, false, value.(bool))
    77  }
    78  
    79  func TestContextOriginalRequestPathValue(t *testing.T) {
    80  	testcases := []struct {
    81  		name          string
    82  		ctx           context.Context
    83  		expectedValue string
    84  		expectedOk    bool
    85  	}{
    86  		{
    87  			name:          "without-value",
    88  			ctx:           context.Background(),
    89  			expectedValue: "",
    90  			expectedOk:    false,
    91  		},
    92  		{
    93  			name:          "with-nil",
    94  			ctx:           context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, nil),
    95  			expectedValue: "",
    96  			expectedOk:    false,
    97  		},
    98  		{
    99  			name:          "with-incompatible-value",
   100  			ctx:           context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, 6666),
   101  			expectedValue: "",
   102  			expectedOk:    false,
   103  		},
   104  		{
   105  			name:          "with-string-value",
   106  			ctx:           context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, "test"),
   107  			expectedValue: "test",
   108  			expectedOk:    true,
   109  		},
   110  		{
   111  			name:          "with-empty-string",
   112  			ctx:           context.WithValue(context.Background(), ctxKeyOriginalRequestPath{}, ""),
   113  			expectedValue: "",
   114  			expectedOk:    true,
   115  		},
   116  	}
   117  
   118  	for _, testcase := range testcases {
   119  		value, ok := ContextOriginalRequestPathValue(testcase.ctx)
   120  		assert.Equal(t, testcase.expectedValue, value, testcase.name)
   121  		assert.Equal(t, testcase.expectedOk, ok, testcase.name)
   122  	}
   123  }
   124  
   125  func TestCreateContextOriginalRequestPath(t *testing.T) {
   126  	ctx := CreateContextOriginalRequestPath(context.Background(), "test")
   127  
   128  	value := ctx.Value(ctxKeyOriginalRequestPath{})
   129  
   130  	assert.NotNil(t, ctx)
   131  	assert.NotNil(t, value)
   132  	assert.IsType(t, string(""), value)
   133  	assert.Equal(t, "test", value.(string))
   134  
   135  	ctx = CreateContextOriginalRequestPath(context.Background(), "")
   136  
   137  	value = ctx.Value(ctxKeyOriginalRequestPath{})
   138  
   139  	assert.NotNil(t, ctx)
   140  	assert.NotNil(t, value)
   141  	assert.IsType(t, string(""), value)
   142  	assert.Equal(t, "", value.(string))
   143  }