github.com/polarismesh/polaris@v1.17.8/store/mysql/group.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package sqldb
    19  
    20  import (
    21  	"database/sql"
    22  	"fmt"
    23  	"time"
    24  
    25  	"go.uber.org/zap"
    26  
    27  	"github.com/polarismesh/polaris/common/model"
    28  	"github.com/polarismesh/polaris/common/utils"
    29  	"github.com/polarismesh/polaris/store"
    30  )
    31  
    32  const (
    33  	// IDAttribute is the name of the attribute that stores the ID of the object.
    34  	IDAttribute string = "id"
    35  
    36  	// NameAttribute will be used as the name of the attribute that stores the name of the object.
    37  	NameAttribute string = "name"
    38  
    39  	// FlagAttribute will be used as the name of the attribute that stores the flag of the object.
    40  	FlagAttribute string = "flag"
    41  
    42  	// GroupIDAttribute will be used as the name of the attribute that stores the group ID of the object.
    43  	GroupIDAttribute string = "group_id"
    44  )
    45  
    46  var (
    47  	groupAttribute map[string]string = map[string]string{
    48  		"name":  "ug.name",
    49  		"id":    "ug.id",
    50  		"owner": "ug.owner",
    51  	}
    52  )
    53  
    54  type groupStore struct {
    55  	master *BaseDB
    56  	slave  *BaseDB
    57  }
    58  
    59  // AddGroup 创建一个用户组
    60  func (u *groupStore) AddGroup(group *model.UserGroupDetail) error {
    61  	if group.ID == "" || group.Name == "" || group.Token == "" {
    62  		return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
    63  			"add usergroup missing some params, groupId is %s, name is %s", group.ID, group.Name))
    64  	}
    65  
    66  	err := RetryTransaction("addGroup", func() error {
    67  		return u.addGroup(group)
    68  	})
    69  
    70  	return store.Error(err)
    71  }
    72  
    73  func (u *groupStore) addGroup(group *model.UserGroupDetail) error {
    74  	tx, err := u.master.Begin()
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	defer func() { _ = tx.Rollback() }()
    80  
    81  	// 先清理无效数据
    82  	if err := cleanInValidGroup(tx, group.Name, group.Owner); err != nil {
    83  		return store.Error(err)
    84  	}
    85  
    86  	addSql := `
    87  	  INSERT INTO user_group (id, name, owner, token, token_enable, comment, flag, ctime, mtime)
    88  	  VALUES (?, ?, ?, ?, ?, ?, ?, sysdate(), sysdate())
    89  	  `
    90  
    91  	if _, err = tx.Exec(addSql, []interface{}{
    92  		group.ID,
    93  		group.Name,
    94  		group.Owner,
    95  		group.Token,
    96  		1,
    97  		group.Comment,
    98  		0,
    99  	}...); err != nil {
   100  		log.Errorf("[Store][Group] add usergroup err: %s", err.Error())
   101  		return err
   102  	}
   103  
   104  	if err := u.addGroupRelation(tx, group.ID, group.ToUserIdSlice()); err != nil {
   105  		log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error())
   106  		return err
   107  	}
   108  
   109  	if err := createDefaultStrategy(tx, model.PrincipalGroup, group.ID, group.Name, group.Owner); err != nil {
   110  		log.Errorf("[Store][Group] add usergroup default strategy err: %s", err.Error())
   111  		return err
   112  	}
   113  
   114  	if err := tx.Commit(); err != nil {
   115  		log.Errorf("[Store][Group] add usergroup tx commit err: %s", err.Error())
   116  		return err
   117  	}
   118  	return nil
   119  }
   120  
   121  // UpdateGroup 更新用户组
   122  func (u *groupStore) UpdateGroup(group *model.ModifyUserGroup) error {
   123  	if group.ID == "" {
   124  		return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
   125  			"update usergroup missing some params, groupId is %s", group.ID))
   126  	}
   127  
   128  	err := RetryTransaction("updateGroup", func() error {
   129  		return u.updateGroup(group)
   130  	})
   131  
   132  	return store.Error(err)
   133  }
   134  
   135  func (u *groupStore) updateGroup(group *model.ModifyUserGroup) error {
   136  	tx, err := u.master.Begin()
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	defer func() { _ = tx.Rollback() }()
   142  
   143  	tokenEnable := 1
   144  	if !group.TokenEnable {
   145  		tokenEnable = 0
   146  	}
   147  
   148  	// 更新用户-用户组关联数据
   149  	if len(group.AddUserIds) != 0 {
   150  		if err := u.addGroupRelation(tx, group.ID, group.AddUserIds); err != nil {
   151  			log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error())
   152  			return err
   153  		}
   154  	}
   155  
   156  	if len(group.RemoveUserIds) != 0 {
   157  		if err := u.removeGroupRelation(tx, group.ID, group.RemoveUserIds); err != nil {
   158  			log.Errorf("[Store][Group] remove usergroup relation err: %s", err.Error())
   159  			return err
   160  		}
   161  	}
   162  
   163  	modifySql := "UPDATE user_group SET token = ?, comment = ?, token_enable = ?, mtime = sysdate() " +
   164  		" WHERE id = ? AND flag = 0"
   165  	if _, err = tx.Exec(modifySql, []interface{}{
   166  		group.Token,
   167  		group.Comment,
   168  		tokenEnable,
   169  		group.ID,
   170  	}...); err != nil {
   171  		log.Errorf("[Store][Group] update usergroup main err: %s", err.Error())
   172  		return err
   173  	}
   174  
   175  	if err := tx.Commit(); err != nil {
   176  		log.Errorf("[Store][Group] update usergroup tx commit err: %s", err.Error())
   177  		return err
   178  	}
   179  
   180  	return nil
   181  }
   182  
   183  // DeleteGroup 删除用户组
   184  func (u *groupStore) DeleteGroup(group *model.UserGroupDetail) error {
   185  	if group.ID == "" || group.Name == "" {
   186  		return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
   187  			"delete usergroup missing some params, groupId is %s", group.ID))
   188  	}
   189  
   190  	err := RetryTransaction("deleteUserGroup", func() error {
   191  		return u.deleteUserGroup(group)
   192  	})
   193  
   194  	return store.Error(err)
   195  }
   196  
   197  func (u *groupStore) deleteUserGroup(group *model.UserGroupDetail) error {
   198  	tx, err := u.master.Begin()
   199  	if err != nil {
   200  		return err
   201  	}
   202  
   203  	defer func() { _ = tx.Rollback() }()
   204  
   205  	if _, err = tx.Exec("DELETE FROM user_group_relation WHERE group_id = ?", []interface{}{
   206  		group.ID,
   207  	}...); err != nil {
   208  		log.Errorf("[Store][Group] clean usergroup relation err: %s", err.Error())
   209  		return err
   210  	}
   211  
   212  	if _, err = tx.Exec("UPDATE user_group SET flag = 1, mtime = sysdate() WHERE id = ?", []interface{}{
   213  		group.ID,
   214  	}...); err != nil {
   215  		log.Errorf("[Store][Group] remove usergroup err: %s", err.Error())
   216  		return err
   217  	}
   218  
   219  	if err := cleanLinkStrategy(tx, model.PrincipalGroup, group.ID, group.Owner); err != nil {
   220  		log.Errorf("[Store][Group] clean usergroup default strategy err: %s", err.Error())
   221  		return err
   222  	}
   223  
   224  	if err := tx.Commit(); err != nil {
   225  		log.Errorf("[Store][Group] delete usergroupr tx commit err: %s", err.Error())
   226  		return err
   227  	}
   228  	return nil
   229  }
   230  
   231  // GetGroup 根据用户组ID获取用户组
   232  func (u *groupStore) GetGroup(groupId string) (*model.UserGroupDetail, error) {
   233  	if groupId == "" {
   234  		return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
   235  			"get usergroup missing some params, groupId is %s", groupId))
   236  	}
   237  
   238  	getSql := `
   239  	  SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable
   240  		  , UNIX_TIMESTAMP(ug.ctime), UNIX_TIMESTAMP(ug.mtime)
   241  	  FROM user_group ug
   242  	  WHERE ug.flag = 0
   243  		  AND ug.id = ? 
   244  	  `
   245  	row := u.master.QueryRow(getSql, groupId)
   246  
   247  	group := &model.UserGroupDetail{
   248  		UserGroup: &model.UserGroup{},
   249  	}
   250  	var (
   251  		ctime, mtime int64
   252  		tokenEnable  int
   253  	)
   254  
   255  	if err := row.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &tokenEnable,
   256  		&ctime, &mtime); err != nil {
   257  		switch err {
   258  		case sql.ErrNoRows:
   259  			return nil, nil
   260  		default:
   261  			return nil, store.Error(err)
   262  		}
   263  	}
   264  	uids, err := u.getGroupLinkUserIds(group.ID)
   265  	if err != nil {
   266  		return nil, store.Error(err)
   267  	}
   268  
   269  	group.UserIds = uids
   270  	group.TokenEnable = tokenEnable == 1
   271  	group.CreateTime = time.Unix(ctime, 0)
   272  	group.ModifyTime = time.Unix(mtime, 0)
   273  
   274  	return group, nil
   275  }
   276  
   277  // GetGroupByName 根据 owner、name 获取用户组
   278  func (u *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, error) {
   279  	if name == "" || owner == "" {
   280  		return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
   281  			"get usergroup missing some params, name=%s, owner=%s", name, owner))
   282  	}
   283  
   284  	var ctime, mtime int64
   285  
   286  	getSql := `
   287  	  SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token
   288  		  , UNIX_TIMESTAMP(ug.ctime), UNIX_TIMESTAMP(ug.mtime)
   289  	  FROM user_group ug
   290  	  WHERE ug.flag = 0
   291  		  AND ug.name = ?
   292  		  AND ug.owner = ? 
   293  	  `
   294  	row := u.master.QueryRow(getSql, name, owner)
   295  
   296  	group := new(model.UserGroup)
   297  
   298  	if err := row.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &ctime, &mtime); err != nil {
   299  		switch err {
   300  		case sql.ErrNoRows:
   301  			return nil, nil
   302  		default:
   303  			return nil, store.Error(err)
   304  		}
   305  	}
   306  
   307  	group.CreateTime = time.Unix(ctime, 0)
   308  	group.ModifyTime = time.Unix(mtime, 0)
   309  
   310  	return group, nil
   311  }
   312  
   313  // GetGroups 根据不同的请求情况进行不同的用户组列表查询
   314  func (u *groupStore) GetGroups(filters map[string]string, offset uint32, limit uint32) (uint32,
   315  	[]*model.UserGroup, error) {
   316  
   317  	// 如果本次请求参数携带了 user_id,那么就是查询这个用户所关联的所有用户组
   318  	if _, ok := filters["user_id"]; ok {
   319  		return u.listGroupByUser(filters, offset, limit)
   320  	}
   321  	// 正常查询用户组信息
   322  	return u.listSimpleGroups(filters, offset, limit)
   323  }
   324  
   325  // listSimpleGroups 正常的用户组查询
   326  func (u *groupStore) listSimpleGroups(filters map[string]string, offset uint32, limit uint32) (uint32,
   327  	[]*model.UserGroup, error) {
   328  
   329  	query := make(map[string]string)
   330  	if _, ok := filters["id"]; ok {
   331  		query["id"] = filters["id"]
   332  	}
   333  	if _, ok := filters["name"]; ok {
   334  		query["name"] = filters["name"]
   335  	}
   336  	filters = query
   337  
   338  	countSql := "SELECT COUNT(*) FROM user_group ug WHERE ug.flag = 0 "
   339  	getSql := `
   340  	  SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable
   341  		  , UNIX_TIMESTAMP(ug.ctime), UNIX_TIMESTAMP(ug.mtime)
   342  		  , ug.flag
   343  	  FROM user_group ug
   344  	  WHERE ug.flag = 0 
   345  	  `
   346  
   347  	args := make([]interface{}, 0)
   348  
   349  	if len(filters) != 0 {
   350  		for k, v := range filters {
   351  			getSql += " AND "
   352  			countSql += " AND "
   353  			if newK, ok := groupAttribute[k]; ok {
   354  				k = newK
   355  			}
   356  			if utils.IsPrefixWildName(v) {
   357  				getSql += (" " + k + " like ? ")
   358  				countSql += (" " + k + " like ? ")
   359  				args = append(args, "%"+v[:len(v)-1]+"%")
   360  			} else {
   361  				getSql += (" " + k + " = ? ")
   362  				countSql += (" " + k + " = ? ")
   363  				args = append(args, v)
   364  			}
   365  		}
   366  	}
   367  
   368  	count, err := queryEntryCount(u.master, countSql, args)
   369  	if err != nil {
   370  		return 0, nil, err
   371  	}
   372  
   373  	getSql += " ORDER BY ug.mtime LIMIT ? , ?"
   374  	args = append(args, offset, limit)
   375  
   376  	groups, err := u.collectGroupsFromRows(u.master.Query, getSql, args)
   377  	if err != nil {
   378  		return 0, nil, err
   379  	}
   380  
   381  	return count, groups, nil
   382  }
   383  
   384  // listGroupByUser 查询某个用户下所关联的用户组信息
   385  func (u *groupStore) listGroupByUser(filters map[string]string, offset uint32, limit uint32) (uint32,
   386  	[]*model.UserGroup, error) {
   387  	countSql := "SELECT COUNT(*) FROM user_group_relation ul LEFT JOIN user_group ug ON " +
   388  		" ul.group_id = ug.id WHERE ug.flag = 0 "
   389  	getSql := "SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable, UNIX_TIMESTAMP(ug.ctime), " +
   390  		" UNIX_TIMESTAMP(ug.mtime), ug.flag " +
   391  		" FROM user_group_relation ul LEFT JOIN user_group ug ON ul.group_id = ug.id WHERE ug.flag = 0 "
   392  
   393  	args := make([]interface{}, 0)
   394  
   395  	if len(filters) != 0 {
   396  		for k, v := range filters {
   397  			getSql += " AND "
   398  			countSql += " AND "
   399  			if newK, ok := userLinkGroupAttributeMapping[k]; ok {
   400  				k = newK
   401  			}
   402  			if utils.IsPrefixWildName(v) {
   403  				getSql += (" " + k + " like ? ")
   404  				countSql += (" " + k + " like ? ")
   405  				args = append(args, "%"+v[:len(v)-1]+"%")
   406  			} else if k == "ug.owner" {
   407  				getSql += " (ug.owner = ?) "
   408  				countSql += " (ug.owner = ?) "
   409  				args = append(args, v)
   410  			} else {
   411  				getSql += (" " + k + " = ? ")
   412  				countSql += (" " + k + " = ? ")
   413  				args = append(args, v)
   414  			}
   415  		}
   416  	}
   417  
   418  	count, err := queryEntryCount(u.master, countSql, args)
   419  	if err != nil {
   420  		return 0, nil, err
   421  	}
   422  
   423  	getSql += " GROUP BY ug.id ORDER BY ug.mtime LIMIT ? , ?"
   424  	args = append(args, offset, limit)
   425  
   426  	groups, err := u.collectGroupsFromRows(u.master.Query, getSql, args)
   427  	if err != nil {
   428  		return 0, nil, err
   429  	}
   430  
   431  	return count, groups, nil
   432  }
   433  
   434  // collectGroupsFromRows 查询用户组列表
   435  func (u *groupStore) collectGroupsFromRows(handler QueryHandler, querySql string,
   436  	args []interface{}) ([]*model.UserGroup, error) {
   437  	rows, err := u.master.Query(querySql, args...)
   438  	if err != nil {
   439  		log.Error("[Store][Group] list group", zap.String("query sql", querySql), zap.Any("args", args))
   440  		return nil, err
   441  	}
   442  	defer rows.Close()
   443  
   444  	groups := make([]*model.UserGroup, 0)
   445  	for rows.Next() {
   446  		group, err := fetchRown2UserGroup(rows)
   447  		if err != nil {
   448  			log.Errorf("[Store][Group] list group by user fetch rows scan err: %s", err.Error())
   449  			return nil, err
   450  		}
   451  		groups = append(groups, group)
   452  	}
   453  
   454  	return groups, nil
   455  }
   456  
   457  // GetGroupsForCache .
   458  func (u *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*model.UserGroupDetail, error) {
   459  	tx, err := u.slave.Begin()
   460  	if err != nil {
   461  		return nil, store.Error(err)
   462  	}
   463  
   464  	defer func() { _ = tx.Commit() }()
   465  
   466  	args := make([]interface{}, 0)
   467  	querySql := "SELECT id, name, owner, comment, token, token_enable, UNIX_TIMESTAMP(ctime), UNIX_TIMESTAMP(mtime), " +
   468  		" flag FROM user_group "
   469  	if !firstUpdate {
   470  		querySql += " WHERE mtime >= FROM_UNIXTIME(?)"
   471  		args = append(args, timeToTimestamp(mtime))
   472  	}
   473  
   474  	rows, err := tx.Query(querySql, args...)
   475  	if err != nil {
   476  		return nil, store.Error(err)
   477  	}
   478  	defer rows.Close()
   479  
   480  	ret := make([]*model.UserGroupDetail, 0)
   481  	for rows.Next() {
   482  		detail := &model.UserGroupDetail{
   483  			UserIds: make(map[string]struct{}, 0),
   484  		}
   485  		group, err := fetchRown2UserGroup(rows)
   486  		if err != nil {
   487  			return nil, store.Error(err)
   488  		}
   489  		uids, err := u.getGroupLinkUserIds(group.ID)
   490  		if err != nil {
   491  			return nil, store.Error(err)
   492  		}
   493  
   494  		detail.UserIds = uids
   495  		detail.UserGroup = group
   496  
   497  		ret = append(ret, detail)
   498  	}
   499  
   500  	return ret, nil
   501  }
   502  
   503  func (u *groupStore) addGroupRelation(tx *BaseTx, groupId string, userIds []string) error {
   504  	if groupId == "" {
   505  		return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
   506  			"add user relation missing some params, groupid is %s", groupId))
   507  	}
   508  	if len(userIds) > utils.MaxBatchSize {
   509  		return store.NewStatusError(store.InvalidUserIDSlice, fmt.Sprintf(
   510  			"user id slice is invalid, len=%d", len(userIds)))
   511  	}
   512  
   513  	for i := range userIds {
   514  		uid := userIds[i]
   515  		addSql := "INSERT INTO user_group_relation (group_id, user_id) VALUE (?,?)"
   516  		args := []interface{}{groupId, uid}
   517  		_, err := tx.Exec(addSql, args...)
   518  		if err != nil {
   519  			err = store.Error(err)
   520  			// 之前的用户已经存在,直接忽略
   521  			if store.Code(err) == store.DuplicateEntryErr {
   522  				continue
   523  			}
   524  			return err
   525  		}
   526  	}
   527  	return nil
   528  }
   529  
   530  func (u *groupStore) removeGroupRelation(tx *BaseTx, groupId string, userIds []string) error {
   531  	if groupId == "" {
   532  		return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf(
   533  			"delete user relation missing some params, groupid is %s", groupId))
   534  	}
   535  	if len(userIds) > utils.MaxBatchSize {
   536  		return store.NewStatusError(store.InvalidUserIDSlice, fmt.Sprintf(
   537  			"user id slice is invalid, len=%d", len(userIds)))
   538  	}
   539  
   540  	for i := range userIds {
   541  		uid := userIds[i]
   542  		addSql := "DELETE FROM user_group_relation WHERE group_id = ? AND user_id = ?"
   543  		args := []interface{}{groupId, uid}
   544  		if _, err := tx.Exec(addSql, args...); err != nil {
   545  			return err
   546  		}
   547  	}
   548  
   549  	return nil
   550  }
   551  
   552  func (u *groupStore) getGroupLinkUserIds(groupId string) (map[string]struct{}, error) {
   553  
   554  	ids := make(map[string]struct{})
   555  
   556  	// 拉取该分组下的所有 user
   557  	idRows, err := u.slave.Query("SELECT user_id FROM user u JOIN user_group_relation ug ON "+
   558  		" u.id = ug.user_id WHERE ug.group_id = ?", groupId)
   559  	if err != nil {
   560  		return nil, err
   561  	}
   562  	defer idRows.Close()
   563  	for idRows.Next() {
   564  		var uid string
   565  		if err := idRows.Scan(&uid); err != nil {
   566  			return nil, err
   567  		}
   568  		ids[uid] = struct{}{}
   569  	}
   570  
   571  	return ids, nil
   572  }
   573  
   574  func fetchRown2UserGroup(rows *sql.Rows) (*model.UserGroup, error) {
   575  	var ctime, mtime int64
   576  	var flag, tokenEnable int
   577  	group := new(model.UserGroup)
   578  	if err := rows.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &tokenEnable,
   579  		&ctime, &mtime, &flag); err != nil {
   580  		return nil, err
   581  	}
   582  
   583  	group.Valid = flag == 0
   584  	group.TokenEnable = tokenEnable == 1
   585  	group.CreateTime = time.Unix(ctime, 0)
   586  	group.ModifyTime = time.Unix(mtime, 0)
   587  
   588  	return group, nil
   589  }
   590  
   591  // cleanInValidUserGroup 清理无效的用户组数据
   592  func cleanInValidGroup(tx *BaseTx, name, owner string) error {
   593  	log.Infof("[Store][User] clean usergroup(%s)", name)
   594  
   595  	str := "delete from user_group where name = ? and flag = 1"
   596  	if _, err := tx.Exec(str, name); err != nil {
   597  		log.Errorf("[Store][User] clean usergroup(%s) err: %s", name, err.Error())
   598  		return err
   599  	}
   600  
   601  	return nil
   602  }