github.com/pyroscope-io/pyroscope@v0.37.3-0.20230725203016-5f6947968bd0/pkg/service/user.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"time"
     7  
     8  	"gorm.io/gorm"
     9  
    10  	"github.com/pyroscope-io/pyroscope/pkg/model"
    11  )
    12  
    13  type UserService struct{ db *gorm.DB }
    14  
    15  func NewUserService(db *gorm.DB) UserService { return UserService{db} }
    16  
    17  func (svc UserService) CreateUser(ctx context.Context, params model.CreateUserParams) (model.User, error) {
    18  	if err := params.Validate(); err != nil {
    19  		return model.User{}, err
    20  	}
    21  	user := model.User{
    22  		Name:              params.Name,
    23  		Email:             params.Email,
    24  		Role:              params.Role,
    25  		IsExternal:        &params.IsExternal,
    26  		PasswordHash:      model.MustPasswordHash(params.Password),
    27  		PasswordChangedAt: time.Now(),
    28  	}
    29  	if params.FullName != nil {
    30  		user.FullName = params.FullName
    31  	}
    32  	return user, svc.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
    33  		// Two separate queries only to simplify error handling (separate for
    34  		// name and email). Feel free to replace it if you deem it necessary.
    35  		if params.Email != nil {
    36  			_, err := findUserByEmail(tx, params.Email)
    37  			switch {
    38  			case errors.Is(err, model.ErrUserNotFound):
    39  			case err == nil:
    40  				return model.ErrUserEmailExists
    41  			default:
    42  				return err
    43  			}
    44  		}
    45  		_, err := findUserByName(tx, params.Name)
    46  		switch {
    47  		case errors.Is(err, model.ErrUserNotFound):
    48  		case err == nil:
    49  			return model.ErrUserNameExists
    50  		default:
    51  			return err
    52  		}
    53  		return tx.Create(&user).Error
    54  	})
    55  }
    56  
    57  func (svc UserService) FindUserByName(ctx context.Context, name string) (model.User, error) {
    58  	if err := model.ValidateUserName(name); err != nil {
    59  		return model.User{}, err
    60  	}
    61  	return findUserByName(svc.db.WithContext(ctx), name)
    62  }
    63  
    64  func (svc UserService) FindUserByEmail(ctx context.Context, email string) (model.User, error) {
    65  	if err := model.ValidateEmail(email); err != nil {
    66  		return model.User{}, err
    67  	}
    68  	return findUserByEmail(svc.db.WithContext(ctx), &email)
    69  }
    70  
    71  func (svc UserService) FindUserByID(ctx context.Context, id uint) (model.User, error) {
    72  	return findUserByID(svc.db.WithContext(ctx), id)
    73  }
    74  
    75  func findUserByName(tx *gorm.DB, name string) (model.User, error) {
    76  	return findUser(tx, model.User{Name: name})
    77  }
    78  
    79  func findUserByEmail(tx *gorm.DB, email *string) (model.User, error) {
    80  	return findUser(tx, model.User{Email: email})
    81  }
    82  
    83  func findUserByID(tx *gorm.DB, id uint) (model.User, error) {
    84  	return findUser(tx, model.User{ID: id})
    85  }
    86  
    87  func findUser(tx *gorm.DB, user model.User) (model.User, error) {
    88  	var u model.User
    89  	r := tx.Where(user).First(&u)
    90  	switch {
    91  	case r.Error == nil:
    92  		return u, nil
    93  	case errors.Is(r.Error, gorm.ErrRecordNotFound):
    94  		return model.User{}, model.ErrUserNotFound
    95  	default:
    96  		return model.User{}, r.Error
    97  	}
    98  }
    99  
   100  func (svc UserService) GetAllUsers(ctx context.Context) ([]model.User, error) {
   101  	var users []model.User
   102  	return users, svc.db.WithContext(ctx).Find(&users).Error
   103  }
   104  
   105  func (svc UserService) UpdateUserByID(ctx context.Context, id uint, params model.UpdateUserParams) (model.User, error) {
   106  	if err := params.Validate(); err != nil {
   107  		return model.User{}, err
   108  	}
   109  	return updateUserByID(svc.db.WithContext(ctx), id, params)
   110  }
   111  
   112  func (svc UserService) UpdateUserByName(ctx context.Context, name string, params model.UpdateUserParams) (model.User, error) {
   113  	if err := model.ValidateUserName(name); err != nil {
   114  		return model.User{}, err
   115  	}
   116  	if err := params.Validate(); err != nil {
   117  		return model.User{}, err
   118  	}
   119  	tx := svc.db.WithContext(ctx)
   120  	user, err := findUserByName(tx, name)
   121  	if err != nil {
   122  		return model.User{}, err
   123  	}
   124  	return updateUserByID(tx, user.ID, params)
   125  }
   126  
   127  func updateUserByID(tx *gorm.DB, id uint, params model.UpdateUserParams) (model.User, error) {
   128  	var updated model.User
   129  	return updated, tx.Transaction(func(tx *gorm.DB) error {
   130  		user, err := findUserByID(tx, id)
   131  		if err != nil {
   132  			return err
   133  		}
   134  		// We only skip update if params are not specified.
   135  		// Otherwise, even if the values match the current ones,
   136  		// the user is to be updated.
   137  		if (model.UpdateUserParams{}) == params {
   138  			updated = user
   139  			return nil
   140  		}
   141  		var columns model.User
   142  		// If the new email matches the current one, ignore.
   143  		if params.Email != nil && user.Email != nil && *user.Email != *params.Email {
   144  			// Make sure it is not in use.
   145  			// Note that we can't rely on the constraint violation error
   146  			// that should occur: underlying database driver errors are
   147  			// not standardized, but service consumers expect friendly
   148  			// typed errors.
   149  			switch _, err = findUserByEmail(tx.Unscoped(), params.Email); {
   150  			case errors.Is(err, model.ErrUserNotFound):
   151  				columns.Email = params.Email
   152  			case err == nil:
   153  				return model.ErrUserEmailExists
   154  			default:
   155  				return err
   156  			}
   157  		}
   158  		// Same for user name.
   159  		if params.Name != nil && user.Name != *params.Name {
   160  			if model.IsUserExternal(user) {
   161  				return model.ErrUserExternalChange
   162  			}
   163  			switch _, err = findUserByName(tx.Unscoped(), *params.Name); {
   164  			case errors.Is(err, model.ErrUserNotFound):
   165  				columns.Name = *params.Name
   166  			case err == nil:
   167  				return model.ErrUserNameExists
   168  			default:
   169  				return err
   170  			}
   171  		}
   172  		columns.FullName = params.FullName
   173  		columns.IsDisabled = params.IsDisabled
   174  		if params.Role != nil {
   175  			columns.Role = *params.Role
   176  		}
   177  		if params.Password != nil {
   178  			if model.IsUserExternal(user) {
   179  				return model.ErrUserExternalChange
   180  			}
   181  			columns.PasswordHash = model.MustPasswordHash(*params.Password)
   182  			columns.PasswordChangedAt = time.Now()
   183  		}
   184  		return tx.Model(user).Updates(columns).Error
   185  	})
   186  }
   187  
   188  func (svc UserService) UpdateUserPasswordByID(ctx context.Context, id uint, params model.UpdateUserPasswordParams) error {
   189  	if err := params.Validate(); err != nil {
   190  		return err
   191  	}
   192  	return svc.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
   193  		user, err := findUserByID(tx, id)
   194  		if err != nil {
   195  			return err
   196  		}
   197  		if err = model.VerifyPassword(user.PasswordHash, params.OldPassword); err != nil {
   198  			return model.ErrUserPasswordInvalid
   199  		}
   200  		columns := model.User{
   201  			ID:                id,
   202  			PasswordHash:      model.MustPasswordHash(params.NewPassword),
   203  			PasswordChangedAt: time.Now(),
   204  		}
   205  		return tx.Model(user).Updates(&columns).Error
   206  	})
   207  }
   208  
   209  func (svc UserService) DeleteUserByID(ctx context.Context, id uint) error {
   210  	return svc.db.WithContext(ctx).Delete(&model.User{}, id).Error
   211  }