github.com/status-im/status-go@v1.1.0/protocol/transport/processed_message_ids_cache.go (about)

     1  package transport
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"strings"
     7  )
     8  
     9  type ProcessedMessageIDsCache struct {
    10  	db *sql.DB
    11  }
    12  
    13  func NewProcessedMessageIDsCache(db *sql.DB) *ProcessedMessageIDsCache {
    14  	return &ProcessedMessageIDsCache{db: db}
    15  }
    16  
    17  func (c *ProcessedMessageIDsCache) Clear() error {
    18  	_, err := c.db.Exec("DELETE FROM transport_message_cache")
    19  	return err
    20  }
    21  
    22  func (c *ProcessedMessageIDsCache) Hits(ids []string) (map[string]bool, error) {
    23  	hits := make(map[string]bool)
    24  
    25  	// Split the results into batches of 999 items.
    26  	// To prevent excessive memory allocations, the maximum value of a host parameter number
    27  	// is SQLITE_MAX_VARIABLE_NUMBER, which defaults to 999
    28  	batch := 999
    29  	for i := 0; i < len(ids); i += batch {
    30  		j := i + batch
    31  		if j > len(ids) {
    32  			j = len(ids)
    33  		}
    34  
    35  		currentBatch := ids[i:j]
    36  
    37  		idsArgs := make([]interface{}, 0, len(currentBatch))
    38  		for _, id := range currentBatch {
    39  			idsArgs = append(idsArgs, id)
    40  		}
    41  
    42  		inVector := strings.Repeat("?, ", len(currentBatch)-1) + "?"
    43  		query := "SELECT id FROM transport_message_cache WHERE id IN (" + inVector + ")" // nolint: gosec
    44  
    45  		rows, err := c.db.Query(query, idsArgs...)
    46  		if err != nil {
    47  			return nil, err
    48  		}
    49  		defer rows.Close()
    50  
    51  		for rows.Next() {
    52  			var id string
    53  			err := rows.Scan(&id)
    54  			if err != nil {
    55  				return nil, err
    56  			}
    57  			hits[id] = true
    58  		}
    59  	}
    60  
    61  	return hits, nil
    62  }
    63  
    64  func (c *ProcessedMessageIDsCache) Add(ids []string, timestamp uint64) (err error) {
    65  	var tx *sql.Tx
    66  	tx, err = c.db.BeginTx(context.Background(), &sql.TxOptions{})
    67  	if err != nil {
    68  		return
    69  	}
    70  
    71  	defer func() {
    72  		if err == nil {
    73  			err = tx.Commit()
    74  			return
    75  		}
    76  		// don't shadow original error
    77  		_ = tx.Rollback()
    78  	}()
    79  
    80  	for _, id := range ids {
    81  
    82  		var stmt *sql.Stmt
    83  		stmt, err = tx.Prepare(`INSERT INTO transport_message_cache(id,timestamp) VALUES (?, ?)`)
    84  		if err != nil {
    85  			return
    86  		}
    87  
    88  		_, err = stmt.Exec(id, timestamp)
    89  		if err != nil {
    90  			return
    91  		}
    92  	}
    93  
    94  	return
    95  }
    96  
    97  func (c *ProcessedMessageIDsCache) Clean(timestamp uint64) error {
    98  	_, err := c.db.Exec(`DELETE FROM transport_message_cache WHERE timestamp < ?`, timestamp)
    99  	return err
   100  }