github.com/status-im/status-go@v1.1.0/services/wallet/collectibles/collection_data_db.go (about) 1 package collectibles 2 3 import ( 4 "database/sql" 5 "fmt" 6 7 "github.com/status-im/status-go/services/wallet/thirdparty" 8 "github.com/status-im/status-go/sqlite" 9 ) 10 11 type CollectionDataStorage interface { 12 SetData(collections []thirdparty.CollectionData, allowUpdate bool) error 13 GetIDsNotInDB(ids []thirdparty.ContractID) ([]thirdparty.ContractID, error) 14 GetData(ids []thirdparty.ContractID) (map[string]thirdparty.CollectionData, error) 15 SetCollectionSocialsData(id thirdparty.ContractID, collectionSocials *thirdparty.CollectionSocials) error 16 GetSocialsForID(contractID thirdparty.ContractID) (*thirdparty.CollectionSocials, error) 17 } 18 19 type CollectionDataDB struct { 20 db *sql.DB 21 } 22 23 func NewCollectionDataDB(sqlDb *sql.DB) *CollectionDataDB { 24 return &CollectionDataDB{ 25 db: sqlDb, 26 } 27 } 28 29 const collectionDataColumns = "chain_id, contract_address, provider, name, slug, image_url, image_payload, community_id" 30 const collectionTraitsColumns = "chain_id, contract_address, trait_type, min, max" 31 const selectCollectionTraitsColumns = "trait_type, min, max" 32 const collectionSocialsColumns = "chain_id, contract_address, provider, website, twitter_handle" 33 const selectCollectionSocialsColumns = "website, twitter_handle, provider" 34 35 func rowsToCollectionTraits(rows *sql.Rows) (map[string]thirdparty.CollectionTrait, error) { 36 traits := make(map[string]thirdparty.CollectionTrait) 37 for rows.Next() { 38 var traitType string 39 var trait thirdparty.CollectionTrait 40 err := rows.Scan( 41 &traitType, 42 &trait.Min, 43 &trait.Max, 44 ) 45 if err != nil { 46 return nil, err 47 } 48 traits[traitType] = trait 49 } 50 return traits, nil 51 } 52 53 func getCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID) (map[string]thirdparty.CollectionTrait, error) { 54 // Get traits list 55 selectTraits, err := creator.Prepare(fmt.Sprintf(`SELECT %s 56 FROM collection_traits_cache 57 WHERE chain_id = ? AND contract_address = ?`, selectCollectionTraitsColumns)) 58 if err != nil { 59 return nil, err 60 } 61 62 rows, err := selectTraits.Query( 63 id.ChainID, 64 id.Address, 65 ) 66 if err != nil { 67 return nil, err 68 } 69 70 return rowsToCollectionTraits(rows) 71 } 72 73 func upsertCollectionTraits(creator sqlite.StatementCreator, id thirdparty.ContractID, traits map[string]thirdparty.CollectionTrait) error { 74 // Rremove old traits list 75 deleteTraits, err := creator.Prepare(`DELETE FROM collection_traits_cache WHERE chain_id = ? AND contract_address = ?`) 76 if err != nil { 77 return err 78 } 79 80 _, err = deleteTraits.Exec( 81 id.ChainID, 82 id.Address, 83 ) 84 if err != nil { 85 return err 86 } 87 88 // Insert new traits list 89 insertTrait, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_traits_cache (%s) 90 VALUES (?, ?, ?, ?, ?)`, collectionTraitsColumns)) 91 if err != nil { 92 return err 93 } 94 95 for traitType, trait := range traits { 96 _, err = insertTrait.Exec( 97 id.ChainID, 98 id.Address, 99 traitType, 100 trait.Min, 101 trait.Max, 102 ) 103 if err != nil { 104 return err 105 } 106 } 107 108 return nil 109 } 110 111 func setCollectionsData(creator sqlite.StatementCreator, collections []thirdparty.CollectionData, allowUpdate bool) error { 112 insertCollection, err := creator.Prepare(fmt.Sprintf(`%s INTO collection_data_cache (%s) 113 VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, insertStatement(allowUpdate), collectionDataColumns)) 114 if err != nil { 115 return err 116 } 117 118 for _, c := range collections { 119 _, err = insertCollection.Exec( 120 c.ID.ChainID, 121 c.ID.Address, 122 c.Provider, 123 c.Name, 124 c.Slug, 125 c.ImageURL, 126 c.ImagePayload, 127 c.CommunityID, 128 ) 129 if err != nil { 130 return err 131 } 132 133 err = upsertContractType(creator, c.ID, c.ContractType) 134 if err != nil { 135 return err 136 } 137 138 if allowUpdate { 139 err = upsertCollectionTraits(creator, c.ID, c.Traits) 140 if err != nil { 141 return err 142 } 143 144 if c.Socials != nil { 145 err = upsertCollectionSocials(creator, c.ID, c.Socials) 146 if err != nil { 147 return err 148 } 149 } 150 } 151 } 152 153 return nil 154 } 155 156 func (o *CollectionDataDB) SetData(collections []thirdparty.CollectionData, allowUpdate bool) (err error) { 157 tx, err := o.db.Begin() 158 if err != nil { 159 return err 160 } 161 defer func() { 162 if err == nil { 163 err = tx.Commit() 164 return 165 } 166 _ = tx.Rollback() 167 }() 168 169 // Insert new collections data 170 err = setCollectionsData(tx, collections, allowUpdate) 171 if err != nil { 172 return err 173 } 174 175 return 176 } 177 178 func scanCollectionsDataRow(row *sql.Row) (*thirdparty.CollectionData, error) { 179 c := thirdparty.CollectionData{ 180 Traits: make(map[string]thirdparty.CollectionTrait), 181 } 182 err := row.Scan( 183 &c.ID.ChainID, 184 &c.ID.Address, 185 &c.Provider, 186 &c.Name, 187 &c.Slug, 188 &c.ImageURL, 189 &c.ImagePayload, 190 &c.CommunityID, 191 ) 192 if err != nil { 193 return nil, err 194 } 195 return &c, nil 196 } 197 198 func (o *CollectionDataDB) GetIDsNotInDB(ids []thirdparty.ContractID) ([]thirdparty.ContractID, error) { 199 ret := make([]thirdparty.ContractID, 0, len(ids)) 200 idMap := make(map[string]thirdparty.ContractID, len(ids)) 201 202 // Ensure we don't have duplicates 203 for _, id := range ids { 204 idMap[id.HashKey()] = id 205 } 206 207 exists, err := o.db.Prepare(`SELECT EXISTS ( 208 SELECT 1 FROM collection_data_cache 209 WHERE chain_id=? AND contract_address=? 210 )`) 211 if err != nil { 212 return nil, err 213 } 214 215 for _, id := range idMap { 216 row := exists.QueryRow( 217 id.ChainID, 218 id.Address, 219 ) 220 var exists bool 221 err = row.Scan(&exists) 222 if err != nil { 223 return nil, err 224 } 225 if !exists { 226 ret = append(ret, id) 227 } 228 } 229 230 return ret, nil 231 } 232 233 func (o *CollectionDataDB) GetData(ids []thirdparty.ContractID) (map[string]thirdparty.CollectionData, error) { 234 ret := make(map[string]thirdparty.CollectionData) 235 236 getData, err := o.db.Prepare(fmt.Sprintf(`SELECT %s 237 FROM collection_data_cache 238 WHERE chain_id=? AND contract_address=?`, collectionDataColumns)) 239 if err != nil { 240 return nil, err 241 } 242 243 for _, id := range ids { 244 row := getData.QueryRow( 245 id.ChainID, 246 id.Address, 247 ) 248 c, err := scanCollectionsDataRow(row) 249 if err == sql.ErrNoRows { 250 continue 251 } else if err != nil { 252 return nil, err 253 } else { 254 // Get traits from different table 255 c.Traits, err = getCollectionTraits(o.db, c.ID) 256 if err != nil { 257 return nil, err 258 } 259 260 // Get contract type from different table 261 c.ContractType, err = readContractType(o.db, c.ID) 262 if err != nil { 263 return nil, err 264 } 265 266 // Get socials from different table 267 c.Socials, err = getCollectionSocials(o.db, c.ID) 268 if err != nil { 269 return nil, err 270 } 271 272 ret[c.ID.HashKey()] = *c 273 } 274 } 275 return ret, nil 276 } 277 278 func (o *CollectionDataDB) GetSocialsForID(contractID thirdparty.ContractID) (*thirdparty.CollectionSocials, error) { 279 return getCollectionSocials(o.db, contractID) 280 } 281 282 func (o *CollectionDataDB) SetCollectionSocialsData(id thirdparty.ContractID, collectionSocials *thirdparty.CollectionSocials) (err error) { 283 tx, err := o.db.Begin() 284 if err != nil { 285 return err 286 } 287 defer func() { 288 if err == nil { 289 err = tx.Commit() 290 return 291 } 292 _ = tx.Rollback() 293 }() 294 295 // Insert new collections socials 296 if collectionSocials != nil { 297 err = upsertCollectionSocials(tx, id, collectionSocials) 298 if err != nil { 299 return err 300 } 301 } 302 303 return 304 } 305 306 func rowsToCollectionSocials(rows *sql.Rows) (*thirdparty.CollectionSocials, error) { 307 var socials *thirdparty.CollectionSocials 308 socials = nil 309 for rows.Next() { 310 var website string 311 var twitterHandle string 312 var provider string 313 err := rows.Scan( 314 &website, 315 &twitterHandle, 316 &provider, 317 ) 318 if err != nil { 319 return nil, err 320 } 321 socials = &thirdparty.CollectionSocials{ 322 Website: website, 323 TwitterHandle: twitterHandle, 324 Provider: provider} 325 } 326 return socials, nil 327 } 328 329 func getCollectionSocials(creator sqlite.StatementCreator, id thirdparty.ContractID) (*thirdparty.CollectionSocials, error) { 330 // Get socials 331 selectSocials, err := creator.Prepare(fmt.Sprintf(`SELECT %s 332 FROM collection_socials_cache 333 WHERE chain_id = ? AND contract_address = ?`, selectCollectionSocialsColumns)) 334 if err != nil { 335 return nil, err 336 } 337 338 rows, err := selectSocials.Query( 339 id.ChainID, 340 id.Address, 341 ) 342 if err != nil { 343 return nil, err 344 } 345 346 return rowsToCollectionSocials(rows) 347 } 348 349 func upsertCollectionSocials(creator sqlite.StatementCreator, id thirdparty.ContractID, socials *thirdparty.CollectionSocials) error { 350 // Insert socials 351 insertSocial, err := creator.Prepare(fmt.Sprintf(`INSERT OR REPLACE INTO collection_socials_cache (%s) 352 VALUES (?, ?, ?, ?, ?)`, collectionSocialsColumns)) 353 if err != nil { 354 return err 355 } 356 357 _, err = insertSocial.Exec( 358 id.ChainID, 359 id.Address, 360 socials.Provider, 361 socials.Website, 362 socials.TwitterHandle, 363 ) 364 if err != nil { 365 return err 366 } 367 368 return nil 369 }