github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/conversations.go (about)

     1  package common
     2  
import (
	"database/sql"
	"errors"
	"strconv"
	"strings"
	"time"

	//"log"

	qgen "github.com/Azareal/Gosora/query_gen"
)
    14  
    15  var Convos ConversationStore
    16  var convoStmts ConvoStmts
    17  
    18  type ConvoStmts struct {
    19  	fetchPost  *sql.Stmt
    20  	getPosts   *sql.Stmt
    21  	countPosts *sql.Stmt
    22  	edit       *sql.Stmt
    23  	create     *sql.Stmt
    24  	delete     *sql.Stmt
    25  	has        *sql.Stmt
    26  
    27  	editPost   *sql.Stmt
    28  	createPost *sql.Stmt
    29  	deletePost *sql.Stmt
    30  
    31  	getUsers *sql.Stmt
    32  }
    33  
    34  func init() {
    35  	DbInits.Add(func(acc *qgen.Accumulator) error {
    36  		cpo := "conversations_posts"
    37  		convoStmts = ConvoStmts{
    38  			fetchPost:  acc.Select(cpo).Columns("cid,body,post,createdBy").Where("pid=?").Prepare(),
    39  			getPosts:   acc.Select(cpo).Columns("pid,body,post,createdBy").Where("cid=?").Limit("?,?").Prepare(),
    40  			countPosts: acc.Count(cpo).Where("cid=?").Prepare(),
    41  			edit:       acc.Update("conversations").Set("lastReplyBy=?,lastReplyAt=?").Where("cid=?").Prepare(),
    42  			create:     acc.Insert("conversations").Columns("createdAt,lastReplyAt").Fields("UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(),
    43  			has:        acc.Count("conversations_participants").Where("uid=? AND cid=?").Prepare(),
    44  
    45  			editPost:   acc.Update(cpo).Set("body=?,post=?").Where("pid=?").Prepare(),
    46  			createPost: acc.Insert(cpo).Columns("cid,body,post,createdBy").Fields("?,?,?,?").Prepare(),
    47  			deletePost: acc.Delete(cpo).Where("pid=?").Prepare(),
    48  
    49  			getUsers: acc.Select("conversations_participants").Columns("uid").Where("cid=?").Prepare(),
    50  		}
    51  		return acc.FirstError()
    52  	})
    53  }
    54  
// Conversation is one row of the conversations table, plus a prebuilt Link.
type Conversation struct {
	ID          int       // cid of the conversation
	Link        string    // relative URL; the store's Get/GetUser fill it via BuildConvoURL
	CreatedBy   int       // user ID of the creator
	CreatedAt   time.Time // creation time (UTC_TIMESTAMP() in the store's Create)
	LastReplyBy int       // user ID recorded as the last replier (the creator on creation)
	LastReplyAt time.Time // time recorded for the last reply
}
    63  
    64  func (co *Conversation) Posts(offset, itemsPerPage int) (posts []*ConversationPost, err error) {
    65  	rows, err := convoStmts.getPosts.Query(co.ID, offset, itemsPerPage)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	defer rows.Close()
    70  
    71  	for rows.Next() {
    72  		p := &ConversationPost{CID: co.ID}
    73  		err := rows.Scan(&p.ID, &p.Body, &p.Post, &p.CreatedBy)
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  		p, err = ConvoPostProcess.OnLoad(p)
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		posts = append(posts, p)
    82  	}
    83  
    84  	return posts, rows.Err()
    85  }
    86  
    87  func (co *Conversation) PostsCount() (count int) {
    88  	return Countf(convoStmts.countPosts, co.ID)
    89  }
    90  
    91  func (co *Conversation) Uids() (ids []int, err error) {
    92  	rows, e := convoStmts.getUsers.Query(co.ID)
    93  	if e != nil {
    94  		return nil, e
    95  	}
    96  	defer rows.Close()
    97  	for rows.Next() {
    98  		var id int
    99  		if e := rows.Scan(&id); e != nil {
   100  			return nil, e
   101  		}
   102  		ids = append(ids, id)
   103  	}
   104  	return ids, rows.Err()
   105  }
   106  
   107  func (co *Conversation) Has(uid int) (in bool) {
   108  	return Countf(convoStmts.has, uid, co.ID) > 0
   109  }
   110  
   111  func (co *Conversation) Update() error {
   112  	_, err := convoStmts.edit.Exec(co.CreatedAt, co.LastReplyBy, co.LastReplyAt, co.ID)
   113  	return err
   114  }
   115  
   116  func (co *Conversation) Create() (int, error) {
   117  	res, err := convoStmts.create.Exec()
   118  	if err != nil {
   119  		return 0, err
   120  	}
   121  
   122  	lastID, err := res.LastInsertId()
   123  	return int(lastID), err
   124  }
   125  
   126  func BuildConvoURL(coid int) string {
   127  	return "/user/convo/" + strconv.Itoa(coid)
   128  }
   129  
// ConversationExtra pairs a Conversation with the users participating in it,
// as assembled by DefaultConversationStore.GetUserExtra.
type ConversationExtra struct {
	*Conversation
	Users []*User // participants, resolved through Users.BulkGetMap
}
   134  
// ConversationStore is the storage interface for conversations;
// DefaultConversationStore is the SQL-backed implementation in this file.
// Method descriptions below follow that implementation.
type ConversationStore interface {
	// Get loads a single conversation by its ID.
	Get(id int) (*Conversation, error)
	// GetUser returns one page, starting at offset, of the conversations uid participates in.
	GetUser(uid, offset int) (cos []*Conversation, err error)
	// GetUserExtra is like GetUser but also attaches each conversation's participants.
	GetUserExtra(uid, offset int) (cos []*ConversationExtra, err error)
	// GetUserCount counts the conversations uid participates in.
	GetUserCount(uid int) (count int)
	// Delete removes a conversation together with its posts and participants.
	Delete(id int) error
	// Count returns the total number of conversations.
	Count() (count int)
	// Create starts a conversation by createdBy with an initial post of content
	// and the given participants, returning the new conversation's ID.
	Create(content string, createdBy int, participants []int) (int, error)
}
   144  
   145  type DefaultConversationStore struct {
   146  	get                *sql.Stmt
   147  	getUser            *sql.Stmt
   148  	getUserCount       *sql.Stmt
   149  	delete             *sql.Stmt
   150  	deletePosts        *sql.Stmt
   151  	deleteParticipants *sql.Stmt
   152  	create             *sql.Stmt
   153  	addParticipant     *sql.Stmt
   154  	count              *sql.Stmt
   155  }
   156  
   157  func NewDefaultConversationStore(acc *qgen.Accumulator) (*DefaultConversationStore, error) {
   158  	co := "conversations"
   159  	return &DefaultConversationStore{
   160  		get:                acc.Select(co).Columns("createdBy,createdAt,lastReplyBy,lastReplyAt").Where("cid=?").Prepare(),
   161  		getUser:            acc.SimpleInnerJoin("conversations_participants AS cp", "conversations AS c", "cp.cid, c.createdBy, c.createdAt, c.lastReplyBy, c.lastReplyAt", "cp.cid=c.cid", "cp.uid=?", "c.lastReplyAt DESC, c.createdAt DESC, c.cid DESC", "?,?"),
   162  		getUserCount:       acc.Count("conversations_participants").Where("uid=?").Prepare(),
   163  		delete:             acc.Delete(co).Where("cid=?").Prepare(),
   164  		deletePosts:        acc.Delete("conversations_posts").Where("cid=?").Prepare(),
   165  		deleteParticipants: acc.Delete("conversations_participants").Where("cid=?").Prepare(),
   166  		create:             acc.Insert(co).Columns("createdBy,createdAt,lastReplyBy,lastReplyAt").Fields("?,UTC_TIMESTAMP(),?,UTC_TIMESTAMP()").Prepare(),
   167  		addParticipant:     acc.Insert("conversations_participants").Columns("uid,cid").Fields("?,?").Prepare(),
   168  		count:              acc.Count(co).Prepare(),
   169  	}, acc.FirstError()
   170  }
   171  
   172  func (s *DefaultConversationStore) Get(id int) (*Conversation, error) {
   173  	co := &Conversation{ID: id}
   174  	err := s.get.QueryRow(id).Scan(&co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt)
   175  	co.Link = BuildConvoURL(co.ID)
   176  	return co, err
   177  }
   178  
   179  func (s *DefaultConversationStore) GetUser(uid, offset int) (cos []*Conversation, err error) {
   180  	rows, err := s.getUser.Query(uid, offset, Config.ItemsPerPage)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  	defer rows.Close()
   185  
   186  	for rows.Next() {
   187  		co := &Conversation{}
   188  		err := rows.Scan(&co.ID, &co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt)
   189  		if err != nil {
   190  			return nil, err
   191  		}
   192  		co.Link = BuildConvoURL(co.ID)
   193  		cos = append(cos, co)
   194  	}
   195  	err = rows.Err()
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	if len(cos) == 0 {
   200  		err = sql.ErrNoRows
   201  	}
   202  	return cos, err
   203  }
   204  
   205  func (s *DefaultConversationStore) GetUserExtra(uid, offset int) (cos []*ConversationExtra, err error) {
   206  	raw, err := s.GetUser(uid, offset)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  	//log.Printf("raw: %+v\n", raw)
   211  
   212  	if len(raw) == 1 {
   213  		//log.Print("r0b2")
   214  		uids, err := raw[0].Uids()
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  		//log.Println("r1b2")
   219  		umap, err := Users.BulkGetMap(uids)
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  		//log.Println("r2b2")
   224  		users := make([]*User, len(umap))
   225  		var i int
   226  		for _, user := range umap {
   227  			users[i] = user
   228  			i++
   229  		}
   230  		return []*ConversationExtra{{raw[0], users}}, nil
   231  	}
   232  	//log.Println("1")
   233  
   234  	cmap := make(map[int]*ConversationExtra, len(raw))
   235  	for _, co := range raw {
   236  		cmap[co.ID] = &ConversationExtra{co, nil}
   237  	}
   238  
   239  	// TODO: Use inqbuild for this or a similar function
   240  	var q string
   241  	idList := make([]interface{}, len(raw))
   242  	for i, co := range raw {
   243  		if i == 0 {
   244  			q = "?"
   245  		} else {
   246  			q += ",?"
   247  		}
   248  		idList[i] = strconv.Itoa(co.ID)
   249  	}
   250  
   251  	rows, err := qgen.NewAcc().Select("conversations_participants").Columns("uid,cid").Where("cid IN(" + q + ")").Query(idList...)
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  	defer rows.Close()
   256  	//log.Println("2")
   257  
   258  	idmap := make(map[int][]int) // cid: []uid
   259  	puidmap := make(map[int]struct{})
   260  	for rows.Next() {
   261  		var uid, cid int
   262  		err := rows.Scan(&uid, &cid)
   263  		if err != nil {
   264  			return nil, err
   265  		}
   266  		idmap[cid] = append(idmap[cid], uid)
   267  		puidmap[uid] = struct{}{}
   268  	}
   269  	if err = rows.Err(); err != nil {
   270  		return nil, err
   271  	}
   272  	//log.Println("3")
   273  	//log.Printf("idmap: %+v\n", idmap)
   274  	//log.Printf("puidmap: %+v\n",puidmap)
   275  
   276  	puids := make([]int, len(puidmap))
   277  	var i int
   278  	for puid, _ := range puidmap {
   279  		puids[i] = puid
   280  		i++
   281  	}
   282  	umap, err := Users.BulkGetMap(puids)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	//log.Println("4")
   287  	//log.Printf("umap: %+v\n", umap)
   288  	for cid, uids := range idmap {
   289  		co := cmap[cid]
   290  		for _, uid := range uids {
   291  			co.Users = append(co.Users, umap[uid])
   292  		}
   293  		//log.Printf("co.Conversation: %+v\n", co.Conversation)
   294  		//log.Printf("co.Users: %+v\n", co.Users)
   295  		cmap[cid] = co
   296  	}
   297  	//log.Printf("cmap: %+v\n", cmap)
   298  	for _, ra := range raw {
   299  		cos = append(cos, cmap[ra.ID])
   300  	}
   301  	//log.Printf("cos: %+v\n", cos)
   302  
   303  	return cos, rows.Err()
   304  }
   305  
   306  func (s *DefaultConversationStore) GetUserCount(uid int) (count int) {
   307  	err := s.getUserCount.QueryRow(uid).Scan(&count)
   308  	if err != nil {
   309  		LogError(err)
   310  	}
   311  	return count
   312  }
   313  
   314  // TODO: Use a foreign key or transaction
   315  func (s *DefaultConversationStore) Delete(id int) error {
   316  	_, err := s.delete.Exec(id)
   317  	if err != nil {
   318  		return err
   319  	}
   320  	_, err = s.deletePosts.Exec(id)
   321  	if err != nil {
   322  		return err
   323  	}
   324  	_, err = s.deleteParticipants.Exec(id)
   325  	return err
   326  }
   327  
   328  func (s *DefaultConversationStore) Create(content string, createdBy int, participants []int) (int, error) {
   329  	if len(participants) == 0 {
   330  		return 0, errors.New("no participants set")
   331  	}
   332  	res, err := s.create.Exec(createdBy, createdBy)
   333  	if err != nil {
   334  		return 0, err
   335  	}
   336  	lastID, err := res.LastInsertId()
   337  	if err != nil {
   338  		return 0, err
   339  	}
   340  
   341  	post := &ConversationPost{CID: int(lastID), Body: content, CreatedBy: createdBy}
   342  	_, err = post.Create()
   343  	if err != nil {
   344  		return 0, err
   345  	}
   346  
   347  	for _, p := range participants {
   348  		if p == createdBy {
   349  			continue
   350  		}
   351  		_, err := s.addParticipant.Exec(p, lastID)
   352  		if err != nil {
   353  			return 0, err
   354  		}
   355  	}
   356  	_, err = s.addParticipant.Exec(createdBy, lastID)
   357  	if err != nil {
   358  		return 0, err
   359  	}
   360  
   361  	return int(lastID), err
   362  }
   363  
   364  // Count returns the total number of topics on these forums
   365  func (s *DefaultConversationStore) Count() (count int) {
   366  	err := s.count.QueryRow().Scan(&count)
   367  	if err != nil {
   368  		LogError(err)
   369  	}
   370  	return count
   371  }