github.com/hernad/nomad@v1.6.112/nomad/state/state_store_acl_sso_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package state
     5  
     6  import (
     7  	"testing"
     8  
     9  	"github.com/hashicorp/go-memdb"
    10  	"github.com/hernad/nomad/ci"
    11  	"github.com/hernad/nomad/nomad/mock"
    12  	"github.com/hernad/nomad/nomad/structs"
    13  	"github.com/shoenig/test/must"
    14  )
    15  
    16  func TestStateStore_UpsertACLAuthMethods(t *testing.T) {
    17  	ci.Parallel(t)
    18  	testState := testStateStore(t)
    19  
    20  	// Create mock auth methods
    21  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLOIDCAuthMethod(), mock.ACLOIDCAuthMethod()}
    22  
    23  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
    24  
    25  	// Check that the index for the table was modified as expected.
    26  	initialIndex, err := testState.Index(TableACLAuthMethods)
    27  	must.NoError(t, err)
    28  	must.Eq(t, 10, initialIndex)
    29  
    30  	// List all the auth methods in the table, so we can perform a number of
    31  	// tests on the return array.
    32  	ws := memdb.NewWatchSet()
    33  	iter, err := testState.GetACLAuthMethods(ws)
    34  	must.NoError(t, err)
    35  
    36  	// Count how many table entries we have, to ensure it is the expected
    37  	// number.
    38  	var count int
    39  
    40  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
    41  		count++
    42  
    43  		// Ensure the create and modify indexes are populated correctly.
    44  		authMethod := raw.(*structs.ACLAuthMethod)
    45  		must.Eq(t, 10, authMethod.CreateIndex)
    46  		must.Eq(t, 10, authMethod.ModifyIndex)
    47  	}
    48  	must.Eq(t, 2, count)
    49  
    50  	// Try writing the same auth methods to state which should not result in an
    51  	// update to the table index.
    52  	must.NoError(t, testState.UpsertACLAuthMethods(20, mockedACLAuthMethods))
    53  	reInsertActualIndex, err := testState.Index(TableACLAuthMethods)
    54  	must.NoError(t, err)
    55  	must.Eq(t, 10, reInsertActualIndex)
    56  
    57  	// Make a change to the auth methods and ensure this update is accepted and
    58  	// the table index is updated.
    59  	updatedMockedAuthMethod1 := mockedACLAuthMethods[0].Copy()
    60  	updatedMockedAuthMethod1.Type = "new type"
    61  	updatedMockedAuthMethod1.SetHash()
    62  	updatedMockedAuthMethod2 := mockedACLAuthMethods[1].Copy()
    63  	updatedMockedAuthMethod2.Type = "yet another new type"
    64  	updatedMockedAuthMethod2.SetHash()
    65  	must.NoError(t, testState.UpsertACLAuthMethods(20, []*structs.ACLAuthMethod{
    66  		updatedMockedAuthMethod1, updatedMockedAuthMethod2,
    67  	}))
    68  
    69  	// Check that the index for the table was modified as expected.
    70  	updatedIndex, err := testState.Index(TableACLAuthMethods)
    71  	must.NoError(t, err)
    72  	must.Eq(t, 20, updatedIndex)
    73  
    74  	// List the ACL auth methods in state.
    75  	iter, err = testState.GetACLAuthMethods(ws)
    76  	must.NoError(t, err)
    77  
    78  	// Count how many table entries we have, to ensure it is the expected
    79  	// number.
    80  	count = 0
    81  
    82  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
    83  		count++
    84  
    85  		// Ensure the create and modify indexes are populated correctly.
    86  		aclAuthMethod := raw.(*structs.ACLAuthMethod)
    87  		must.Eq(t, 10, aclAuthMethod.CreateIndex)
    88  		must.Eq(t, 20, aclAuthMethod.ModifyIndex)
    89  	}
    90  	must.Eq(t, 2, count, must.Sprintf("incorrect number of ACL auth methods found"))
    91  
    92  	// Try adding a new auth method, which has a name clash with an existing
    93  	// entry.
    94  	dup := mock.ACLOIDCAuthMethod()
    95  	dup.Name = mockedACLAuthMethods[0].Name
    96  	dup.Type = mockedACLAuthMethods[0].Type
    97  
    98  	err = testState.UpsertACLAuthMethods(50, []*structs.ACLAuthMethod{dup})
    99  	must.NoError(t, err)
   100  
   101  	// Get all the ACL auth methods from state.
   102  	iter, err = testState.GetACLAuthMethods(ws)
   103  	must.NoError(t, err)
   104  
   105  	// Count how many table entries we have, to ensure it is the expected
   106  	// number.
   107  	count = 0
   108  
   109  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   110  		count++
   111  	}
   112  	must.Eq(t, 2, count, must.Sprintf("incorrect number of ACL auth methods found"))
   113  }
   114  
   115  func TestStateStore_DeleteACLAuthMethods(t *testing.T) {
   116  	ci.Parallel(t)
   117  	testState := testStateStore(t)
   118  
   119  	// Generate some mocked ACL auth methods for testing and upsert these
   120  	// straight into state.
   121  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLOIDCAuthMethod(), mock.ACLOIDCAuthMethod()}
   122  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   123  
   124  	// Try and delete a method using a name that doesn't exist. This should
   125  	// return an error and not change the index for the table.
   126  	err := testState.DeleteACLAuthMethods(20, []string{"not-a-method"})
   127  	must.EqError(t, err, "ACL auth method not found")
   128  
   129  	tableIndex, err := testState.Index(TableACLAuthMethods)
   130  	must.NoError(t, err)
   131  	must.Eq(t, 10, tableIndex)
   132  
   133  	// Delete one of the previously upserted auth methods. This should succeed
   134  	// and modify the table index.
   135  	err = testState.DeleteACLAuthMethods(20, []string{mockedACLAuthMethods[0].Name})
   136  	must.NoError(t, err)
   137  
   138  	tableIndex, err = testState.Index(TableACLAuthMethods)
   139  	must.NoError(t, err)
   140  	must.Eq(t, 20, tableIndex)
   141  
   142  	// List the ACL auth methods and ensure we now only have one present and
   143  	// that it is the one we expect.
   144  	ws := memdb.NewWatchSet()
   145  	iter, err := testState.GetACLAuthMethods(ws)
   146  	must.NoError(t, err)
   147  
   148  	var aclAuthMethods []*structs.ACLAuthMethod
   149  
   150  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   151  		aclAuthMethods = append(aclAuthMethods, raw.(*structs.ACLAuthMethod))
   152  	}
   153  
   154  	must.Len(t, 1, aclAuthMethods, must.Sprintf("incorrect number of auth methods found"))
   155  	must.True(t, aclAuthMethods[0].Equal(mockedACLAuthMethods[1]))
   156  
   157  	// Delete the final remaining auth method. This should succeed and modify
   158  	// the table index.
   159  	err = testState.DeleteACLAuthMethods(30, []string{mockedACLAuthMethods[1].Name})
   160  	must.NoError(t, err)
   161  
   162  	tableIndex, err = testState.Index(TableACLAuthMethods)
   163  	must.NoError(t, err)
   164  	must.Eq(t, 30, tableIndex)
   165  
   166  	// List the auth methods and ensure we have zero entries.
   167  	iter, err = testState.GetACLAuthMethods(ws)
   168  	must.NoError(t, err)
   169  
   170  	aclAuthMethods = []*structs.ACLAuthMethod{}
   171  
   172  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   173  		aclAuthMethods = append(aclAuthMethods, raw.(*structs.ACLAuthMethod))
   174  	}
   175  	must.Len(t, 0, aclAuthMethods, must.Sprintf("incorrect number of ACL roles found"))
   176  }
   177  
   178  func TestStateStore_GetACLAuthMethods(t *testing.T) {
   179  	ci.Parallel(t)
   180  	testState := testStateStore(t)
   181  
   182  	// Generate a some mocked ACL auth methods for testing and upsert these
   183  	// straight into state.
   184  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLOIDCAuthMethod(), mock.ACLOIDCAuthMethod()}
   185  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   186  
   187  	// List the auth methods and ensure they are exactly as we expect.
   188  	ws := memdb.NewWatchSet()
   189  	iter, err := testState.GetACLAuthMethods(ws)
   190  	must.NoError(t, err)
   191  
   192  	var aclAuthMethods []*structs.ACLAuthMethod
   193  
   194  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   195  		aclAuthMethods = append(aclAuthMethods, raw.(*structs.ACLAuthMethod))
   196  	}
   197  
   198  	expected := mockedACLAuthMethods
   199  	for i := range expected {
   200  		expected[i].CreateIndex = 10
   201  		expected[i].ModifyIndex = 10
   202  	}
   203  
   204  	must.SliceContainsAll(t, aclAuthMethods, expected)
   205  }
   206  
   207  func TestStateStore_GetACLAuthMethodByName(t *testing.T) {
   208  	ci.Parallel(t)
   209  	testState := testStateStore(t)
   210  
   211  	// Generate a some mocked ACL auth methods for testing and upsert these
   212  	// straight into state.
   213  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLOIDCAuthMethod(), mock.ACLOIDCAuthMethod()}
   214  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   215  
   216  	ws := memdb.NewWatchSet()
   217  
   218  	// Try reading an auth method that does not exist.
   219  	authMethod, err := testState.GetACLAuthMethodByName(ws, "not-a-method")
   220  	must.NoError(t, err)
   221  	must.Nil(t, authMethod)
   222  
   223  	// Read the two ACL roles that we should find.
   224  	authMethod, err = testState.GetACLAuthMethodByName(ws, mockedACLAuthMethods[0].Name)
   225  	must.NoError(t, err)
   226  	must.Equal(t, mockedACLAuthMethods[0], authMethod)
   227  
   228  	authMethod, err = testState.GetACLAuthMethodByName(ws, mockedACLAuthMethods[1].Name)
   229  	must.NoError(t, err)
   230  	must.Equal(t, mockedACLAuthMethods[1], authMethod)
   231  }
   232  
   233  func TestStateStore_GetDefaultACLAuthMethod(t *testing.T) {
   234  	ci.Parallel(t)
   235  	testState := testStateStore(t)
   236  
   237  	// Generate 2 auth methods, make one of them default
   238  	am1 := mock.ACLOIDCAuthMethod()
   239  	am1.Default = true
   240  	am2 := mock.ACLOIDCAuthMethod()
   241  
   242  	// upsert
   243  	mockedACLAuthMethods := []*structs.ACLAuthMethod{am1, am2}
   244  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   245  
   246  	// Get the default method
   247  	ws := memdb.NewWatchSet()
   248  	defaultACLAuthMethod, err := testState.GetDefaultACLAuthMethod(ws)
   249  	must.NoError(t, err)
   250  
   251  	must.True(t, defaultACLAuthMethod.Default)
   252  	must.Eq(t, am1, defaultACLAuthMethod)
   253  
   254  }