github.com/status-im/status-go@v1.1.0/protocol/storenodes/database.go (about)

     1  package storenodes
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/multiformats/go-multiaddr"
    10  
    11  	"github.com/status-im/status-go/eth-node/types"
    12  )
    13  
    14  type Database struct {
    15  	db *sql.DB
    16  }
    17  
    18  func NewDB(db *sql.DB) *Database {
    19  	return &Database{db: db}
    20  }
    21  
    22  // syncSave will sync the storenodes in the DB from the snode slice
    23  //   - if a storenode is not in the provided list, it will be soft-deleted
    24  //   - if a storenode is in the provided list, it will be inserted or updated
    25  func (d *Database) syncSave(communityID types.HexBytes, snode []Storenode, clock uint64) (err error) {
    26  	var tx *sql.Tx
    27  	tx, err = d.db.Begin()
    28  	if err != nil {
    29  		return err
    30  	}
    31  	defer func() {
    32  		if err == nil {
    33  			err = tx.Commit()
    34  			return
    35  		}
    36  		_ = tx.Rollback()
    37  	}()
    38  
    39  	now := time.Now().Unix()
    40  	dbNodes, err := d.getByCommunityID(communityID, tx)
    41  	if err != nil {
    42  		return fmt.Errorf("getting storenodes by community id: %w", err)
    43  	}
    44  	// Soft-delete db nodes that are not in the provided list
    45  	for _, dbN := range dbNodes {
    46  		if find(dbN, snode) != nil {
    47  			continue
    48  		}
    49  		if clock != 0 && dbN.Clock >= clock {
    50  			continue
    51  		}
    52  		if err := d.softDelete(communityID, dbN.StorenodeID, now, tx); err != nil {
    53  			return fmt.Errorf("soft deleting existing storenodes: %w", err)
    54  		}
    55  
    56  	}
    57  	// Insert or update the nodes in the provided list
    58  	for _, n := range snode {
    59  		// defensively validate the communityID
    60  		if len(n.CommunityID) == 0 || !bytes.Equal(communityID, n.CommunityID) {
    61  			err = fmt.Errorf("communityID mismatch %v != %v", communityID, n.CommunityID)
    62  			return err
    63  		}
    64  		dbN := find(n, dbNodes)
    65  		if dbN != nil && n.Clock != 0 && dbN.Clock >= n.Clock {
    66  			continue
    67  		}
    68  		if err := d.upsert(n, tx); err != nil {
    69  			return fmt.Errorf("upserting storenodes: %w", err)
    70  		}
    71  	}
    72  	// TODO for now only allow one storenode per community
    73  	count, err := d.countByCommunity(communityID, tx)
    74  	if err != nil {
    75  		return err
    76  	}
    77  	if count > 1 {
    78  		err = fmt.Errorf("only one storenode per community is allowed")
    79  		return err
    80  	}
    81  	return nil
    82  }
    83  
    84  func (d *Database) getAll() ([]Storenode, error) {
    85  	rows, err := d.db.Query(`
    86  		SELECT community_id, storenode_id, name, address, fleet, version, clock, removed, deleted_at
    87  		FROM community_storenodes
    88  		WHERE removed = 0
    89  	`)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	defer rows.Close()
    94  	return toStorenodes(rows)
    95  }
    96  
    97  func (d *Database) getByCommunityID(communityID types.HexBytes, tx ...*sql.Tx) ([]Storenode, error) {
    98  	var rows *sql.Rows
    99  	var err error
   100  	q := `
   101  	SELECT community_id, storenode_id, name, address, fleet, version, clock, removed, deleted_at
   102  	FROM community_storenodes
   103  	WHERE community_id = ? AND removed = 0
   104  `
   105  	if len(tx) > 0 {
   106  		rows, err = tx[0].Query(q, communityID)
   107  	} else {
   108  		rows, err = d.db.Query(q, communityID)
   109  	}
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	defer rows.Close()
   114  	return toStorenodes(rows)
   115  }
   116  
   117  func (d *Database) softDelete(communityID types.HexBytes, storenodeID string, deletedAt int64, tx *sql.Tx) error {
   118  	_, err := tx.Exec("UPDATE community_storenodes SET removed = 1, deleted_at = ? WHERE community_id = ? AND storenode_id = ?", deletedAt, communityID, storenodeID)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	return nil
   123  }
   124  
   125  func (d *Database) upsert(n Storenode, tx *sql.Tx) error {
   126  	_, err := tx.Exec(`INSERT OR REPLACE INTO community_storenodes(
   127  		community_id,
   128  		storenode_id,
   129  		name,
   130  		address,
   131  		fleet,
   132  		version,
   133  		clock,
   134  		removed,
   135  		deleted_at
   136  	) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
   137  		n.CommunityID,
   138  		n.StorenodeID,
   139  		n.Name,
   140  		n.Address.String(),
   141  		n.Fleet,
   142  		n.Version,
   143  		n.Clock,
   144  		n.Removed,
   145  		n.DeletedAt,
   146  	)
   147  	if err != nil {
   148  		return err
   149  	}
   150  	return nil
   151  }
   152  
   153  func (d *Database) countByCommunity(communityID types.HexBytes, tx *sql.Tx) (int, error) {
   154  	var count int
   155  	err := tx.QueryRow(`SELECT COUNT(*) FROM community_storenodes WHERE community_id = ? AND removed = 0`, communityID).Scan(&count)
   156  	if err != nil {
   157  		return 0, err
   158  	}
   159  	return count, nil
   160  }
   161  
   162  func toStorenodes(rows *sql.Rows) ([]Storenode, error) {
   163  	var result []Storenode
   164  
   165  	for rows.Next() {
   166  		var m Storenode
   167  		var addr string
   168  		if err := rows.Scan(
   169  			&m.CommunityID,
   170  			&m.StorenodeID,
   171  			&m.Name,
   172  			&addr,
   173  			&m.Fleet,
   174  			&m.Version,
   175  			&m.Clock,
   176  			&m.Removed,
   177  			&m.DeletedAt,
   178  		); err != nil {
   179  			return nil, err
   180  		}
   181  
   182  		maddr, err := multiaddr.NewMultiaddr(addr)
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  		m.Address = maddr
   187  		result = append(result, m)
   188  	}
   189  
   190  	return result, nil
   191  }
   192  
   193  func find(n Storenode, nodes []Storenode) *Storenode {
   194  	for i, node := range nodes {
   195  		if node.StorenodeID == n.StorenodeID && bytes.Equal(node.CommunityID, n.CommunityID) {
   196  			return &nodes[i]
   197  		}
   198  	}
   199  	return nil
   200  }