github.com/status-im/status-go@v1.1.0/wakuv2/persistence/dbstore.go (about)

     1  package persistence
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	gowakuPersistence "github.com/waku-org/go-waku/waku/persistence"
    13  	"github.com/waku-org/go-waku/waku/v2/protocol"
    14  	storepb "github.com/waku-org/go-waku/waku/v2/protocol/legacy_store/pb"
    15  	"github.com/waku-org/go-waku/waku/v2/protocol/pb"
    16  	"github.com/waku-org/go-waku/waku/v2/timesource"
    17  	"github.com/waku-org/go-waku/waku/v2/utils"
    18  
    19  	"go.uber.org/zap"
    20  )
    21  
    22  var ErrInvalidCursor = errors.New("invalid cursor")
    23  
    24  var ErrFutureMessage = errors.New("message timestamp in the future")
    25  var ErrMessageTooOld = errors.New("message too old")
    26  
    27  // MaxTimeVariance is the maximum duration in the future allowed for a message timestamp
    28  const MaxTimeVariance = time.Duration(20) * time.Second
    29  
    30  // DBStore is a MessageProvider that has a *sql.DB connection
    31  type DBStore struct {
    32  	db  *sql.DB
    33  	log *zap.Logger
    34  
    35  	maxMessages int
    36  	maxDuration time.Duration
    37  
    38  	wg     sync.WaitGroup
    39  	cancel context.CancelFunc
    40  }
    41  
    42  // DBOption is an optional setting that can be used to configure the DBStore
    43  type DBOption func(*DBStore) error
    44  
    45  // WithDB is a DBOption that lets you use any custom *sql.DB with a DBStore.
    46  func WithDB(db *sql.DB) DBOption {
    47  	return func(d *DBStore) error {
    48  		d.db = db
    49  		return nil
    50  	}
    51  }
    52  
    53  // WithRetentionPolicy is a DBOption that specifies the max number of messages
    54  // to be stored and duration before they're removed from the message store
    55  func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
    56  	return func(d *DBStore) error {
    57  		d.maxDuration = maxDuration
    58  		d.maxMessages = maxMessages
    59  		return nil
    60  	}
    61  }
    62  
    63  // Creates a new DB store using the db specified via options.
    64  // It will create a messages table if it does not exist and
    65  // clean up records according to the retention policy used
    66  func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
    67  	result := new(DBStore)
    68  	result.log = log.Named("dbstore")
    69  
    70  	for _, opt := range options {
    71  		err := opt(result)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  	}
    76  
    77  	return result, nil
    78  }
    79  
    80  func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error {
    81  	ctx, cancel := context.WithCancel(ctx)
    82  
    83  	d.cancel = cancel
    84  
    85  	err := d.cleanOlderRecords()
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	d.wg.Add(1)
    91  	go d.checkForOlderRecords(ctx, 60*time.Second)
    92  
    93  	return nil
    94  }
    95  
    96  func (d *DBStore) Validate(env *protocol.Envelope) error {
    97  	n := time.Unix(0, env.Index().ReceiverTime)
    98  	upperBound := n.Add(MaxTimeVariance)
    99  	lowerBound := n.Add(-MaxTimeVariance)
   100  
   101  	// Ensure that messages don't "jump" to the front of the queue with future timestamps
   102  	if env.Message().GetTimestamp() > upperBound.UnixNano() {
   103  		return ErrFutureMessage
   104  	}
   105  
   106  	if env.Message().GetTimestamp() < lowerBound.UnixNano() {
   107  		return ErrMessageTooOld
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func (d *DBStore) cleanOlderRecords() error {
   114  	d.log.Debug("Cleaning older records...")
   115  
   116  	// Delete older messages
   117  	if d.maxDuration > 0 {
   118  		start := time.Now()
   119  		sqlStmt := `DELETE FROM store_messages WHERE receiverTimestamp < ?`
   120  		_, err := d.db.Exec(sqlStmt, utils.GetUnixEpochFrom(time.Now().Add(-d.maxDuration)))
   121  		if err != nil {
   122  			return err
   123  		}
   124  		elapsed := time.Since(start)
   125  		d.log.Debug("deleting older records from the DB", zap.Duration("duration", elapsed))
   126  	}
   127  
   128  	// Limit number of records to a max N
   129  	if d.maxMessages > 0 {
   130  		start := time.Now()
   131  		sqlStmt := `DELETE FROM store_messages WHERE id IN (SELECT id FROM store_messages ORDER BY receiverTimestamp DESC LIMIT -1 OFFSET ?)`
   132  		_, err := d.db.Exec(sqlStmt, d.maxMessages)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		elapsed := time.Since(start)
   137  		d.log.Debug("deleting excess records from the DB", zap.Duration("duration", elapsed))
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
   144  	defer d.wg.Done()
   145  
   146  	ticker := time.NewTicker(t)
   147  	defer ticker.Stop()
   148  
   149  	for {
   150  		select {
   151  		case <-ctx.Done():
   152  			return
   153  		case <-ticker.C:
   154  			err := d.cleanOlderRecords()
   155  			if err != nil {
   156  				d.log.Error("cleaning older records", zap.Error(err))
   157  			}
   158  		}
   159  	}
   160  }
   161  
   162  // Stop closes a DB connection
   163  func (d *DBStore) Stop() {
   164  	if d.cancel == nil {
   165  		return
   166  	}
   167  
   168  	d.cancel()
   169  	d.wg.Wait()
   170  	d.db.Close()
   171  }
   172  
   173  // Put inserts a WakuMessage into the DB
   174  func (d *DBStore) Put(env *protocol.Envelope) error {
   175  	stmt, err := d.db.Prepare("INSERT INTO store_messages (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)")
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	cursor := env.Index()
   181  	dbKey := NewDBKey(uint64(cursor.SenderTime), uint64(env.Index().ReceiverTime), env.PubsubTopic(), env.Index().Digest)
   182  	_, err = stmt.Exec(dbKey.Bytes(), cursor.ReceiverTime, env.Message().Timestamp, env.Message().ContentTopic, env.PubsubTopic(), env.Message().Payload, env.Message().Version)
   183  	if err != nil {
   184  		return err
   185  	}
   186  
   187  	err = stmt.Close()
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  // Query retrieves messages from the DB
   196  func (d *DBStore) Query(query *storepb.HistoryQuery) (*storepb.Index, []gowakuPersistence.StoredMessage, error) {
   197  	start := time.Now()
   198  	defer func() {
   199  		elapsed := time.Since(start)
   200  		d.log.Info(fmt.Sprintf("Loading records from the DB took %s", elapsed))
   201  	}()
   202  
   203  	sqlQuery := `SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version 
   204  					 FROM store_messages 
   205  					 %s
   206  					 ORDER BY senderTimestamp %s, id %s, pubsubTopic %s, receiverTimestamp %s `
   207  
   208  	var conditions []string
   209  	var parameters []interface{}
   210  	paramCnt := 0
   211  
   212  	if query.PubsubTopic != "" {
   213  		paramCnt++
   214  		conditions = append(conditions, fmt.Sprintf("pubsubTopic = $%d", paramCnt))
   215  		parameters = append(parameters, query.PubsubTopic)
   216  	}
   217  
   218  	if len(query.ContentFilters) != 0 {
   219  		var ctPlaceHolder []string
   220  		for _, ct := range query.ContentFilters {
   221  			if ct.ContentTopic != "" {
   222  				paramCnt++
   223  				ctPlaceHolder = append(ctPlaceHolder, fmt.Sprintf("$%d", paramCnt))
   224  				parameters = append(parameters, ct.ContentTopic)
   225  			}
   226  		}
   227  		conditions = append(conditions, "contentTopic IN ("+strings.Join(ctPlaceHolder, ", ")+")")
   228  	}
   229  
   230  	usesCursor := false
   231  	if query.PagingInfo.Cursor != nil {
   232  		usesCursor = true
   233  		var exists bool
   234  		cursorDBKey := NewDBKey(uint64(query.PagingInfo.Cursor.SenderTime), uint64(query.PagingInfo.Cursor.ReceiverTime), query.PagingInfo.Cursor.PubsubTopic, query.PagingInfo.Cursor.Digest)
   235  
   236  		err := d.db.QueryRow("SELECT EXISTS(SELECT 1 FROM store_messages WHERE id = $1)",
   237  			cursorDBKey.Bytes(),
   238  		).Scan(&exists)
   239  
   240  		if err != nil {
   241  			return nil, nil, err
   242  		}
   243  
   244  		if exists {
   245  			eqOp := ">"
   246  			if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
   247  				eqOp = "<"
   248  			}
   249  			paramCnt++
   250  			conditions = append(conditions, fmt.Sprintf("id %s $%d", eqOp, paramCnt))
   251  
   252  			parameters = append(parameters, cursorDBKey.Bytes())
   253  		} else {
   254  			return nil, nil, ErrInvalidCursor
   255  		}
   256  	}
   257  
   258  	if query.GetStartTime() != 0 {
   259  		if !usesCursor || query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
   260  			paramCnt++
   261  			conditions = append(conditions, fmt.Sprintf("id >= $%d", paramCnt))
   262  			startTimeDBKey := NewDBKey(uint64(query.GetStartTime()), uint64(query.GetStartTime()), "", []byte{})
   263  			parameters = append(parameters, startTimeDBKey.Bytes())
   264  		}
   265  
   266  	}
   267  
   268  	if query.GetEndTime() != 0 {
   269  		if !usesCursor || query.PagingInfo.Direction == storepb.PagingInfo_FORWARD {
   270  			paramCnt++
   271  			conditions = append(conditions, fmt.Sprintf("id <= $%d", paramCnt))
   272  			endTimeDBKey := NewDBKey(uint64(query.GetEndTime()), uint64(query.GetEndTime()), "", []byte{})
   273  			parameters = append(parameters, endTimeDBKey.Bytes())
   274  		}
   275  	}
   276  
   277  	conditionStr := ""
   278  	if len(conditions) != 0 {
   279  		conditionStr = "WHERE " + strings.Join(conditions, " AND ")
   280  	}
   281  
   282  	orderDirection := "ASC"
   283  	if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
   284  		orderDirection = "DESC"
   285  	}
   286  
   287  	paramCnt++
   288  	sqlQuery += fmt.Sprintf("LIMIT $%d", paramCnt)
   289  	sqlQuery = fmt.Sprintf(sqlQuery, conditionStr, orderDirection, orderDirection, orderDirection, orderDirection)
   290  
   291  	stmt, err := d.db.Prepare(sqlQuery)
   292  	if err != nil {
   293  		return nil, nil, err
   294  	}
   295  	defer stmt.Close()
   296  
   297  	pageSize := query.PagingInfo.PageSize + 1
   298  
   299  	parameters = append(parameters, pageSize)
   300  	rows, err := stmt.Query(parameters...)
   301  	if err != nil {
   302  		return nil, nil, err
   303  	}
   304  
   305  	var result []gowakuPersistence.StoredMessage
   306  	for rows.Next() {
   307  		record, err := d.GetStoredMessage(rows)
   308  		if err != nil {
   309  			return nil, nil, err
   310  		}
   311  		result = append(result, record)
   312  	}
   313  	defer rows.Close()
   314  
   315  	var cursor *storepb.Index
   316  	if len(result) != 0 {
   317  		if len(result) > int(query.PagingInfo.PageSize) {
   318  			result = result[0:query.PagingInfo.PageSize]
   319  			lastMsgIdx := len(result) - 1
   320  			cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index()
   321  		}
   322  	}
   323  
   324  	// The retrieved messages list should always be in chronological order
   325  	if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
   326  		for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
   327  			result[i], result[j] = result[j], result[i]
   328  		}
   329  	}
   330  
   331  	return cursor, result, nil
   332  }
   333  
   334  // MostRecentTimestamp returns an unix timestamp with the most recent senderTimestamp
   335  // in the message table
   336  func (d *DBStore) MostRecentTimestamp() (int64, error) {
   337  	result := sql.NullInt64{}
   338  
   339  	err := d.db.QueryRow(`SELECT max(senderTimestamp) FROM store_messages`).Scan(&result)
   340  	if err != nil && err != sql.ErrNoRows {
   341  		return 0, err
   342  	}
   343  	return result.Int64, nil
   344  }
   345  
   346  // Count returns the number of rows in the message table
   347  func (d *DBStore) Count() (int, error) {
   348  	var result int
   349  	err := d.db.QueryRow(`SELECT COUNT(*) FROM store_messages`).Scan(&result)
   350  	if err != nil && err != sql.ErrNoRows {
   351  		return 0, err
   352  	}
   353  	return result, nil
   354  }
   355  
   356  // GetAll returns all the stored WakuMessages
   357  func (d *DBStore) GetAll() ([]gowakuPersistence.StoredMessage, error) {
   358  	start := time.Now()
   359  	defer func() {
   360  		elapsed := time.Since(start)
   361  		d.log.Info("loading records from the DB", zap.Duration("duration", elapsed))
   362  	}()
   363  
   364  	rows, err := d.db.Query("SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version FROM store_messages ORDER BY senderTimestamp ASC")
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	var result []gowakuPersistence.StoredMessage
   370  
   371  	defer rows.Close()
   372  
   373  	for rows.Next() {
   374  		record, err := d.GetStoredMessage(rows)
   375  		if err != nil {
   376  			return nil, err
   377  		}
   378  		result = append(result, record)
   379  	}
   380  
   381  	d.log.Info("DB returned records", zap.Int("count", len(result)))
   382  
   383  	err = rows.Err()
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	return result, nil
   389  }
   390  
   391  // GetStoredMessage is a helper function used to convert a `*sql.Rows` into a `StoredMessage`
   392  func (d *DBStore) GetStoredMessage(row *sql.Rows) (gowakuPersistence.StoredMessage, error) {
   393  	var id []byte
   394  	var receiverTimestamp int64
   395  	var senderTimestamp int64
   396  	var contentTopic string
   397  	var payload []byte
   398  	var version uint32
   399  	var pubsubTopic string
   400  
   401  	err := row.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version)
   402  	if err != nil {
   403  		d.log.Error("scanning messages from db", zap.Error(err))
   404  		return gowakuPersistence.StoredMessage{}, err
   405  	}
   406  
   407  	msg := new(pb.WakuMessage)
   408  	msg.ContentTopic = contentTopic
   409  	msg.Payload = payload
   410  	msg.Timestamp = &senderTimestamp
   411  	msg.Version = &version
   412  
   413  	record := gowakuPersistence.StoredMessage{
   414  		ID:           id,
   415  		PubsubTopic:  pubsubTopic,
   416  		ReceiverTime: receiverTimestamp,
   417  		Message:      msg,
   418  	}
   419  
   420  	return record, nil
   421  }