github.com/ngocphuongnb/tetua@v0.0.7-alpha/app/mock/repository/repository.go (about)

     1  package mockrepository
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/ngocphuongnb/tetua/app/entities"
    12  	"github.com/ngocphuongnb/tetua/app/utils"
    13  )
    14  
    15  var FakeRepoErrors = map[string]error{}
    16  
// Repository is a generic, in-memory stand-in for a persistent repository of
// E entities. Entities are read and written through their struct fields via
// reflection: the methods assign an int to the "ID" field and, in Create, a
// *time.Time to "CreatedAt" and "UpdatedAt".
// NOTE(review): these field requirements are only checked at runtime by
// reflect (a mismatch panics) — confirm for each entity type used.
type Repository[E entities.Entity] struct {
	// Name identifies the repository in error messages and is the prefix of
	// the FakeRepoErrors keys its methods consult.
	Name     string
	// entities holds the stored entities; Create appends to it.
	entities []*E
	// mu is locked by Create, Update and DeleteByID around changes to entities.
	mu       sync.Mutex
}
    22  
// Filter holds the query options accepted by Repository.Find. The form/json
// tags control how each field is bound from request input.
// NOTE(review): Find in this mock only consults Page and Limit; Search,
// Sorts, IgnoreUrlParams and ExcludeIDs are carried but ignored here.
type Filter struct {
	// Search is a free-text search term.
	Search          string           `form:"search" json:"search"`
	// Page is the 1-based page number; Find treats values < 1 as 1.
	Page            int              `form:"page" json:"page"`
	// Limit is the page size; Find treats values < 1 as 10.
	Limit           int              `form:"limit" json:"limit"`
	// Sorts lists sort directives; note the tag key is "orders", not "sorts".
	Sorts           []*entities.Sort `form:"orders" json:"orders"`
	// IgnoreUrlParams presumably names URL parameters to leave out when
	// building pagination links — TODO confirm against the real repository.
	IgnoreUrlParams []string         `form:"ignore_url_params" json:"ignore_url_params"`
	// ExcludeIDs lists entity IDs that should be left out of the result.
	ExcludeIDs      []int            `form:"exclude_ids" json:"exclude_ids"`
}
    31  
    32  func setEntityField[E entities.Entity](entity *E, field string, value interface{}) {
    33  	reflect.ValueOf(entity).Elem().FieldByName(field).Set(reflect.ValueOf(value))
    34  }
    35  
    36  func idEQ[E entities.Entity](entity1, entity2 *E) bool {
    37  	return getEntityField(entity1, "ID") == getEntityField(entity2, "ID")
    38  }
    39  
    40  func getEntityField[E entities.Entity](entity *E, field string) interface{} {
    41  	r := reflect.ValueOf(entity)
    42  	f := reflect.Indirect(r).FieldByName(field)
    43  	return f.Interface()
    44  }
    45  
    46  func getEntityByField[E entities.Entity](name string, slice []*E, compareField string, compareValue interface{}) (*E, error) {
    47  	foundEntities := utils.SliceFilter(slice, func(e *E) bool {
    48  		return compareValue == getEntityField(e, compareField)
    49  	})
    50  
    51  	if len(foundEntities) == 0 {
    52  		return nil, &entities.NotFoundError{Message: name + " not found with " + compareField + " = " + fmt.Sprintf("%v", compareValue)}
    53  	}
    54  
    55  	return foundEntities[0], nil
    56  }
    57  
    58  func ByID[E entities.Entity](ctx context.Context, name string, slice []*E, id int) (*E, error) {
    59  	if ctx.Value("query_error") != nil {
    60  		return nil, errors.New("ByID error")
    61  	}
    62  
    63  	return getEntityByField(name, slice, "ID", id)
    64  }
    65  
    66  func (m *Repository[E]) All(ctx context.Context) ([]*E, error) {
    67  	return m.entities, nil
    68  }
    69  
    70  func (m *Repository[E]) ByID(ctx context.Context, id int) (*E, error) {
    71  	return ByID(ctx, m.Name, m.entities, id)
    72  }
    73  
    74  func (m *Repository[E]) Create(ctx context.Context, entity *E) (*E, error) {
    75  	if ctx.Value("create_error") != nil {
    76  		return nil, errors.New("Error create " + m.Name)
    77  	}
    78  
    79  	if err, ok := FakeRepoErrors[m.Name+"_create"]; ok && err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	for _, e := range m.entities {
    84  		if idEQ(e, entity) {
    85  			return nil, errors.New(m.Name + " already exists")
    86  		}
    87  	}
    88  
    89  	now := time.Now()
    90  	m.mu.Lock()
    91  	defer m.mu.Unlock()
    92  	setEntityField(entity, "ID", len(m.entities)+1)
    93  	setEntityField(entity, "CreatedAt", &now)
    94  	setEntityField(entity, "UpdatedAt", &now)
    95  	m.entities = append(m.entities, entity)
    96  
    97  	return entity, nil
    98  }
    99  
   100  func (m *Repository[E]) Update(ctx context.Context, entity *E) (*E, error) {
   101  	if ctx.Value("update_error") != nil {
   102  		return nil, errors.New("Error save " + m.Name)
   103  	}
   104  
   105  	if err, ok := FakeRepoErrors[m.Name+"_update"]; ok && err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	found := false
   110  	m.mu.Lock()
   111  	defer m.mu.Unlock()
   112  	m.entities = utils.SliceMap(m.entities, func(e *E) *E {
   113  		if idEQ(e, entity) {
   114  			found = true
   115  			return entity
   116  		}
   117  		return e
   118  	})
   119  
   120  	if !found {
   121  		return nil, errors.New(m.Name + " not found")
   122  	}
   123  
   124  	return entity, nil
   125  }
   126  
   127  func (m *Repository[E]) DeleteByID(ctx context.Context, id int) error {
   128  	if err, ok := FakeRepoErrors[m.Name+"_deleteByID"]; ok && err != nil {
   129  		return err
   130  	}
   131  
   132  	found := false
   133  	m.mu.Lock()
   134  	defer m.mu.Unlock()
   135  
   136  	m.entities = utils.SliceFilter(m.entities, func(e *E) bool {
   137  		if getEntityField(e, "ID") == id {
   138  			found = true
   139  		}
   140  		return getEntityField(e, "ID") != id
   141  	})
   142  
   143  	if !found {
   144  		return errors.New(m.Name + " not found")
   145  	}
   146  
   147  	return nil
   148  }
   149  
   150  func (m *Repository[E]) Find(ctx context.Context, filters ...*Filter) ([]*E, error) {
   151  	if err, ok := FakeRepoErrors[m.Name+"_find"]; ok && err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	if len(filters) == 0 {
   156  		return m.entities, nil
   157  	}
   158  
   159  	if filters[0].Page < 1 {
   160  		filters[0].Page = 1
   161  	}
   162  
   163  	if filters[0].Limit < 1 {
   164  		filters[0].Limit = 10
   165  	}
   166  
   167  	result := make([]*E, 0)
   168  	filter := *filters[0]
   169  	offset := (filter.Page - 1) * filter.Limit
   170  
   171  	for index, e := range m.entities {
   172  		if index < offset {
   173  			continue
   174  		}
   175  		if index >= offset+filter.Limit {
   176  			break
   177  		}
   178  		result = append(result, e)
   179  	}
   180  
   181  	return result, nil
   182  }
   183  
   184  func (m *Repository[E]) Paginate(ctx context.Context, filters ...*entities.Filter) (*entities.Paginate[E], error) {
   185  	if err, ok := FakeRepoErrors[m.Name+"_paginate"]; ok && err != nil {
   186  		return nil, err
   187  	}
   188  
   189  	result := make([]*E, 0)
   190  	filter := *filters[0]
   191  	offset := (filter.Page - 1) * filter.Limit
   192  
   193  	for index, e := range m.entities {
   194  		if index < offset {
   195  			continue
   196  		}
   197  		if index >= offset+filter.Limit {
   198  			break
   199  		}
   200  		result = append(result, e)
   201  	}
   202  
   203  	return &entities.Paginate[E]{
   204  		PageCurrent: filter.Page,
   205  		PageSize:    filter.Limit,
   206  		Total:       len(m.entities),
   207  		Data:        result,
   208  	}, nil
   209  }