github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/user_store.go (about) 1 package common 2 3 import ( 4 "database/sql" 5 "errors" 6 "strconv" 7 8 qgen "github.com/Azareal/Gosora/query_gen" 9 "golang.org/x/crypto/bcrypt" 10 ) 11 12 // TODO: Add the watchdog goroutine 13 // TODO: Add some sort of update method 14 var Users UserStore 15 var ErrAccountExists = errors.New("this username is already in use") 16 var ErrLongUsername = errors.New("this username is too long") 17 var ErrSomeUsersNotFound = errors.New("Unable to find some users") 18 19 type UserStore interface { 20 DirtyGet(id int) *User 21 Get(id int) (*User, error) 22 Getn(id int) *User 23 GetByName(name string) (*User, error) 24 BulkGetByName(names []string) (list []*User, err error) 25 RawBulkGetByNameForConvo(f func(int, string, int, bool, int, int) error, names []string) error 26 Exists(id int) bool 27 SearchOffset(name, email string, gid, offset, perPage int) (users []*User, err error) 28 GetOffset(offset, perPage int) ([]*User, error) 29 Each(f func(*User) error) error 30 //BulkGet(ids []int) ([]*User, error) 31 BulkGetMap(ids []int) (map[int]*User, error) 32 BypassGet(id int) (*User, error) 33 ClearLastIPs() error 34 Create(name, password, email string, group int, active bool) (int, error) 35 Reload(id int) error 36 Count() int 37 CountSearch(name, email string, gid int) int 38 39 SetCache(cache UserCache) 40 GetCache() UserCache 41 } 42 43 type DefaultUserStore struct { 44 cache UserCache 45 46 get *sql.Stmt 47 getByName *sql.Stmt 48 searchOffset *sql.Stmt 49 getOffset *sql.Stmt 50 getAll *sql.Stmt 51 exists *sql.Stmt 52 register *sql.Stmt 53 nameExists *sql.Stmt 54 55 count *sql.Stmt 56 countSearch *sql.Stmt 57 58 clearIPs *sql.Stmt 59 } 60 61 // NewDefaultUserStore gives you a new instance of DefaultUserStore 62 func NewDefaultUserStore(cache UserCache) (*DefaultUserStore, error) { 63 acc := qgen.NewAcc() 64 if cache == nil { 65 cache = NewNullUserCache() 66 } 67 u := "users" 68 allCols := "uid,name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo" 69 // TODO: Add an admin version of registerStmt with more flexibility? 70 return &DefaultUserStore{ 71 cache: cache, 72 73 get: acc.Select(u).Columns("name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo").Where("uid=?").Prepare(), 74 getByName: acc.Select(u).Columns(allCols).Where("name=?").Prepare(), 75 searchOffset: acc.Select(u).Columns(allCols).Where("(name=? OR ?='') AND (email=? OR ?='') AND (group=? OR ?=0)").Orderby("uid ASC").Limit("?,?").Prepare(), 76 getOffset: acc.Select(u).Columns(allCols).Orderby("uid ASC").Limit("?,?").Prepare(), 77 getAll: acc.Select(u).Columns(allCols).Prepare(), 78 79 exists: acc.Exists(u, "uid").Prepare(), 80 register: acc.Insert(u).Columns("name,email,password,salt,group,is_super_admin,session,active,message,createdAt,lastActiveAt,lastLiked,oldestItemLikedCreatedAt").Fields("?,?,?,?,?,0,'',?,'',UTC_TIMESTAMP(),UTC_TIMESTAMP(),UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), // TODO: Implement user_count on users_groups here 81 nameExists: acc.Exists(u, "name").Prepare(), 82 83 count: acc.Count(u).Prepare(), 84 countSearch: acc.Count(u).Where("(name=? OR ?='') AND (email=? OR ?='') AND (group=? OR ?=0)").Prepare(), 85 86 clearIPs: acc.Update(u).Set("last_ip=''").Where("last_ip!=''").Prepare(), 87 }, acc.FirstError() 88 } 89 90 func (s *DefaultUserStore) DirtyGet(id int) *User { 91 user, err := s.Get(id) 92 if err == nil { 93 return user 94 } 95 /*if s.OutOfBounds(id) { 96 return BlankUser() 97 }*/ 98 return BlankUser() 99 } 100 101 func (s *DefaultUserStore) scanUser(r *sql.Row, u *User) (embeds int, err error) { 102 e := r.Scan(&u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage) 103 return embeds, e 104 } 105 106 // TODO: Log weird cache errors? Not just here but in every *Cache? 107 func (s *DefaultUserStore) Get(id int) (*User, error) { 108 u, err := s.cache.Get(id) 109 if err == nil { 110 //log.Print("cached user") 111 //log.Print(string(debug.Stack())) 112 //log.Println("") 113 return u, nil 114 } 115 //log.Print("uncached user") 116 117 u = &User{ID: id, Loggedin: true} 118 embeds, err := s.scanUser(s.get.QueryRow(id), u) 119 if err == nil { 120 if embeds != -1 { 121 u.ParseSettings = DefaultParseSettings.CopyPtr() 122 u.ParseSettings.NoEmbed = embeds == 0 123 } 124 u.Init() 125 s.cache.Set(u) 126 } 127 return u, err 128 } 129 130 func (s *DefaultUserStore) Getn(id int) *User { 131 u := s.cache.Getn(id) 132 if u != nil { 133 return u 134 } 135 136 u = &User{ID: id, Loggedin: true} 137 embeds, err := s.scanUser(s.get.QueryRow(id), u) 138 if err != nil { 139 return nil 140 } 141 if embeds != -1 { 142 u.ParseSettings = DefaultParseSettings.CopyPtr() 143 u.ParseSettings.NoEmbed = embeds == 0 144 } 145 u.Init() 146 s.cache.Set(u) 147 return u 148 } 149 150 // TODO: Log weird cache errors? Not just here but in every *Cache? 151 // ! This bypasses the cache, use frugally 152 func (s *DefaultUserStore) GetByName(name string) (*User, error) { 153 u := &User{Loggedin: true} 154 var embeds int 155 err := s.getByName.QueryRow(name).Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage) 156 if err != nil { 157 return nil, err 158 } 159 if embeds != -1 { 160 u.ParseSettings = DefaultParseSettings.CopyPtr() 161 u.ParseSettings.NoEmbed = embeds == 0 162 } 163 u.Init() 164 s.cache.Set(u) 165 return u, nil 166 } 167 168 // TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts? 169 // ! This bypasses the cache, use frugally 170 func (s *DefaultUserStore) BulkGetByName(names []string) (list []*User, err error) { 171 if len(names) == 0 { 172 return list, nil 173 } else if len(names) == 1 { 174 user, err := s.GetByName(names[0]) 175 if err != nil { 176 return list, err 177 } 178 return []*User{user}, nil 179 } 180 181 idList, q := inqbuildstr(names) 182 rows, err := qgen.NewAcc().Select("users").Columns("uid,name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo").Where("name IN(" + q + ")").Query(idList...) 183 if err != nil { 184 return list, err 185 } 186 defer rows.Close() 187 188 var embeds int 189 for rows.Next() { 190 u := &User{Loggedin: true} 191 err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage) 192 if err != nil { 193 return list, err 194 } 195 if embeds != -1 { 196 u.ParseSettings = DefaultParseSettings.CopyPtr() 197 u.ParseSettings.NoEmbed = embeds == 0 198 } 199 u.Init() 200 s.cache.Set(u) 201 list = append(list, u) 202 } 203 if err = rows.Err(); err != nil { 204 return list, err 205 } 206 207 // Did we miss any users? 208 if len(names) > len(list) { 209 return list, ErrSomeUsersNotFound 210 } 211 return list, err 212 } 213 214 // Special case function for efficiency 215 func (s *DefaultUserStore) RawBulkGetByNameForConvo(f func(int, string, int, bool, int, int) error, names []string) error { 216 idList, q := inqbuildstr(names) 217 rows, e := qgen.NewAcc().Select("users").Columns("uid,name,group,is_super_admin,temp_group,who_can_convo").Where("name IN(" + q + ")").Query(idList...) 218 if e != nil { 219 return e 220 } 221 defer rows.Close() 222 for rows.Next() { 223 var name string 224 var id, group, temp_group, who_can_convo int 225 var super_admin bool 226 if e = rows.Scan(&id, &name, &group, &super_admin, &temp_group, &who_can_convo); e != nil { 227 return e 228 } 229 if e = f(id, name, group, super_admin, temp_group, who_can_convo); e != nil { 230 return e 231 } 232 } 233 return rows.Err() 234 } 235 236 // TODO: Optimise this, so we don't wind up hitting the database every-time for small gaps 237 // TODO: Make this a little more consistent with DefaultGroupStore's GetRange method 238 func (s *DefaultUserStore) GetOffset(offset, perPage int) (users []*User, err error) { 239 rows, err := s.getOffset.Query(offset, perPage) 240 if err != nil { 241 return users, err 242 } 243 defer rows.Close() 244 245 var embeds int 246 for rows.Next() { 247 u := &User{Loggedin: true} 248 err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage) 249 if err != nil { 250 return nil, err 251 } 252 if embeds != -1 { 253 u.ParseSettings = DefaultParseSettings.CopyPtr() 254 u.ParseSettings.NoEmbed = embeds == 0 255 } 256 u.Init() 257 s.cache.Set(u) 258 users = append(users, u) 259 } 260 return users, rows.Err() 261 } 262 func (s *DefaultUserStore) SearchOffset(name, email string, gid, offset, perPage int) (users []*User, err error) { 263 rows, err := s.searchOffset.Query(name, name, email, email, gid, gid, offset, perPage) 264 if err != nil { 265 return users, err 266 } 267 defer rows.Close() 268 269 var embeds int 270 for rows.Next() { 271 u := &User{Loggedin: true} 272 err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage) 273 if err != nil { 274 return nil, err 275 } 276 if embeds != -1 { 277 u.ParseSettings = DefaultParseSettings.CopyPtr() 278 u.ParseSettings.NoEmbed = embeds == 0 279 } 280 u.Init() 281 s.cache.Set(u) 282 users = append(users, u) 283 } 284 return users, rows.Err() 285 } 286 func (s *DefaultUserStore) Each(f func(*User) error) error { 287 rows, e := s.getAll.Query() 288 if e != nil { 289 return e 290 } 291 defer rows.Close() 292 var embeds int 293 for rows.Next() { 294 u := new(User) 295 if e := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage); e != nil { 296 return e 297 } 298 if embeds != -1 { 299 u.ParseSettings = DefaultParseSettings.CopyPtr() 300 u.ParseSettings.NoEmbed = embeds == 0 301 } 302 u.Init() 303 if e := f(u); e != nil { 304 return e 305 } 306 } 307 return rows.Err() 308 } 309 310 // TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts? 311 // TODO: ID of 0 should always error? 312 func (s *DefaultUserStore) BulkGetMap(ids []int) (list map[int]*User, err error) { 313 idCount := len(ids) 314 list = make(map[int]*User) 315 if idCount == 0 { 316 return list, nil 317 } 318 319 var stillHere []int 320 sliceList := s.cache.BulkGet(ids) 321 if len(sliceList) > 0 { 322 for i, sliceItem := range sliceList { 323 if sliceItem != nil { 324 list[sliceItem.ID] = sliceItem 325 } else { 326 stillHere = append(stillHere, ids[i]) 327 } 328 } 329 ids = stillHere 330 } 331 332 // If every user is in the cache, then return immediately 333 if len(ids) == 0 { 334 return list, nil 335 } else if len(ids) == 1 { 336 user, err := s.Get(ids[0]) 337 if err != nil { 338 return list, err 339 } 340 list[user.ID] = user 341 return list, nil 342 } 343 344 idList, q := inqbuild(ids) 345 rows, err := qgen.NewAcc().Select("users").Columns("uid,name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo").Where("uid IN(" + q + ")").Query(idList...) 346 if err != nil { 347 return list, err 348 } 349 defer rows.Close() 350 351 var embeds int 352 for rows.Next() { 353 u := &User{Loggedin: true} 354 err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage) 355 if err != nil { 356 return list, err 357 } 358 if embeds != -1 { 359 u.ParseSettings = DefaultParseSettings.CopyPtr() 360 u.ParseSettings.NoEmbed = embeds == 0 361 } 362 u.Init() 363 s.cache.Set(u) 364 list[u.ID] = u 365 } 366 if err = rows.Err(); err != nil { 367 return list, err 368 } 369 370 // Did we miss any users? 371 if idCount > len(list) { 372 var sidList string 373 for _, id := range ids { 374 _, ok := list[id] 375 if !ok { 376 sidList += strconv.Itoa(id) + "," 377 } 378 } 379 if sidList != "" { 380 sidList = sidList[0 : len(sidList)-1] 381 err = errors.New("Unable to find users with the following IDs: " + sidList) 382 } 383 } 384 385 return list, err 386 } 387 388 func (s *DefaultUserStore) BypassGet(id int) (*User, error) { 389 u := &User{ID: id, Loggedin: true} 390 embeds, err := s.scanUser(s.get.QueryRow(id), u) 391 if err == nil { 392 if embeds != -1 { 393 u.ParseSettings = DefaultParseSettings.CopyPtr() 394 u.ParseSettings.NoEmbed = embeds == 0 395 } 396 u.Init() 397 } 398 return u, err 399 } 400 401 func (s *DefaultUserStore) Reload(id int) error { 402 u, err := s.BypassGet(id) 403 if err != nil { 404 s.cache.Remove(id) 405 return err 406 } 407 _ = s.cache.Set(u) 408 TopicListThaw.Thaw() 409 return nil 410 } 411 412 func (s *DefaultUserStore) Exists(id int) bool { 413 err := s.exists.QueryRow(id).Scan(&id) 414 if err != nil && err != ErrNoRows { 415 LogError(err) 416 } 417 return err != ErrNoRows 418 } 419 420 func (s *DefaultUserStore) ClearLastIPs() error { 421 _, e := s.clearIPs.Exec() 422 return e 423 } 424 425 // TODO: Change active to a bool? 426 // TODO: Use unique keys for the usernames 427 func (s *DefaultUserStore) Create(name, password, email string, group int, active bool) (int, error) { 428 // TODO: Strip spaces? 429 430 // ? This number might be a little screwy with Unicode, but it's the only consistent thing we have, as Unicode characters can be any number of bytes in theory? 431 if len(name) > Config.MaxUsernameLength { 432 return 0, ErrLongUsername 433 } 434 435 // Is this name already taken..? 436 err := s.nameExists.QueryRow(name).Scan(&name) 437 if err != ErrNoRows { 438 return 0, ErrAccountExists 439 } 440 salt, err := GenerateSafeString(SaltLength) 441 if err != nil { 442 return 0, err 443 } 444 hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost) 445 if err != nil { 446 return 0, err 447 } 448 449 res, err := s.register.Exec(name, email, string(hashedPassword), salt, group, active) 450 if err != nil { 451 return 0, err 452 } 453 lastID, err := res.LastInsertId() 454 return int(lastID), err 455 } 456 457 // Count returns the total number of users registered on the forums 458 func (s *DefaultUserStore) Count() (count int) { 459 return Countf(s.count) 460 } 461 462 func (s *DefaultUserStore) CountSearch(name, email string, gid int) (count int) { 463 return Countf(s.countSearch, name, name, email, email, gid, gid) 464 } 465 466 func (s *DefaultUserStore) SetCache(cache UserCache) { 467 s.cache = cache 468 } 469 470 // TODO: We're temporarily doing this so that you can do ucache != nil in getTopicUser. Refactor it. 471 func (s *DefaultUserStore) GetCache() UserCache { 472 _, ok := s.cache.(*NullUserCache) 473 if ok { 474 return nil 475 } 476 return s.cache 477 }