github.com/polarismesh/polaris@v1.17.8/store/mysql/group.go (about) 1 /** 2 * Tencent is pleased to support the open source community by making Polaris available. 3 * 4 * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. 5 * 6 * Licensed under the BSD 3-Clause License (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * https://opensource.org/licenses/BSD-3-Clause 11 * 12 * Unless required by applicable law or agreed to in writing, software distributed 13 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 14 * CONDITIONS OF ANY KIND, either express or implied. See the License for the 15 * specific language governing permissions and limitations under the License. 16 */ 17 18 package sqldb 19 20 import ( 21 "database/sql" 22 "fmt" 23 "time" 24 25 "go.uber.org/zap" 26 27 "github.com/polarismesh/polaris/common/model" 28 "github.com/polarismesh/polaris/common/utils" 29 "github.com/polarismesh/polaris/store" 30 ) 31 32 const ( 33 // IDAttribute is the name of the attribute that stores the ID of the object. 34 IDAttribute string = "id" 35 36 // NameAttribute will be used as the name of the attribute that stores the name of the object. 37 NameAttribute string = "name" 38 39 // FlagAttribute will be used as the name of the attribute that stores the flag of the object. 40 FlagAttribute string = "flag" 41 42 // GroupIDAttribute will be used as the name of the attribute that stores the group ID of the object. 43 GroupIDAttribute string = "group_id" 44 ) 45 46 var ( 47 groupAttribute map[string]string = map[string]string{ 48 "name": "ug.name", 49 "id": "ug.id", 50 "owner": "ug.owner", 51 } 52 ) 53 54 type groupStore struct { 55 master *BaseDB 56 slave *BaseDB 57 } 58 59 // AddGroup 创建一个用户组 60 func (u *groupStore) AddGroup(group *model.UserGroupDetail) error { 61 if group.ID == "" || group.Name == "" || group.Token == "" { 62 return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 63 "add usergroup missing some params, groupId is %s, name is %s", group.ID, group.Name)) 64 } 65 66 err := RetryTransaction("addGroup", func() error { 67 return u.addGroup(group) 68 }) 69 70 return store.Error(err) 71 } 72 73 func (u *groupStore) addGroup(group *model.UserGroupDetail) error { 74 tx, err := u.master.Begin() 75 if err != nil { 76 return err 77 } 78 79 defer func() { _ = tx.Rollback() }() 80 81 // 先清理无效数据 82 if err := cleanInValidGroup(tx, group.Name, group.Owner); err != nil { 83 return store.Error(err) 84 } 85 86 addSql := ` 87 INSERT INTO user_group (id, name, owner, token, token_enable, comment, flag, ctime, mtime) 88 VALUES (?, ?, ?, ?, ?, ?, ?, sysdate(), sysdate()) 89 ` 90 91 if _, err = tx.Exec(addSql, []interface{}{ 92 group.ID, 93 group.Name, 94 group.Owner, 95 group.Token, 96 1, 97 group.Comment, 98 0, 99 }...); err != nil { 100 log.Errorf("[Store][Group] add usergroup err: %s", err.Error()) 101 return err 102 } 103 104 if err := u.addGroupRelation(tx, group.ID, group.ToUserIdSlice()); err != nil { 105 log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error()) 106 return err 107 } 108 109 if err := createDefaultStrategy(tx, model.PrincipalGroup, group.ID, group.Name, group.Owner); err != nil { 110 log.Errorf("[Store][Group] add usergroup default strategy err: %s", err.Error()) 111 return err 112 } 113 114 if err := tx.Commit(); err != nil { 115 log.Errorf("[Store][Group] add usergroup tx commit err: %s", err.Error()) 116 return err 117 } 118 return nil 119 } 120 121 // UpdateGroup 更新用户组 122 func (u *groupStore) UpdateGroup(group *model.ModifyUserGroup) error { 123 if group.ID == "" { 124 return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 125 "update usergroup missing some params, groupId is %s", group.ID)) 126 } 127 128 err := RetryTransaction("updateGroup", func() error { 129 return u.updateGroup(group) 130 }) 131 132 return store.Error(err) 133 } 134 135 func (u *groupStore) updateGroup(group *model.ModifyUserGroup) error { 136 tx, err := u.master.Begin() 137 if err != nil { 138 return err 139 } 140 141 defer func() { _ = tx.Rollback() }() 142 143 tokenEnable := 1 144 if !group.TokenEnable { 145 tokenEnable = 0 146 } 147 148 // 更新用户-用户组关联数据 149 if len(group.AddUserIds) != 0 { 150 if err := u.addGroupRelation(tx, group.ID, group.AddUserIds); err != nil { 151 log.Errorf("[Store][Group] add usergroup relation err: %s", err.Error()) 152 return err 153 } 154 } 155 156 if len(group.RemoveUserIds) != 0 { 157 if err := u.removeGroupRelation(tx, group.ID, group.RemoveUserIds); err != nil { 158 log.Errorf("[Store][Group] remove usergroup relation err: %s", err.Error()) 159 return err 160 } 161 } 162 163 modifySql := "UPDATE user_group SET token = ?, comment = ?, token_enable = ?, mtime = sysdate() " + 164 " WHERE id = ? AND flag = 0" 165 if _, err = tx.Exec(modifySql, []interface{}{ 166 group.Token, 167 group.Comment, 168 tokenEnable, 169 group.ID, 170 }...); err != nil { 171 log.Errorf("[Store][Group] update usergroup main err: %s", err.Error()) 172 return err 173 } 174 175 if err := tx.Commit(); err != nil { 176 log.Errorf("[Store][Group] update usergroup tx commit err: %s", err.Error()) 177 return err 178 } 179 180 return nil 181 } 182 183 // DeleteGroup 删除用户组 184 func (u *groupStore) DeleteGroup(group *model.UserGroupDetail) error { 185 if group.ID == "" || group.Name == "" { 186 return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 187 "delete usergroup missing some params, groupId is %s", group.ID)) 188 } 189 190 err := RetryTransaction("deleteUserGroup", func() error { 191 return u.deleteUserGroup(group) 192 }) 193 194 return store.Error(err) 195 } 196 197 func (u *groupStore) deleteUserGroup(group *model.UserGroupDetail) error { 198 tx, err := u.master.Begin() 199 if err != nil { 200 return err 201 } 202 203 defer func() { _ = tx.Rollback() }() 204 205 if _, err = tx.Exec("DELETE FROM user_group_relation WHERE group_id = ?", []interface{}{ 206 group.ID, 207 }...); err != nil { 208 log.Errorf("[Store][Group] clean usergroup relation err: %s", err.Error()) 209 return err 210 } 211 212 if _, err = tx.Exec("UPDATE user_group SET flag = 1, mtime = sysdate() WHERE id = ?", []interface{}{ 213 group.ID, 214 }...); err != nil { 215 log.Errorf("[Store][Group] remove usergroup err: %s", err.Error()) 216 return err 217 } 218 219 if err := cleanLinkStrategy(tx, model.PrincipalGroup, group.ID, group.Owner); err != nil { 220 log.Errorf("[Store][Group] clean usergroup default strategy err: %s", err.Error()) 221 return err 222 } 223 224 if err := tx.Commit(); err != nil { 225 log.Errorf("[Store][Group] delete usergroupr tx commit err: %s", err.Error()) 226 return err 227 } 228 return nil 229 } 230 231 // GetGroup 根据用户组ID获取用户组 232 func (u *groupStore) GetGroup(groupId string) (*model.UserGroupDetail, error) { 233 if groupId == "" { 234 return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 235 "get usergroup missing some params, groupId is %s", groupId)) 236 } 237 238 getSql := ` 239 SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable 240 , UNIX_TIMESTAMP(ug.ctime), UNIX_TIMESTAMP(ug.mtime) 241 FROM user_group ug 242 WHERE ug.flag = 0 243 AND ug.id = ? 244 ` 245 row := u.master.QueryRow(getSql, groupId) 246 247 group := &model.UserGroupDetail{ 248 UserGroup: &model.UserGroup{}, 249 } 250 var ( 251 ctime, mtime int64 252 tokenEnable int 253 ) 254 255 if err := row.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &tokenEnable, 256 &ctime, &mtime); err != nil { 257 switch err { 258 case sql.ErrNoRows: 259 return nil, nil 260 default: 261 return nil, store.Error(err) 262 } 263 } 264 uids, err := u.getGroupLinkUserIds(group.ID) 265 if err != nil { 266 return nil, store.Error(err) 267 } 268 269 group.UserIds = uids 270 group.TokenEnable = tokenEnable == 1 271 group.CreateTime = time.Unix(ctime, 0) 272 group.ModifyTime = time.Unix(mtime, 0) 273 274 return group, nil 275 } 276 277 // GetGroupByName 根据 owner、name 获取用户组 278 func (u *groupStore) GetGroupByName(name, owner string) (*model.UserGroup, error) { 279 if name == "" || owner == "" { 280 return nil, store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 281 "get usergroup missing some params, name=%s, owner=%s", name, owner)) 282 } 283 284 var ctime, mtime int64 285 286 getSql := ` 287 SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token 288 , UNIX_TIMESTAMP(ug.ctime), UNIX_TIMESTAMP(ug.mtime) 289 FROM user_group ug 290 WHERE ug.flag = 0 291 AND ug.name = ? 292 AND ug.owner = ? 293 ` 294 row := u.master.QueryRow(getSql, name, owner) 295 296 group := new(model.UserGroup) 297 298 if err := row.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &ctime, &mtime); err != nil { 299 switch err { 300 case sql.ErrNoRows: 301 return nil, nil 302 default: 303 return nil, store.Error(err) 304 } 305 } 306 307 group.CreateTime = time.Unix(ctime, 0) 308 group.ModifyTime = time.Unix(mtime, 0) 309 310 return group, nil 311 } 312 313 // GetGroups 根据不同的请求情况进行不同的用户组列表查询 314 func (u *groupStore) GetGroups(filters map[string]string, offset uint32, limit uint32) (uint32, 315 []*model.UserGroup, error) { 316 317 // 如果本次请求参数携带了 user_id,那么就是查询这个用户所关联的所有用户组 318 if _, ok := filters["user_id"]; ok { 319 return u.listGroupByUser(filters, offset, limit) 320 } 321 // 正常查询用户组信息 322 return u.listSimpleGroups(filters, offset, limit) 323 } 324 325 // listSimpleGroups 正常的用户组查询 326 func (u *groupStore) listSimpleGroups(filters map[string]string, offset uint32, limit uint32) (uint32, 327 []*model.UserGroup, error) { 328 329 query := make(map[string]string) 330 if _, ok := filters["id"]; ok { 331 query["id"] = filters["id"] 332 } 333 if _, ok := filters["name"]; ok { 334 query["name"] = filters["name"] 335 } 336 filters = query 337 338 countSql := "SELECT COUNT(*) FROM user_group ug WHERE ug.flag = 0 " 339 getSql := ` 340 SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable 341 , UNIX_TIMESTAMP(ug.ctime), UNIX_TIMESTAMP(ug.mtime) 342 , ug.flag 343 FROM user_group ug 344 WHERE ug.flag = 0 345 ` 346 347 args := make([]interface{}, 0) 348 349 if len(filters) != 0 { 350 for k, v := range filters { 351 getSql += " AND " 352 countSql += " AND " 353 if newK, ok := groupAttribute[k]; ok { 354 k = newK 355 } 356 if utils.IsPrefixWildName(v) { 357 getSql += (" " + k + " like ? ") 358 countSql += (" " + k + " like ? ") 359 args = append(args, "%"+v[:len(v)-1]+"%") 360 } else { 361 getSql += (" " + k + " = ? ") 362 countSql += (" " + k + " = ? ") 363 args = append(args, v) 364 } 365 } 366 } 367 368 count, err := queryEntryCount(u.master, countSql, args) 369 if err != nil { 370 return 0, nil, err 371 } 372 373 getSql += " ORDER BY ug.mtime LIMIT ? , ?" 374 args = append(args, offset, limit) 375 376 groups, err := u.collectGroupsFromRows(u.master.Query, getSql, args) 377 if err != nil { 378 return 0, nil, err 379 } 380 381 return count, groups, nil 382 } 383 384 // listGroupByUser 查询某个用户下所关联的用户组信息 385 func (u *groupStore) listGroupByUser(filters map[string]string, offset uint32, limit uint32) (uint32, 386 []*model.UserGroup, error) { 387 countSql := "SELECT COUNT(*) FROM user_group_relation ul LEFT JOIN user_group ug ON " + 388 " ul.group_id = ug.id WHERE ug.flag = 0 " 389 getSql := "SELECT ug.id, ug.name, ug.owner, ug.comment, ug.token, ug.token_enable, UNIX_TIMESTAMP(ug.ctime), " + 390 " UNIX_TIMESTAMP(ug.mtime), ug.flag " + 391 " FROM user_group_relation ul LEFT JOIN user_group ug ON ul.group_id = ug.id WHERE ug.flag = 0 " 392 393 args := make([]interface{}, 0) 394 395 if len(filters) != 0 { 396 for k, v := range filters { 397 getSql += " AND " 398 countSql += " AND " 399 if newK, ok := userLinkGroupAttributeMapping[k]; ok { 400 k = newK 401 } 402 if utils.IsPrefixWildName(v) { 403 getSql += (" " + k + " like ? ") 404 countSql += (" " + k + " like ? ") 405 args = append(args, "%"+v[:len(v)-1]+"%") 406 } else if k == "ug.owner" { 407 getSql += " (ug.owner = ?) " 408 countSql += " (ug.owner = ?) " 409 args = append(args, v) 410 } else { 411 getSql += (" " + k + " = ? ") 412 countSql += (" " + k + " = ? ") 413 args = append(args, v) 414 } 415 } 416 } 417 418 count, err := queryEntryCount(u.master, countSql, args) 419 if err != nil { 420 return 0, nil, err 421 } 422 423 getSql += " GROUP BY ug.id ORDER BY ug.mtime LIMIT ? , ?" 424 args = append(args, offset, limit) 425 426 groups, err := u.collectGroupsFromRows(u.master.Query, getSql, args) 427 if err != nil { 428 return 0, nil, err 429 } 430 431 return count, groups, nil 432 } 433 434 // collectGroupsFromRows 查询用户组列表 435 func (u *groupStore) collectGroupsFromRows(handler QueryHandler, querySql string, 436 args []interface{}) ([]*model.UserGroup, error) { 437 rows, err := u.master.Query(querySql, args...) 438 if err != nil { 439 log.Error("[Store][Group] list group", zap.String("query sql", querySql), zap.Any("args", args)) 440 return nil, err 441 } 442 defer rows.Close() 443 444 groups := make([]*model.UserGroup, 0) 445 for rows.Next() { 446 group, err := fetchRown2UserGroup(rows) 447 if err != nil { 448 log.Errorf("[Store][Group] list group by user fetch rows scan err: %s", err.Error()) 449 return nil, err 450 } 451 groups = append(groups, group) 452 } 453 454 return groups, nil 455 } 456 457 // GetGroupsForCache . 458 func (u *groupStore) GetGroupsForCache(mtime time.Time, firstUpdate bool) ([]*model.UserGroupDetail, error) { 459 tx, err := u.slave.Begin() 460 if err != nil { 461 return nil, store.Error(err) 462 } 463 464 defer func() { _ = tx.Commit() }() 465 466 args := make([]interface{}, 0) 467 querySql := "SELECT id, name, owner, comment, token, token_enable, UNIX_TIMESTAMP(ctime), UNIX_TIMESTAMP(mtime), " + 468 " flag FROM user_group " 469 if !firstUpdate { 470 querySql += " WHERE mtime >= FROM_UNIXTIME(?)" 471 args = append(args, timeToTimestamp(mtime)) 472 } 473 474 rows, err := tx.Query(querySql, args...) 475 if err != nil { 476 return nil, store.Error(err) 477 } 478 defer rows.Close() 479 480 ret := make([]*model.UserGroupDetail, 0) 481 for rows.Next() { 482 detail := &model.UserGroupDetail{ 483 UserIds: make(map[string]struct{}, 0), 484 } 485 group, err := fetchRown2UserGroup(rows) 486 if err != nil { 487 return nil, store.Error(err) 488 } 489 uids, err := u.getGroupLinkUserIds(group.ID) 490 if err != nil { 491 return nil, store.Error(err) 492 } 493 494 detail.UserIds = uids 495 detail.UserGroup = group 496 497 ret = append(ret, detail) 498 } 499 500 return ret, nil 501 } 502 503 func (u *groupStore) addGroupRelation(tx *BaseTx, groupId string, userIds []string) error { 504 if groupId == "" { 505 return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 506 "add user relation missing some params, groupid is %s", groupId)) 507 } 508 if len(userIds) > utils.MaxBatchSize { 509 return store.NewStatusError(store.InvalidUserIDSlice, fmt.Sprintf( 510 "user id slice is invalid, len=%d", len(userIds))) 511 } 512 513 for i := range userIds { 514 uid := userIds[i] 515 addSql := "INSERT INTO user_group_relation (group_id, user_id) VALUE (?,?)" 516 args := []interface{}{groupId, uid} 517 _, err := tx.Exec(addSql, args...) 518 if err != nil { 519 err = store.Error(err) 520 // 之前的用户已经存在,直接忽略 521 if store.Code(err) == store.DuplicateEntryErr { 522 continue 523 } 524 return err 525 } 526 } 527 return nil 528 } 529 530 func (u *groupStore) removeGroupRelation(tx *BaseTx, groupId string, userIds []string) error { 531 if groupId == "" { 532 return store.NewStatusError(store.EmptyParamsErr, fmt.Sprintf( 533 "delete user relation missing some params, groupid is %s", groupId)) 534 } 535 if len(userIds) > utils.MaxBatchSize { 536 return store.NewStatusError(store.InvalidUserIDSlice, fmt.Sprintf( 537 "user id slice is invalid, len=%d", len(userIds))) 538 } 539 540 for i := range userIds { 541 uid := userIds[i] 542 addSql := "DELETE FROM user_group_relation WHERE group_id = ? AND user_id = ?" 543 args := []interface{}{groupId, uid} 544 if _, err := tx.Exec(addSql, args...); err != nil { 545 return err 546 } 547 } 548 549 return nil 550 } 551 552 func (u *groupStore) getGroupLinkUserIds(groupId string) (map[string]struct{}, error) { 553 554 ids := make(map[string]struct{}) 555 556 // 拉取该分组下的所有 user 557 idRows, err := u.slave.Query("SELECT user_id FROM user u JOIN user_group_relation ug ON "+ 558 " u.id = ug.user_id WHERE ug.group_id = ?", groupId) 559 if err != nil { 560 return nil, err 561 } 562 defer idRows.Close() 563 for idRows.Next() { 564 var uid string 565 if err := idRows.Scan(&uid); err != nil { 566 return nil, err 567 } 568 ids[uid] = struct{}{} 569 } 570 571 return ids, nil 572 } 573 574 func fetchRown2UserGroup(rows *sql.Rows) (*model.UserGroup, error) { 575 var ctime, mtime int64 576 var flag, tokenEnable int 577 group := new(model.UserGroup) 578 if err := rows.Scan(&group.ID, &group.Name, &group.Owner, &group.Comment, &group.Token, &tokenEnable, 579 &ctime, &mtime, &flag); err != nil { 580 return nil, err 581 } 582 583 group.Valid = flag == 0 584 group.TokenEnable = tokenEnable == 1 585 group.CreateTime = time.Unix(ctime, 0) 586 group.ModifyTime = time.Unix(mtime, 0) 587 588 return group, nil 589 } 590 591 // cleanInValidUserGroup 清理无效的用户组数据 592 func cleanInValidGroup(tx *BaseTx, name, owner string) error { 593 log.Infof("[Store][User] clean usergroup(%s)", name) 594 595 str := "delete from user_group where name = ? and flag = 1" 596 if _, err := tx.Exec(str, name); err != nil { 597 log.Errorf("[Store][User] clean usergroup(%s) err: %s", name, err.Error()) 598 return err 599 } 600 601 return nil 602 }