github.com/status-im/status-go@v1.1.0/services/wallet/collectibles/filter_test.go (about) 1 package collectibles 2 3 import ( 4 "context" 5 "database/sql" 6 "math/big" 7 "testing" 8 9 "github.com/ethereum/go-ethereum/common" 10 11 "github.com/status-im/status-go/protocol/communities/token" 12 "github.com/status-im/status-go/services/wallet/bigint" 13 w_common "github.com/status-im/status-go/services/wallet/common" 14 "github.com/status-im/status-go/services/wallet/thirdparty" 15 "github.com/status-im/status-go/t/helpers" 16 "github.com/status-im/status-go/walletdatabase" 17 18 "github.com/stretchr/testify/require" 19 ) 20 21 func setupTestFilterDB(t *testing.T) (db *sql.DB, close func()) { 22 db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{}) 23 require.NoError(t, err) 24 25 return db, func() { 26 require.NoError(t, db.Close()) 27 } 28 } 29 30 func TestFilterOwnedCollectibles(t *testing.T) { 31 db, close := setupTestFilterDB(t) 32 defer close() 33 34 oDB := NewOwnershipDB(db) 35 cDB := NewCollectibleDataDB(db) 36 37 const nData = 50 38 data := thirdparty.GenerateTestCollectiblesData(nData) 39 communityData := thirdparty.GenerateTestCollectiblesCommunityData(nData) 40 41 ownerAddresses := []common.Address{ 42 common.HexToAddress("0x1234"), 43 common.HexToAddress("0x5678"), 44 common.HexToAddress("0xABCD"), 45 } 46 randomAddress := common.HexToAddress("0xFFFF") 47 48 dataPerID := make(map[string]thirdparty.CollectibleData) 49 communityDataPerID := make(map[string]thirdparty.CollectibleCommunityInfo) 50 balancesPerChainIDAndOwner := make(map[w_common.ChainID]map[common.Address]thirdparty.TokenBalancesPerContractAddress) 51 52 var err error 53 54 var commonID thirdparty.CollectibleUniqueID 55 56 for i := 0; i < nData; i++ { 57 iData := data[i] 58 iCommunityData := communityData[i] 59 60 if i == 1 { 61 // Insert a duplicate ID to represent 2 owners having the same ERC1155 collectible 62 iData = data[0] 63 iCommunityData = communityData[0] 64 commonID = iData.ID 65 } 66 67 dataPerID[iData.ID.HashKey()] = iData 68 communityDataPerID[iData.ID.HashKey()] = iCommunityData 69 70 chainID := iData.ID.ContractID.ChainID 71 ownerAddress := ownerAddresses[i%len(ownerAddresses)] 72 73 if _, ok := balancesPerChainIDAndOwner[chainID]; !ok { 74 balancesPerChainIDAndOwner[chainID] = make(map[common.Address]thirdparty.TokenBalancesPerContractAddress) 75 } 76 if _, ok := balancesPerChainIDAndOwner[chainID][ownerAddress]; !ok { 77 balancesPerChainIDAndOwner[chainID][ownerAddress] = make(thirdparty.TokenBalancesPerContractAddress) 78 } 79 80 contractAddress := iData.ID.ContractID.Address 81 if _, ok := balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress]; !ok { 82 balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress] = make([]thirdparty.TokenBalance, 0, len(data)) 83 } 84 85 tokenBalance := thirdparty.TokenBalance{ 86 TokenID: iData.ID.TokenID, 87 Balance: &bigint.BigInt{Int: big.NewInt(int64(i % 10))}, 88 } 89 balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress] = append(balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress], tokenBalance) 90 } 91 92 timestamp := int64(1234567890) 93 94 for chainID, balancesPerOwner := range balancesPerChainIDAndOwner { 95 for ownerAddress, balances := range balancesPerOwner { 96 _, _, _, err = oDB.Update(chainID, ownerAddress, balances, timestamp) 97 require.NoError(t, err) 98 } 99 } 100 101 err = cDB.SetData(data, true) 102 require.NoError(t, err) 103 for i := 0; i < nData; i++ { 104 err = cDB.SetCommunityInfo(data[i].ID, communityData[i]) 105 require.NoError(t, err) 106 } 107 108 var filter Filter 109 var filterIDs []thirdparty.CollectibleUniqueID 110 var expectedIDs []thirdparty.CollectibleUniqueID 111 var tmpIDs []thirdparty.CollectibleUniqueID 112 113 ctx := context.Background() 114 115 filterChains := []w_common.ChainID{w_common.ChainID(1), w_common.ChainID(2)} 116 filterAddresses := []common.Address{ownerAddresses[0], ownerAddresses[1], ownerAddresses[2], randomAddress} 117 118 // Test common case 119 filter = allFilter() 120 121 tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) 122 require.NoError(t, err) 123 124 expectedIDs = tmpIDs 125 126 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 127 require.NoError(t, err) 128 require.Equal(t, expectedIDs, filterIDs) 129 130 // Test only non-community 131 filter = allFilter() 132 filter.FilterCommunity = OnlyNonCommunity 133 134 tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) 135 require.NoError(t, err) 136 137 expectedIDs = nil 138 for _, id := range tmpIDs { 139 if dataPerID[id.HashKey()].CommunityID == "" { 140 expectedIDs = append(expectedIDs, id) 141 } 142 } 143 144 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 145 require.NoError(t, err) 146 require.Equal(t, expectedIDs, filterIDs) 147 148 // Test only community 149 filter = allFilter() 150 filter.FilterCommunity = OnlyCommunity 151 152 tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) 153 require.NoError(t, err) 154 155 expectedIDs = nil 156 for _, id := range tmpIDs { 157 if dataPerID[id.HashKey()].CommunityID != "" { 158 expectedIDs = append(expectedIDs, id) 159 } 160 } 161 162 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 163 require.NoError(t, err) 164 require.Equal(t, expectedIDs, filterIDs) 165 166 // Test specific community 167 communityIDa := data[0].CommunityID 168 communityIDb := data[1].CommunityID 169 communityIDs := []string{communityIDa, communityIDb} 170 171 filter = allFilter() 172 filter.CommunityIDs = communityIDs 173 174 tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) 175 require.NoError(t, err) 176 177 expectedIDs = nil 178 for _, id := range tmpIDs { 179 if dataPerID[id.HashKey()].CommunityID == communityIDa || dataPerID[id.HashKey()].CommunityID == communityIDb { 180 expectedIDs = append(expectedIDs, id) 181 } 182 } 183 184 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 185 require.NoError(t, err) 186 require.Equal(t, expectedIDs, filterIDs) 187 188 // Test specific privileges level 189 privilegeLevel := token.PrivilegesLevel(2) 190 191 filter = allFilter() 192 filter.CommunityPrivilegesLevels = []token.PrivilegesLevel{privilegeLevel} 193 194 tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) 195 require.NoError(t, err) 196 197 expectedIDs = nil 198 for _, id := range tmpIDs { 199 if communityDataPerID[id.HashKey()].PrivilegesLevel == privilegeLevel { 200 expectedIDs = append(expectedIDs, id) 201 } 202 } 203 204 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 205 require.NoError(t, err) 206 require.Equal(t, expectedIDs, filterIDs) 207 208 // Test specific collectible IDs 209 tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData) 210 require.NoError(t, err) 211 212 filter = allFilter() 213 for i := 0; i < 5; i++ { 214 filter.CollectibleIDs = append(filter.CollectibleIDs, tmpIDs[i*2]) 215 } 216 expectedIDs = filter.CollectibleIDs 217 218 filter.CollectibleIDs = append(filter.CollectibleIDs, thirdparty.CollectibleUniqueID{ 219 ContractID: thirdparty.ContractID{ 220 ChainID: w_common.ChainID(1), 221 Address: common.HexToAddress("0x1234"), 222 }, 223 TokenID: &bigint.BigInt{Int: big.NewInt(9999999)}, 224 }) 225 226 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 227 require.NoError(t, err) 228 require.Equal(t, expectedIDs, filterIDs) 229 230 // Test collectible ID owned by both accounts 0 and 1 231 filterChains = []w_common.ChainID{commonID.ContractID.ChainID} 232 filterAddresses = []common.Address{ownerAddresses[0], ownerAddresses[1]} 233 234 filter = allFilter() 235 filter.CollectibleIDs = append(filter.CollectibleIDs, commonID) 236 expectedIDs = filter.CollectibleIDs 237 238 filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData) 239 require.NoError(t, err) 240 require.Equal(t, expectedIDs, filterIDs) 241 }