github.com/ngocphuongnb/tetua@v0.0.7-alpha/packages/entrepository/user.go (about)

     1  package entrepository
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/ngocphuongnb/tetua/app/entities"
     9  	e "github.com/ngocphuongnb/tetua/app/entities"
    10  	"github.com/ngocphuongnb/tetua/app/utils"
    11  	"github.com/ngocphuongnb/tetua/packages/entrepository/ent"
    12  	"github.com/ngocphuongnb/tetua/packages/entrepository/ent/user"
    13  )
    14  
    15  type UserRepository struct {
    16  	*BaseRepository[e.User, ent.User, *ent.UserQuery, *e.UserFilter]
    17  }
    18  
    19  func userById(ctx context.Context, client *ent.Client, id int) (*ent.User, error) {
    20  	return client.User.Query().
    21  		Where(user.IDEQ(id)).
    22  		WithRoles().
    23  		WithAvatarImage().
    24  		Only(ctx)
    25  }
    26  
    27  func (u *UserRepository) ByUsername(ctx context.Context, username string) (*entities.User, error) {
    28  	user, err := u.Client.User.
    29  		Query().
    30  		Where(user.UsernameEQ(username)).
    31  		WithRoles().
    32  		WithAvatarImage().
    33  		Only(ctx)
    34  	if err != nil {
    35  		return nil, EntError(err, fmt.Sprintf("user not found with username: %s", username))
    36  	}
    37  
    38  	return entUserToUser(user), nil
    39  }
    40  
    41  func (u *UserRepository) ByUsernameOrEmail(ctx context.Context, username, email string) ([]*entities.User, error) {
    42  	user, err := u.Client.User.
    43  		Query().
    44  		Where(
    45  			user.Or(
    46  				user.UsernameEQ(username),
    47  				user.EmailEQ(email),
    48  			),
    49  		).
    50  		WithRoles().
    51  		WithAvatarImage().
    52  		All(ctx)
    53  	if err != nil {
    54  		return nil, EntError(err, fmt.Sprintf("user not found with username or email: %s %s", username, email))
    55  	}
    56  
    57  	return entUsersToUsers(user), nil
    58  }
    59  
    60  func (u *UserRepository) ByProvider(ctx context.Context, providerName, providerId string) (*entities.User, error) {
    61  	user, err := u.Client.User.
    62  		Query().
    63  		Where(
    64  			user.Provider(providerName),
    65  			user.Provider(providerId),
    66  		).
    67  		WithRoles().
    68  		WithAvatarImage().
    69  		Only(ctx)
    70  
    71  	if err != nil {
    72  		return nil, EntError(err, fmt.Sprintf("user not found with provider: %s %s", providerName, providerId))
    73  	}
    74  
    75  	return entUserToUser(user), nil
    76  }
    77  
    78  func (ur *UserRepository) CreateIfNotExistsByProvider(ctx context.Context, userData *entities.User) (*entities.User, error) {
    79  	u, err := ur.Client.User.
    80  		Query().
    81  		Where(
    82  			user.Provider(userData.Provider),
    83  			user.ProviderID(userData.ProviderID),
    84  		).
    85  		WithRoles().
    86  		WithAvatarImage().
    87  		Only(ctx)
    88  
    89  	if err != nil {
    90  		if ent.IsNotFound(err) {
    91  			return ur.Create(ctx, userData)
    92  		}
    93  
    94  		return nil, err
    95  	}
    96  
    97  	return entUserToUser(u), nil
    98  }
    99  
   100  func (ur *UserRepository) Setting(ctx context.Context, id int, userData *entities.SettingMutation) (*entities.User, error) {
   101  	uu := ur.Client.User.UpdateOneID(id).
   102  		SetUsername(userData.Username).
   103  		SetDisplayName(userData.DisplayName).
   104  		SetURL(userData.URL).
   105  		SetBio(userData.Bio).
   106  		SetBioHTML(userData.BioHTML).
   107  		SetEmail(userData.Email)
   108  
   109  	if userData.AvatarImageID > 0 {
   110  		uu.SetAvatarImageID(userData.AvatarImageID)
   111  	}
   112  
   113  	if userData.Password != "" {
   114  		uu.SetPassword(userData.Password)
   115  	}
   116  
   117  	user, err := uu.Save(ctx)
   118  
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	return entUserToUser(user), nil
   124  }
   125  
   126  func CreateUserRepository(client *ent.Client) *UserRepository {
   127  	return &UserRepository{
   128  		BaseRepository: &BaseRepository[e.User, ent.User, *ent.UserQuery, *e.UserFilter]{
   129  			Name:      "user",
   130  			Client:    client,
   131  			ConvertFn: entUserToUser,
   132  			ByIDFn:    userById,
   133  			DeleteByIDFn: func(ctx context.Context, client *ent.Client, id int) error {
   134  				return client.User.DeleteOneID(id).Exec(ctx)
   135  			},
   136  			CreateFn: func(ctx context.Context, client *ent.Client, data *e.User) (*ent.User, error) {
   137  				uc := client.User.Create().
   138  					SetUsername(data.Username).
   139  					SetDisplayName(data.DisplayName).
   140  					SetURL(data.URL).
   141  					SetBio(data.Bio).
   142  					SetBioHTML(data.BioHTML).
   143  					SetEmail(data.Email).
   144  					SetProvider(data.Provider).
   145  					SetProviderID(data.ProviderID).
   146  					SetProviderUsername(data.ProviderUsername).
   147  					SetProviderAvatar(data.ProviderAvatar).
   148  					SetActive(data.Active)
   149  
   150  				if data.AvatarImageID > 0 {
   151  					uc.SetAvatarImageID(data.AvatarImageID)
   152  				}
   153  
   154  				if data.Provider == "local" {
   155  					uc.SetProviderID(fmt.Sprintf("%d", time.Now().UnixMicro()))
   156  				}
   157  
   158  				if len(data.RoleIDs) > 0 {
   159  					uc.AddRoleIDs(data.RoleIDs...)
   160  				}
   161  
   162  				if data.Password != "" {
   163  					uc.SetPassword(data.Password)
   164  				}
   165  
   166  				user, err := uc.Save(ctx)
   167  
   168  				if err != nil {
   169  					return nil, err
   170  				}
   171  
   172  				if data.Provider == "local" {
   173  					user, err = client.User.
   174  						UpdateOneID(user.ID).
   175  						SetProviderID(fmt.Sprintf("%d", user.ID)).Save(ctx)
   176  					if err != nil {
   177  						return nil, err
   178  					}
   179  				}
   180  
   181  				return userById(ctx, client, user.ID)
   182  			},
   183  			UpdateFn: func(ctx context.Context, client *ent.Client, data *e.User) (*ent.User, error) {
   184  				if data.ID == 0 {
   185  					return nil, fmt.Errorf("user id is required")
   186  				}
   187  				uu := client.User.UpdateOneID(data.ID).
   188  					SetUsername(data.Username).
   189  					SetDisplayName(data.DisplayName).
   190  					SetURL(data.URL).
   191  					SetBio(data.Bio).
   192  					SetBioHTML(data.BioHTML).
   193  					SetEmail(data.Email).
   194  					SetProvider(data.Provider).
   195  					SetProviderID(data.ProviderID).
   196  					SetProviderUsername(data.ProviderUsername).
   197  					SetProviderAvatar(data.ProviderAvatar).
   198  					SetActive(data.Active)
   199  
   200  				if data.AvatarImageID > 0 {
   201  					uu.SetAvatarImageID(data.AvatarImageID)
   202  				}
   203  
   204  				if len(data.RoleIDs) > 0 {
   205  					oldUserEnt, err := userById(ctx, client, data.ID)
   206  					if err != nil {
   207  						return nil, err
   208  					}
   209  					oldUser := entUserToUser(oldUserEnt)
   210  					oldRoleIDs := utils.SliceMap(oldUser.Roles, func(r *entities.Role) int {
   211  						return r.ID
   212  					})
   213  					uu.RemoveRoleIDs(oldRoleIDs...)
   214  					uu.AddRoleIDs(data.RoleIDs...)
   215  				}
   216  
   217  				if data.Password != "" {
   218  					uu.SetPassword(data.Password)
   219  				}
   220  
   221  				user, err := uu.Save(ctx)
   222  
   223  				if err != nil {
   224  					return nil, err
   225  				}
   226  
   227  				return user, nil
   228  			},
   229  			QueryFilterFn: func(client *ent.Client, filters ...*e.UserFilter) *ent.UserQuery {
   230  				query := client.User.Query().Where(user.DeletedAtIsNil())
   231  
   232  				if len(filters) > 0 && filters[0].Search != "" {
   233  					query = query.Where(user.UsernameContainsFold(filters[0].Search))
   234  				}
   235  				return query
   236  			},
   237  			FindFn: func(ctx context.Context, query *ent.UserQuery, filters ...*e.UserFilter) ([]*ent.User, error) {
   238  				page, limit, sorts := getPaginateParams(filters...)
   239  				return query.
   240  					Limit(limit).
   241  					Offset((page - 1) * limit).
   242  					Order(sorts...).All(ctx)
   243  			},
   244  		},
   245  	}
   246  }