github.com/status-im/status-go@v1.1.0/protocol/common/raw_messages_persistence.go (about) 1 package common 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "database/sql" 8 "encoding/gob" 9 "errors" 10 "strings" 11 "time" 12 13 "github.com/status-im/status-go/eth-node/crypto" 14 "github.com/status-im/status-go/eth-node/types" 15 "github.com/status-im/status-go/protocol/protobuf" 16 ) 17 18 type RawMessageConfirmation struct { 19 // DataSyncID is the ID of the datasync message sent 20 DataSyncID []byte 21 // MessageID is the message id of the message 22 MessageID []byte 23 // PublicKey is the compressed receiver public key 24 PublicKey []byte 25 // ConfirmedAt is the unix timestamp in seconds of when the message was confirmed 26 ConfirmedAt int64 27 } 28 29 type RawMessagesPersistence struct { 30 db *sql.DB 31 } 32 33 func NewRawMessagesPersistence(db *sql.DB) *RawMessagesPersistence { 34 return &RawMessagesPersistence{db: db} 35 } 36 37 func (db RawMessagesPersistence) SaveRawMessage(message *RawMessage) error { 38 tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) 39 if err != nil { 40 return err 41 } 42 defer func() { 43 if err == nil { 44 err = tx.Commit() 45 return 46 } 47 // don't shadow original error 48 _ = tx.Rollback() 49 }() 50 51 var pubKeys [][]byte 52 for _, pk := range message.Recipients { 53 pubKeys = append(pubKeys, crypto.CompressPubkey(pk)) 54 } 55 // Encode recipients 56 var encodedRecipients bytes.Buffer 57 encoder := gob.NewEncoder(&encodedRecipients) 58 59 if err := encoder.Encode(pubKeys); err != nil { 60 return err 61 } 62 63 // If the message is not sent, we check whether there's a record 64 // in the database already and preserve the state 65 if !message.Sent { 66 oldMessage, err := db.rawMessageByID(tx, message.ID) 67 if err != nil && err != sql.ErrNoRows { 68 return err 69 } 70 if oldMessage != nil { 71 message.Sent = oldMessage.Sent 72 } 73 } 74 var sender []byte 75 if message.Sender != nil { 76 sender = crypto.FromECDSA(message.Sender) 77 } 78 _, err = tx.Exec(` 79 INSERT INTO 80 raw_messages 81 ( 82 id, 83 local_chat_id, 84 last_sent, 85 send_count, 86 sent, 87 message_type, 88 recipients, 89 skip_encryption, 90 send_push_notification, 91 skip_group_message_wrap, 92 send_on_personal_topic, 93 payload, 94 sender, 95 community_id, 96 resend_type, 97 pubsub_topic, 98 hash_ratchet_group_id, 99 community_key_ex_msg_type, 100 resend_method 101 ) 102 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 103 message.ID, 104 message.LocalChatID, 105 message.LastSent, 106 message.SendCount, 107 message.Sent, 108 message.MessageType, 109 encodedRecipients.Bytes(), 110 message.SkipEncryptionLayer, 111 message.SendPushNotification, 112 message.SkipGroupMessageWrap, 113 message.SendOnPersonalTopic, 114 message.Payload, 115 sender, 116 message.CommunityID, 117 message.ResendType, 118 message.PubsubTopic, 119 message.HashRatchetGroupID, 120 message.CommunityKeyExMsgType, 121 message.ResendMethod, 122 ) 123 return err 124 } 125 126 func (db RawMessagesPersistence) RawMessageByID(id string) (*RawMessage, error) { 127 tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) 128 if err != nil { 129 return nil, err 130 } 131 defer func() { 132 if err == nil { 133 err = tx.Commit() 134 return 135 } 136 // don't shadow original error 137 _ = tx.Rollback() 138 }() 139 140 return db.rawMessageByID(tx, id) 141 } 142 143 func (db RawMessagesPersistence) rawMessageByID(tx *sql.Tx, id string) (*RawMessage, error) { 144 var rawPubKeys [][]byte 145 var encodedRecipients []byte 146 var skipGroupMessageWrap, sendOnPersonalTopic sql.NullBool 147 var sender []byte 148 message := &RawMessage{} 149 150 err := tx.QueryRow(` 151 SELECT 152 id, 153 local_chat_id, 154 last_sent, 155 send_count, 156 sent, 157 message_type, 158 recipients, 159 skip_encryption, 160 send_push_notification, 161 skip_group_message_wrap, 162 send_on_personal_topic, 163 payload, 164 sender, 165 community_id, 166 resend_type, 167 pubsub_topic, 168 hash_ratchet_group_id, 169 community_key_ex_msg_type, 170 resend_method 171 FROM 172 raw_messages 173 WHERE 174 id = ?`, 175 id, 176 ).Scan( 177 &message.ID, 178 &message.LocalChatID, 179 &message.LastSent, 180 &message.SendCount, 181 &message.Sent, 182 &message.MessageType, 183 &encodedRecipients, 184 &message.SkipEncryptionLayer, 185 &message.SendPushNotification, 186 &skipGroupMessageWrap, 187 &sendOnPersonalTopic, 188 &message.Payload, 189 &sender, 190 &message.CommunityID, 191 &message.ResendType, 192 &message.PubsubTopic, 193 &message.HashRatchetGroupID, 194 &message.CommunityKeyExMsgType, 195 &message.ResendMethod, 196 ) 197 if err != nil { 198 return nil, err 199 } 200 201 if encodedRecipients != nil { 202 // Restore recipients 203 decoder := gob.NewDecoder(bytes.NewBuffer(encodedRecipients)) 204 err = decoder.Decode(&rawPubKeys) 205 if err != nil { 206 return nil, err 207 } 208 for _, pkBytes := range rawPubKeys { 209 pubkey, err := crypto.DecompressPubkey(pkBytes) 210 if err != nil { 211 return nil, err 212 } 213 message.Recipients = append(message.Recipients, pubkey) 214 } 215 } 216 217 if skipGroupMessageWrap.Valid { 218 message.SkipGroupMessageWrap = skipGroupMessageWrap.Bool 219 } 220 221 if sendOnPersonalTopic.Valid { 222 message.SendOnPersonalTopic = sendOnPersonalTopic.Bool 223 } 224 225 if sender != nil { 226 message.Sender, err = crypto.ToECDSA(sender) 227 if err != nil { 228 return nil, err 229 } 230 } 231 return message, nil 232 } 233 234 func (db RawMessagesPersistence) RawMessagesIDsByType(t protobuf.ApplicationMetadataMessage_Type) ([]string, error) { 235 ids := []string{} 236 237 rows, err := db.db.Query(` 238 SELECT 239 id 240 FROM 241 raw_messages 242 WHERE 243 message_type = ?`, 244 t) 245 if err != nil { 246 return ids, err 247 } 248 defer rows.Close() 249 250 for rows.Next() { 251 var id string 252 if err := rows.Scan(&id); err != nil { 253 return ids, err 254 } 255 ids = append(ids, id) 256 } 257 258 return ids, nil 259 } 260 261 // MarkAsConfirmed marks all the messages with dataSyncID as confirmed and returns 262 // the messageIDs that can be considered confirmed. 263 // If atLeastOne is set it will return messageid if at least once of the messages 264 // sent has been confirmed 265 func (db RawMessagesPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID types.HexBytes, err error) { 266 tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) 267 if err != nil { 268 return nil, err 269 } 270 defer func() { 271 if err == nil { 272 err = tx.Commit() 273 return 274 } 275 // don't shadow original error 276 _ = tx.Rollback() 277 }() 278 279 confirmedAt := time.Now().Unix() 280 _, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, confirmedAt, dataSyncID) 281 if err != nil { 282 return 283 } 284 285 // Select any tuple that has a message_id with a datasync_id = ? and that has just been confirmed 286 rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID) 287 if err != nil { 288 return 289 } 290 defer rows.Close() 291 292 confirmedResult := true 293 294 for rows.Next() { 295 var confirmedAt int64 296 err = rows.Scan(&messageID, &confirmedAt) 297 if err != nil { 298 return 299 } 300 confirmed := confirmedAt > 0 301 302 if atLeastOne && confirmed { 303 // We return, as at least one was confirmed 304 return 305 } 306 307 confirmedResult = confirmedResult && confirmed 308 } 309 310 if !confirmedResult { 311 messageID = nil 312 return 313 } 314 315 return 316 } 317 318 func (db RawMessagesPersistence) InsertPendingConfirmation(confirmation *RawMessageConfirmation) error { 319 320 _, err := db.db.Exec(`INSERT INTO raw_message_confirmations 321 (datasync_id, message_id, public_key) 322 VALUES 323 (?,?,?)`, 324 confirmation.DataSyncID, 325 confirmation.MessageID, 326 confirmation.PublicKey, 327 ) 328 return err 329 } 330 331 func (db RawMessagesPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.Message) error { 332 _, err := db.db.Exec(`INSERT INTO hash_ratchet_encrypted_messages(hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, m.Hash, m.Sig, m.TTL, m.Timestamp, types.TopicTypeToByteArray(m.Topic), m.Payload, m.Dst, m.P2P, m.Padding, groupID, keyID) 333 return err 334 } 335 336 func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.Message, error) { 337 var messages []*types.Message 338 339 rows, err := db.db.Query(`SELECT hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ?`, keyID) 340 if err != nil { 341 return nil, err 342 } 343 344 for rows.Next() { 345 var topic []byte 346 message := &types.Message{} 347 348 err := rows.Scan(&message.Hash, &message.Sig, &message.TTL, &message.Timestamp, &topic, &message.Payload, &message.Dst, &message.P2P, &message.Padding) 349 if err != nil { 350 return nil, err 351 } 352 353 message.Topic = types.BytesToTopic(topic) 354 messages = append(messages, message) 355 } 356 357 return messages, nil 358 } 359 360 func (db RawMessagesPersistence) GetHashRatchetMessagesCountForGroup(groupID []byte) (int, error) { 361 var count int 362 err := db.db.QueryRow(`SELECT count(*) FROM hash_ratchet_encrypted_messages WHERE group_id = ?`, groupID).Scan(&count) 363 if err == nil { 364 return count, nil 365 } 366 if errors.Is(err, sql.ErrNoRows) { 367 return 0, nil 368 } 369 return 0, err 370 } 371 372 func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error { 373 if len(ids) == 0 { 374 return nil 375 } 376 377 idsArgs := make([]interface{}, 0, len(ids)) 378 for _, id := range ids { 379 idsArgs = append(idsArgs, id) 380 } 381 inVector := strings.Repeat("?, ", len(ids)-1) + "?" 382 383 _, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec 384 385 return err 386 } 387 388 func (db *RawMessagesPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error { 389 _, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE timestamp < ?", timestamp) 390 return err 391 } 392 393 func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) { 394 var alreadyCompleted int 395 err := db.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash).Scan(&alreadyCompleted) 396 if err != nil { 397 return false, err 398 } 399 return alreadyCompleted > 0, nil 400 } 401 402 func (db *RawMessagesPersistence) SaveMessageSegment(segment *SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error { 403 sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) 404 405 _, err := db.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", 406 segment.EntireMessageHash, segment.Index, segment.SegmentsCount, segment.ParitySegmentIndex, segment.ParitySegmentsCount, sigPubKeyBlob, segment.Payload, timestamp) 407 408 return err 409 } 410 411 // Get ordered message segments for given hash 412 func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*SegmentMessage, error) { 413 sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) 414 415 rows, err := db.db.Query(` 416 SELECT 417 hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload 418 FROM 419 message_segments 420 WHERE 421 hash = ? AND sig_pub_key = ? 422 ORDER BY 423 (segments_count = 0) ASC, -- Prioritize segments_count > 0 424 segment_index ASC, 425 parity_segment_index ASC`, 426 hash, sigPubKeyBlob) 427 if err != nil { 428 return nil, err 429 } 430 defer rows.Close() 431 432 var segments []*SegmentMessage 433 for rows.Next() { 434 segment := &SegmentMessage{ 435 SegmentMessage: &protobuf.SegmentMessage{}, 436 } 437 err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.ParitySegmentIndex, &segment.ParitySegmentsCount, &segment.Payload) 438 if err != nil { 439 return nil, err 440 } 441 segments = append(segments, segment) 442 } 443 err = rows.Err() 444 if err != nil { 445 return nil, err 446 } 447 448 return segments, nil 449 } 450 451 func (db *RawMessagesPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error { 452 _, err := db.db.Exec("DELETE FROM message_segments WHERE timestamp < ?", timestamp) 453 return err 454 } 455 456 func (db *RawMessagesPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error { 457 tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) 458 if err != nil { 459 return err 460 } 461 462 defer func() { 463 if err == nil { 464 err = tx.Commit() 465 return 466 } 467 // don't shadow original error 468 _ = tx.Rollback() 469 }() 470 471 sigPubKeyBlob := crypto.CompressPubkey(sigPubKey) 472 473 _, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob) 474 if err != nil { 475 return err 476 } 477 478 _, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)", hash, sigPubKeyBlob, timestamp) 479 if err != nil { 480 return err 481 } 482 483 return err 484 } 485 486 func (db *RawMessagesPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error { 487 _, err := db.db.Exec("DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp) 488 return err 489 } 490 491 func (db RawMessagesPersistence) UpdateRawMessageSent(id string, sent bool) error { 492 _, err := db.db.Exec("UPDATE raw_messages SET sent = ? WHERE id = ?", sent, id) 493 return err 494 } 495 496 func (db RawMessagesPersistence) UpdateRawMessageLastSent(id string, lastSent uint64) error { 497 _, err := db.db.Exec("UPDATE raw_messages SET last_sent = ? WHERE id = ?", lastSent, id) 498 return err 499 }