github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/nomad/state/state_store_acl_sso_test.go (about)

     1  package state
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/hashicorp/go-memdb"
     7  	"github.com/hashicorp/nomad/ci"
     8  	"github.com/hashicorp/nomad/nomad/mock"
     9  	"github.com/hashicorp/nomad/nomad/structs"
    10  	"github.com/shoenig/test/must"
    11  )
    12  
    13  func TestStateStore_UpsertACLAuthMethods(t *testing.T) {
    14  	ci.Parallel(t)
    15  	testState := testStateStore(t)
    16  
    17  	// Create mock auth methods
    18  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLAuthMethod(), mock.ACLAuthMethod()}
    19  
    20  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
    21  
    22  	// Check that the index for the table was modified as expected.
    23  	initialIndex, err := testState.Index(TableACLAuthMethods)
    24  	must.NoError(t, err)
    25  	must.Eq(t, 10, initialIndex)
    26  
    27  	// List all the auth methods in the table, so we can perform a number of
    28  	// tests on the return array.
    29  	ws := memdb.NewWatchSet()
    30  	iter, err := testState.GetACLAuthMethods(ws)
    31  	must.NoError(t, err)
    32  
    33  	// Count how many table entries we have, to ensure it is the expected
    34  	// number.
    35  	var count int
    36  
    37  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
    38  		count++
    39  
    40  		// Ensure the create and modify indexes are populated correctly.
    41  		authMethod := raw.(*structs.ACLAuthMethod)
    42  		must.Eq(t, 10, authMethod.CreateIndex)
    43  		must.Eq(t, 10, authMethod.ModifyIndex)
    44  	}
    45  	must.Eq(t, 2, count)
    46  
    47  	// Try writing the same auth methods to state which should not result in an
    48  	// update to the table index.
    49  	must.NoError(t, testState.UpsertACLAuthMethods(20, mockedACLAuthMethods))
    50  	reInsertActualIndex, err := testState.Index(TableACLAuthMethods)
    51  	must.NoError(t, err)
    52  	must.Eq(t, 10, reInsertActualIndex)
    53  
    54  	// Make a change to the auth methods and ensure this update is accepted and
    55  	// the table index is updated.
    56  	updatedMockedAuthMethod1 := mockedACLAuthMethods[0].Copy()
    57  	updatedMockedAuthMethod1.Type = "new type"
    58  	updatedMockedAuthMethod1.SetHash()
    59  	updatedMockedAuthMethod2 := mockedACLAuthMethods[1].Copy()
    60  	updatedMockedAuthMethod2.Type = "yet another new type"
    61  	updatedMockedAuthMethod2.SetHash()
    62  	must.NoError(t, testState.UpsertACLAuthMethods(20, []*structs.ACLAuthMethod{
    63  		updatedMockedAuthMethod1, updatedMockedAuthMethod2,
    64  	}))
    65  
    66  	// Check that the index for the table was modified as expected.
    67  	updatedIndex, err := testState.Index(TableACLAuthMethods)
    68  	must.NoError(t, err)
    69  	must.Eq(t, 20, updatedIndex)
    70  
    71  	// List the ACL auth methods in state.
    72  	iter, err = testState.GetACLAuthMethods(ws)
    73  	must.NoError(t, err)
    74  
    75  	// Count how many table entries we have, to ensure it is the expected
    76  	// number.
    77  	count = 0
    78  
    79  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
    80  		count++
    81  
    82  		// Ensure the create and modify indexes are populated correctly.
    83  		aclAuthMethod := raw.(*structs.ACLAuthMethod)
    84  		must.Eq(t, 10, aclAuthMethod.CreateIndex)
    85  		must.Eq(t, 20, aclAuthMethod.ModifyIndex)
    86  	}
    87  	must.Eq(t, 2, count, must.Sprintf("incorrect number of ACL auth methods found"))
    88  
    89  	// Try adding a new auth method, which has a name clash with an existing
    90  	// entry.
    91  	dup := mock.ACLAuthMethod()
    92  	dup.Name = mockedACLAuthMethods[0].Name
    93  	dup.Type = mockedACLAuthMethods[0].Type
    94  
    95  	err = testState.UpsertACLAuthMethods(50, []*structs.ACLAuthMethod{dup})
    96  	must.NoError(t, err)
    97  
    98  	// Get all the ACL auth methods from state.
    99  	iter, err = testState.GetACLAuthMethods(ws)
   100  	must.NoError(t, err)
   101  
   102  	// Count how many table entries we have, to ensure it is the expected
   103  	// number.
   104  	count = 0
   105  
   106  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   107  		count++
   108  	}
   109  	must.Eq(t, 2, count, must.Sprintf("incorrect number of ACL auth methods found"))
   110  }
   111  
   112  func TestStateStore_DeleteACLAuthMethods(t *testing.T) {
   113  	ci.Parallel(t)
   114  	testState := testStateStore(t)
   115  
   116  	// Generate some mocked ACL auth methods for testing and upsert these
   117  	// straight into state.
   118  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLAuthMethod(), mock.ACLAuthMethod()}
   119  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   120  
   121  	// Try and delete a method using a name that doesn't exist. This should
   122  	// return an error and not change the index for the table.
   123  	err := testState.DeleteACLAuthMethods(20, []string{"not-a-method"})
   124  	must.EqError(t, err, "ACL auth method not found")
   125  
   126  	tableIndex, err := testState.Index(TableACLAuthMethods)
   127  	must.NoError(t, err)
   128  	must.Eq(t, 10, tableIndex)
   129  
   130  	// Delete one of the previously upserted auth methods. This should succeed
   131  	// and modify the table index.
   132  	err = testState.DeleteACLAuthMethods(20, []string{mockedACLAuthMethods[0].Name})
   133  	must.NoError(t, err)
   134  
   135  	tableIndex, err = testState.Index(TableACLAuthMethods)
   136  	must.NoError(t, err)
   137  	must.Eq(t, 20, tableIndex)
   138  
   139  	// List the ACL auth methods and ensure we now only have one present and
   140  	// that it is the one we expect.
   141  	ws := memdb.NewWatchSet()
   142  	iter, err := testState.GetACLAuthMethods(ws)
   143  	must.NoError(t, err)
   144  
   145  	var aclAuthMethods []*structs.ACLAuthMethod
   146  
   147  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   148  		aclAuthMethods = append(aclAuthMethods, raw.(*structs.ACLAuthMethod))
   149  	}
   150  
   151  	must.Len(t, 1, aclAuthMethods, must.Sprintf("incorrect number of auth methods found"))
   152  	must.True(t, aclAuthMethods[0].Equal(mockedACLAuthMethods[1]))
   153  
   154  	// Delete the final remaining auth method. This should succeed and modify
   155  	// the table index.
   156  	err = testState.DeleteACLAuthMethods(30, []string{mockedACLAuthMethods[1].Name})
   157  	must.NoError(t, err)
   158  
   159  	tableIndex, err = testState.Index(TableACLAuthMethods)
   160  	must.NoError(t, err)
   161  	must.Eq(t, 30, tableIndex)
   162  
   163  	// List the auth methods and ensure we have zero entries.
   164  	iter, err = testState.GetACLAuthMethods(ws)
   165  	must.NoError(t, err)
   166  
   167  	aclAuthMethods = []*structs.ACLAuthMethod{}
   168  
   169  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   170  		aclAuthMethods = append(aclAuthMethods, raw.(*structs.ACLAuthMethod))
   171  	}
   172  	must.Len(t, 0, aclAuthMethods, must.Sprintf("incorrect number of ACL roles found"))
   173  }
   174  
   175  func TestStateStore_GetACLAuthMethods(t *testing.T) {
   176  	ci.Parallel(t)
   177  	testState := testStateStore(t)
   178  
   179  	// Generate a some mocked ACL auth methods for testing and upsert these
   180  	// straight into state.
   181  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLAuthMethod(), mock.ACLAuthMethod()}
   182  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   183  
   184  	// List the auth methods and ensure they are exactly as we expect.
   185  	ws := memdb.NewWatchSet()
   186  	iter, err := testState.GetACLAuthMethods(ws)
   187  	must.NoError(t, err)
   188  
   189  	var aclAuthMethods []*structs.ACLAuthMethod
   190  
   191  	for raw := iter.Next(); raw != nil; raw = iter.Next() {
   192  		aclAuthMethods = append(aclAuthMethods, raw.(*structs.ACLAuthMethod))
   193  	}
   194  
   195  	expected := mockedACLAuthMethods
   196  	for i := range expected {
   197  		expected[i].CreateIndex = 10
   198  		expected[i].ModifyIndex = 10
   199  	}
   200  
   201  	must.SliceContainsAll(t, aclAuthMethods, expected)
   202  }
   203  
   204  func TestStateStore_GetACLAuthMethodByName(t *testing.T) {
   205  	ci.Parallel(t)
   206  	testState := testStateStore(t)
   207  
   208  	// Generate a some mocked ACL auth methods for testing and upsert these
   209  	// straight into state.
   210  	mockedACLAuthMethods := []*structs.ACLAuthMethod{mock.ACLAuthMethod(), mock.ACLAuthMethod()}
   211  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   212  
   213  	ws := memdb.NewWatchSet()
   214  
   215  	// Try reading an auth method that does not exist.
   216  	authMethod, err := testState.GetACLAuthMethodByName(ws, "not-a-method")
   217  	must.NoError(t, err)
   218  	must.Nil(t, authMethod)
   219  
   220  	// Read the two ACL roles that we should find.
   221  	authMethod, err = testState.GetACLAuthMethodByName(ws, mockedACLAuthMethods[0].Name)
   222  	must.NoError(t, err)
   223  	must.Equal(t, mockedACLAuthMethods[0], authMethod)
   224  
   225  	authMethod, err = testState.GetACLAuthMethodByName(ws, mockedACLAuthMethods[1].Name)
   226  	must.NoError(t, err)
   227  	must.Equal(t, mockedACLAuthMethods[1], authMethod)
   228  }
   229  
   230  func TestStateStore_GetDefaultACLAuthMethodByType(t *testing.T) {
   231  	ci.Parallel(t)
   232  	testState := testStateStore(t)
   233  
   234  	// Generate 2 auth methods, make one of them default
   235  	am1 := mock.ACLAuthMethod()
   236  	am1.Default = true
   237  	am2 := mock.ACLAuthMethod()
   238  
   239  	// upsert
   240  	mockedACLAuthMethods := []*structs.ACLAuthMethod{am1, am2}
   241  	must.NoError(t, testState.UpsertACLAuthMethods(10, mockedACLAuthMethods))
   242  
   243  	// Get the default method for OIDC
   244  	ws := memdb.NewWatchSet()
   245  	defaultOIDCMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "OIDC")
   246  	must.NoError(t, err)
   247  
   248  	must.True(t, defaultOIDCMethod.Default)
   249  	must.Eq(t, am1, defaultOIDCMethod)
   250  
   251  	// Get the default method for jwt (should not return anything)
   252  	defaultJWTMethod, err := testState.GetDefaultACLAuthMethodByType(ws, "JWT")
   253  	must.NoError(t, err)
   254  	must.Nil(t, defaultJWTMethod)
   255  }