github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/userloginstate/userloginstate_test.go (about) 1 // Copyright 2023 Gravitational, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package userloginstate 16 17 import ( 18 "context" 19 "testing" 20 21 "github.com/google/go-cmp/cmp" 22 "github.com/stretchr/testify/require" 23 "google.golang.org/grpc" 24 "google.golang.org/protobuf/types/known/emptypb" 25 26 userloginstatev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/userloginstate/v1" 27 "github.com/gravitational/teleport/api/types" 28 "github.com/gravitational/teleport/api/types/header" 29 "github.com/gravitational/teleport/api/types/trait" 30 "github.com/gravitational/teleport/api/types/userloginstate" 31 conv "github.com/gravitational/teleport/api/types/userloginstate/convert/v1" 32 ) 33 34 type mockClient struct { 35 userloginstatev1.UserLoginStateServiceClient 36 37 t *testing.T 38 39 getUserLoginStatesRequest *userloginstatev1.GetUserLoginStatesRequest 40 getUserLoginStateRequest *userloginstatev1.GetUserLoginStateRequest 41 upsertUserLoginStateRequest *userloginstatev1.UpsertUserLoginStateRequest 42 deleteUserLoginStateRequest *userloginstatev1.DeleteUserLoginStateRequest 43 deleteAllUserLoginStatesRequest *userloginstatev1.DeleteAllUserLoginStatesRequest 44 } 45 46 func (m *mockClient) GetUserLoginStates(_ context.Context, in *userloginstatev1.GetUserLoginStatesRequest, _ ...grpc.CallOption) (*userloginstatev1.GetUserLoginStatesResponse, error) { 47 m.getUserLoginStatesRequest = in 48 return &userloginstatev1.GetUserLoginStatesResponse{ 49 UserLoginStates: []*userloginstatev1.UserLoginState{ 50 newUserLoginStateProto(m.t, "uls1"), 51 newUserLoginStateProto(m.t, "uls2"), 52 newUserLoginStateProto(m.t, "uls3"), 53 }, 54 }, nil 55 } 56 57 func (m *mockClient) GetUserLoginState(_ context.Context, in *userloginstatev1.GetUserLoginStateRequest, _ ...grpc.CallOption) (*userloginstatev1.UserLoginState, error) { 58 m.getUserLoginStateRequest = in 59 return newUserLoginStateProto(m.t, in.Name), nil 60 } 61 62 func (m *mockClient) UpsertUserLoginState(_ context.Context, in *userloginstatev1.UpsertUserLoginStateRequest, _ ...grpc.CallOption) (*userloginstatev1.UserLoginState, error) { 63 m.upsertUserLoginStateRequest = in 64 return in.UserLoginState, nil 65 } 66 67 func (m *mockClient) DeleteUserLoginState(_ context.Context, in *userloginstatev1.DeleteUserLoginStateRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) { 68 m.deleteUserLoginStateRequest = in 69 return nil, nil 70 } 71 72 func (m *mockClient) DeleteAllUserLoginStates(_ context.Context, in *userloginstatev1.DeleteAllUserLoginStatesRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) { 73 m.deleteAllUserLoginStatesRequest = in 74 return nil, nil 75 } 76 77 func TestGetUserLoginStates(t *testing.T) { 78 t.Parallel() 79 mockClient := &mockClient{t: t} 80 client := NewClient(mockClient) 81 82 states, err := client.GetUserLoginStates(context.Background()) 83 require.NoError(t, err) 84 85 require.NotNil(t, mockClient.getUserLoginStatesRequest) 86 87 require.Empty(t, cmp.Diff([]*userloginstate.UserLoginState{ 88 newUserLoginState(t, "uls1"), 89 newUserLoginState(t, "uls2"), 90 newUserLoginState(t, "uls3"), 91 }, states)) 92 } 93 94 func TestGetUserLoginState(t *testing.T) { 95 t.Parallel() 96 97 mockClient := &mockClient{t: t} 98 client := NewClient(mockClient) 99 100 uls, err := client.GetUserLoginState(context.Background(), "uls1") 101 require.NoError(t, err) 102 103 require.Equal(t, "uls1", mockClient.getUserLoginStateRequest.Name) 104 105 require.Empty(t, cmp.Diff(newUserLoginState(t, "uls1"), uls)) 106 } 107 108 func TestUpsertUserLoginState(t *testing.T) { 109 t.Parallel() 110 111 mockClient := &mockClient{t: t} 112 client := NewClient(mockClient) 113 114 uls := newUserLoginState(t, "uls1") 115 116 resp, err := client.UpsertUserLoginState(context.Background(), uls) 117 require.NoError(t, err) 118 119 require.Empty(t, cmp.Diff(uls, mustFromProto(t, mockClient.upsertUserLoginStateRequest.UserLoginState))) 120 require.Empty(t, cmp.Diff(resp, newUserLoginState(t, "uls1"))) 121 } 122 123 func TestDeleteUserLoginState(t *testing.T) { 124 t.Parallel() 125 126 mockClient := &mockClient{t: t} 127 client := NewClient(mockClient) 128 129 require.NoError(t, client.DeleteUserLoginState(context.Background(), "uls1")) 130 131 require.Equal(t, "uls1", mockClient.deleteUserLoginStateRequest.Name) 132 } 133 134 func TestDeleteAllUserLoginStates(t *testing.T) { 135 t.Parallel() 136 137 mockClient := &mockClient{t: t} 138 client := NewClient(mockClient) 139 140 require.NoError(t, client.DeleteAllUserLoginStates(context.Background())) 141 142 require.NotNil(t, mockClient.deleteAllUserLoginStatesRequest) 143 } 144 145 func newUserLoginStateProto(t *testing.T, name string) *userloginstatev1.UserLoginState { 146 t.Helper() 147 148 return conv.ToProto(newUserLoginState(t, name)) 149 } 150 151 func newUserLoginState(t *testing.T, name string) *userloginstate.UserLoginState { 152 t.Helper() 153 154 uls, err := userloginstate.New(header.Metadata{ 155 Name: name, 156 }, userloginstate.Spec{ 157 Roles: []string{"role1", "role2"}, 158 OriginalTraits: trait.Traits{}, 159 Traits: trait.Traits{ 160 "trait1": []string{"value1", "value2"}, 161 "trait2": []string{"value1", "value2"}, 162 }, 163 UserType: types.UserTypeLocal, 164 }) 165 require.NoError(t, err) 166 167 return uls 168 } 169 170 func mustFromProto(t *testing.T, msg *userloginstatev1.UserLoginState) *userloginstate.UserLoginState { 171 t.Helper() 172 173 uls, err := conv.FromProto(msg) 174 require.NoError(t, err) 175 176 return uls 177 }