github.com/ngocphuongnb/tetua@v0.0.7-alpha/app/mock/repository/repository.go (about) 1 package mockrepository 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "sync" 9 "time" 10 11 "github.com/ngocphuongnb/tetua/app/entities" 12 "github.com/ngocphuongnb/tetua/app/utils" 13 ) 14 15 var FakeRepoErrors = map[string]error{} 16 17 type Repository[E entities.Entity] struct { 18 Name string 19 entities []*E 20 mu sync.Mutex 21 } 22 23 type Filter struct { 24 Search string `form:"search" json:"search"` 25 Page int `form:"page" json:"page"` 26 Limit int `form:"limit" json:"limit"` 27 Sorts []*entities.Sort `form:"orders" json:"orders"` 28 IgnoreUrlParams []string `form:"ignore_url_params" json:"ignore_url_params"` 29 ExcludeIDs []int `form:"exclude_ids" json:"exclude_ids"` 30 } 31 32 func setEntityField[E entities.Entity](entity *E, field string, value interface{}) { 33 reflect.ValueOf(entity).Elem().FieldByName(field).Set(reflect.ValueOf(value)) 34 } 35 36 func idEQ[E entities.Entity](entity1, entity2 *E) bool { 37 return getEntityField(entity1, "ID") == getEntityField(entity2, "ID") 38 } 39 40 func getEntityField[E entities.Entity](entity *E, field string) interface{} { 41 r := reflect.ValueOf(entity) 42 f := reflect.Indirect(r).FieldByName(field) 43 return f.Interface() 44 } 45 46 func getEntityByField[E entities.Entity](name string, slice []*E, compareField string, compareValue interface{}) (*E, error) { 47 foundEntities := utils.SliceFilter(slice, func(e *E) bool { 48 return compareValue == getEntityField(e, compareField) 49 }) 50 51 if len(foundEntities) == 0 { 52 return nil, &entities.NotFoundError{Message: name + " not found with " + compareField + " = " + fmt.Sprintf("%v", compareValue)} 53 } 54 55 return foundEntities[0], nil 56 } 57 58 func ByID[E entities.Entity](ctx context.Context, name string, slice []*E, id int) (*E, error) { 59 if ctx.Value("query_error") != nil { 60 return nil, errors.New("ByID error") 61 } 62 63 return getEntityByField(name, slice, "ID", id) 64 } 65 66 func (m *Repository[E]) All(ctx context.Context) ([]*E, error) { 67 return m.entities, nil 68 } 69 70 func (m *Repository[E]) ByID(ctx context.Context, id int) (*E, error) { 71 return ByID(ctx, m.Name, m.entities, id) 72 } 73 74 func (m *Repository[E]) Create(ctx context.Context, entity *E) (*E, error) { 75 if ctx.Value("create_error") != nil { 76 return nil, errors.New("Error create " + m.Name) 77 } 78 79 if err, ok := FakeRepoErrors[m.Name+"_create"]; ok && err != nil { 80 return nil, err 81 } 82 83 for _, e := range m.entities { 84 if idEQ(e, entity) { 85 return nil, errors.New(m.Name + " already exists") 86 } 87 } 88 89 now := time.Now() 90 m.mu.Lock() 91 defer m.mu.Unlock() 92 setEntityField(entity, "ID", len(m.entities)+1) 93 setEntityField(entity, "CreatedAt", &now) 94 setEntityField(entity, "UpdatedAt", &now) 95 m.entities = append(m.entities, entity) 96 97 return entity, nil 98 } 99 100 func (m *Repository[E]) Update(ctx context.Context, entity *E) (*E, error) { 101 if ctx.Value("update_error") != nil { 102 return nil, errors.New("Error save " + m.Name) 103 } 104 105 if err, ok := FakeRepoErrors[m.Name+"_update"]; ok && err != nil { 106 return nil, err 107 } 108 109 found := false 110 m.mu.Lock() 111 defer m.mu.Unlock() 112 m.entities = utils.SliceMap(m.entities, func(e *E) *E { 113 if idEQ(e, entity) { 114 found = true 115 return entity 116 } 117 return e 118 }) 119 120 if !found { 121 return nil, errors.New(m.Name + " not found") 122 } 123 124 return entity, nil 125 } 126 127 func (m *Repository[E]) DeleteByID(ctx context.Context, id int) error { 128 if err, ok := FakeRepoErrors[m.Name+"_deleteByID"]; ok && err != nil { 129 return err 130 } 131 132 found := false 133 m.mu.Lock() 134 defer m.mu.Unlock() 135 136 m.entities = utils.SliceFilter(m.entities, func(e *E) bool { 137 if getEntityField(e, "ID") == id { 138 found = true 139 } 140 return getEntityField(e, "ID") != id 141 }) 142 143 if !found { 144 return errors.New(m.Name + " not found") 145 } 146 147 return nil 148 } 149 150 func (m *Repository[E]) Find(ctx context.Context, filters ...*Filter) ([]*E, error) { 151 if err, ok := FakeRepoErrors[m.Name+"_find"]; ok && err != nil { 152 return nil, err 153 } 154 155 if len(filters) == 0 { 156 return m.entities, nil 157 } 158 159 if filters[0].Page < 1 { 160 filters[0].Page = 1 161 } 162 163 if filters[0].Limit < 1 { 164 filters[0].Limit = 10 165 } 166 167 result := make([]*E, 0) 168 filter := *filters[0] 169 offset := (filter.Page - 1) * filter.Limit 170 171 for index, e := range m.entities { 172 if index < offset { 173 continue 174 } 175 if index >= offset+filter.Limit { 176 break 177 } 178 result = append(result, e) 179 } 180 181 return result, nil 182 } 183 184 func (m *Repository[E]) Paginate(ctx context.Context, filters ...*entities.Filter) (*entities.Paginate[E], error) { 185 if err, ok := FakeRepoErrors[m.Name+"_paginate"]; ok && err != nil { 186 return nil, err 187 } 188 189 result := make([]*E, 0) 190 filter := *filters[0] 191 offset := (filter.Page - 1) * filter.Limit 192 193 for index, e := range m.entities { 194 if index < offset { 195 continue 196 } 197 if index >= offset+filter.Limit { 198 break 199 } 200 result = append(result, e) 201 } 202 203 return &entities.Paginate[E]{ 204 PageCurrent: filter.Page, 205 PageSize: filter.Limit, 206 Total: len(m.entities), 207 Data: result, 208 }, nil 209 }