github.com/status-im/status-go@v1.1.0/protocol/common/raw_messages_persistence.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/ecdsa"
     7  	"database/sql"
     8  	"encoding/gob"
     9  	"errors"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/status-im/status-go/eth-node/crypto"
    14  	"github.com/status-im/status-go/eth-node/types"
    15  	"github.com/status-im/status-go/protocol/protobuf"
    16  )
    17  
    18  type RawMessageConfirmation struct {
    19  	// DataSyncID is the ID of the datasync message sent
    20  	DataSyncID []byte
    21  	// MessageID is the message id of the message
    22  	MessageID []byte
    23  	// PublicKey is the compressed receiver public key
    24  	PublicKey []byte
    25  	// ConfirmedAt is the unix timestamp in seconds of when the message was confirmed
    26  	ConfirmedAt int64
    27  }
    28  
    29  type RawMessagesPersistence struct {
    30  	db *sql.DB
    31  }
    32  
    33  func NewRawMessagesPersistence(db *sql.DB) *RawMessagesPersistence {
    34  	return &RawMessagesPersistence{db: db}
    35  }
    36  
    37  func (db RawMessagesPersistence) SaveRawMessage(message *RawMessage) error {
    38  	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
    39  	if err != nil {
    40  		return err
    41  	}
    42  	defer func() {
    43  		if err == nil {
    44  			err = tx.Commit()
    45  			return
    46  		}
    47  		// don't shadow original error
    48  		_ = tx.Rollback()
    49  	}()
    50  
    51  	var pubKeys [][]byte
    52  	for _, pk := range message.Recipients {
    53  		pubKeys = append(pubKeys, crypto.CompressPubkey(pk))
    54  	}
    55  	// Encode recipients
    56  	var encodedRecipients bytes.Buffer
    57  	encoder := gob.NewEncoder(&encodedRecipients)
    58  
    59  	if err := encoder.Encode(pubKeys); err != nil {
    60  		return err
    61  	}
    62  
    63  	// If the message is not sent, we check whether there's a record
    64  	// in the database already and preserve the state
    65  	if !message.Sent {
    66  		oldMessage, err := db.rawMessageByID(tx, message.ID)
    67  		if err != nil && err != sql.ErrNoRows {
    68  			return err
    69  		}
    70  		if oldMessage != nil {
    71  			message.Sent = oldMessage.Sent
    72  		}
    73  	}
    74  	var sender []byte
    75  	if message.Sender != nil {
    76  		sender = crypto.FromECDSA(message.Sender)
    77  	}
    78  	_, err = tx.Exec(`
    79  		 INSERT INTO
    80  		 raw_messages
    81  		 (
    82  		   id,
    83  		   local_chat_id,
    84  		   last_sent,
    85  		   send_count,
    86  		   sent,
    87  		   message_type,
    88  		   recipients,
    89  		   skip_encryption,
    90  	       send_push_notification,
    91  		   skip_group_message_wrap,
    92  		   send_on_personal_topic,
    93  		   payload,
    94  		   sender,
    95  		   community_id,
    96  		   resend_type,
    97  		   pubsub_topic,
    98  		   hash_ratchet_group_id,
    99  		   community_key_ex_msg_type,
   100  		   resend_method
   101  		)
   102  		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
   103  		message.ID,
   104  		message.LocalChatID,
   105  		message.LastSent,
   106  		message.SendCount,
   107  		message.Sent,
   108  		message.MessageType,
   109  		encodedRecipients.Bytes(),
   110  		message.SkipEncryptionLayer,
   111  		message.SendPushNotification,
   112  		message.SkipGroupMessageWrap,
   113  		message.SendOnPersonalTopic,
   114  		message.Payload,
   115  		sender,
   116  		message.CommunityID,
   117  		message.ResendType,
   118  		message.PubsubTopic,
   119  		message.HashRatchetGroupID,
   120  		message.CommunityKeyExMsgType,
   121  		message.ResendMethod,
   122  	)
   123  	return err
   124  }
   125  
   126  func (db RawMessagesPersistence) RawMessageByID(id string) (*RawMessage, error) {
   127  	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	defer func() {
   132  		if err == nil {
   133  			err = tx.Commit()
   134  			return
   135  		}
   136  		// don't shadow original error
   137  		_ = tx.Rollback()
   138  	}()
   139  
   140  	return db.rawMessageByID(tx, id)
   141  }
   142  
   143  func (db RawMessagesPersistence) rawMessageByID(tx *sql.Tx, id string) (*RawMessage, error) {
   144  	var rawPubKeys [][]byte
   145  	var encodedRecipients []byte
   146  	var skipGroupMessageWrap, sendOnPersonalTopic sql.NullBool
   147  	var sender []byte
   148  	message := &RawMessage{}
   149  
   150  	err := tx.QueryRow(`
   151  			SELECT
   152  			  id,
   153  			  local_chat_id,
   154  			  last_sent,
   155  			  send_count,
   156  			  sent,
   157  			  message_type,
   158  			  recipients,
   159  			  skip_encryption,
   160  		      send_push_notification,
   161  			  skip_group_message_wrap,
   162  			  send_on_personal_topic,
   163  		      payload,
   164  			  sender,
   165  			  community_id,
   166  			  resend_type,
   167  			  pubsub_topic,
   168  			  hash_ratchet_group_id,
   169  			  community_key_ex_msg_type,
   170  			  resend_method
   171  			FROM
   172  				raw_messages
   173  			WHERE
   174  				id = ?`,
   175  		id,
   176  	).Scan(
   177  		&message.ID,
   178  		&message.LocalChatID,
   179  		&message.LastSent,
   180  		&message.SendCount,
   181  		&message.Sent,
   182  		&message.MessageType,
   183  		&encodedRecipients,
   184  		&message.SkipEncryptionLayer,
   185  		&message.SendPushNotification,
   186  		&skipGroupMessageWrap,
   187  		&sendOnPersonalTopic,
   188  		&message.Payload,
   189  		&sender,
   190  		&message.CommunityID,
   191  		&message.ResendType,
   192  		&message.PubsubTopic,
   193  		&message.HashRatchetGroupID,
   194  		&message.CommunityKeyExMsgType,
   195  		&message.ResendMethod,
   196  	)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	if encodedRecipients != nil {
   202  		// Restore recipients
   203  		decoder := gob.NewDecoder(bytes.NewBuffer(encodedRecipients))
   204  		err = decoder.Decode(&rawPubKeys)
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  		for _, pkBytes := range rawPubKeys {
   209  			pubkey, err := crypto.DecompressPubkey(pkBytes)
   210  			if err != nil {
   211  				return nil, err
   212  			}
   213  			message.Recipients = append(message.Recipients, pubkey)
   214  		}
   215  	}
   216  
   217  	if skipGroupMessageWrap.Valid {
   218  		message.SkipGroupMessageWrap = skipGroupMessageWrap.Bool
   219  	}
   220  
   221  	if sendOnPersonalTopic.Valid {
   222  		message.SendOnPersonalTopic = sendOnPersonalTopic.Bool
   223  	}
   224  
   225  	if sender != nil {
   226  		message.Sender, err = crypto.ToECDSA(sender)
   227  		if err != nil {
   228  			return nil, err
   229  		}
   230  	}
   231  	return message, nil
   232  }
   233  
   234  func (db RawMessagesPersistence) RawMessagesIDsByType(t protobuf.ApplicationMetadataMessage_Type) ([]string, error) {
   235  	ids := []string{}
   236  
   237  	rows, err := db.db.Query(`
   238  			SELECT
   239  			  id
   240  			FROM
   241  				raw_messages
   242  			WHERE
   243  			message_type = ?`,
   244  		t)
   245  	if err != nil {
   246  		return ids, err
   247  	}
   248  	defer rows.Close()
   249  
   250  	for rows.Next() {
   251  		var id string
   252  		if err := rows.Scan(&id); err != nil {
   253  			return ids, err
   254  		}
   255  		ids = append(ids, id)
   256  	}
   257  
   258  	return ids, nil
   259  }
   260  
   261  // MarkAsConfirmed marks all the messages with dataSyncID as confirmed and returns
   262  // the messageIDs that can be considered confirmed.
   263  // If atLeastOne is set it will return messageid if at least once of the messages
   264  // sent has been confirmed
   265  func (db RawMessagesPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID types.HexBytes, err error) {
   266  	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  	defer func() {
   271  		if err == nil {
   272  			err = tx.Commit()
   273  			return
   274  		}
   275  		// don't shadow original error
   276  		_ = tx.Rollback()
   277  	}()
   278  
   279  	confirmedAt := time.Now().Unix()
   280  	_, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, confirmedAt, dataSyncID)
   281  	if err != nil {
   282  		return
   283  	}
   284  
   285  	// Select any tuple that has a message_id with a datasync_id = ? and that has just been confirmed
   286  	rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID)
   287  	if err != nil {
   288  		return
   289  	}
   290  	defer rows.Close()
   291  
   292  	confirmedResult := true
   293  
   294  	for rows.Next() {
   295  		var confirmedAt int64
   296  		err = rows.Scan(&messageID, &confirmedAt)
   297  		if err != nil {
   298  			return
   299  		}
   300  		confirmed := confirmedAt > 0
   301  
   302  		if atLeastOne && confirmed {
   303  			// We return, as at least one was confirmed
   304  			return
   305  		}
   306  
   307  		confirmedResult = confirmedResult && confirmed
   308  	}
   309  
   310  	if !confirmedResult {
   311  		messageID = nil
   312  		return
   313  	}
   314  
   315  	return
   316  }
   317  
   318  func (db RawMessagesPersistence) InsertPendingConfirmation(confirmation *RawMessageConfirmation) error {
   319  
   320  	_, err := db.db.Exec(`INSERT INTO raw_message_confirmations
   321  		 (datasync_id, message_id, public_key)
   322  		 VALUES
   323  		 (?,?,?)`,
   324  		confirmation.DataSyncID,
   325  		confirmation.MessageID,
   326  		confirmation.PublicKey,
   327  	)
   328  	return err
   329  }
   330  
   331  func (db RawMessagesPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.Message) error {
   332  	_, err := db.db.Exec(`INSERT INTO hash_ratchet_encrypted_messages(hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, m.Hash, m.Sig, m.TTL, m.Timestamp, types.TopicTypeToByteArray(m.Topic), m.Payload, m.Dst, m.P2P, m.Padding, groupID, keyID)
   333  	return err
   334  }
   335  
   336  func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.Message, error) {
   337  	var messages []*types.Message
   338  
   339  	rows, err := db.db.Query(`SELECT hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ?`, keyID)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	for rows.Next() {
   345  		var topic []byte
   346  		message := &types.Message{}
   347  
   348  		err := rows.Scan(&message.Hash, &message.Sig, &message.TTL, &message.Timestamp, &topic, &message.Payload, &message.Dst, &message.P2P, &message.Padding)
   349  		if err != nil {
   350  			return nil, err
   351  		}
   352  
   353  		message.Topic = types.BytesToTopic(topic)
   354  		messages = append(messages, message)
   355  	}
   356  
   357  	return messages, nil
   358  }
   359  
   360  func (db RawMessagesPersistence) GetHashRatchetMessagesCountForGroup(groupID []byte) (int, error) {
   361  	var count int
   362  	err := db.db.QueryRow(`SELECT count(*) FROM hash_ratchet_encrypted_messages WHERE group_id = ?`, groupID).Scan(&count)
   363  	if err == nil {
   364  		return count, nil
   365  	}
   366  	if errors.Is(err, sql.ErrNoRows) {
   367  		return 0, nil
   368  	}
   369  	return 0, err
   370  }
   371  
   372  func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
   373  	if len(ids) == 0 {
   374  		return nil
   375  	}
   376  
   377  	idsArgs := make([]interface{}, 0, len(ids))
   378  	for _, id := range ids {
   379  		idsArgs = append(idsArgs, id)
   380  	}
   381  	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
   382  
   383  	_, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec
   384  
   385  	return err
   386  }
   387  
   388  func (db *RawMessagesPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error {
   389  	_, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE timestamp < ?", timestamp)
   390  	return err
   391  }
   392  
   393  func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) {
   394  	var alreadyCompleted int
   395  	err := db.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash).Scan(&alreadyCompleted)
   396  	if err != nil {
   397  		return false, err
   398  	}
   399  	return alreadyCompleted > 0, nil
   400  }
   401  
   402  func (db *RawMessagesPersistence) SaveMessageSegment(segment *SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error {
   403  	sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)
   404  
   405  	_, err := db.db.Exec("INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
   406  		segment.EntireMessageHash, segment.Index, segment.SegmentsCount, segment.ParitySegmentIndex, segment.ParitySegmentsCount, sigPubKeyBlob, segment.Payload, timestamp)
   407  
   408  	return err
   409  }
   410  
   411  // Get ordered message segments for given hash
   412  func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*SegmentMessage, error) {
   413  	sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)
   414  
   415  	rows, err := db.db.Query(`
   416  		SELECT
   417  			hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload
   418  		FROM
   419  			message_segments
   420  		WHERE
   421  			hash = ? AND sig_pub_key = ?
   422  		ORDER BY
   423  			(segments_count = 0) ASC, -- Prioritize segments_count > 0
   424  			segment_index ASC,
   425  			parity_segment_index ASC`,
   426  		hash, sigPubKeyBlob)
   427  	if err != nil {
   428  		return nil, err
   429  	}
   430  	defer rows.Close()
   431  
   432  	var segments []*SegmentMessage
   433  	for rows.Next() {
   434  		segment := &SegmentMessage{
   435  			SegmentMessage: &protobuf.SegmentMessage{},
   436  		}
   437  		err := rows.Scan(&segment.EntireMessageHash, &segment.Index, &segment.SegmentsCount, &segment.ParitySegmentIndex, &segment.ParitySegmentsCount, &segment.Payload)
   438  		if err != nil {
   439  			return nil, err
   440  		}
   441  		segments = append(segments, segment)
   442  	}
   443  	err = rows.Err()
   444  	if err != nil {
   445  		return nil, err
   446  	}
   447  
   448  	return segments, nil
   449  }
   450  
   451  func (db *RawMessagesPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error {
   452  	_, err := db.db.Exec("DELETE FROM message_segments WHERE timestamp < ?", timestamp)
   453  	return err
   454  }
   455  
   456  func (db *RawMessagesPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) error {
   457  	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
   458  	if err != nil {
   459  		return err
   460  	}
   461  
   462  	defer func() {
   463  		if err == nil {
   464  			err = tx.Commit()
   465  			return
   466  		}
   467  		// don't shadow original error
   468  		_ = tx.Rollback()
   469  	}()
   470  
   471  	sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)
   472  
   473  	_, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob)
   474  	if err != nil {
   475  		return err
   476  	}
   477  
   478  	_, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)", hash, sigPubKeyBlob, timestamp)
   479  	if err != nil {
   480  		return err
   481  	}
   482  
   483  	return err
   484  }
   485  
   486  func (db *RawMessagesPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error {
   487  	_, err := db.db.Exec("DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp)
   488  	return err
   489  }
   490  
   491  func (db RawMessagesPersistence) UpdateRawMessageSent(id string, sent bool) error {
   492  	_, err := db.db.Exec("UPDATE raw_messages SET sent = ? WHERE id = ?", sent, id)
   493  	return err
   494  }
   495  
   496  func (db RawMessagesPersistence) UpdateRawMessageLastSent(id string, lastSent uint64) error {
   497  	_, err := db.db.Exec("UPDATE raw_messages SET last_sent = ? WHERE id = ?", lastSent, id)
   498  	return err
   499  }