github.com/status-im/status-go@v1.1.0/protocol/peersyncing/sync_message_persistence.go (about)

     1  package peersyncing
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"strings"
     8  )
     9  
    10  type SyncMessagePersistence interface {
    11  	Add(SyncMessage) error
    12  	All() ([]SyncMessage, error)
    13  	Complement([]SyncMessage) ([]SyncMessage, error)
    14  	ByChatIDs([][]byte, int) ([]SyncMessage, error)
    15  	ByMessageIDs([][]byte) ([]SyncMessage, error)
    16  }
    17  
    18  type SyncMessageSQLitePersistence struct {
    19  	db *sql.DB
    20  }
    21  
    22  func NewSyncMessageSQLitePersistence(db *sql.DB) *SyncMessageSQLitePersistence {
    23  	return &SyncMessageSQLitePersistence{db: db}
    24  }
    25  
    26  func (p *SyncMessageSQLitePersistence) Add(message SyncMessage) error {
    27  	if err := message.Valid(); err != nil {
    28  		return err
    29  	}
    30  	_, err := p.db.Exec(`INSERT INTO peersyncing_messages (id, type, chat_id, payload, timestamp) VALUES (?, ?, ?, ?, ?)`, message.ID, message.Type, message.ChatID, message.Payload, message.Timestamp)
    31  	return err
    32  }
    33  
    34  func (p *SyncMessageSQLitePersistence) All() ([]SyncMessage, error) {
    35  	var messages []SyncMessage
    36  	rows, err := p.db.Query(`SELECT id, type, chat_id, payload, timestamp FROM peersyncing_messages`)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	defer rows.Close()
    42  
    43  	for rows.Next() {
    44  		var m SyncMessage
    45  
    46  		err := rows.Scan(&m.ID, &m.Type, &m.ChatID, &m.Payload, &m.Timestamp)
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  
    51  		messages = append(messages, m)
    52  	}
    53  	return messages, nil
    54  }
    55  
    56  func (p *SyncMessageSQLitePersistence) ByChatIDs(ids [][]byte, limit int) ([]SyncMessage, error) {
    57  	if len(ids) == 0 {
    58  		return nil, nil
    59  	}
    60  
    61  	queryArgs := make([]interface{}, 0, len(ids))
    62  	for _, id := range ids {
    63  		queryArgs = append(queryArgs, id)
    64  	}
    65  	queryArgs = append(queryArgs, limit)
    66  
    67  	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
    68  	query := "SELECT id, type, chat_id, payload, timestamp FROM peersyncing_messages WHERE chat_id IN (" + inVector + ") ORDER BY timestamp DESC LIMIT ?" // nolint: gosec
    69  
    70  	var messages []SyncMessage
    71  	rows, err := p.db.Query(query, queryArgs...)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	defer rows.Close()
    77  
    78  	for rows.Next() {
    79  		var m SyncMessage
    80  
    81  		err := rows.Scan(&m.ID, &m.Type, &m.ChatID, &m.Payload, &m.Timestamp)
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  
    86  		messages = append(messages, m)
    87  	}
    88  	return messages, nil
    89  }
    90  
    91  func (p *SyncMessageSQLitePersistence) Complement(messages []SyncMessage) ([]SyncMessage, error) {
    92  	if len(messages) == 0 {
    93  		return nil, nil
    94  	}
    95  
    96  	ids := make([]interface{}, 0, len(messages))
    97  	for _, m := range messages {
    98  		ids = append(ids, m.ID)
    99  	}
   100  
   101  	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
   102  	query := "SELECT id, type, chat_id, payload, timestamp FROM peersyncing_messages WHERE id IN (" + inVector + ")" // nolint: gosec
   103  
   104  	availableMessages := make(map[string]SyncMessage)
   105  	rows, err := p.db.Query(query, ids...)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	defer rows.Close()
   111  
   112  	for rows.Next() {
   113  		var m SyncMessage
   114  
   115  		err := rows.Scan(&m.ID, &m.Type, &m.ChatID, &m.Payload, &m.Timestamp)
   116  		if err != nil {
   117  			return nil, err
   118  		}
   119  
   120  		fmt.Printf("GOT MESSAGE: %x\n", m.ID)
   121  		availableMessages[hex.EncodeToString(m.ID)] = m
   122  	}
   123  
   124  	var complement []SyncMessage
   125  	for _, m := range messages {
   126  		fmt.Printf("CHECKING MESSAGE: %x\n", m.ID)
   127  		if _, ok := availableMessages[hex.EncodeToString(m.ID)]; !ok {
   128  			complement = append(complement, m)
   129  		}
   130  	}
   131  
   132  	return complement, nil
   133  }
   134  
   135  func (p *SyncMessageSQLitePersistence) ByMessageIDs(ids [][]byte) ([]SyncMessage, error) {
   136  	if len(ids) == 0 {
   137  		return nil, nil
   138  	}
   139  
   140  	queryArgs := make([]interface{}, 0, len(ids))
   141  	for _, id := range ids {
   142  		queryArgs = append(queryArgs, id)
   143  	}
   144  
   145  	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
   146  	query := "SELECT id, type, chat_id, payload, timestamp FROM peersyncing_messages WHERE id IN (" + inVector + ")" // nolint: gosec
   147  
   148  	var messages []SyncMessage
   149  	rows, err := p.db.Query(query, queryArgs...)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	defer rows.Close()
   155  
   156  	for rows.Next() {
   157  		var m SyncMessage
   158  
   159  		err := rows.Scan(&m.ID, &m.Type, &m.ChatID, &m.Payload, &m.Timestamp)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  
   164  		messages = append(messages, m)
   165  	}
   166  	return messages, nil
   167  
   168  }