github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/userloginstate/userloginstate_test.go (about)

     1  // Copyright 2023 Gravitational, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package userloginstate
    16  
    17  import (
    18  	"context"
    19  	"testing"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	"github.com/stretchr/testify/require"
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/protobuf/types/known/emptypb"
    25  
    26  	userloginstatev1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/userloginstate/v1"
    27  	"github.com/gravitational/teleport/api/types"
    28  	"github.com/gravitational/teleport/api/types/header"
    29  	"github.com/gravitational/teleport/api/types/trait"
    30  	"github.com/gravitational/teleport/api/types/userloginstate"
    31  	conv "github.com/gravitational/teleport/api/types/userloginstate/convert/v1"
    32  )
    33  
    34  type mockClient struct {
    35  	userloginstatev1.UserLoginStateServiceClient
    36  
    37  	t *testing.T
    38  
    39  	getUserLoginStatesRequest       *userloginstatev1.GetUserLoginStatesRequest
    40  	getUserLoginStateRequest        *userloginstatev1.GetUserLoginStateRequest
    41  	upsertUserLoginStateRequest     *userloginstatev1.UpsertUserLoginStateRequest
    42  	deleteUserLoginStateRequest     *userloginstatev1.DeleteUserLoginStateRequest
    43  	deleteAllUserLoginStatesRequest *userloginstatev1.DeleteAllUserLoginStatesRequest
    44  }
    45  
    46  func (m *mockClient) GetUserLoginStates(_ context.Context, in *userloginstatev1.GetUserLoginStatesRequest, _ ...grpc.CallOption) (*userloginstatev1.GetUserLoginStatesResponse, error) {
    47  	m.getUserLoginStatesRequest = in
    48  	return &userloginstatev1.GetUserLoginStatesResponse{
    49  		UserLoginStates: []*userloginstatev1.UserLoginState{
    50  			newUserLoginStateProto(m.t, "uls1"),
    51  			newUserLoginStateProto(m.t, "uls2"),
    52  			newUserLoginStateProto(m.t, "uls3"),
    53  		},
    54  	}, nil
    55  }
    56  
    57  func (m *mockClient) GetUserLoginState(_ context.Context, in *userloginstatev1.GetUserLoginStateRequest, _ ...grpc.CallOption) (*userloginstatev1.UserLoginState, error) {
    58  	m.getUserLoginStateRequest = in
    59  	return newUserLoginStateProto(m.t, in.Name), nil
    60  }
    61  
    62  func (m *mockClient) UpsertUserLoginState(_ context.Context, in *userloginstatev1.UpsertUserLoginStateRequest, _ ...grpc.CallOption) (*userloginstatev1.UserLoginState, error) {
    63  	m.upsertUserLoginStateRequest = in
    64  	return in.UserLoginState, nil
    65  }
    66  
    67  func (m *mockClient) DeleteUserLoginState(_ context.Context, in *userloginstatev1.DeleteUserLoginStateRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
    68  	m.deleteUserLoginStateRequest = in
    69  	return nil, nil
    70  }
    71  
    72  func (m *mockClient) DeleteAllUserLoginStates(_ context.Context, in *userloginstatev1.DeleteAllUserLoginStatesRequest, _ ...grpc.CallOption) (*emptypb.Empty, error) {
    73  	m.deleteAllUserLoginStatesRequest = in
    74  	return nil, nil
    75  }
    76  
    77  func TestGetUserLoginStates(t *testing.T) {
    78  	t.Parallel()
    79  	mockClient := &mockClient{t: t}
    80  	client := NewClient(mockClient)
    81  
    82  	states, err := client.GetUserLoginStates(context.Background())
    83  	require.NoError(t, err)
    84  
    85  	require.NotNil(t, mockClient.getUserLoginStatesRequest)
    86  
    87  	require.Empty(t, cmp.Diff([]*userloginstate.UserLoginState{
    88  		newUserLoginState(t, "uls1"),
    89  		newUserLoginState(t, "uls2"),
    90  		newUserLoginState(t, "uls3"),
    91  	}, states))
    92  }
    93  
    94  func TestGetUserLoginState(t *testing.T) {
    95  	t.Parallel()
    96  
    97  	mockClient := &mockClient{t: t}
    98  	client := NewClient(mockClient)
    99  
   100  	uls, err := client.GetUserLoginState(context.Background(), "uls1")
   101  	require.NoError(t, err)
   102  
   103  	require.Equal(t, "uls1", mockClient.getUserLoginStateRequest.Name)
   104  
   105  	require.Empty(t, cmp.Diff(newUserLoginState(t, "uls1"), uls))
   106  }
   107  
   108  func TestUpsertUserLoginState(t *testing.T) {
   109  	t.Parallel()
   110  
   111  	mockClient := &mockClient{t: t}
   112  	client := NewClient(mockClient)
   113  
   114  	uls := newUserLoginState(t, "uls1")
   115  
   116  	resp, err := client.UpsertUserLoginState(context.Background(), uls)
   117  	require.NoError(t, err)
   118  
   119  	require.Empty(t, cmp.Diff(uls, mustFromProto(t, mockClient.upsertUserLoginStateRequest.UserLoginState)))
   120  	require.Empty(t, cmp.Diff(resp, newUserLoginState(t, "uls1")))
   121  }
   122  
   123  func TestDeleteUserLoginState(t *testing.T) {
   124  	t.Parallel()
   125  
   126  	mockClient := &mockClient{t: t}
   127  	client := NewClient(mockClient)
   128  
   129  	require.NoError(t, client.DeleteUserLoginState(context.Background(), "uls1"))
   130  
   131  	require.Equal(t, "uls1", mockClient.deleteUserLoginStateRequest.Name)
   132  }
   133  
   134  func TestDeleteAllUserLoginStates(t *testing.T) {
   135  	t.Parallel()
   136  
   137  	mockClient := &mockClient{t: t}
   138  	client := NewClient(mockClient)
   139  
   140  	require.NoError(t, client.DeleteAllUserLoginStates(context.Background()))
   141  
   142  	require.NotNil(t, mockClient.deleteAllUserLoginStatesRequest)
   143  }
   144  
   145  func newUserLoginStateProto(t *testing.T, name string) *userloginstatev1.UserLoginState {
   146  	t.Helper()
   147  
   148  	return conv.ToProto(newUserLoginState(t, name))
   149  }
   150  
   151  func newUserLoginState(t *testing.T, name string) *userloginstate.UserLoginState {
   152  	t.Helper()
   153  
   154  	uls, err := userloginstate.New(header.Metadata{
   155  		Name: name,
   156  	}, userloginstate.Spec{
   157  		Roles:          []string{"role1", "role2"},
   158  		OriginalTraits: trait.Traits{},
   159  		Traits: trait.Traits{
   160  			"trait1": []string{"value1", "value2"},
   161  			"trait2": []string{"value1", "value2"},
   162  		},
   163  		UserType: types.UserTypeLocal,
   164  	})
   165  	require.NoError(t, err)
   166  
   167  	return uls
   168  }
   169  
   170  func mustFromProto(t *testing.T, msg *userloginstatev1.UserLoginState) *userloginstate.UserLoginState {
   171  	t.Helper()
   172  
   173  	uls, err := conv.FromProto(msg)
   174  	require.NoError(t, err)
   175  
   176  	return uls
   177  }