github.com/ngocphuongnb/tetua@v0.0.7-alpha/packages/entrepository/base.go (about)

     1  package entrepository
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/ngocphuongnb/tetua/app/entities"
     9  	"github.com/ngocphuongnb/tetua/app/utils"
    10  	"github.com/ngocphuongnb/tetua/packages/entrepository/ent"
    11  )
    12  
    13  type EntityType interface {
    14  	ent.Comment | ent.File | ent.Permission | ent.Post | ent.Page | ent.Role | ent.Setting | ent.Topic | ent.User
    15  }
    16  
    17  type QueryFilter interface {
    18  	GetSearch() string
    19  	GetPage() int
    20  	GetLimit() int
    21  	GetSorts() []*entities.Sort
    22  	GetIgnoreUrlParams() []string
    23  	GetExcludeIDs() []int
    24  	Base() string
    25  }
    26  
    27  type EntityQuery[EE EntityType] interface {
    28  	*ent.CommentQuery | *ent.FileQuery | *ent.PermissionQuery | *ent.PostQuery | *ent.PageQuery | *ent.RoleQuery | *ent.SettingQuery | *ent.TopicQuery | *ent.UserQuery
    29  	Count(context.Context) (int, error)
    30  	All(context.Context) ([]*EE, error)
    31  }
    32  
// BaseRepository is a generic repository that maps between a domain entity E
// and an ent entity EE. All storage work is delegated to the *Fn callbacks,
// which the methods call without nil checks, so each callback must be set
// before the method that uses it is called.
//
// Type parameters: E is the domain entity returned to callers, EE the
// ent-generated entity, EQ the ent query builder for EE, and QF the filter
// type accepted by listing methods.
type BaseRepository[E entities.Entity, EE EntityType, EQ EntityQuery[EE], QF QueryFilter] struct {
	// Name is the human-readable entity name used in ByID's not-found message.
	Name          string
	// Client is passed to every callback that talks to the database.
	Client        *ent.Client
	// ConvertFn maps an ent entity to its domain entity; applied to every
	// result returned by ByID, Create, Update, All, Find and Paginate.
	ConvertFn     func(entEntity *EE) *E
	// ByIDFn loads a single entity by id (used by ByID).
	ByIDFn        func(ctx context.Context, client *ent.Client, id int) (*EE, error)
	// DeleteByIDFn deletes a single entity by id (used by DeleteByID).
	DeleteByIDFn  func(ctx context.Context, client *ent.Client, id int) error
	// CreateFn persists a new entity built from data (used by Create).
	CreateFn      func(ctx context.Context, client *ent.Client, data *E) (*EE, error)
	// UpdateFn persists changes described by data (used by Update).
	UpdateFn      func(ctx context.Context, client *ent.Client, data *E) (*EE, error)
	// FindFn executes a query produced by QueryFilterFn; it also receives the
	// same filters (used by Find and Paginate).
	FindFn        func(ctx context.Context, query EQ, filters ...QF) ([]*EE, error)
	// QueryFilterFn builds a query from filters (used by Count, All, Find and
	// Paginate; All calls it with no filters).
	QueryFilterFn func(client *ent.Client, filters ...QF) EQ
}
    44  
    45  func (b *BaseRepository[E, EE, EQ, QF]) ByID(ctx context.Context, id int) (*E, error) {
    46  	entity, err := b.ByIDFn(ctx, b.Client, id)
    47  
    48  	if err != nil {
    49  		return nil, EntError(err, fmt.Sprintf("%s not found with id: %d", b.Name, id))
    50  	}
    51  
    52  	return b.ConvertFn(entity), nil
    53  }
    54  
    55  func (b *BaseRepository[E, EE, EQ, QF]) DeleteByID(ctx context.Context, id int) error {
    56  	return b.DeleteByIDFn(ctx, b.Client, id)
    57  }
    58  
    59  func (b *BaseRepository[E, EE, EQ, QF]) Create(ctx context.Context, data *E) (*E, error) {
    60  	entity, err := b.CreateFn(ctx, b.Client, data)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	return b.ConvertFn(entity), nil
    66  }
    67  
    68  func (b *BaseRepository[E, EE, EQ, QF]) Update(ctx context.Context, data *E) (*E, error) {
    69  	entity, err := b.UpdateFn(ctx, b.Client, data)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	return b.ConvertFn(entity), nil
    75  }
    76  
    77  func (b *BaseRepository[E, EE, EQ, QF]) Count(ctx context.Context, filters ...QF) (int, error) {
    78  	return b.QueryFilterFn(b.Client, filters...).Count(ctx)
    79  }
    80  
    81  func getPaginateParams[F QueryFilter](filters ...F) (int, int, []ent.OrderFunc) {
    82  	page := 1
    83  	limit := 10
    84  	sorts := []ent.OrderFunc{ent.Desc("id")}
    85  
    86  	if len(filters) > 0 {
    87  		if filters[0].GetLimit() > 0 {
    88  			limit = filters[0].GetLimit()
    89  		}
    90  
    91  		if filters[0].GetPage() > 0 {
    92  			page = filters[0].GetPage()
    93  		}
    94  
    95  		if len(filters[0].GetSorts()) > 0 {
    96  			sorts = getSortFNs(filters[0].GetSorts())
    97  		}
    98  	}
    99  
   100  	return page, limit, sorts
   101  }
   102  
   103  func (b *BaseRepository[E, EE, EQ, QF]) All(ctx context.Context) ([]*E, error) {
   104  	query := b.QueryFilterFn(b.Client)
   105  	if items, err := query.All(ctx); err != nil {
   106  		return nil, err
   107  	} else {
   108  		return utils.SliceMap(items, b.ConvertFn), nil
   109  	}
   110  }
   111  
   112  func (b *BaseRepository[E, EE, EQ, QF]) Find(ctx context.Context, filters ...QF) ([]*E, error) {
   113  	query := b.QueryFilterFn(b.Client, filters...)
   114  	if items, err := b.FindFn(ctx, query, filters...); err != nil {
   115  		return nil, err
   116  	} else {
   117  		return utils.SliceMap(items, b.ConvertFn), nil
   118  	}
   119  }
   120  
   121  func (b *BaseRepository[E, EE, EQ, QF]) Paginate(ctx context.Context, filters ...QF) (*entities.Paginate[E], error) {
   122  	var err1 error
   123  	var err2 error
   124  	var wg sync.WaitGroup
   125  	total := 0
   126  	base := ""
   127  	items := make([]*EE, 0)
   128  	page, limit, _ := getPaginateParams(filters[0])
   129  
   130  	wg.Add(2)
   131  	go func(wg *sync.WaitGroup) {
   132  		defer wg.Done()
   133  		total, err1 = b.QueryFilterFn(b.Client, filters...).Count(ctx)
   134  	}(&wg)
   135  	go func(wg *sync.WaitGroup) {
   136  		defer wg.Done()
   137  		items, err2 = b.FindFn(ctx, b.QueryFilterFn(b.Client, filters...), filters...)
   138  	}(&wg)
   139  	wg.Wait()
   140  
   141  	if err := utils.FirstError(err1, err2); err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	if len(filters) > 0 {
   146  		base = filters[0].Base()
   147  	}
   148  
   149  	return &entities.Paginate[E]{
   150  		Data:        utils.SliceMap(items, b.ConvertFn),
   151  		BaseUrl:     base,
   152  		Total:       total,
   153  		PageSize:    limit,
   154  		PageCurrent: page,
   155  	}, nil
   156  }
   157  
   158  func EntError(err error, msg string) error {
   159  	if ent.IsNotFound(err) {
   160  		return &entities.NotFoundError{Message: msg}
   161  	}
   162  
   163  	return err
   164  }