github.com/status-im/status-go@v1.1.0/services/wallet/collectibles/filter_test.go (about)

     1  package collectibles
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"math/big"
     7  	"testing"
     8  
     9  	"github.com/ethereum/go-ethereum/common"
    10  
    11  	"github.com/status-im/status-go/protocol/communities/token"
    12  	"github.com/status-im/status-go/services/wallet/bigint"
    13  	w_common "github.com/status-im/status-go/services/wallet/common"
    14  	"github.com/status-im/status-go/services/wallet/thirdparty"
    15  	"github.com/status-im/status-go/t/helpers"
    16  	"github.com/status-im/status-go/walletdatabase"
    17  
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  func setupTestFilterDB(t *testing.T) (db *sql.DB, close func()) {
    22  	db, err := helpers.SetupTestMemorySQLDB(walletdatabase.DbInitializer{})
    23  	require.NoError(t, err)
    24  
    25  	return db, func() {
    26  		require.NoError(t, db.Close())
    27  	}
    28  }
    29  
    30  func TestFilterOwnedCollectibles(t *testing.T) {
    31  	db, close := setupTestFilterDB(t)
    32  	defer close()
    33  
    34  	oDB := NewOwnershipDB(db)
    35  	cDB := NewCollectibleDataDB(db)
    36  
    37  	const nData = 50
    38  	data := thirdparty.GenerateTestCollectiblesData(nData)
    39  	communityData := thirdparty.GenerateTestCollectiblesCommunityData(nData)
    40  
    41  	ownerAddresses := []common.Address{
    42  		common.HexToAddress("0x1234"),
    43  		common.HexToAddress("0x5678"),
    44  		common.HexToAddress("0xABCD"),
    45  	}
    46  	randomAddress := common.HexToAddress("0xFFFF")
    47  
    48  	dataPerID := make(map[string]thirdparty.CollectibleData)
    49  	communityDataPerID := make(map[string]thirdparty.CollectibleCommunityInfo)
    50  	balancesPerChainIDAndOwner := make(map[w_common.ChainID]map[common.Address]thirdparty.TokenBalancesPerContractAddress)
    51  
    52  	var err error
    53  
    54  	var commonID thirdparty.CollectibleUniqueID
    55  
    56  	for i := 0; i < nData; i++ {
    57  		iData := data[i]
    58  		iCommunityData := communityData[i]
    59  
    60  		if i == 1 {
    61  			// Insert a duplicate ID to represent 2 owners having the same ERC1155 collectible
    62  			iData = data[0]
    63  			iCommunityData = communityData[0]
    64  			commonID = iData.ID
    65  		}
    66  
    67  		dataPerID[iData.ID.HashKey()] = iData
    68  		communityDataPerID[iData.ID.HashKey()] = iCommunityData
    69  
    70  		chainID := iData.ID.ContractID.ChainID
    71  		ownerAddress := ownerAddresses[i%len(ownerAddresses)]
    72  
    73  		if _, ok := balancesPerChainIDAndOwner[chainID]; !ok {
    74  			balancesPerChainIDAndOwner[chainID] = make(map[common.Address]thirdparty.TokenBalancesPerContractAddress)
    75  		}
    76  		if _, ok := balancesPerChainIDAndOwner[chainID][ownerAddress]; !ok {
    77  			balancesPerChainIDAndOwner[chainID][ownerAddress] = make(thirdparty.TokenBalancesPerContractAddress)
    78  		}
    79  
    80  		contractAddress := iData.ID.ContractID.Address
    81  		if _, ok := balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress]; !ok {
    82  			balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress] = make([]thirdparty.TokenBalance, 0, len(data))
    83  		}
    84  
    85  		tokenBalance := thirdparty.TokenBalance{
    86  			TokenID: iData.ID.TokenID,
    87  			Balance: &bigint.BigInt{Int: big.NewInt(int64(i % 10))},
    88  		}
    89  		balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress] = append(balancesPerChainIDAndOwner[chainID][ownerAddress][contractAddress], tokenBalance)
    90  	}
    91  
    92  	timestamp := int64(1234567890)
    93  
    94  	for chainID, balancesPerOwner := range balancesPerChainIDAndOwner {
    95  		for ownerAddress, balances := range balancesPerOwner {
    96  			_, _, _, err = oDB.Update(chainID, ownerAddress, balances, timestamp)
    97  			require.NoError(t, err)
    98  		}
    99  	}
   100  
   101  	err = cDB.SetData(data, true)
   102  	require.NoError(t, err)
   103  	for i := 0; i < nData; i++ {
   104  		err = cDB.SetCommunityInfo(data[i].ID, communityData[i])
   105  		require.NoError(t, err)
   106  	}
   107  
   108  	var filter Filter
   109  	var filterIDs []thirdparty.CollectibleUniqueID
   110  	var expectedIDs []thirdparty.CollectibleUniqueID
   111  	var tmpIDs []thirdparty.CollectibleUniqueID
   112  
   113  	ctx := context.Background()
   114  
   115  	filterChains := []w_common.ChainID{w_common.ChainID(1), w_common.ChainID(2)}
   116  	filterAddresses := []common.Address{ownerAddresses[0], ownerAddresses[1], ownerAddresses[2], randomAddress}
   117  
   118  	// Test common case
   119  	filter = allFilter()
   120  
   121  	tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
   122  	require.NoError(t, err)
   123  
   124  	expectedIDs = tmpIDs
   125  
   126  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   127  	require.NoError(t, err)
   128  	require.Equal(t, expectedIDs, filterIDs)
   129  
   130  	// Test only non-community
   131  	filter = allFilter()
   132  	filter.FilterCommunity = OnlyNonCommunity
   133  
   134  	tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
   135  	require.NoError(t, err)
   136  
   137  	expectedIDs = nil
   138  	for _, id := range tmpIDs {
   139  		if dataPerID[id.HashKey()].CommunityID == "" {
   140  			expectedIDs = append(expectedIDs, id)
   141  		}
   142  	}
   143  
   144  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   145  	require.NoError(t, err)
   146  	require.Equal(t, expectedIDs, filterIDs)
   147  
   148  	// Test only community
   149  	filter = allFilter()
   150  	filter.FilterCommunity = OnlyCommunity
   151  
   152  	tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
   153  	require.NoError(t, err)
   154  
   155  	expectedIDs = nil
   156  	for _, id := range tmpIDs {
   157  		if dataPerID[id.HashKey()].CommunityID != "" {
   158  			expectedIDs = append(expectedIDs, id)
   159  		}
   160  	}
   161  
   162  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   163  	require.NoError(t, err)
   164  	require.Equal(t, expectedIDs, filterIDs)
   165  
   166  	// Test specific community
   167  	communityIDa := data[0].CommunityID
   168  	communityIDb := data[1].CommunityID
   169  	communityIDs := []string{communityIDa, communityIDb}
   170  
   171  	filter = allFilter()
   172  	filter.CommunityIDs = communityIDs
   173  
   174  	tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
   175  	require.NoError(t, err)
   176  
   177  	expectedIDs = nil
   178  	for _, id := range tmpIDs {
   179  		if dataPerID[id.HashKey()].CommunityID == communityIDa || dataPerID[id.HashKey()].CommunityID == communityIDb {
   180  			expectedIDs = append(expectedIDs, id)
   181  		}
   182  	}
   183  
   184  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   185  	require.NoError(t, err)
   186  	require.Equal(t, expectedIDs, filterIDs)
   187  
   188  	// Test specific privileges level
   189  	privilegeLevel := token.PrivilegesLevel(2)
   190  
   191  	filter = allFilter()
   192  	filter.CommunityPrivilegesLevels = []token.PrivilegesLevel{privilegeLevel}
   193  
   194  	tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
   195  	require.NoError(t, err)
   196  
   197  	expectedIDs = nil
   198  	for _, id := range tmpIDs {
   199  		if communityDataPerID[id.HashKey()].PrivilegesLevel == privilegeLevel {
   200  			expectedIDs = append(expectedIDs, id)
   201  		}
   202  	}
   203  
   204  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   205  	require.NoError(t, err)
   206  	require.Equal(t, expectedIDs, filterIDs)
   207  
   208  	// Test specific collectible IDs
   209  	tmpIDs, err = oDB.GetOwnedCollectibles(filterChains, filterAddresses, 0, nData)
   210  	require.NoError(t, err)
   211  
   212  	filter = allFilter()
   213  	for i := 0; i < 5; i++ {
   214  		filter.CollectibleIDs = append(filter.CollectibleIDs, tmpIDs[i*2])
   215  	}
   216  	expectedIDs = filter.CollectibleIDs
   217  
   218  	filter.CollectibleIDs = append(filter.CollectibleIDs, thirdparty.CollectibleUniqueID{
   219  		ContractID: thirdparty.ContractID{
   220  			ChainID: w_common.ChainID(1),
   221  			Address: common.HexToAddress("0x1234"),
   222  		},
   223  		TokenID: &bigint.BigInt{Int: big.NewInt(9999999)},
   224  	})
   225  
   226  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   227  	require.NoError(t, err)
   228  	require.Equal(t, expectedIDs, filterIDs)
   229  
   230  	// Test collectible ID owned by both accounts 0 and 1
   231  	filterChains = []w_common.ChainID{commonID.ContractID.ChainID}
   232  	filterAddresses = []common.Address{ownerAddresses[0], ownerAddresses[1]}
   233  
   234  	filter = allFilter()
   235  	filter.CollectibleIDs = append(filter.CollectibleIDs, commonID)
   236  	expectedIDs = filter.CollectibleIDs
   237  
   238  	filterIDs, err = filterOwnedCollectibles(ctx, db, filterChains, filterAddresses, filter, 0, nData)
   239  	require.NoError(t, err)
   240  	require.Equal(t, expectedIDs, filterIDs)
   241  }