github.com/1aal/kubeblocks@v0.0.0-20231107070852-e1c03e598921/pkg/lorry/engines/postgres/user.go (about)

     1  /*
     2  Copyright (C) 2022-2023 ApeCloud Co., Ltd
     3  
     4  This file is part of KubeBlocks project
     5  
     6  This program is free software: you can redistribute it and/or modify
     7  it under the terms of the GNU Affero General Public License as published by
     8  the Free Software Foundation, either version 3 of the License, or
     9  (at your option) any later version.
    10  
    11  This program is distributed in the hope that it will be useful
    12  but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  GNU Affero General Public License for more details.
    15  
    16  You should have received a copy of the GNU Affero General Public License
    17  along with this program.  If not, see <http://www.gnu.org/licenses/>.
    18  */
    19  
    20  package postgres
    21  
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"golang.org/x/exp/slices"

	"github.com/1aal/kubeblocks/pkg/lorry/engines/models"
)
    29  	"github.com/1aal/kubeblocks/pkg/lorry/engines/models"
    30  )
    31  
    32  const (
    33  	listUserTpl = `
    34  	SELECT usename AS userName, valuntil <now() AS expired,  usesuper,
    35  	ARRAY(SELECT
    36  		case
    37  			when b.rolname = 'pg_read_all_data' THEN 'readonly'
    38  			when b.rolname = 'pg_write_all_data' THEN 'readwrite'
    39  		else b.rolname
    40  		end
    41  	FROM pg_catalog.pg_auth_members m
    42  	JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid)
    43  	WHERE m.member = usesysid ) as roles
    44  	FROM pg_catalog.pg_user
    45  	WHERE usename <> 'postgres' and usename  not like 'kb%'
    46  	ORDER BY usename;
    47  	`
    48  	descUserTpl = `
    49  	SELECT usename AS userName,  valuntil <now() AS expired, usesuper,
    50  	ARRAY(SELECT
    51  	 case
    52  		 when b.rolname = 'pg_read_all_data' THEN 'readonly'
    53  		 when b.rolname = 'pg_write_all_data' THEN 'readwrite'
    54  	 else b.rolname
    55  	 end
    56  	FROM pg_catalog.pg_auth_members m
    57  	JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid)
    58  	WHERE m.member = usesysid ) as roles
    59  	FROM pg_user
    60  	WHERE usename = '%s';
    61  	`
    62  	createUserTpl         = "CREATE USER %s WITH PASSWORD '%s';"
    63  	dropUserTpl           = "DROP USER IF EXISTS %s;"
    64  	grantTpl              = "GRANT %s TO %s;"
    65  	revokeTpl             = "REVOKE %s FROM %s;"
    66  	listSystemAccountsTpl = "SELECT rolname FROM pg_catalog.pg_roles WHERE pg_roles.rolname LIKE 'kb%'"
    67  )
    68  
    69  func (mgr *Manager) ListUsers(ctx context.Context) ([]models.UserInfo, error) {
    70  	data, err := mgr.Query(ctx, listUserTpl)
    71  	if err != nil {
    72  		mgr.Logger.Error(err, "error executing %s")
    73  		return nil, err
    74  	}
    75  
    76  	return pgUserRolesProcessor(data)
    77  }
    78  
    79  func (mgr *Manager) ListSystemAccounts(ctx context.Context) ([]models.UserInfo, error) {
    80  	data, err := mgr.Query(ctx, listSystemAccountsTpl)
    81  	if err != nil {
    82  		mgr.Logger.Error(err, "error executing %s")
    83  		return nil, err
    84  	}
    85  	type roleInfo struct {
    86  		Rolname string `json:"rolname"`
    87  	}
    88  	var roles []roleInfo
    89  	if err := json.Unmarshal(data, &roles); err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	users := []models.UserInfo{}
    94  	for _, role := range roles {
    95  		user := models.UserInfo{
    96  			RoleName: role.Rolname,
    97  		}
    98  		users = append(users, user)
    99  	}
   100  
   101  	return users, nil
   102  }
   103  
   104  func (mgr *Manager) DescribeUser(ctx context.Context, userName string) (*models.UserInfo, error) {
   105  	sql := fmt.Sprintf(descUserTpl, userName)
   106  
   107  	data, err := mgr.Query(ctx, sql)
   108  	if err != nil {
   109  		mgr.Logger.Error(err, "execute sql failed", "sql", sql)
   110  		return nil, err
   111  	}
   112  
   113  	users, err := pgUserRolesProcessor(data)
   114  	if err != nil {
   115  		mgr.Logger.Error(err, "parse data failed", "data", string(data))
   116  		return nil, err
   117  	}
   118  
   119  	if len(users) > 0 {
   120  		return &users[0], nil
   121  	}
   122  	return nil, nil
   123  }
   124  
   125  func (mgr *Manager) CreateUser(ctx context.Context, userName, password string) error {
   126  	sql := fmt.Sprintf(createUserTpl, userName, password)
   127  
   128  	_, err := mgr.Exec(ctx, sql)
   129  	if err != nil {
   130  		mgr.Logger.Error(err, "execute sql failed", "sql", sql)
   131  		return err
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  func (mgr *Manager) DeleteUser(ctx context.Context, userName string) error {
   138  	sql := fmt.Sprintf(dropUserTpl, userName)
   139  
   140  	_, err := mgr.Exec(ctx, sql)
   141  	if err != nil {
   142  		mgr.Logger.Error(err, "execute sql failed", "sql", sql)
   143  		return err
   144  	}
   145  
   146  	return nil
   147  }
   148  
   149  func (mgr *Manager) GrantUserRole(ctx context.Context, userName, roleName string) error {
   150  	var sql string
   151  	if models.SuperUserRole.EqualTo(roleName) {
   152  		sql = "ALTER USER " + userName + " WITH SUPERUSER;"
   153  	} else {
   154  		roleDesc, _ := role2PGRole(roleName)
   155  		sql = fmt.Sprintf(grantTpl, roleDesc, userName)
   156  	}
   157  	_, err := mgr.Exec(ctx, sql)
   158  	if err != nil {
   159  		mgr.Logger.Error(err, "execute sql failed", "sql", sql)
   160  		return err
   161  	}
   162  
   163  	return nil
   164  }
   165  
   166  func (mgr *Manager) RevokeUserRole(ctx context.Context, userName, roleName string) error {
   167  	var sql string
   168  	if models.SuperUserRole.EqualTo(roleName) {
   169  		sql = "ALTER USER " + userName + " WITH NOSUPERUSER;"
   170  	} else {
   171  		roleDesc, _ := role2PGRole(roleName)
   172  		sql = fmt.Sprintf(revokeTpl, roleDesc, userName)
   173  	}
   174  
   175  	_, err := mgr.Exec(ctx, sql)
   176  	if err != nil {
   177  		mgr.Logger.Error(err, "execute sql failed", "sql", sql)
   178  		return err
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  // post-processing
   185  func pgUserRolesProcessor(data interface{}) ([]models.UserInfo, error) {
   186  	type pgUserInfo struct {
   187  		UserName string   `json:"username"`
   188  		Expired  bool     `json:"expired"`
   189  		Super    bool     `json:"usesuper"`
   190  		Roles    []string `json:"roles"`
   191  	}
   192  	// parse data to struct
   193  	var pgUsers []pgUserInfo
   194  	err := json.Unmarshal(data.([]byte), &pgUsers)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	// parse roles
   199  	users := make([]models.UserInfo, len(pgUsers))
   200  	for i := range pgUsers {
   201  		users[i] = models.UserInfo{
   202  			UserName: pgUsers[i].UserName,
   203  		}
   204  
   205  		if pgUsers[i].Expired {
   206  			users[i].Expired = "T"
   207  		} else {
   208  			users[i].Expired = "F"
   209  		}
   210  
   211  		// parse Super attribute
   212  		if pgUsers[i].Super {
   213  			pgUsers[i].Roles = append(pgUsers[i].Roles, string(models.SuperUserRole))
   214  		}
   215  
   216  		// convert to RoleType and sort by weight
   217  		roleTypes := make([]models.RoleType, 0)
   218  		for _, role := range pgUsers[i].Roles {
   219  			roleTypes = append(roleTypes, models.String2RoleType(role))
   220  		}
   221  		slices.SortFunc(roleTypes, models.SortRoleByWeight)
   222  		if len(roleTypes) > 0 {
   223  			users[i].RoleName = string(roleTypes[0])
   224  		}
   225  	}
   226  	return users, nil
   227  }
   228  
   229  func role2PGRole(roleName string) (string, error) {
   230  	roleType := models.String2RoleType(roleName)
   231  	switch roleType {
   232  	case models.ReadWriteRole:
   233  		return "pg_write_all_data", nil
   234  	case models.ReadOnlyRole:
   235  		return "pg_read_all_data", nil
   236  	}
   237  	return "", fmt.Errorf("role name: %s is not supported", roleName)
   238  }