github.com/status-im/status-go@v1.1.0/protocol/pushnotificationclient/persistence.go (about)

     1  package pushnotificationclient
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/ecdsa"
     7  	"database/sql"
     8  	"encoding/gob"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/golang/protobuf/proto"
    13  
    14  	"github.com/status-im/status-go/eth-node/crypto"
    15  	"github.com/status-im/status-go/protocol/protobuf"
    16  )
    17  
    18  type Persistence struct {
    19  	db *sql.DB
    20  }
    21  
    22  func NewPersistence(db *sql.DB) *Persistence {
    23  	return &Persistence{db: db}
    24  }
    25  
    26  func (p *Persistence) GetLastPushNotificationRegistration() (*protobuf.PushNotificationRegistration, []*ecdsa.PublicKey, error) {
    27  	var registrationBytes []byte
    28  	var contactIDsBytes []byte
    29  	err := p.db.QueryRow(`SELECT registration,contact_ids FROM push_notification_client_registrations LIMIT 1`).Scan(&registrationBytes, &contactIDsBytes)
    30  	if err == sql.ErrNoRows {
    31  		return nil, nil, nil
    32  	} else if err != nil {
    33  		return nil, nil, err
    34  	}
    35  
    36  	var publicKeyBytes [][]byte
    37  	var contactIDs []*ecdsa.PublicKey
    38  	// Restore contactIDs
    39  	contactIDsDecoder := gob.NewDecoder(bytes.NewBuffer(contactIDsBytes))
    40  	err = contactIDsDecoder.Decode(&publicKeyBytes)
    41  	if err != nil {
    42  		return nil, nil, err
    43  	}
    44  	for _, pkBytes := range publicKeyBytes {
    45  		pk, err := crypto.DecompressPubkey(pkBytes)
    46  		if err != nil {
    47  			return nil, nil, err
    48  		}
    49  		contactIDs = append(contactIDs, pk)
    50  	}
    51  
    52  	registration := &protobuf.PushNotificationRegistration{}
    53  
    54  	err = proto.Unmarshal(registrationBytes, registration)
    55  	if err != nil {
    56  		return nil, nil, err
    57  	}
    58  
    59  	return registration, contactIDs, nil
    60  }
    61  
    62  func (p *Persistence) SaveLastPushNotificationRegistration(registration *protobuf.PushNotificationRegistration, contactIDs []*ecdsa.PublicKey) error {
    63  	var encodedContactIDs bytes.Buffer
    64  	var contactIDsBytes [][]byte
    65  	for _, pk := range contactIDs {
    66  		contactIDsBytes = append(contactIDsBytes, crypto.CompressPubkey(pk))
    67  	}
    68  	pkEncoder := gob.NewEncoder(&encodedContactIDs)
    69  	if err := pkEncoder.Encode(contactIDsBytes); err != nil {
    70  		return err
    71  	}
    72  
    73  	marshaledRegistration, err := proto.Marshal(registration)
    74  	if err != nil {
    75  		return err
    76  	}
    77  	_, err = p.db.Exec(`INSERT INTO push_notification_client_registrations (registration,contact_ids) VALUES (?, ?)`, marshaledRegistration, encodedContactIDs.Bytes())
    78  	return err
    79  }
    80  
    81  func (p *Persistence) TrackPushNotification(chatID string, messageID []byte) error {
    82  	trackedAt := time.Now().Unix()
    83  	_, err := p.db.Exec(`INSERT INTO push_notification_client_tracked_messages (chat_id, message_id, tracked_at) VALUES (?,?,?)`, chatID, messageID, trackedAt)
    84  	return err
    85  }
    86  
    87  func (p *Persistence) TrackedMessage(messageID []byte) (bool, error) {
    88  	var count uint64
    89  	err := p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_tracked_messages WHERE message_id = ?`, messageID).Scan(&count)
    90  	if err != nil {
    91  		return false, err
    92  	}
    93  
    94  	if count == 0 {
    95  		return false, nil
    96  	}
    97  
    98  	return true, nil
    99  }
   100  
   101  func (p *Persistence) SavePushNotificationQuery(publicKey *ecdsa.PublicKey, queryID []byte) error {
   102  	queriedAt := time.Now().Unix()
   103  	_, err := p.db.Exec(`INSERT INTO push_notification_client_queries (public_key, query_id, queried_at) VALUES (?,?,?)`, crypto.CompressPubkey(publicKey), queryID, queriedAt)
   104  	return err
   105  }
   106  
   107  func (p *Persistence) GetQueriedAt(publicKey *ecdsa.PublicKey) (int64, error) {
   108  	var queriedAt int64
   109  	err := p.db.QueryRow(`SELECT queried_at FROM push_notification_client_queries WHERE public_key = ? ORDER BY queried_at DESC LIMIT 1`, crypto.CompressPubkey(publicKey)).Scan(&queriedAt)
   110  	if err == sql.ErrNoRows {
   111  		return 0, nil
   112  	}
   113  	if err != nil {
   114  		return 0, err
   115  	}
   116  
   117  	return queriedAt, nil
   118  }
   119  
   120  func (p *Persistence) GetQueryPublicKey(queryID []byte) (*ecdsa.PublicKey, error) {
   121  	var publicKeyBytes []byte
   122  	err := p.db.QueryRow(`SELECT public_key FROM push_notification_client_queries WHERE query_id = ?`, queryID).Scan(&publicKeyBytes)
   123  	if err == sql.ErrNoRows {
   124  		return nil, nil
   125  	}
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	publicKey, err := crypto.DecompressPubkey(publicKeyBytes)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	return publicKey, nil
   135  }
   136  
   137  func (p *Persistence) SavePushNotificationInfo(infos []*PushNotificationInfo) error {
   138  	tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{})
   139  	defer func() {
   140  		if err == nil {
   141  			err = tx.Commit()
   142  			return
   143  		}
   144  		// don't shadow original error
   145  		_ = tx.Rollback()
   146  	}()
   147  	for _, info := range infos {
   148  		var latestVersion uint64
   149  		clientCompressedKey := crypto.CompressPubkey(info.PublicKey)
   150  		err := tx.QueryRow(`SELECT IFNULL(MAX(version),0) FROM push_notification_client_info WHERE public_key = ? AND installation_id = ? LIMIT 1`, clientCompressedKey, info.InstallationID).Scan(&latestVersion)
   151  		if err != sql.ErrNoRows && err != nil {
   152  			return err
   153  		}
   154  		if latestVersion > info.Version {
   155  			// Nothing to do
   156  			continue
   157  		}
   158  
   159  		// Remove anything that as a lower version
   160  		_, err = tx.Exec(`DELETE FROM push_notification_client_info WHERE public_key = ? AND installation_id = ? AND version < ?`, clientCompressedKey, info.InstallationID, info.Version)
   161  		if err != nil {
   162  			return err
   163  		}
   164  		// Insert
   165  		_, err = tx.Exec(`INSERT INTO push_notification_client_info (public_key, server_public_key, installation_id, access_token, retrieved_at, version) VALUES (?, ?, ?, ?, ?,?)`, clientCompressedKey, crypto.CompressPubkey(info.ServerPublicKey), info.InstallationID, info.AccessToken, info.RetrievedAt, info.Version)
   166  		if err != nil {
   167  			return err
   168  		}
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  func (p *Persistence) GetPushNotificationInfo(publicKey *ecdsa.PublicKey, installationIDs []string) ([]*PushNotificationInfo, error) {
   175  	queryArgs := make([]interface{}, 0, len(installationIDs)+1)
   176  	queryArgs = append(queryArgs, crypto.CompressPubkey(publicKey))
   177  	for _, installationID := range installationIDs {
   178  		queryArgs = append(queryArgs, installationID)
   179  	}
   180  
   181  	inVector := strings.Repeat("?, ", len(installationIDs)-1) + "?"
   182  
   183  	rows, err := p.db.Query(`SELECT server_public_key, installation_id, version, access_token, retrieved_at FROM push_notification_client_info WHERE public_key = ? AND installation_id IN (`+inVector+`)`, queryArgs...) //nolint: gosec
   184  
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	defer rows.Close()
   189  
   190  	var infos []*PushNotificationInfo
   191  	for rows.Next() {
   192  		var serverPublicKeyBytes []byte
   193  		info := &PushNotificationInfo{PublicKey: publicKey}
   194  		err := rows.Scan(&serverPublicKeyBytes, &info.InstallationID, &info.Version, &info.AccessToken, &info.RetrievedAt)
   195  		if err != nil {
   196  			return nil, err
   197  		}
   198  
   199  		serverPublicKey, err := crypto.DecompressPubkey(serverPublicKeyBytes)
   200  		if err != nil {
   201  			return nil, err
   202  		}
   203  
   204  		info.ServerPublicKey = serverPublicKey
   205  		infos = append(infos, info)
   206  	}
   207  
   208  	return infos, nil
   209  }
   210  
   211  func (p *Persistence) GetPushNotificationInfoByPublicKey(publicKey *ecdsa.PublicKey) ([]*PushNotificationInfo, error) {
   212  	rows, err := p.db.Query(`SELECT server_public_key, installation_id, access_token, retrieved_at FROM push_notification_client_info WHERE public_key = ?`, crypto.CompressPubkey(publicKey))
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	defer rows.Close()
   217  
   218  	var infos []*PushNotificationInfo
   219  	for rows.Next() {
   220  		var serverPublicKeyBytes []byte
   221  		info := &PushNotificationInfo{PublicKey: publicKey}
   222  		err := rows.Scan(&serverPublicKeyBytes, &info.InstallationID, &info.AccessToken, &info.RetrievedAt)
   223  		if err != nil {
   224  			return nil, err
   225  		}
   226  
   227  		serverPublicKey, err := crypto.DecompressPubkey(serverPublicKeyBytes)
   228  		if err != nil {
   229  			return nil, err
   230  		}
   231  
   232  		info.ServerPublicKey = serverPublicKey
   233  		infos = append(infos, info)
   234  	}
   235  
   236  	return infos, nil
   237  }
   238  
   239  func (p *Persistence) ShouldSendNotificationFor(publicKey *ecdsa.PublicKey, installationID string, messageID []byte) (bool, error) {
   240  	// First we check that we are tracking this message, next we check that we haven't already sent this
   241  	var count uint64
   242  	err := p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_tracked_messages WHERE message_id = ?`, messageID).Scan(&count)
   243  	if err != nil {
   244  		return false, err
   245  	}
   246  
   247  	if count == 0 {
   248  		return false, nil
   249  	}
   250  
   251  	err = p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_sent_notifications WHERE message_id = ? AND public_key = ? AND installation_id = ? `, messageID, crypto.CompressPubkey(publicKey), installationID).Scan(&count)
   252  	if err != nil {
   253  		return false, err
   254  	}
   255  
   256  	return count == 0, nil
   257  }
   258  
   259  func (p *Persistence) ShouldSendNotificationToAllInstallationIDs(publicKey *ecdsa.PublicKey, messageID []byte) (bool, error) {
   260  	// First we check that we are tracking this message, next we check that we haven't already sent this
   261  	var count uint64
   262  	err := p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_tracked_messages WHERE message_id = ?`, messageID).Scan(&count)
   263  	if err != nil {
   264  		return false, err
   265  	}
   266  
   267  	if count == 0 {
   268  		return false, nil
   269  	}
   270  
   271  	err = p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_sent_notifications WHERE message_id = ? AND public_key = ? `, messageID, crypto.CompressPubkey(publicKey)).Scan(&count)
   272  	if err != nil {
   273  		return false, err
   274  	}
   275  
   276  	return count == 0, nil
   277  }
   278  
   279  func (p *Persistence) UpsertSentNotification(n *SentNotification) error {
   280  	_, err := p.db.Exec(`INSERT INTO push_notification_client_sent_notifications (public_key, installation_id, message_id, last_tried_at, retry_count, success, error, hashed_public_key,chat_id, notification_type) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, crypto.CompressPubkey(n.PublicKey), n.InstallationID, n.MessageID, n.LastTriedAt, n.RetryCount, n.Success, n.Error, n.HashedPublicKey(), n.ChatID, n.NotificationType)
   281  	return err
   282  }
   283  
   284  func (p *Persistence) GetSentNotification(hashedPublicKey []byte, installationID string, messageID []byte) (*SentNotification, error) {
   285  	var publicKeyBytes []byte
   286  	sentNotification := &SentNotification{
   287  		InstallationID: installationID,
   288  		MessageID:      messageID,
   289  	}
   290  	err := p.db.QueryRow(`SELECT retry_count, last_tried_at, error, success, public_key,chat_id,notification_type FROM push_notification_client_sent_notifications WHERE hashed_public_key = ?`, hashedPublicKey).Scan(&sentNotification.RetryCount, &sentNotification.LastTriedAt, &sentNotification.Error, &sentNotification.Success, &publicKeyBytes, &sentNotification.ChatID, &sentNotification.NotificationType)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	publicKey, err := crypto.DecompressPubkey(publicKeyBytes)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  
   300  	sentNotification.PublicKey = publicKey
   301  
   302  	return sentNotification, nil
   303  }
   304  
   305  func (p *Persistence) UpdateNotificationResponse(messageID []byte, response *protobuf.PushNotificationReport) error {
   306  	_, err := p.db.Exec(`UPDATE push_notification_client_sent_notifications SET success = ?, error = ? WHERE hashed_public_key = ? AND installation_id = ? AND message_id = ? AND NOT success`, response.Success, response.Error, response.PublicKey, response.InstallationId, messageID)
   307  	return err
   308  }
   309  
   310  func (p *Persistence) GetRetriablePushNotifications() ([]*SentNotification, error) {
   311  	var notifications []*SentNotification
   312  	rows, err := p.db.Query(`SELECT retry_count, last_tried_at, error, success, public_key, installation_id, message_id,chat_id, notification_type FROM push_notification_client_sent_notifications WHERE NOT success AND error = ? AND retry_count <= ?`, protobuf.PushNotificationReport_WRONG_TOKEN, maxPushNotificationRetries)
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  	defer rows.Close()
   317  
   318  	for rows.Next() {
   319  		var publicKeyBytes []byte
   320  		notification := &SentNotification{}
   321  		err = rows.Scan(&notification.RetryCount, &notification.LastTriedAt, &notification.Error, &notification.Success, &publicKeyBytes, &notification.InstallationID, &notification.MessageID, &notification.ChatID, &notification.NotificationType)
   322  		if err != nil {
   323  			return nil, err
   324  		}
   325  		publicKey, err := crypto.DecompressPubkey(publicKeyBytes)
   326  		if err != nil {
   327  			return nil, err
   328  		}
   329  		notification.PublicKey = publicKey
   330  		notifications = append(notifications, notification)
   331  	}
   332  	return notifications, err
   333  }
   334  
   335  func (p *Persistence) UpsertServer(server *PushNotificationServer) error {
   336  	_, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at, access_token, last_retried_at, retry_count, server_type) VALUES (?,?,?,?,?,?,?)`, crypto.CompressPubkey(server.PublicKey), server.Registered, server.RegisteredAt, server.AccessToken, server.LastRetriedAt, server.RetryCount, server.Type)
   337  	return err
   338  
   339  }
   340  
   341  func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
   342  	rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token,last_retried_at, retry_count, server_type FROM push_notification_client_servers`)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  	defer rows.Close()
   347  
   348  	var servers []*PushNotificationServer
   349  	for rows.Next() {
   350  		server := &PushNotificationServer{}
   351  		var key []byte
   352  		err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken, &server.LastRetriedAt, &server.RetryCount, &server.Type)
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  		parsedKey, err := crypto.DecompressPubkey(key)
   357  		if err != nil {
   358  			return nil, err
   359  		}
   360  		server.PublicKey = parsedKey
   361  		servers = append(servers, server)
   362  	}
   363  	return servers, nil
   364  }
   365  
   366  func (p *Persistence) GetServersByPublicKey(keys []*ecdsa.PublicKey) ([]*PushNotificationServer, error) {
   367  
   368  	keyArgs := make([]interface{}, 0, len(keys))
   369  	for _, key := range keys {
   370  		keyArgs = append(keyArgs, crypto.CompressPubkey(key))
   371  	}
   372  
   373  	inVector := strings.Repeat("?, ", len(keys)-1) + "?"
   374  	rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token FROM push_notification_client_servers WHERE public_key IN (`+inVector+")", keyArgs...) //nolint: gosec
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  	defer rows.Close()
   379  
   380  	var servers []*PushNotificationServer
   381  	for rows.Next() {
   382  		server := &PushNotificationServer{}
   383  		var key []byte
   384  		err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken)
   385  		if err != nil {
   386  			return nil, err
   387  		}
   388  		parsedKey, err := crypto.DecompressPubkey(key)
   389  		if err != nil {
   390  			return nil, err
   391  		}
   392  		server.PublicKey = parsedKey
   393  		servers = append(servers, server)
   394  	}
   395  	return servers, nil
   396  }