github.com/status-im/status-go@v1.1.0/wakuv2/persistence/dbstore.go (about) 1 package persistence 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "strings" 9 "sync" 10 "time" 11 12 gowakuPersistence "github.com/waku-org/go-waku/waku/persistence" 13 "github.com/waku-org/go-waku/waku/v2/protocol" 14 storepb "github.com/waku-org/go-waku/waku/v2/protocol/legacy_store/pb" 15 "github.com/waku-org/go-waku/waku/v2/protocol/pb" 16 "github.com/waku-org/go-waku/waku/v2/timesource" 17 "github.com/waku-org/go-waku/waku/v2/utils" 18 19 "go.uber.org/zap" 20 ) 21 22 var ErrInvalidCursor = errors.New("invalid cursor") 23 24 var ErrFutureMessage = errors.New("message timestamp in the future") 25 var ErrMessageTooOld = errors.New("message too old") 26 27 // MaxTimeVariance is the maximum duration in the future allowed for a message timestamp 28 const MaxTimeVariance = time.Duration(20) * time.Second 29 30 // DBStore is a MessageProvider that has a *sql.DB connection 31 type DBStore struct { 32 db *sql.DB 33 log *zap.Logger 34 35 maxMessages int 36 maxDuration time.Duration 37 38 wg sync.WaitGroup 39 cancel context.CancelFunc 40 } 41 42 // DBOption is an optional setting that can be used to configure the DBStore 43 type DBOption func(*DBStore) error 44 45 // WithDB is a DBOption that lets you use any custom *sql.DB with a DBStore. 46 func WithDB(db *sql.DB) DBOption { 47 return func(d *DBStore) error { 48 d.db = db 49 return nil 50 } 51 } 52 53 // WithRetentionPolicy is a DBOption that specifies the max number of messages 54 // to be stored and duration before they're removed from the message store 55 func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption { 56 return func(d *DBStore) error { 57 d.maxDuration = maxDuration 58 d.maxMessages = maxMessages 59 return nil 60 } 61 } 62 63 // Creates a new DB store using the db specified via options. 64 // It will create a messages table if it does not exist and 65 // clean up records according to the retention policy used 66 func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) { 67 result := new(DBStore) 68 result.log = log.Named("dbstore") 69 70 for _, opt := range options { 71 err := opt(result) 72 if err != nil { 73 return nil, err 74 } 75 } 76 77 return result, nil 78 } 79 80 func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error { 81 ctx, cancel := context.WithCancel(ctx) 82 83 d.cancel = cancel 84 85 err := d.cleanOlderRecords() 86 if err != nil { 87 return err 88 } 89 90 d.wg.Add(1) 91 go d.checkForOlderRecords(ctx, 60*time.Second) 92 93 return nil 94 } 95 96 func (d *DBStore) Validate(env *protocol.Envelope) error { 97 n := time.Unix(0, env.Index().ReceiverTime) 98 upperBound := n.Add(MaxTimeVariance) 99 lowerBound := n.Add(-MaxTimeVariance) 100 101 // Ensure that messages don't "jump" to the front of the queue with future timestamps 102 if env.Message().GetTimestamp() > upperBound.UnixNano() { 103 return ErrFutureMessage 104 } 105 106 if env.Message().GetTimestamp() < lowerBound.UnixNano() { 107 return ErrMessageTooOld 108 } 109 110 return nil 111 } 112 113 func (d *DBStore) cleanOlderRecords() error { 114 d.log.Debug("Cleaning older records...") 115 116 // Delete older messages 117 if d.maxDuration > 0 { 118 start := time.Now() 119 sqlStmt := `DELETE FROM store_messages WHERE receiverTimestamp < ?` 120 _, err := d.db.Exec(sqlStmt, utils.GetUnixEpochFrom(time.Now().Add(-d.maxDuration))) 121 if err != nil { 122 return err 123 } 124 elapsed := time.Since(start) 125 d.log.Debug("deleting older records from the DB", zap.Duration("duration", elapsed)) 126 } 127 128 // Limit number of records to a max N 129 if d.maxMessages > 0 { 130 start := time.Now() 131 sqlStmt := `DELETE FROM store_messages WHERE id IN (SELECT id FROM store_messages ORDER BY receiverTimestamp DESC LIMIT -1 OFFSET ?)` 132 _, err := d.db.Exec(sqlStmt, d.maxMessages) 133 if err != nil { 134 return err 135 } 136 elapsed := time.Since(start) 137 d.log.Debug("deleting excess records from the DB", zap.Duration("duration", elapsed)) 138 } 139 140 return nil 141 } 142 143 func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) { 144 defer d.wg.Done() 145 146 ticker := time.NewTicker(t) 147 defer ticker.Stop() 148 149 for { 150 select { 151 case <-ctx.Done(): 152 return 153 case <-ticker.C: 154 err := d.cleanOlderRecords() 155 if err != nil { 156 d.log.Error("cleaning older records", zap.Error(err)) 157 } 158 } 159 } 160 } 161 162 // Stop closes a DB connection 163 func (d *DBStore) Stop() { 164 if d.cancel == nil { 165 return 166 } 167 168 d.cancel() 169 d.wg.Wait() 170 d.db.Close() 171 } 172 173 // Put inserts a WakuMessage into the DB 174 func (d *DBStore) Put(env *protocol.Envelope) error { 175 stmt, err := d.db.Prepare("INSERT INTO store_messages (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)") 176 if err != nil { 177 return err 178 } 179 180 cursor := env.Index() 181 dbKey := NewDBKey(uint64(cursor.SenderTime), uint64(env.Index().ReceiverTime), env.PubsubTopic(), env.Index().Digest) 182 _, err = stmt.Exec(dbKey.Bytes(), cursor.ReceiverTime, env.Message().Timestamp, env.Message().ContentTopic, env.PubsubTopic(), env.Message().Payload, env.Message().Version) 183 if err != nil { 184 return err 185 } 186 187 err = stmt.Close() 188 if err != nil { 189 return err 190 } 191 192 return nil 193 } 194 195 // Query retrieves messages from the DB 196 func (d *DBStore) Query(query *storepb.HistoryQuery) (*storepb.Index, []gowakuPersistence.StoredMessage, error) { 197 start := time.Now() 198 defer func() { 199 elapsed := time.Since(start) 200 d.log.Info(fmt.Sprintf("Loading records from the DB took %s", elapsed)) 201 }() 202 203 sqlQuery := `SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version 204 FROM store_messages 205 %s 206 ORDER BY senderTimestamp %s, id %s, pubsubTopic %s, receiverTimestamp %s ` 207 208 var conditions []string 209 var parameters []interface{} 210 paramCnt := 0 211 212 if query.PubsubTopic != "" { 213 paramCnt++ 214 conditions = append(conditions, fmt.Sprintf("pubsubTopic = $%d", paramCnt)) 215 parameters = append(parameters, query.PubsubTopic) 216 } 217 218 if len(query.ContentFilters) != 0 { 219 var ctPlaceHolder []string 220 for _, ct := range query.ContentFilters { 221 if ct.ContentTopic != "" { 222 paramCnt++ 223 ctPlaceHolder = append(ctPlaceHolder, fmt.Sprintf("$%d", paramCnt)) 224 parameters = append(parameters, ct.ContentTopic) 225 } 226 } 227 conditions = append(conditions, "contentTopic IN ("+strings.Join(ctPlaceHolder, ", ")+")") 228 } 229 230 usesCursor := false 231 if query.PagingInfo.Cursor != nil { 232 usesCursor = true 233 var exists bool 234 cursorDBKey := NewDBKey(uint64(query.PagingInfo.Cursor.SenderTime), uint64(query.PagingInfo.Cursor.ReceiverTime), query.PagingInfo.Cursor.PubsubTopic, query.PagingInfo.Cursor.Digest) 235 236 err := d.db.QueryRow("SELECT EXISTS(SELECT 1 FROM store_messages WHERE id = $1)", 237 cursorDBKey.Bytes(), 238 ).Scan(&exists) 239 240 if err != nil { 241 return nil, nil, err 242 } 243 244 if exists { 245 eqOp := ">" 246 if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD { 247 eqOp = "<" 248 } 249 paramCnt++ 250 conditions = append(conditions, fmt.Sprintf("id %s $%d", eqOp, paramCnt)) 251 252 parameters = append(parameters, cursorDBKey.Bytes()) 253 } else { 254 return nil, nil, ErrInvalidCursor 255 } 256 } 257 258 if query.GetStartTime() != 0 { 259 if !usesCursor || query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD { 260 paramCnt++ 261 conditions = append(conditions, fmt.Sprintf("id >= $%d", paramCnt)) 262 startTimeDBKey := NewDBKey(uint64(query.GetStartTime()), uint64(query.GetStartTime()), "", []byte{}) 263 parameters = append(parameters, startTimeDBKey.Bytes()) 264 } 265 266 } 267 268 if query.GetEndTime() != 0 { 269 if !usesCursor || query.PagingInfo.Direction == storepb.PagingInfo_FORWARD { 270 paramCnt++ 271 conditions = append(conditions, fmt.Sprintf("id <= $%d", paramCnt)) 272 endTimeDBKey := NewDBKey(uint64(query.GetEndTime()), uint64(query.GetEndTime()), "", []byte{}) 273 parameters = append(parameters, endTimeDBKey.Bytes()) 274 } 275 } 276 277 conditionStr := "" 278 if len(conditions) != 0 { 279 conditionStr = "WHERE " + strings.Join(conditions, " AND ") 280 } 281 282 orderDirection := "ASC" 283 if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD { 284 orderDirection = "DESC" 285 } 286 287 paramCnt++ 288 sqlQuery += fmt.Sprintf("LIMIT $%d", paramCnt) 289 sqlQuery = fmt.Sprintf(sqlQuery, conditionStr, orderDirection, orderDirection, orderDirection, orderDirection) 290 291 stmt, err := d.db.Prepare(sqlQuery) 292 if err != nil { 293 return nil, nil, err 294 } 295 defer stmt.Close() 296 297 pageSize := query.PagingInfo.PageSize + 1 298 299 parameters = append(parameters, pageSize) 300 rows, err := stmt.Query(parameters...) 301 if err != nil { 302 return nil, nil, err 303 } 304 305 var result []gowakuPersistence.StoredMessage 306 for rows.Next() { 307 record, err := d.GetStoredMessage(rows) 308 if err != nil { 309 return nil, nil, err 310 } 311 result = append(result, record) 312 } 313 defer rows.Close() 314 315 var cursor *storepb.Index 316 if len(result) != 0 { 317 if len(result) > int(query.PagingInfo.PageSize) { 318 result = result[0:query.PagingInfo.PageSize] 319 lastMsgIdx := len(result) - 1 320 cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index() 321 } 322 } 323 324 // The retrieved messages list should always be in chronological order 325 if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD { 326 for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { 327 result[i], result[j] = result[j], result[i] 328 } 329 } 330 331 return cursor, result, nil 332 } 333 334 // MostRecentTimestamp returns an unix timestamp with the most recent senderTimestamp 335 // in the message table 336 func (d *DBStore) MostRecentTimestamp() (int64, error) { 337 result := sql.NullInt64{} 338 339 err := d.db.QueryRow(`SELECT max(senderTimestamp) FROM store_messages`).Scan(&result) 340 if err != nil && err != sql.ErrNoRows { 341 return 0, err 342 } 343 return result.Int64, nil 344 } 345 346 // Count returns the number of rows in the message table 347 func (d *DBStore) Count() (int, error) { 348 var result int 349 err := d.db.QueryRow(`SELECT COUNT(*) FROM store_messages`).Scan(&result) 350 if err != nil && err != sql.ErrNoRows { 351 return 0, err 352 } 353 return result, nil 354 } 355 356 // GetAll returns all the stored WakuMessages 357 func (d *DBStore) GetAll() ([]gowakuPersistence.StoredMessage, error) { 358 start := time.Now() 359 defer func() { 360 elapsed := time.Since(start) 361 d.log.Info("loading records from the DB", zap.Duration("duration", elapsed)) 362 }() 363 364 rows, err := d.db.Query("SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version FROM store_messages ORDER BY senderTimestamp ASC") 365 if err != nil { 366 return nil, err 367 } 368 369 var result []gowakuPersistence.StoredMessage 370 371 defer rows.Close() 372 373 for rows.Next() { 374 record, err := d.GetStoredMessage(rows) 375 if err != nil { 376 return nil, err 377 } 378 result = append(result, record) 379 } 380 381 d.log.Info("DB returned records", zap.Int("count", len(result))) 382 383 err = rows.Err() 384 if err != nil { 385 return nil, err 386 } 387 388 return result, nil 389 } 390 391 // GetStoredMessage is a helper function used to convert a `*sql.Rows` into a `StoredMessage` 392 func (d *DBStore) GetStoredMessage(row *sql.Rows) (gowakuPersistence.StoredMessage, error) { 393 var id []byte 394 var receiverTimestamp int64 395 var senderTimestamp int64 396 var contentTopic string 397 var payload []byte 398 var version uint32 399 var pubsubTopic string 400 401 err := row.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version) 402 if err != nil { 403 d.log.Error("scanning messages from db", zap.Error(err)) 404 return gowakuPersistence.StoredMessage{}, err 405 } 406 407 msg := new(pb.WakuMessage) 408 msg.ContentTopic = contentTopic 409 msg.Payload = payload 410 msg.Timestamp = &senderTimestamp 411 msg.Version = &version 412 413 record := gowakuPersistence.StoredMessage{ 414 ID: id, 415 PubsubTopic: pubsubTopic, 416 ReceiverTime: receiverTimestamp, 417 Message: msg, 418 } 419 420 return record, nil 421 }