github.com/polarismesh/polaris@v1.17.8/cache/auth/user_test.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package auth
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/golang/mock/gomock"
    28  	"github.com/stretchr/testify/assert"
    29  
    30  	types "github.com/polarismesh/polaris/cache/api"
    31  	"github.com/polarismesh/polaris/common/model"
    32  	"github.com/polarismesh/polaris/common/utils"
    33  	"github.com/polarismesh/polaris/store/mock"
    34  )
    35  
    36  // 创建一个测试mock userCache
    37  func newTestUserCache(t *testing.T) (*gomock.Controller, *mock.MockStore, *userCache) {
    38  	ctl := gomock.NewController(t)
    39  
    40  	var cacheMgr types.CacheManager
    41  	mockStore := mock.NewMockStore(ctl)
    42  
    43  	uc := NewUserCache(mockStore, cacheMgr)
    44  	opt := map[string]interface{}{}
    45  	_ = uc.Initialize(opt)
    46  	mockStore.EXPECT().GetUnixSecond(gomock.Any()).Return(time.Now().Unix(), nil).AnyTimes()
    47  
    48  	return ctl, mockStore, uc.(*userCache)
    49  }
    50  
    51  // 生成测试数据
    52  func genModelUsers(total int) []*model.User {
    53  	if total%10 != 0 {
    54  		panic(errors.New("total must like 10, 20, 30, 40, ..."))
    55  	}
    56  
    57  	out := make([]*model.User, 0, total)
    58  
    59  	var owner *model.User
    60  
    61  	for i := 0; i < total; i++ {
    62  		if i%10 == 0 {
    63  			owner = &model.User{
    64  				ID:       fmt.Sprintf("owner-user-%d", i),
    65  				Name:     fmt.Sprintf("owner-user-%d", i),
    66  				Password: fmt.Sprintf("owner-user-%d", i),
    67  				Owner:    "",
    68  				Source:   "Polaris",
    69  				Type:     model.OwnerUserRole,
    70  				Token:    fmt.Sprintf("owner-user-%d", i),
    71  				Valid:    true,
    72  			}
    73  			out = append(out, owner)
    74  			continue
    75  		}
    76  
    77  		entry := &model.User{
    78  			ID:       fmt.Sprintf("sub-user-%d", i),
    79  			Name:     fmt.Sprintf("sub-user-%d", i),
    80  			Password: fmt.Sprintf("sub-user-%d", i),
    81  			Owner:    owner.ID,
    82  			Source:   "Polaris",
    83  			Type:     model.SubAccountUserRole,
    84  			Token:    fmt.Sprintf("sub-user-%d", i),
    85  			Valid:    true,
    86  		}
    87  
    88  		out = append(out, entry)
    89  	}
    90  	return out
    91  }
    92  
    93  func genModelUserGroups(users []*model.User) []*model.UserGroupDetail {
    94  
    95  	out := make([]*model.UserGroupDetail, 0, len(users))
    96  
    97  	for i := 0; i < len(users); i++ {
    98  		entry := &model.UserGroupDetail{
    99  			UserGroup: &model.UserGroup{
   100  				ID:          utils.NewUUID(),
   101  				Name:        fmt.Sprintf("group-%d", i),
   102  				Owner:       users[0].ID,
   103  				Token:       users[i].Token,
   104  				TokenEnable: true,
   105  				Valid:       true,
   106  				Comment:     "",
   107  				CreateTime:  time.Time{},
   108  				ModifyTime:  time.Time{},
   109  			},
   110  			UserIds: map[string]struct{}{
   111  				users[i].ID: {},
   112  			},
   113  		}
   114  
   115  		out = append(out, entry)
   116  	}
   117  	return out
   118  }
   119  
   120  func TestUserCache_UpdateNormal(t *testing.T) {
   121  	ctrl, store, uc := newTestUserCache(t)
   122  
   123  	defer ctrl.Finish()
   124  
   125  	users := genModelUsers(10)
   126  	groups := genModelUserGroups(users)
   127  
   128  	t.Run("首次更新用户", func(t *testing.T) {
   129  		copyUsers := make([]*model.User, 0, len(users))
   130  		copyGroups := make([]*model.UserGroupDetail, 0, len(groups))
   131  
   132  		for i := range users {
   133  			copyUser := *users[i]
   134  			copyUsers = append(copyUsers, &copyUser)
   135  		}
   136  
   137  		for i := range groups {
   138  			copyGroup := *groups[i]
   139  			newUserIds := make(map[string]struct{}, len(copyGroup.UserIds))
   140  			for k, v := range groups[i].UserIds {
   141  				newUserIds[k] = v
   142  			}
   143  			copyGroup.UserIds = newUserIds
   144  			copyGroups = append(copyGroups, &copyGroup)
   145  		}
   146  		store.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).Return(copyUsers, nil).Times(1)
   147  		store.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).Return(copyGroups, nil).Times(1)
   148  
   149  		assert.NoError(t, uc.Update())
   150  
   151  		u := uc.GetUserByID(users[1].ID)
   152  		assert.NotNil(t, u)
   153  		assert.Equal(t, u, users[1])
   154  
   155  		u = uc.GetUserByName(users[1].Name, users[0].Name)
   156  		assert.NotNil(t, u)
   157  		assert.Equal(t, u, users[1])
   158  
   159  		g := uc.GetGroup(groups[1].ID)
   160  		assert.NotNil(t, g)
   161  		assert.Equal(t, g, groups[1])
   162  
   163  		gid := uc.GetUserLinkGroupIds(users[1].ID)
   164  		assert.Equal(t, 1, len(gid))
   165  		assert.Equal(t, groups[1].ID, gid[0])
   166  	})
   167  
   168  	t.Run("部分用户删除", func(t *testing.T) {
   169  
   170  		deleteCnt := 0
   171  		for i := range users {
   172  			// 主账户/管理账户 不能删除,因此这里对于第一个用户需要跳过
   173  			if users[i].Type != model.SubAccountUserRole {
   174  				continue
   175  			}
   176  			if rand.Int31n(3) < 1 {
   177  				users[i].Valid = false
   178  				delete(groups[i].UserIds, users[i].ID)
   179  				deleteCnt++
   180  			}
   181  
   182  			users[i].Comment = fmt.Sprintf("Update user %d", i)
   183  		}
   184  
   185  		copyUsers := make([]*model.User, 0, len(users))
   186  		copyGroups := make([]*model.UserGroupDetail, 0, len(groups))
   187  
   188  		for i := range users {
   189  			copyUser := *users[i]
   190  			copyUsers = append(copyUsers, &copyUser)
   191  		}
   192  
   193  		for i := range groups {
   194  			copyGroup := *groups[i]
   195  			newUserIds := make(map[string]struct{}, len(copyGroup.UserIds))
   196  			for k, v := range groups[i].UserIds {
   197  				newUserIds[k] = v
   198  			}
   199  			copyGroup.UserIds = newUserIds
   200  			copyGroups = append(copyGroups, &copyGroup)
   201  		}
   202  
   203  		store.EXPECT().GetUsersForCache(gomock.Any(), gomock.Any()).Return(copyUsers, nil).Times(1)
   204  		store.EXPECT().GetGroupsForCache(gomock.Any(), gomock.Any()).Return(copyGroups, nil).Times(1)
   205  
   206  		assert.NoError(t, uc.Update())
   207  
   208  		mockTn := time.Now()
   209  		for i := range users {
   210  			u := uc.GetUserByID(users[i].ID)
   211  
   212  			users[i].CreateTime = mockTn
   213  			users[i].ModifyTime = mockTn
   214  
   215  			if users[i].Valid {
   216  				u.CreateTime = mockTn
   217  				u.ModifyTime = mockTn
   218  				assert.NotNil(t, u)
   219  				assert.Equal(t, u, users[i])
   220  
   221  				u = uc.GetUserByName(users[i].Name, users[0].Name)
   222  				assert.NotNil(t, u)
   223  				assert.Equal(t, u, users[i])
   224  
   225  				g := uc.GetGroup(groups[i].ID)
   226  				assert.NotNil(t, g)
   227  				assert.Equal(t, g, groups[i])
   228  
   229  				gid := uc.GetUserLinkGroupIds(users[i].ID)
   230  				assert.Equal(t, 1, len(gid))
   231  				assert.Equal(t, groups[i].ID, gid[0])
   232  			} else {
   233  				assert.Nil(t, u)
   234  
   235  				u = uc.GetUserByName(users[i].Name, users[0].Name)
   236  				assert.Nil(t, u)
   237  
   238  				g := uc.GetGroup(groups[i].ID)
   239  				assert.NotNil(t, g)
   240  				assert.Equal(t, g, groups[i])
   241  				assert.Equal(t, 0, len(groups[i].UserIds))
   242  
   243  				gid := uc.GetUserLinkGroupIds(users[i].ID)
   244  				assert.Equal(t, 0, len(gid))
   245  			}
   246  		}
   247  
   248  	})
   249  
   250  }