github.com/status-im/status-go@v1.1.0/protocol/pushnotificationclient/persistence.go (about) 1 package pushnotificationclient 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "database/sql" 8 "encoding/gob" 9 "strings" 10 "time" 11 12 "github.com/golang/protobuf/proto" 13 14 "github.com/status-im/status-go/eth-node/crypto" 15 "github.com/status-im/status-go/protocol/protobuf" 16 ) 17 18 type Persistence struct { 19 db *sql.DB 20 } 21 22 func NewPersistence(db *sql.DB) *Persistence { 23 return &Persistence{db: db} 24 } 25 26 func (p *Persistence) GetLastPushNotificationRegistration() (*protobuf.PushNotificationRegistration, []*ecdsa.PublicKey, error) { 27 var registrationBytes []byte 28 var contactIDsBytes []byte 29 err := p.db.QueryRow(`SELECT registration,contact_ids FROM push_notification_client_registrations LIMIT 1`).Scan(®istrationBytes, &contactIDsBytes) 30 if err == sql.ErrNoRows { 31 return nil, nil, nil 32 } else if err != nil { 33 return nil, nil, err 34 } 35 36 var publicKeyBytes [][]byte 37 var contactIDs []*ecdsa.PublicKey 38 // Restore contactIDs 39 contactIDsDecoder := gob.NewDecoder(bytes.NewBuffer(contactIDsBytes)) 40 err = contactIDsDecoder.Decode(&publicKeyBytes) 41 if err != nil { 42 return nil, nil, err 43 } 44 for _, pkBytes := range publicKeyBytes { 45 pk, err := crypto.DecompressPubkey(pkBytes) 46 if err != nil { 47 return nil, nil, err 48 } 49 contactIDs = append(contactIDs, pk) 50 } 51 52 registration := &protobuf.PushNotificationRegistration{} 53 54 err = proto.Unmarshal(registrationBytes, registration) 55 if err != nil { 56 return nil, nil, err 57 } 58 59 return registration, contactIDs, nil 60 } 61 62 func (p *Persistence) SaveLastPushNotificationRegistration(registration *protobuf.PushNotificationRegistration, contactIDs []*ecdsa.PublicKey) error { 63 var encodedContactIDs bytes.Buffer 64 var contactIDsBytes [][]byte 65 for _, pk := range contactIDs { 66 contactIDsBytes = append(contactIDsBytes, crypto.CompressPubkey(pk)) 67 } 68 pkEncoder := gob.NewEncoder(&encodedContactIDs) 69 if err := pkEncoder.Encode(contactIDsBytes); err != nil { 70 return err 71 } 72 73 marshaledRegistration, err := proto.Marshal(registration) 74 if err != nil { 75 return err 76 } 77 _, err = p.db.Exec(`INSERT INTO push_notification_client_registrations (registration,contact_ids) VALUES (?, ?)`, marshaledRegistration, encodedContactIDs.Bytes()) 78 return err 79 } 80 81 func (p *Persistence) TrackPushNotification(chatID string, messageID []byte) error { 82 trackedAt := time.Now().Unix() 83 _, err := p.db.Exec(`INSERT INTO push_notification_client_tracked_messages (chat_id, message_id, tracked_at) VALUES (?,?,?)`, chatID, messageID, trackedAt) 84 return err 85 } 86 87 func (p *Persistence) TrackedMessage(messageID []byte) (bool, error) { 88 var count uint64 89 err := p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_tracked_messages WHERE message_id = ?`, messageID).Scan(&count) 90 if err != nil { 91 return false, err 92 } 93 94 if count == 0 { 95 return false, nil 96 } 97 98 return true, nil 99 } 100 101 func (p *Persistence) SavePushNotificationQuery(publicKey *ecdsa.PublicKey, queryID []byte) error { 102 queriedAt := time.Now().Unix() 103 _, err := p.db.Exec(`INSERT INTO push_notification_client_queries (public_key, query_id, queried_at) VALUES (?,?,?)`, crypto.CompressPubkey(publicKey), queryID, queriedAt) 104 return err 105 } 106 107 func (p *Persistence) GetQueriedAt(publicKey *ecdsa.PublicKey) (int64, error) { 108 var queriedAt int64 109 err := p.db.QueryRow(`SELECT queried_at FROM push_notification_client_queries WHERE public_key = ? ORDER BY queried_at DESC LIMIT 1`, crypto.CompressPubkey(publicKey)).Scan(&queriedAt) 110 if err == sql.ErrNoRows { 111 return 0, nil 112 } 113 if err != nil { 114 return 0, err 115 } 116 117 return queriedAt, nil 118 } 119 120 func (p *Persistence) GetQueryPublicKey(queryID []byte) (*ecdsa.PublicKey, error) { 121 var publicKeyBytes []byte 122 err := p.db.QueryRow(`SELECT public_key FROM push_notification_client_queries WHERE query_id = ?`, queryID).Scan(&publicKeyBytes) 123 if err == sql.ErrNoRows { 124 return nil, nil 125 } 126 if err != nil { 127 return nil, err 128 } 129 130 publicKey, err := crypto.DecompressPubkey(publicKeyBytes) 131 if err != nil { 132 return nil, err 133 } 134 return publicKey, nil 135 } 136 137 func (p *Persistence) SavePushNotificationInfo(infos []*PushNotificationInfo) error { 138 tx, err := p.db.BeginTx(context.Background(), &sql.TxOptions{}) 139 defer func() { 140 if err == nil { 141 err = tx.Commit() 142 return 143 } 144 // don't shadow original error 145 _ = tx.Rollback() 146 }() 147 for _, info := range infos { 148 var latestVersion uint64 149 clientCompressedKey := crypto.CompressPubkey(info.PublicKey) 150 err := tx.QueryRow(`SELECT IFNULL(MAX(version),0) FROM push_notification_client_info WHERE public_key = ? AND installation_id = ? LIMIT 1`, clientCompressedKey, info.InstallationID).Scan(&latestVersion) 151 if err != sql.ErrNoRows && err != nil { 152 return err 153 } 154 if latestVersion > info.Version { 155 // Nothing to do 156 continue 157 } 158 159 // Remove anything that as a lower version 160 _, err = tx.Exec(`DELETE FROM push_notification_client_info WHERE public_key = ? AND installation_id = ? AND version < ?`, clientCompressedKey, info.InstallationID, info.Version) 161 if err != nil { 162 return err 163 } 164 // Insert 165 _, err = tx.Exec(`INSERT INTO push_notification_client_info (public_key, server_public_key, installation_id, access_token, retrieved_at, version) VALUES (?, ?, ?, ?, ?,?)`, clientCompressedKey, crypto.CompressPubkey(info.ServerPublicKey), info.InstallationID, info.AccessToken, info.RetrievedAt, info.Version) 166 if err != nil { 167 return err 168 } 169 } 170 171 return nil 172 } 173 174 func (p *Persistence) GetPushNotificationInfo(publicKey *ecdsa.PublicKey, installationIDs []string) ([]*PushNotificationInfo, error) { 175 queryArgs := make([]interface{}, 0, len(installationIDs)+1) 176 queryArgs = append(queryArgs, crypto.CompressPubkey(publicKey)) 177 for _, installationID := range installationIDs { 178 queryArgs = append(queryArgs, installationID) 179 } 180 181 inVector := strings.Repeat("?, ", len(installationIDs)-1) + "?" 182 183 rows, err := p.db.Query(`SELECT server_public_key, installation_id, version, access_token, retrieved_at FROM push_notification_client_info WHERE public_key = ? AND installation_id IN (`+inVector+`)`, queryArgs...) //nolint: gosec 184 185 if err != nil { 186 return nil, err 187 } 188 defer rows.Close() 189 190 var infos []*PushNotificationInfo 191 for rows.Next() { 192 var serverPublicKeyBytes []byte 193 info := &PushNotificationInfo{PublicKey: publicKey} 194 err := rows.Scan(&serverPublicKeyBytes, &info.InstallationID, &info.Version, &info.AccessToken, &info.RetrievedAt) 195 if err != nil { 196 return nil, err 197 } 198 199 serverPublicKey, err := crypto.DecompressPubkey(serverPublicKeyBytes) 200 if err != nil { 201 return nil, err 202 } 203 204 info.ServerPublicKey = serverPublicKey 205 infos = append(infos, info) 206 } 207 208 return infos, nil 209 } 210 211 func (p *Persistence) GetPushNotificationInfoByPublicKey(publicKey *ecdsa.PublicKey) ([]*PushNotificationInfo, error) { 212 rows, err := p.db.Query(`SELECT server_public_key, installation_id, access_token, retrieved_at FROM push_notification_client_info WHERE public_key = ?`, crypto.CompressPubkey(publicKey)) 213 if err != nil { 214 return nil, err 215 } 216 defer rows.Close() 217 218 var infos []*PushNotificationInfo 219 for rows.Next() { 220 var serverPublicKeyBytes []byte 221 info := &PushNotificationInfo{PublicKey: publicKey} 222 err := rows.Scan(&serverPublicKeyBytes, &info.InstallationID, &info.AccessToken, &info.RetrievedAt) 223 if err != nil { 224 return nil, err 225 } 226 227 serverPublicKey, err := crypto.DecompressPubkey(serverPublicKeyBytes) 228 if err != nil { 229 return nil, err 230 } 231 232 info.ServerPublicKey = serverPublicKey 233 infos = append(infos, info) 234 } 235 236 return infos, nil 237 } 238 239 func (p *Persistence) ShouldSendNotificationFor(publicKey *ecdsa.PublicKey, installationID string, messageID []byte) (bool, error) { 240 // First we check that we are tracking this message, next we check that we haven't already sent this 241 var count uint64 242 err := p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_tracked_messages WHERE message_id = ?`, messageID).Scan(&count) 243 if err != nil { 244 return false, err 245 } 246 247 if count == 0 { 248 return false, nil 249 } 250 251 err = p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_sent_notifications WHERE message_id = ? AND public_key = ? AND installation_id = ? `, messageID, crypto.CompressPubkey(publicKey), installationID).Scan(&count) 252 if err != nil { 253 return false, err 254 } 255 256 return count == 0, nil 257 } 258 259 func (p *Persistence) ShouldSendNotificationToAllInstallationIDs(publicKey *ecdsa.PublicKey, messageID []byte) (bool, error) { 260 // First we check that we are tracking this message, next we check that we haven't already sent this 261 var count uint64 262 err := p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_tracked_messages WHERE message_id = ?`, messageID).Scan(&count) 263 if err != nil { 264 return false, err 265 } 266 267 if count == 0 { 268 return false, nil 269 } 270 271 err = p.db.QueryRow(`SELECT COUNT(1) FROM push_notification_client_sent_notifications WHERE message_id = ? AND public_key = ? `, messageID, crypto.CompressPubkey(publicKey)).Scan(&count) 272 if err != nil { 273 return false, err 274 } 275 276 return count == 0, nil 277 } 278 279 func (p *Persistence) UpsertSentNotification(n *SentNotification) error { 280 _, err := p.db.Exec(`INSERT INTO push_notification_client_sent_notifications (public_key, installation_id, message_id, last_tried_at, retry_count, success, error, hashed_public_key,chat_id, notification_type) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, crypto.CompressPubkey(n.PublicKey), n.InstallationID, n.MessageID, n.LastTriedAt, n.RetryCount, n.Success, n.Error, n.HashedPublicKey(), n.ChatID, n.NotificationType) 281 return err 282 } 283 284 func (p *Persistence) GetSentNotification(hashedPublicKey []byte, installationID string, messageID []byte) (*SentNotification, error) { 285 var publicKeyBytes []byte 286 sentNotification := &SentNotification{ 287 InstallationID: installationID, 288 MessageID: messageID, 289 } 290 err := p.db.QueryRow(`SELECT retry_count, last_tried_at, error, success, public_key,chat_id,notification_type FROM push_notification_client_sent_notifications WHERE hashed_public_key = ?`, hashedPublicKey).Scan(&sentNotification.RetryCount, &sentNotification.LastTriedAt, &sentNotification.Error, &sentNotification.Success, &publicKeyBytes, &sentNotification.ChatID, &sentNotification.NotificationType) 291 if err != nil { 292 return nil, err 293 } 294 295 publicKey, err := crypto.DecompressPubkey(publicKeyBytes) 296 if err != nil { 297 return nil, err 298 } 299 300 sentNotification.PublicKey = publicKey 301 302 return sentNotification, nil 303 } 304 305 func (p *Persistence) UpdateNotificationResponse(messageID []byte, response *protobuf.PushNotificationReport) error { 306 _, err := p.db.Exec(`UPDATE push_notification_client_sent_notifications SET success = ?, error = ? WHERE hashed_public_key = ? AND installation_id = ? AND message_id = ? AND NOT success`, response.Success, response.Error, response.PublicKey, response.InstallationId, messageID) 307 return err 308 } 309 310 func (p *Persistence) GetRetriablePushNotifications() ([]*SentNotification, error) { 311 var notifications []*SentNotification 312 rows, err := p.db.Query(`SELECT retry_count, last_tried_at, error, success, public_key, installation_id, message_id,chat_id, notification_type FROM push_notification_client_sent_notifications WHERE NOT success AND error = ? AND retry_count <= ?`, protobuf.PushNotificationReport_WRONG_TOKEN, maxPushNotificationRetries) 313 if err != nil { 314 return nil, err 315 } 316 defer rows.Close() 317 318 for rows.Next() { 319 var publicKeyBytes []byte 320 notification := &SentNotification{} 321 err = rows.Scan(¬ification.RetryCount, ¬ification.LastTriedAt, ¬ification.Error, ¬ification.Success, &publicKeyBytes, ¬ification.InstallationID, ¬ification.MessageID, ¬ification.ChatID, ¬ification.NotificationType) 322 if err != nil { 323 return nil, err 324 } 325 publicKey, err := crypto.DecompressPubkey(publicKeyBytes) 326 if err != nil { 327 return nil, err 328 } 329 notification.PublicKey = publicKey 330 notifications = append(notifications, notification) 331 } 332 return notifications, err 333 } 334 335 func (p *Persistence) UpsertServer(server *PushNotificationServer) error { 336 _, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at, access_token, last_retried_at, retry_count, server_type) VALUES (?,?,?,?,?,?,?)`, crypto.CompressPubkey(server.PublicKey), server.Registered, server.RegisteredAt, server.AccessToken, server.LastRetriedAt, server.RetryCount, server.Type) 337 return err 338 339 } 340 341 func (p *Persistence) GetServers() ([]*PushNotificationServer, error) { 342 rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token,last_retried_at, retry_count, server_type FROM push_notification_client_servers`) 343 if err != nil { 344 return nil, err 345 } 346 defer rows.Close() 347 348 var servers []*PushNotificationServer 349 for rows.Next() { 350 server := &PushNotificationServer{} 351 var key []byte 352 err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken, &server.LastRetriedAt, &server.RetryCount, &server.Type) 353 if err != nil { 354 return nil, err 355 } 356 parsedKey, err := crypto.DecompressPubkey(key) 357 if err != nil { 358 return nil, err 359 } 360 server.PublicKey = parsedKey 361 servers = append(servers, server) 362 } 363 return servers, nil 364 } 365 366 func (p *Persistence) GetServersByPublicKey(keys []*ecdsa.PublicKey) ([]*PushNotificationServer, error) { 367 368 keyArgs := make([]interface{}, 0, len(keys)) 369 for _, key := range keys { 370 keyArgs = append(keyArgs, crypto.CompressPubkey(key)) 371 } 372 373 inVector := strings.Repeat("?, ", len(keys)-1) + "?" 374 rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token FROM push_notification_client_servers WHERE public_key IN (`+inVector+")", keyArgs...) //nolint: gosec 375 if err != nil { 376 return nil, err 377 } 378 defer rows.Close() 379 380 var servers []*PushNotificationServer 381 for rows.Next() { 382 server := &PushNotificationServer{} 383 var key []byte 384 err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken) 385 if err != nil { 386 return nil, err 387 } 388 parsedKey, err := crypto.DecompressPubkey(key) 389 if err != nil { 390 return nil, err 391 } 392 server.PublicKey = parsedKey 393 servers = append(servers, server) 394 } 395 return servers, nil 396 }