github.com/status-im/status-go@v1.1.0/services/wallet/collectibles/collection_data_db.go (about)

     1  package collectibles
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  
     7  	"github.com/status-im/status-go/services/wallet/thirdparty"
     8  	"github.com/status-im/status-go/sqlite"
     9  )
    10  
    11  type CollectionDataStorage interface {
    12  	SetData(collections []thirdparty.CollectionData, allowUpdate bool) error
    13  	GetIDsNotInDB(ids []thirdparty.ContractID) ([]thirdparty.ContractID, error)
    14  	GetData(ids []thirdparty.ContractID) (map[string]thirdparty.CollectionData, error)
    15  	SetCollectionSocialsData(id thirdparty.ContractID, collectionSocials *thirdparty.CollectionSocials) error
    16  	GetSocialsForID(contractID thirdparty.ContractID) (*thirdparty.CollectionSocials, error)
    17  }
    18  
    19  type CollectionDataDB struct {
    20  	db *sql.DB
    21  }
    22  
    23  func NewCollectionDataDB(sqlDb *sql.DB) *CollectionDataDB {
    24  	return &CollectionDataDB{
    25  		db: sqlDb,
    26  	}
    27  }
    28  
    29  const collectionDataColumns = "chain_id, contract_address, provider, name, slug, image_url, image_payload, community_id"
    30  const collectionTraitsColumns = "chain_id, contract_address, trait_type, min, max"
    31  const selectCollectionTraitsColumns = "trait_type, min, max"
    32  const collectionSocialsColumns = "chain_id, contract_address, provider, website, twitter_handle"
    33  const selectCollectionSocialsColumns = "website, twitter_handle, provider"
    34  
    35  func rowsToCollectionTraits(rows *sql.Rows) (map[string]thirdparty.CollectionTrait, error) {
    36  	traits := make(map[string]thirdparty.CollectionTrait)
    37  	for rows.Next() {
    38  		var traitType string
    39  		var trait thirdparty.CollectionTrait
    40  		err := rows.Scan(
    41  			&traitType,
    42  			&trait.Min,
    43  			&trait.Max,
    44  		)
    45  		if err != nil {
    46  			return nil, err
    47  		}
    48  		traits[traitType] = trait
    49  	}
    50  	return traits, nil
    51  }
    52  
    53  func getCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID) (map[string]thirdparty.CollectionTrait, error) {
    54  	// Get traits list
    55  	selectTraits, err := creator.Prepare(fmt.Sprintf(`SELECT %s
    56  		FROM collection_traits_cache
    57  		WHERE chain_id = ? AND contract_address = ?`, selectCollectionTraitsColumns))
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	rows, err := selectTraits.Query(
    63  		id.ChainID,
    64  		id.Address,
    65  	)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	return rowsToCollectionTraits(rows)
    71  }
    72  
    73  func upsertCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID, traits map[string]thirdparty.CollectionTrait) error {
    74  	// Rremove old traits list
    75  	deleteTraits, err := creator.Prepare(`DELETE FROM collection_traits_cache WHERE chain_id = ? AND contract_address = ?`)
    76  	if err != nil {
    77  		return err
    78  	}
    79  
    80  	_, err = deleteTraits.Exec(
    81  		id.ChainID,
    82  		id.Address,
    83  	)
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	// Insert new traits list
    89  	insertTrait, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_traits_cache (%s)
    90  		VALUES (?, ?, ?, ?, ?)`, collectionTraitsColumns))
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	for traitType, trait := range traits {
    96  		_, err = insertTrait.Exec(
    97  			id.ChainID,
    98  			id.Address,
    99  			traitType,
   100  			trait.Min,
   101  			trait.Max,
   102  		)
   103  		if err != nil {
   104  			return err
   105  		}
   106  	}
   107  
   108  	return nil
   109  }
   110  
   111  func setCollectionsData(creator sqlite.StatementCreator, collections []thirdparty.CollectionData, allowUpdate bool) error {
   112  	insertCollection, err := creator.Prepare(fmt.Sprintf(`%s INTO collection_data_cache (%s) 
   113  																				VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, insertStatement(allowUpdate), collectionDataColumns))
   114  	if err != nil {
   115  		return err
   116  	}
   117  
   118  	for _, c := range collections {
   119  		_, err = insertCollection.Exec(
   120  			c.ID.ChainID,
   121  			c.ID.Address,
   122  			c.Provider,
   123  			c.Name,
   124  			c.Slug,
   125  			c.ImageURL,
   126  			c.ImagePayload,
   127  			c.CommunityID,
   128  		)
   129  		if err != nil {
   130  			return err
   131  		}
   132  
   133  		err = upsertContractType(creator, c.ID, c.ContractType)
   134  		if err != nil {
   135  			return err
   136  		}
   137  
   138  		if allowUpdate {
   139  			err = upsertCollectionTraits(creator, c.ID, c.Traits)
   140  			if err != nil {
   141  				return err
   142  			}
   143  
   144  			if c.Socials != nil {
   145  				err = upsertCollectionSocials(creator, c.ID, c.Socials)
   146  				if err != nil {
   147  					return err
   148  				}
   149  			}
   150  		}
   151  	}
   152  
   153  	return nil
   154  }
   155  
   156  func (o *CollectionDataDB) SetData(collections []thirdparty.CollectionData, allowUpdate bool) (err error) {
   157  	tx, err := o.db.Begin()
   158  	if err != nil {
   159  		return err
   160  	}
   161  	defer func() {
   162  		if err == nil {
   163  			err = tx.Commit()
   164  			return
   165  		}
   166  		_ = tx.Rollback()
   167  	}()
   168  
   169  	// Insert new collections data
   170  	err = setCollectionsData(tx, collections, allowUpdate)
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	return
   176  }
   177  
   178  func scanCollectionsDataRow(row *sql.Row) (*thirdparty.CollectionData, error) {
   179  	c := thirdparty.CollectionData{
   180  		Traits: make(map[string]thirdparty.CollectionTrait),
   181  	}
   182  	err := row.Scan(
   183  		&c.ID.ChainID,
   184  		&c.ID.Address,
   185  		&c.Provider,
   186  		&c.Name,
   187  		&c.Slug,
   188  		&c.ImageURL,
   189  		&c.ImagePayload,
   190  		&c.CommunityID,
   191  	)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	return &c, nil
   196  }
   197  
   198  func (o *CollectionDataDB) GetIDsNotInDB(ids []thirdparty.ContractID) ([]thirdparty.ContractID, error) {
   199  	ret := make([]thirdparty.ContractID, 0, len(ids))
   200  	idMap := make(map[string]thirdparty.ContractID, len(ids))
   201  
   202  	// Ensure we don't have duplicates
   203  	for _, id := range ids {
   204  		idMap[id.HashKey()] = id
   205  	}
   206  
   207  	exists, err := o.db.Prepare(`SELECT EXISTS (
   208  			SELECT 1 FROM collection_data_cache
   209  			WHERE chain_id=? AND contract_address=?
   210  		)`)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	for _, id := range idMap {
   216  		row := exists.QueryRow(
   217  			id.ChainID,
   218  			id.Address,
   219  		)
   220  		var exists bool
   221  		err = row.Scan(&exists)
   222  		if err != nil {
   223  			return nil, err
   224  		}
   225  		if !exists {
   226  			ret = append(ret, id)
   227  		}
   228  	}
   229  
   230  	return ret, nil
   231  }
   232  
   233  func (o *CollectionDataDB) GetData(ids []thirdparty.ContractID) (map[string]thirdparty.CollectionData, error) {
   234  	ret := make(map[string]thirdparty.CollectionData)
   235  
   236  	getData, err := o.db.Prepare(fmt.Sprintf(`SELECT %s
   237  		FROM collection_data_cache
   238  		WHERE chain_id=? AND contract_address=?`, collectionDataColumns))
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	for _, id := range ids {
   244  		row := getData.QueryRow(
   245  			id.ChainID,
   246  			id.Address,
   247  		)
   248  		c, err := scanCollectionsDataRow(row)
   249  		if err == sql.ErrNoRows {
   250  			continue
   251  		} else if err != nil {
   252  			return nil, err
   253  		} else {
   254  			// Get traits from different table
   255  			c.Traits, err = getCollectionTraits(o.db, c.ID)
   256  			if err != nil {
   257  				return nil, err
   258  			}
   259  
   260  			// Get contract type from different table
   261  			c.ContractType, err = readContractType(o.db, c.ID)
   262  			if err != nil {
   263  				return nil, err
   264  			}
   265  
   266  			// Get socials from different table
   267  			c.Socials, err = getCollectionSocials(o.db, c.ID)
   268  			if err != nil {
   269  				return nil, err
   270  			}
   271  
   272  			ret[c.ID.HashKey()] = *c
   273  		}
   274  	}
   275  	return ret, nil
   276  }
   277  
   278  func (o *CollectionDataDB) GetSocialsForID(contractID thirdparty.ContractID) (*thirdparty.CollectionSocials, error) {
   279  	return getCollectionSocials(o.db, contractID)
   280  }
   281  
   282  func (o *CollectionDataDB) SetCollectionSocialsData(id thirdparty.ContractID, collectionSocials *thirdparty.CollectionSocials) (err error) {
   283  	tx, err := o.db.Begin()
   284  	if err != nil {
   285  		return err
   286  	}
   287  	defer func() {
   288  		if err == nil {
   289  			err = tx.Commit()
   290  			return
   291  		}
   292  		_ = tx.Rollback()
   293  	}()
   294  
   295  	// Insert new collections socials
   296  	if collectionSocials != nil {
   297  		err = upsertCollectionSocials(tx, id, collectionSocials)
   298  		if err != nil {
   299  			return err
   300  		}
   301  	}
   302  
   303  	return
   304  }
   305  
   306  func rowsToCollectionSocials(rows *sql.Rows) (*thirdparty.CollectionSocials, error) {
   307  	var socials *thirdparty.CollectionSocials
   308  	socials = nil
   309  	for rows.Next() {
   310  		var website string
   311  		var twitterHandle string
   312  		var provider string
   313  		err := rows.Scan(
   314  			&website,
   315  			&twitterHandle,
   316  			&provider,
   317  		)
   318  		if err != nil {
   319  			return nil, err
   320  		}
   321  		socials = &thirdparty.CollectionSocials{
   322  			Website:       website,
   323  			TwitterHandle: twitterHandle,
   324  			Provider:      provider}
   325  	}
   326  	return socials, nil
   327  }
   328  
   329  func getCollectionSocials(creator sqlite.StatementCreator, id thirdparty.ContractID) (*thirdparty.CollectionSocials, error) {
   330  	// Get socials
   331  	selectSocials, err := creator.Prepare(fmt.Sprintf(`SELECT %s
   332                  FROM collection_socials_cache
   333                  WHERE chain_id = ? AND contract_address = ?`, selectCollectionSocialsColumns))
   334  	if err != nil {
   335  		return nil, err
   336  	}
   337  
   338  	rows, err := selectSocials.Query(
   339  		id.ChainID,
   340  		id.Address,
   341  	)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	return rowsToCollectionSocials(rows)
   347  }
   348  
   349  func upsertCollectionSocials(creator sqlite.StatementCreator, id thirdparty.ContractID, socials *thirdparty.CollectionSocials) error {
   350  	// Insert socials
   351  	insertSocial, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_socials_cache (%s)
   352          VALUES (?, ?, ?, ?, ?)`, collectionSocialsColumns))
   353  	if err != nil {
   354  		return err
   355  	}
   356  
   357  	_, err = insertSocial.Exec(
   358  		id.ChainID,
   359  		id.Address,
   360  		socials.Provider,
   361  		socials.Website,
   362  		socials.TwitterHandle,
   363  	)
   364  	if err != nil {
   365  		return err
   366  	}
   367  
   368  	return nil
   369  }