github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/user_store.go (about)

     1  package common
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"strconv"
     7  
     8  	qgen "github.com/Azareal/Gosora/query_gen"
     9  	"golang.org/x/crypto/bcrypt"
    10  )
    11  
    12  // TODO: Add the watchdog goroutine
    13  // TODO: Add some sort of update method
    14  var Users UserStore
    15  var ErrAccountExists = errors.New("this username is already in use")
    16  var ErrLongUsername = errors.New("this username is too long")
    17  var ErrSomeUsersNotFound = errors.New("Unable to find some users")
    18  
    19  type UserStore interface {
    20  	DirtyGet(id int) *User
    21  	Get(id int) (*User, error)
    22  	Getn(id int) *User
    23  	GetByName(name string) (*User, error)
    24  	BulkGetByName(names []string) (list []*User, err error)
    25  	RawBulkGetByNameForConvo(f func(int, string, int, bool, int, int) error, names []string) error
    26  	Exists(id int) bool
    27  	SearchOffset(name, email string, gid, offset, perPage int) (users []*User, err error)
    28  	GetOffset(offset, perPage int) ([]*User, error)
    29  	Each(f func(*User) error) error
    30  	//BulkGet(ids []int) ([]*User, error)
    31  	BulkGetMap(ids []int) (map[int]*User, error)
    32  	BypassGet(id int) (*User, error)
    33  	ClearLastIPs() error
    34  	Create(name, password, email string, group int, active bool) (int, error)
    35  	Reload(id int) error
    36  	Count() int
    37  	CountSearch(name, email string, gid int) int
    38  
    39  	SetCache(cache UserCache)
    40  	GetCache() UserCache
    41  }
    42  
    43  type DefaultUserStore struct {
    44  	cache UserCache
    45  
    46  	get          *sql.Stmt
    47  	getByName    *sql.Stmt
    48  	searchOffset *sql.Stmt
    49  	getOffset    *sql.Stmt
    50  	getAll       *sql.Stmt
    51  	exists       *sql.Stmt
    52  	register     *sql.Stmt
    53  	nameExists   *sql.Stmt
    54  
    55  	count       *sql.Stmt
    56  	countSearch *sql.Stmt
    57  
    58  	clearIPs *sql.Stmt
    59  }
    60  
    61  // NewDefaultUserStore gives you a new instance of DefaultUserStore
    62  func NewDefaultUserStore(cache UserCache) (*DefaultUserStore, error) {
    63  	acc := qgen.NewAcc()
    64  	if cache == nil {
    65  		cache = NewNullUserCache()
    66  	}
    67  	u := "users"
    68  	allCols := "uid,name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo"
    69  	// TODO: Add an admin version of registerStmt with more flexibility?
    70  	return &DefaultUserStore{
    71  		cache: cache,
    72  
    73  		get:          acc.Select(u).Columns("name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo").Where("uid=?").Prepare(),
    74  		getByName:    acc.Select(u).Columns(allCols).Where("name=?").Prepare(),
    75  		searchOffset: acc.Select(u).Columns(allCols).Where("(name=? OR ?='') AND (email=? OR ?='') AND (group=? OR ?=0)").Orderby("uid ASC").Limit("?,?").Prepare(),
    76  		getOffset:    acc.Select(u).Columns(allCols).Orderby("uid ASC").Limit("?,?").Prepare(),
    77  		getAll:       acc.Select(u).Columns(allCols).Prepare(),
    78  
    79  		exists:     acc.Exists(u, "uid").Prepare(),
    80  		register:   acc.Insert(u).Columns("name,email,password,salt,group,is_super_admin,session,active,message,createdAt,lastActiveAt,lastLiked,oldestItemLikedCreatedAt").Fields("?,?,?,?,?,0,'',?,'',UTC_TIMESTAMP(),UTC_TIMESTAMP(),UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), // TODO: Implement user_count on users_groups here
    81  		nameExists: acc.Exists(u, "name").Prepare(),
    82  
    83  		count:       acc.Count(u).Prepare(),
    84  		countSearch: acc.Count(u).Where("(name=? OR ?='') AND (email=? OR ?='') AND (group=? OR ?=0)").Prepare(),
    85  
    86  		clearIPs: acc.Update(u).Set("last_ip=''").Where("last_ip!=''").Prepare(),
    87  	}, acc.FirstError()
    88  }
    89  
    90  func (s *DefaultUserStore) DirtyGet(id int) *User {
    91  	user, err := s.Get(id)
    92  	if err == nil {
    93  		return user
    94  	}
    95  	/*if s.OutOfBounds(id) {
    96  		return BlankUser()
    97  	}*/
    98  	return BlankUser()
    99  }
   100  
   101  func (s *DefaultUserStore) scanUser(r *sql.Row, u *User) (embeds int, err error) {
   102  	e := r.Scan(&u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage)
   103  	return embeds, e
   104  }
   105  
   106  // TODO: Log weird cache errors? Not just here but in every *Cache?
   107  func (s *DefaultUserStore) Get(id int) (*User, error) {
   108  	u, err := s.cache.Get(id)
   109  	if err == nil {
   110  		//log.Print("cached user")
   111  		//log.Print(string(debug.Stack()))
   112  		//log.Println("")
   113  		return u, nil
   114  	}
   115  	//log.Print("uncached user")
   116  
   117  	u = &User{ID: id, Loggedin: true}
   118  	embeds, err := s.scanUser(s.get.QueryRow(id), u)
   119  	if err == nil {
   120  		if embeds != -1 {
   121  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   122  			u.ParseSettings.NoEmbed = embeds == 0
   123  		}
   124  		u.Init()
   125  		s.cache.Set(u)
   126  	}
   127  	return u, err
   128  }
   129  
   130  func (s *DefaultUserStore) Getn(id int) *User {
   131  	u := s.cache.Getn(id)
   132  	if u != nil {
   133  		return u
   134  	}
   135  
   136  	u = &User{ID: id, Loggedin: true}
   137  	embeds, err := s.scanUser(s.get.QueryRow(id), u)
   138  	if err != nil {
   139  		return nil
   140  	}
   141  	if embeds != -1 {
   142  		u.ParseSettings = DefaultParseSettings.CopyPtr()
   143  		u.ParseSettings.NoEmbed = embeds == 0
   144  	}
   145  	u.Init()
   146  	s.cache.Set(u)
   147  	return u
   148  }
   149  
   150  // TODO: Log weird cache errors? Not just here but in every *Cache?
   151  // ! This bypasses the cache, use frugally
   152  func (s *DefaultUserStore) GetByName(name string) (*User, error) {
   153  	u := &User{Loggedin: true}
   154  	var embeds int
   155  	err := s.getByName.QueryRow(name).Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	if embeds != -1 {
   160  		u.ParseSettings = DefaultParseSettings.CopyPtr()
   161  		u.ParseSettings.NoEmbed = embeds == 0
   162  	}
   163  	u.Init()
   164  	s.cache.Set(u)
   165  	return u, nil
   166  }
   167  
   168  // TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
   169  // ! This bypasses the cache, use frugally
   170  func (s *DefaultUserStore) BulkGetByName(names []string) (list []*User, err error) {
   171  	if len(names) == 0 {
   172  		return list, nil
   173  	} else if len(names) == 1 {
   174  		user, err := s.GetByName(names[0])
   175  		if err != nil {
   176  			return list, err
   177  		}
   178  		return []*User{user}, nil
   179  	}
   180  
   181  	idList, q := inqbuildstr(names)
   182  	rows, err := qgen.NewAcc().Select("users").Columns("uid,name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo").Where("name IN(" + q + ")").Query(idList...)
   183  	if err != nil {
   184  		return list, err
   185  	}
   186  	defer rows.Close()
   187  
   188  	var embeds int
   189  	for rows.Next() {
   190  		u := &User{Loggedin: true}
   191  		err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage)
   192  		if err != nil {
   193  			return list, err
   194  		}
   195  		if embeds != -1 {
   196  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   197  			u.ParseSettings.NoEmbed = embeds == 0
   198  		}
   199  		u.Init()
   200  		s.cache.Set(u)
   201  		list = append(list, u)
   202  	}
   203  	if err = rows.Err(); err != nil {
   204  		return list, err
   205  	}
   206  
   207  	// Did we miss any users?
   208  	if len(names) > len(list) {
   209  		return list, ErrSomeUsersNotFound
   210  	}
   211  	return list, err
   212  }
   213  
   214  // Special case function for efficiency
   215  func (s *DefaultUserStore) RawBulkGetByNameForConvo(f func(int, string, int, bool, int, int) error, names []string) error {
   216  	idList, q := inqbuildstr(names)
   217  	rows, e := qgen.NewAcc().Select("users").Columns("uid,name,group,is_super_admin,temp_group,who_can_convo").Where("name IN(" + q + ")").Query(idList...)
   218  	if e != nil {
   219  		return e
   220  	}
   221  	defer rows.Close()
   222  	for rows.Next() {
   223  		var name string
   224  		var id, group, temp_group, who_can_convo int
   225  		var super_admin bool
   226  		if e = rows.Scan(&id, &name, &group, &super_admin, &temp_group, &who_can_convo); e != nil {
   227  			return e
   228  		}
   229  		if e = f(id, name, group, super_admin, temp_group, who_can_convo); e != nil {
   230  			return e
   231  		}
   232  	}
   233  	return rows.Err()
   234  }
   235  
   236  // TODO: Optimise this, so we don't wind up hitting the database every-time for small gaps
   237  // TODO: Make this a little more consistent with DefaultGroupStore's GetRange method
   238  func (s *DefaultUserStore) GetOffset(offset, perPage int) (users []*User, err error) {
   239  	rows, err := s.getOffset.Query(offset, perPage)
   240  	if err != nil {
   241  		return users, err
   242  	}
   243  	defer rows.Close()
   244  
   245  	var embeds int
   246  	for rows.Next() {
   247  		u := &User{Loggedin: true}
   248  		err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage)
   249  		if err != nil {
   250  			return nil, err
   251  		}
   252  		if embeds != -1 {
   253  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   254  			u.ParseSettings.NoEmbed = embeds == 0
   255  		}
   256  		u.Init()
   257  		s.cache.Set(u)
   258  		users = append(users, u)
   259  	}
   260  	return users, rows.Err()
   261  }
   262  func (s *DefaultUserStore) SearchOffset(name, email string, gid, offset, perPage int) (users []*User, err error) {
   263  	rows, err := s.searchOffset.Query(name, name, email, email, gid, gid, offset, perPage)
   264  	if err != nil {
   265  		return users, err
   266  	}
   267  	defer rows.Close()
   268  
   269  	var embeds int
   270  	for rows.Next() {
   271  		u := &User{Loggedin: true}
   272  		err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage)
   273  		if err != nil {
   274  			return nil, err
   275  		}
   276  		if embeds != -1 {
   277  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   278  			u.ParseSettings.NoEmbed = embeds == 0
   279  		}
   280  		u.Init()
   281  		s.cache.Set(u)
   282  		users = append(users, u)
   283  	}
   284  	return users, rows.Err()
   285  }
   286  func (s *DefaultUserStore) Each(f func(*User) error) error {
   287  	rows, e := s.getAll.Query()
   288  	if e != nil {
   289  		return e
   290  	}
   291  	defer rows.Close()
   292  	var embeds int
   293  	for rows.Next() {
   294  		u := new(User)
   295  		if e := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage); e != nil {
   296  			return e
   297  		}
   298  		if embeds != -1 {
   299  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   300  			u.ParseSettings.NoEmbed = embeds == 0
   301  		}
   302  		u.Init()
   303  		if e := f(u); e != nil {
   304  			return e
   305  		}
   306  	}
   307  	return rows.Err()
   308  }
   309  
   310  // TODO: Optimise the query to avoid preparing it on the spot? Maybe, use knowledge of the most common IN() parameter counts?
   311  // TODO: ID of 0 should always error?
   312  func (s *DefaultUserStore) BulkGetMap(ids []int) (list map[int]*User, err error) {
   313  	idCount := len(ids)
   314  	list = make(map[int]*User)
   315  	if idCount == 0 {
   316  		return list, nil
   317  	}
   318  
   319  	var stillHere []int
   320  	sliceList := s.cache.BulkGet(ids)
   321  	if len(sliceList) > 0 {
   322  		for i, sliceItem := range sliceList {
   323  			if sliceItem != nil {
   324  				list[sliceItem.ID] = sliceItem
   325  			} else {
   326  				stillHere = append(stillHere, ids[i])
   327  			}
   328  		}
   329  		ids = stillHere
   330  	}
   331  
   332  	// If every user is in the cache, then return immediately
   333  	if len(ids) == 0 {
   334  		return list, nil
   335  	} else if len(ids) == 1 {
   336  		user, err := s.Get(ids[0])
   337  		if err != nil {
   338  			return list, err
   339  		}
   340  		list[user.ID] = user
   341  		return list, nil
   342  	}
   343  
   344  	idList, q := inqbuild(ids)
   345  	rows, err := qgen.NewAcc().Select("users").Columns("uid,name,group,active,is_super_admin,session,email,avatar,message,level,score,posts,liked,last_ip,temp_group,createdAt,enable_embeds,profile_comments,who_can_convo").Where("uid IN(" + q + ")").Query(idList...)
   346  	if err != nil {
   347  		return list, err
   348  	}
   349  	defer rows.Close()
   350  
   351  	var embeds int
   352  	for rows.Next() {
   353  		u := &User{Loggedin: true}
   354  		err := rows.Scan(&u.ID, &u.Name, &u.Group, &u.Active, &u.IsSuperAdmin, &u.Session, &u.Email, &u.RawAvatar, &u.Message, &u.Level, &u.Score, &u.Posts, &u.Liked, &u.LastIP, &u.TempGroup, &u.CreatedAt, &embeds, &u.Privacy.ShowComments, &u.Privacy.AllowMessage)
   355  		if err != nil {
   356  			return list, err
   357  		}
   358  		if embeds != -1 {
   359  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   360  			u.ParseSettings.NoEmbed = embeds == 0
   361  		}
   362  		u.Init()
   363  		s.cache.Set(u)
   364  		list[u.ID] = u
   365  	}
   366  	if err = rows.Err(); err != nil {
   367  		return list, err
   368  	}
   369  
   370  	// Did we miss any users?
   371  	if idCount > len(list) {
   372  		var sidList string
   373  		for _, id := range ids {
   374  			_, ok := list[id]
   375  			if !ok {
   376  				sidList += strconv.Itoa(id) + ","
   377  			}
   378  		}
   379  		if sidList != "" {
   380  			sidList = sidList[0 : len(sidList)-1]
   381  			err = errors.New("Unable to find users with the following IDs: " + sidList)
   382  		}
   383  	}
   384  
   385  	return list, err
   386  }
   387  
   388  func (s *DefaultUserStore) BypassGet(id int) (*User, error) {
   389  	u := &User{ID: id, Loggedin: true}
   390  	embeds, err := s.scanUser(s.get.QueryRow(id), u)
   391  	if err == nil {
   392  		if embeds != -1 {
   393  			u.ParseSettings = DefaultParseSettings.CopyPtr()
   394  			u.ParseSettings.NoEmbed = embeds == 0
   395  		}
   396  		u.Init()
   397  	}
   398  	return u, err
   399  }
   400  
   401  func (s *DefaultUserStore) Reload(id int) error {
   402  	u, err := s.BypassGet(id)
   403  	if err != nil {
   404  		s.cache.Remove(id)
   405  		return err
   406  	}
   407  	_ = s.cache.Set(u)
   408  	TopicListThaw.Thaw()
   409  	return nil
   410  }
   411  
   412  func (s *DefaultUserStore) Exists(id int) bool {
   413  	err := s.exists.QueryRow(id).Scan(&id)
   414  	if err != nil && err != ErrNoRows {
   415  		LogError(err)
   416  	}
   417  	return err != ErrNoRows
   418  }
   419  
   420  func (s *DefaultUserStore) ClearLastIPs() error {
   421  	_, e := s.clearIPs.Exec()
   422  	return e
   423  }
   424  
   425  // TODO: Change active to a bool?
   426  // TODO: Use unique keys for the usernames
   427  func (s *DefaultUserStore) Create(name, password, email string, group int, active bool) (int, error) {
   428  	// TODO: Strip spaces?
   429  
   430  	// ? This number might be a little screwy with Unicode, but it's the only consistent thing we have, as Unicode characters can be any number of bytes in theory?
   431  	if len(name) > Config.MaxUsernameLength {
   432  		return 0, ErrLongUsername
   433  	}
   434  
   435  	// Is this name already taken..?
   436  	err := s.nameExists.QueryRow(name).Scan(&name)
   437  	if err != ErrNoRows {
   438  		return 0, ErrAccountExists
   439  	}
   440  	salt, err := GenerateSafeString(SaltLength)
   441  	if err != nil {
   442  		return 0, err
   443  	}
   444  	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password+salt), bcrypt.DefaultCost)
   445  	if err != nil {
   446  		return 0, err
   447  	}
   448  
   449  	res, err := s.register.Exec(name, email, string(hashedPassword), salt, group, active)
   450  	if err != nil {
   451  		return 0, err
   452  	}
   453  	lastID, err := res.LastInsertId()
   454  	return int(lastID), err
   455  }
   456  
   457  // Count returns the total number of users registered on the forums
   458  func (s *DefaultUserStore) Count() (count int) {
   459  	return Countf(s.count)
   460  }
   461  
   462  func (s *DefaultUserStore) CountSearch(name, email string, gid int) (count int) {
   463  	return Countf(s.countSearch, name, name, email, email, gid, gid)
   464  }
   465  
   466  func (s *DefaultUserStore) SetCache(cache UserCache) {
   467  	s.cache = cache
   468  }
   469  
   470  // TODO: We're temporarily doing this so that you can do ucache != nil in getTopicUser. Refactor it.
   471  func (s *DefaultUserStore) GetCache() UserCache {
   472  	_, ok := s.cache.(*NullUserCache)
   473  	if ok {
   474  		return nil
   475  	}
   476  	return s.cache
   477  }