github.com/code-to-go/safepool.lib@v0.0.0-20221205180519-ee25e63c226e/api/chat/chat.go (about)

     1  package chat
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"path"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/code-to-go/safepool.lib/core"
    14  	pool "github.com/code-to-go/safepool.lib/pool"
    15  	"github.com/code-to-go/safepool.lib/security"
    16  	"github.com/godruoyi/go-snowflake"
    17  	"github.com/sirupsen/logrus"
    18  )
    19  
    20  type Message struct {
    21  	Id          uint64    `json:"id,string"`
    22  	Author      string    `json:"author"`
    23  	Time        time.Time `json:"time"`
    24  	Content     string    `json:"content"`
    25  	ContentType string    `json:"contentType"`
    26  	Attachments [][]byte  `json:"attachments"`
    27  	Signature   []byte    `json:"signature"`
    28  }
    29  
    30  func getHash(m *Message) []byte {
    31  	h := security.NewHash()
    32  	h.Write([]byte(m.Content))
    33  	h.Write([]byte(m.ContentType))
    34  	h.Write([]byte(m.Author))
    35  	for _, a := range m.Attachments {
    36  		h.Write(a)
    37  	}
    38  	return h.Sum(nil)
    39  }
    40  
    41  type Chat struct {
    42  	Pool *pool.Pool
    43  }
    44  
    45  func Get(p *pool.Pool) Chat {
    46  	return Chat{
    47  		Pool: p,
    48  	}
    49  }
    50  
    51  func (c *Chat) TimeOffset(s *pool.Pool) time.Time {
    52  	return sqlGetOffset(s.Name)
    53  }
    54  
    55  func (c *Chat) Accept(s *pool.Pool, head pool.Head) bool {
    56  	name := head.Name
    57  	if !strings.HasPrefix(name, "/chat/") || !strings.HasSuffix(name, ".chat") || head.Size > 10*1024*1024 {
    58  		return false
    59  	}
    60  	name = path.Base(name)
    61  	id, err := strconv.ParseInt(name[0:len(name)-5], 10, 64)
    62  	if err != nil {
    63  		return false
    64  	}
    65  
    66  	buf := bytes.Buffer{}
    67  	err = s.Get(head.Id, nil, &buf)
    68  	if core.IsErr(err, "cannot read %s from %s: %v", head.Name, s.Name) {
    69  		return true
    70  	}
    71  
    72  	var m Message
    73  	err = json.Unmarshal(buf.Bytes(), &m)
    74  	if core.IsErr(err, "invalid chat message %s: %v", head.Name) {
    75  		return true
    76  	}
    77  
    78  	h := getHash(&m)
    79  	if !security.Verify(m.Author, h, m.Signature) {
    80  		logrus.Error("message %s has invalid signature", head.Name)
    81  		return true
    82  	}
    83  
    84  	err = sqlSetMessage(s.Name, uint64(id), m.Author, m, head.TimeStamp)
    85  	core.IsErr(err, "cannot write message %s to db:%v", head.Name)
    86  	return true
    87  }
    88  
    89  func (c *Chat) SendMessage(content string, contentType string, attachments [][]byte) (uint64, error) {
    90  	m := Message{
    91  		Id:          snowflake.ID(),
    92  		Author:      c.Pool.Self.Id(),
    93  		Time:        time.Now(),
    94  		Content:     content,
    95  		ContentType: contentType,
    96  		Attachments: attachments,
    97  	}
    98  	h := getHash(&m)
    99  	signature, err := security.Sign(c.Pool.Self, h)
   100  	if core.IsErr(err, "cannot sign chat message: %v") {
   101  		return 0, err
   102  	}
   103  	m.Signature = signature
   104  
   105  	data, err := json.Marshal(m)
   106  	if core.IsErr(err, "cannot sign chat message: %v") {
   107  		return 0, err
   108  	}
   109  
   110  	go func() {
   111  		name := fmt.Sprintf("/chat/%d.chat", m.Id)
   112  		_, err = c.Pool.Post(name, bytes.NewBuffer(data), nil)
   113  		core.IsErr(err, "cannot write chat message: %v")
   114  	}()
   115  
   116  	err = sqlSetMessage(c.Pool.Name, m.Id, c.Pool.Self.Id(), m, time.Now())
   117  	if core.IsErr(err, "cannot save message to db: %v") {
   118  		return 0, err
   119  	}
   120  
   121  	core.Info("added chat message with id %d", m.Id)
   122  	return m.Id, nil
   123  }
   124  
   125  func (c *Chat) GetMessages(afterId, beforeId uint64, limit int) ([]Message, error) {
   126  	messages, err := sqlGetMessages(c.Pool.Name, afterId, beforeId, limit)
   127  	if core.IsErr(err, "cannot read messages from db: %v") {
   128  		return nil, err
   129  	}
   130  
   131  	sort.Slice(messages, func(i, j int) bool {
   132  		return messages[i].Id < messages[j].Id
   133  	})
   134  	return messages, nil
   135  }