github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/conversations.go (about) 1 package common 2 3 import ( 4 "errors" 5 "time" 6 7 //"log" 8 9 "database/sql" 10 "strconv" 11 12 qgen "github.com/Azareal/Gosora/query_gen" 13 ) 14 15 var Convos ConversationStore 16 var convoStmts ConvoStmts 17 18 type ConvoStmts struct { 19 fetchPost *sql.Stmt 20 getPosts *sql.Stmt 21 countPosts *sql.Stmt 22 edit *sql.Stmt 23 create *sql.Stmt 24 delete *sql.Stmt 25 has *sql.Stmt 26 27 editPost *sql.Stmt 28 createPost *sql.Stmt 29 deletePost *sql.Stmt 30 31 getUsers *sql.Stmt 32 } 33 34 func init() { 35 DbInits.Add(func(acc *qgen.Accumulator) error { 36 cpo := "conversations_posts" 37 convoStmts = ConvoStmts{ 38 fetchPost: acc.Select(cpo).Columns("cid,body,post,createdBy").Where("pid=?").Prepare(), 39 getPosts: acc.Select(cpo).Columns("pid,body,post,createdBy").Where("cid=?").Limit("?,?").Prepare(), 40 countPosts: acc.Count(cpo).Where("cid=?").Prepare(), 41 edit: acc.Update("conversations").Set("lastReplyBy=?,lastReplyAt=?").Where("cid=?").Prepare(), 42 create: acc.Insert("conversations").Columns("createdAt,lastReplyAt").Fields("UTC_TIMESTAMP(),UTC_TIMESTAMP()").Prepare(), 43 has: acc.Count("conversations_participants").Where("uid=? AND cid=?").Prepare(), 44 45 editPost: acc.Update(cpo).Set("body=?,post=?").Where("pid=?").Prepare(), 46 createPost: acc.Insert(cpo).Columns("cid,body,post,createdBy").Fields("?,?,?,?").Prepare(), 47 deletePost: acc.Delete(cpo).Where("pid=?").Prepare(), 48 49 getUsers: acc.Select("conversations_participants").Columns("uid").Where("cid=?").Prepare(), 50 } 51 return acc.FirstError() 52 }) 53 } 54 55 type Conversation struct { 56 ID int 57 Link string 58 CreatedBy int 59 CreatedAt time.Time 60 LastReplyBy int 61 LastReplyAt time.Time 62 } 63 64 func (co *Conversation) Posts(offset, itemsPerPage int) (posts []*ConversationPost, err error) { 65 rows, err := convoStmts.getPosts.Query(co.ID, offset, itemsPerPage) 66 if err != nil { 67 return nil, err 68 } 69 defer rows.Close() 70 71 for rows.Next() { 72 p := &ConversationPost{CID: co.ID} 73 err := rows.Scan(&p.ID, &p.Body, &p.Post, &p.CreatedBy) 74 if err != nil { 75 return nil, err 76 } 77 p, err = ConvoPostProcess.OnLoad(p) 78 if err != nil { 79 return nil, err 80 } 81 posts = append(posts, p) 82 } 83 84 return posts, rows.Err() 85 } 86 87 func (co *Conversation) PostsCount() (count int) { 88 return Countf(convoStmts.countPosts, co.ID) 89 } 90 91 func (co *Conversation) Uids() (ids []int, err error) { 92 rows, e := convoStmts.getUsers.Query(co.ID) 93 if e != nil { 94 return nil, e 95 } 96 defer rows.Close() 97 for rows.Next() { 98 var id int 99 if e := rows.Scan(&id); e != nil { 100 return nil, e 101 } 102 ids = append(ids, id) 103 } 104 return ids, rows.Err() 105 } 106 107 func (co *Conversation) Has(uid int) (in bool) { 108 return Countf(convoStmts.has, uid, co.ID) > 0 109 } 110 111 func (co *Conversation) Update() error { 112 _, err := convoStmts.edit.Exec(co.CreatedAt, co.LastReplyBy, co.LastReplyAt, co.ID) 113 return err 114 } 115 116 func (co *Conversation) Create() (int, error) { 117 res, err := convoStmts.create.Exec() 118 if err != nil { 119 return 0, err 120 } 121 122 lastID, err := res.LastInsertId() 123 return int(lastID), err 124 } 125 126 func BuildConvoURL(coid int) string { 127 return "/user/convo/" + strconv.Itoa(coid) 128 } 129 130 type ConversationExtra struct { 131 *Conversation 132 Users []*User 133 } 134 135 type ConversationStore interface { 136 Get(id int) (*Conversation, error) 137 GetUser(uid, offset int) (cos []*Conversation, err error) 138 GetUserExtra(uid, offset int) (cos []*ConversationExtra, err error) 139 GetUserCount(uid int) (count int) 140 Delete(id int) error 141 Count() (count int) 142 Create(content string, createdBy int, participants []int) (int, error) 143 } 144 145 type DefaultConversationStore struct { 146 get *sql.Stmt 147 getUser *sql.Stmt 148 getUserCount *sql.Stmt 149 delete *sql.Stmt 150 deletePosts *sql.Stmt 151 deleteParticipants *sql.Stmt 152 create *sql.Stmt 153 addParticipant *sql.Stmt 154 count *sql.Stmt 155 } 156 157 func NewDefaultConversationStore(acc *qgen.Accumulator) (*DefaultConversationStore, error) { 158 co := "conversations" 159 return &DefaultConversationStore{ 160 get: acc.Select(co).Columns("createdBy,createdAt,lastReplyBy,lastReplyAt").Where("cid=?").Prepare(), 161 getUser: acc.SimpleInnerJoin("conversations_participants AS cp", "conversations AS c", "cp.cid, c.createdBy, c.createdAt, c.lastReplyBy, c.lastReplyAt", "cp.cid=c.cid", "cp.uid=?", "c.lastReplyAt DESC, c.createdAt DESC, c.cid DESC", "?,?"), 162 getUserCount: acc.Count("conversations_participants").Where("uid=?").Prepare(), 163 delete: acc.Delete(co).Where("cid=?").Prepare(), 164 deletePosts: acc.Delete("conversations_posts").Where("cid=?").Prepare(), 165 deleteParticipants: acc.Delete("conversations_participants").Where("cid=?").Prepare(), 166 create: acc.Insert(co).Columns("createdBy,createdAt,lastReplyBy,lastReplyAt").Fields("?,UTC_TIMESTAMP(),?,UTC_TIMESTAMP()").Prepare(), 167 addParticipant: acc.Insert("conversations_participants").Columns("uid,cid").Fields("?,?").Prepare(), 168 count: acc.Count(co).Prepare(), 169 }, acc.FirstError() 170 } 171 172 func (s *DefaultConversationStore) Get(id int) (*Conversation, error) { 173 co := &Conversation{ID: id} 174 err := s.get.QueryRow(id).Scan(&co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt) 175 co.Link = BuildConvoURL(co.ID) 176 return co, err 177 } 178 179 func (s *DefaultConversationStore) GetUser(uid, offset int) (cos []*Conversation, err error) { 180 rows, err := s.getUser.Query(uid, offset, Config.ItemsPerPage) 181 if err != nil { 182 return nil, err 183 } 184 defer rows.Close() 185 186 for rows.Next() { 187 co := &Conversation{} 188 err := rows.Scan(&co.ID, &co.CreatedBy, &co.CreatedAt, &co.LastReplyBy, &co.LastReplyAt) 189 if err != nil { 190 return nil, err 191 } 192 co.Link = BuildConvoURL(co.ID) 193 cos = append(cos, co) 194 } 195 err = rows.Err() 196 if err != nil { 197 return nil, err 198 } 199 if len(cos) == 0 { 200 err = sql.ErrNoRows 201 } 202 return cos, err 203 } 204 205 func (s *DefaultConversationStore) GetUserExtra(uid, offset int) (cos []*ConversationExtra, err error) { 206 raw, err := s.GetUser(uid, offset) 207 if err != nil { 208 return nil, err 209 } 210 //log.Printf("raw: %+v\n", raw) 211 212 if len(raw) == 1 { 213 //log.Print("r0b2") 214 uids, err := raw[0].Uids() 215 if err != nil { 216 return nil, err 217 } 218 //log.Println("r1b2") 219 umap, err := Users.BulkGetMap(uids) 220 if err != nil { 221 return nil, err 222 } 223 //log.Println("r2b2") 224 users := make([]*User, len(umap)) 225 var i int 226 for _, user := range umap { 227 users[i] = user 228 i++ 229 } 230 return []*ConversationExtra{{raw[0], users}}, nil 231 } 232 //log.Println("1") 233 234 cmap := make(map[int]*ConversationExtra, len(raw)) 235 for _, co := range raw { 236 cmap[co.ID] = &ConversationExtra{co, nil} 237 } 238 239 // TODO: Use inqbuild for this or a similar function 240 var q string 241 idList := make([]interface{}, len(raw)) 242 for i, co := range raw { 243 if i == 0 { 244 q = "?" 245 } else { 246 q += ",?" 247 } 248 idList[i] = strconv.Itoa(co.ID) 249 } 250 251 rows, err := qgen.NewAcc().Select("conversations_participants").Columns("uid,cid").Where("cid IN(" + q + ")").Query(idList...) 252 if err != nil { 253 return nil, err 254 } 255 defer rows.Close() 256 //log.Println("2") 257 258 idmap := make(map[int][]int) // cid: []uid 259 puidmap := make(map[int]struct{}) 260 for rows.Next() { 261 var uid, cid int 262 err := rows.Scan(&uid, &cid) 263 if err != nil { 264 return nil, err 265 } 266 idmap[cid] = append(idmap[cid], uid) 267 puidmap[uid] = struct{}{} 268 } 269 if err = rows.Err(); err != nil { 270 return nil, err 271 } 272 //log.Println("3") 273 //log.Printf("idmap: %+v\n", idmap) 274 //log.Printf("puidmap: %+v\n",puidmap) 275 276 puids := make([]int, len(puidmap)) 277 var i int 278 for puid, _ := range puidmap { 279 puids[i] = puid 280 i++ 281 } 282 umap, err := Users.BulkGetMap(puids) 283 if err != nil { 284 return nil, err 285 } 286 //log.Println("4") 287 //log.Printf("umap: %+v\n", umap) 288 for cid, uids := range idmap { 289 co := cmap[cid] 290 for _, uid := range uids { 291 co.Users = append(co.Users, umap[uid]) 292 } 293 //log.Printf("co.Conversation: %+v\n", co.Conversation) 294 //log.Printf("co.Users: %+v\n", co.Users) 295 cmap[cid] = co 296 } 297 //log.Printf("cmap: %+v\n", cmap) 298 for _, ra := range raw { 299 cos = append(cos, cmap[ra.ID]) 300 } 301 //log.Printf("cos: %+v\n", cos) 302 303 return cos, rows.Err() 304 } 305 306 func (s *DefaultConversationStore) GetUserCount(uid int) (count int) { 307 err := s.getUserCount.QueryRow(uid).Scan(&count) 308 if err != nil { 309 LogError(err) 310 } 311 return count 312 } 313 314 // TODO: Use a foreign key or transaction 315 func (s *DefaultConversationStore) Delete(id int) error { 316 _, err := s.delete.Exec(id) 317 if err != nil { 318 return err 319 } 320 _, err = s.deletePosts.Exec(id) 321 if err != nil { 322 return err 323 } 324 _, err = s.deleteParticipants.Exec(id) 325 return err 326 } 327 328 func (s *DefaultConversationStore) Create(content string, createdBy int, participants []int) (int, error) { 329 if len(participants) == 0 { 330 return 0, errors.New("no participants set") 331 } 332 res, err := s.create.Exec(createdBy, createdBy) 333 if err != nil { 334 return 0, err 335 } 336 lastID, err := res.LastInsertId() 337 if err != nil { 338 return 0, err 339 } 340 341 post := &ConversationPost{CID: int(lastID), Body: content, CreatedBy: createdBy} 342 _, err = post.Create() 343 if err != nil { 344 return 0, err 345 } 346 347 for _, p := range participants { 348 if p == createdBy { 349 continue 350 } 351 _, err := s.addParticipant.Exec(p, lastID) 352 if err != nil { 353 return 0, err 354 } 355 } 356 _, err = s.addParticipant.Exec(createdBy, lastID) 357 if err != nil { 358 return 0, err 359 } 360 361 return int(lastID), err 362 } 363 364 // Count returns the total number of topics on these forums 365 func (s *DefaultConversationStore) Count() (count int) { 366 err := s.count.QueryRow().Scan(&count) 367 if err != nil { 368 LogError(err) 369 } 370 return count 371 }