github.com/status-im/status-go@v1.1.0/services/wallet/activity/service.go (about)

     1  package activity
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"encoding/json"
     7  	"errors"
     8  	"strconv"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/ethereum/go-ethereum/common"
    14  	"github.com/ethereum/go-ethereum/event"
    15  	"github.com/ethereum/go-ethereum/log"
    16  
    17  	"github.com/status-im/status-go/multiaccounts/accounts"
    18  	"github.com/status-im/status-go/services/wallet/async"
    19  	"github.com/status-im/status-go/services/wallet/collectibles"
    20  	w_common "github.com/status-im/status-go/services/wallet/common"
    21  	"github.com/status-im/status-go/services/wallet/thirdparty"
    22  	"github.com/status-im/status-go/services/wallet/token"
    23  	"github.com/status-im/status-go/services/wallet/walletevent"
    24  	"github.com/status-im/status-go/transactions"
    25  )
    26  
    27  const (
    28  	// EventActivityFilteringDone contains a FilterResponse payload
    29  	EventActivityFilteringDone          walletevent.EventType = "wallet-activity-filtering-done"
    30  	EventActivityFilteringUpdate        walletevent.EventType = "wallet-activity-filtering-entries-updated"
    31  	EventActivityGetRecipientsDone      walletevent.EventType = "wallet-activity-get-recipients-result"
    32  	EventActivityGetOldestTimestampDone walletevent.EventType = "wallet-activity-get-oldest-timestamp-result"
    33  	EventActivityGetCollectibles        walletevent.EventType = "wallet-activity-get-collectibles"
    34  
    35  	// EventActivitySessionUpdated contains a SessionUpdate payload
    36  	EventActivitySessionUpdated walletevent.EventType = "wallet-activity-session-updated"
    37  )
    38  
    39  var (
    40  	filterTask = async.TaskType{
    41  		ID:     1,
    42  		Policy: async.ReplacementPolicyCancelOld,
    43  	}
    44  	getRecipientsTask = async.TaskType{
    45  		ID:     2,
    46  		Policy: async.ReplacementPolicyIgnoreNew,
    47  	}
    48  	getOldestTimestampTask = async.TaskType{
    49  		ID:     3,
    50  		Policy: async.ReplacementPolicyCancelOld,
    51  	}
    52  	getCollectiblesTask = async.TaskType{
    53  		ID:     4,
    54  		Policy: async.ReplacementPolicyCancelOld,
    55  	}
    56  )
    57  
// Service provides an async interface to the wallet activity queries, ensuring
// only one request of each task type is running at a time. Results are
// published as events on eventFeed. It also lazy-loads NFT info (name and image
// URL) and maps tokens to and from their symbols for the activity filter.
type Service struct {
	// Storage and collaborators, injected by NewService.
	db           *sql.DB
	accountsDB   *accounts.Database
	tokenManager token.ManagerInterface
	collectibles collectibles.ManagerInterface
	eventFeed    *event.Feed

	// scheduler runs the tasks enqueued by the *Async methods and by CancelFilterTask.
	scheduler *async.MultiClientScheduler

	// Session state. lastSessionID backs nextSessionID (atomic increment);
	// sessions and debounceDuration are initialized by NewService. Nothing else
	// in this file reads these members.
	sessions      map[SessionID]*Session
	lastSessionID atomic.Int32
	subscriptions event.Subscription
	ch            chan walletevent.Event
	// sessionsRWMutex is used to protect all sessions related members
	sessionsRWMutex  sync.RWMutex
	debounceDuration time.Duration

	// pendingTracker is set by NewService; it is not referenced by the code in
	// this file.
	pendingTracker *transactions.PendingTxTracker
}
    78  
    79  func (s *Service) nextSessionID() SessionID {
    80  	return SessionID(s.lastSessionID.Add(1))
    81  }
    82  
    83  func NewService(db *sql.DB, accountsDB *accounts.Database, tokenManager token.ManagerInterface, collectibles collectibles.ManagerInterface, eventFeed *event.Feed, pendingTracker *transactions.PendingTxTracker) *Service {
    84  	return &Service{
    85  		db:           db,
    86  		accountsDB:   accountsDB,
    87  		tokenManager: tokenManager,
    88  		collectibles: collectibles,
    89  		eventFeed:    eventFeed,
    90  		scheduler:    async.NewMultiClientScheduler(),
    91  
    92  		sessions: make(map[SessionID]*Session),
    93  		// here to be overwritten by tests
    94  		debounceDuration: 1 * time.Second,
    95  
    96  		pendingTracker: pendingTracker,
    97  	}
    98  }
    99  
   100  type ErrorCode = int
   101  
   102  const (
   103  	ErrorCodeSuccess ErrorCode = iota + 1
   104  	ErrorCodeTaskCanceled
   105  	ErrorCodeFailed
   106  )
   107  
// FilterResponse is the JSON payload of EventActivityFilteringDone.
//
// Activities, Offset and HasMore are only populated when ErrorCode is
// ErrorCodeSuccess; Offset then echoes the offset passed to FilterActivityAsync.
type FilterResponse struct {
	// Page of activity entries matching the filter.
	Activities []Entry `json:"activities"`
	Offset     int     `json:"offset"`
	// Used to indicate that there might be more entries that were not returned,
	// based on a simple heuristic: it is true when the page came back full
	// (len(Activities) == limit).
	HasMore   bool      `json:"hasMore"`
	ErrorCode ErrorCode `json:"errorCode"`
}
   116  
   117  // FilterActivityAsync allows only one filter task to run at a time
   118  // it cancels the current one if a new one is started
   119  // and should not expect other owners to have data in one of the queried tables
   120  //
   121  // All calls will trigger an EventActivityFilteringDone event with the result of the filtering
   122  // TODO #12120: replace with session based APIs
   123  func (s *Service) FilterActivityAsync(requestID int32, addresses []common.Address, chainIDs []w_common.ChainID, filter Filter, offset int, limit int) {
   124  	s.scheduler.Enqueue(requestID, filterTask, func(ctx context.Context) (interface{}, error) {
   125  		allAddresses := s.areAllAddresses(addresses)
   126  		activities, err := getActivityEntries(ctx, s.getDeps(), addresses, allAddresses, chainIDs, filter, offset, limit)
   127  		return activities, err
   128  	}, func(result interface{}, taskType async.TaskType, err error) {
   129  		res := FilterResponse{
   130  			ErrorCode: ErrorCodeFailed,
   131  		}
   132  
   133  		if errors.Is(err, context.Canceled) || errors.Is(err, async.ErrTaskOverwritten) {
   134  			res.ErrorCode = ErrorCodeTaskCanceled
   135  		} else if err == nil {
   136  			activities := result.([]Entry)
   137  			res.Activities = activities
   138  			res.Offset = offset
   139  			res.HasMore = len(activities) == limit
   140  			res.ErrorCode = ErrorCodeSuccess
   141  		}
   142  
   143  		sendResponseEvent(s.eventFeed, &requestID, EventActivityFilteringDone, res, err)
   144  
   145  		s.getActivityDetailsAsync(requestID, res.Activities)
   146  	})
   147  }
   148  
// CollectibleHeader is the minimal description of a collectible (its unique ID,
// name and image URL) returned in GetollectiblesResponse.
//
// NOTE: the image URL key is snake_case ("image_url"), unlike the camelCase keys
// of the response types; it is part of the wire format, so keep it as-is.
type CollectibleHeader struct {
	ID       thirdparty.CollectibleUniqueID `json:"id"`
	Name     string                         `json:"name"`
	ImageURL string                         `json:"image_url"`
}
   154  
// GetollectiblesResponse is the JSON payload of EventActivityGetCollectibles.
// (The exported name is missing a "C"; it is kept unchanged for compatibility.)
//
// Collectibles, Offset and HasMore are only populated when ErrorCode is
// ErrorCodeSuccess; Offset then echoes the offset passed to
// GetActivityCollectiblesAsync.
type GetollectiblesResponse struct {
	Collectibles []CollectibleHeader `json:"collectibles"`
	Offset       int                 `json:"offset"`
	// Used to indicate that there might be more collectibles that were not returned
	// based on a simple heuristic (the requested page came back full)
	HasMore   bool      `json:"hasMore"`
	ErrorCode ErrorCode `json:"errorCode"`
}
   163  
   164  func (s *Service) GetActivityCollectiblesAsync(requestID int32, chainIDs []w_common.ChainID, addresses []common.Address, offset int, limit int) {
   165  	s.scheduler.Enqueue(requestID, getCollectiblesTask, func(ctx context.Context) (interface{}, error) {
   166  		collectibles, err := GetActivityCollectibles(ctx, s.db, chainIDs, addresses, offset, limit)
   167  
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  
   172  		data, err := s.collectibles.FetchAssetsByCollectibleUniqueID(ctx, collectibles, true)
   173  		if err != nil {
   174  			return nil, err
   175  		}
   176  
   177  		res := make([]CollectibleHeader, 0, len(data))
   178  
   179  		for _, c := range data {
   180  			res = append(res, CollectibleHeader{
   181  				ID:       c.CollectibleData.ID,
   182  				Name:     c.CollectibleData.Name,
   183  				ImageURL: c.CollectibleData.ImageURL,
   184  			})
   185  		}
   186  
   187  		return res, err
   188  	}, func(result interface{}, taskType async.TaskType, err error) {
   189  		res := GetollectiblesResponse{
   190  			ErrorCode: ErrorCodeFailed,
   191  		}
   192  
   193  		if errors.Is(err, context.Canceled) || errors.Is(err, async.ErrTaskOverwritten) {
   194  			res.ErrorCode = ErrorCodeTaskCanceled
   195  		} else if err == nil {
   196  			collectibles := result.([]CollectibleHeader)
   197  			res.Collectibles = collectibles
   198  			res.Offset = offset
   199  			res.HasMore = len(collectibles) == limit
   200  			res.ErrorCode = ErrorCodeSuccess
   201  		}
   202  
   203  		sendResponseEvent(s.eventFeed, &requestID, EventActivityGetCollectibles, res, err)
   204  	})
   205  }
   206  
   207  func (s *Service) GetMultiTxDetails(ctx context.Context, multiTxID int) (*EntryDetails, error) {
   208  	return getMultiTxDetails(ctx, s.db, multiTxID)
   209  }
   210  
   211  func (s *Service) GetTxDetails(ctx context.Context, id string) (*EntryDetails, error) {
   212  	return getTxDetails(ctx, s.db, id)
   213  }
   214  
   215  // getActivityDetails check if any of the entries have details that are not loaded then fetch and emit result
   216  func (s *Service) getActivityDetails(ctx context.Context, entries []Entry) ([]*EntryData, error) {
   217  	res := make([]*EntryData, 0)
   218  	var err error
   219  	ids := make([]thirdparty.CollectibleUniqueID, 0)
   220  	entriesForIds := make(map[string][]*Entry)
   221  
   222  	idExists := func(ids []thirdparty.CollectibleUniqueID, id *thirdparty.CollectibleUniqueID) bool {
   223  		for _, existingID := range ids {
   224  			if existingID.Same(id) {
   225  				return true
   226  			}
   227  		}
   228  		return false
   229  	}
   230  
   231  	for i := range entries {
   232  		if !entries[i].isNFT() {
   233  			continue
   234  		}
   235  
   236  		id := entries[i].anyIdentity()
   237  		if id == nil {
   238  			continue
   239  		}
   240  
   241  		entriesForIds[id.HashKey()] = append(entriesForIds[id.HashKey()], &entries[i])
   242  		if !idExists(ids, id) {
   243  			ids = append(ids, *id)
   244  		}
   245  	}
   246  
   247  	if len(ids) == 0 {
   248  		return nil, nil
   249  	}
   250  
   251  	log.Debug("wallet.activity.Service lazyLoadDetails", "entries.len", len(entries), "ids.len", len(ids))
   252  
   253  	colData, err := s.collectibles.FetchAssetsByCollectibleUniqueID(ctx, ids, true)
   254  	if err != nil {
   255  		log.Error("Error fetching collectible details", "error", err)
   256  		return nil, err
   257  	}
   258  
   259  	for _, col := range colData {
   260  		nftName := w_common.NewAndSet(col.CollectibleData.Name)
   261  		nftURL := w_common.NewAndSet(col.CollectibleData.ImageURL)
   262  		for i := range ids {
   263  			if !col.CollectibleData.ID.Same(&ids[i]) {
   264  				continue
   265  			}
   266  
   267  			entryList, ok := entriesForIds[ids[i].HashKey()]
   268  			if !ok {
   269  				continue
   270  			}
   271  			for _, e := range entryList {
   272  				data := &EntryData{
   273  					NftName: nftName,
   274  					NftURL:  nftURL,
   275  				}
   276  				if e.payloadType == MultiTransactionPT {
   277  					data.ID = w_common.NewAndSet(e.id)
   278  				} else {
   279  					data.Transaction = e.transaction
   280  				}
   281  
   282  				data.PayloadType = e.payloadType
   283  				res = append(res, data)
   284  			}
   285  		}
   286  	}
   287  	return res, nil
   288  }
   289  
// GetRecipientsResponse is the JSON payload of EventActivityGetRecipientsDone.
// Offset echoes the offset passed to GetRecipientsAsync.
type GetRecipientsResponse struct {
	// Recipient addresses for the requested page, as returned by GetRecipients.
	Addresses []common.Address `json:"addresses"`
	Offset    int              `json:"offset"`
	// Used to indicate that there might be more entries that were not returned
	// based on a simple heuristic
	HasMore   bool      `json:"hasMore"`
	ErrorCode ErrorCode `json:"errorCode"`
}
   298  
   299  // GetRecipientsAsync returns true if a task is already running or scheduled due to a previous call; meaning that
   300  // this call won't receive an answer but client should rely on the answer from the previous call.
   301  // If no task is already scheduled false will be returned
   302  func (s *Service) GetRecipientsAsync(requestID int32, chainIDs []w_common.ChainID, addresses []common.Address, offset int, limit int) bool {
   303  	return s.scheduler.Enqueue(requestID, getRecipientsTask, func(ctx context.Context) (interface{}, error) {
   304  		var err error
   305  		result := &GetRecipientsResponse{
   306  			Offset:    offset,
   307  			ErrorCode: ErrorCodeSuccess,
   308  		}
   309  		result.Addresses, result.HasMore, err = GetRecipients(ctx, s.db, chainIDs, addresses, offset, limit)
   310  		return result, err
   311  	}, func(result interface{}, taskType async.TaskType, err error) {
   312  		res := result.(*GetRecipientsResponse)
   313  		if errors.Is(err, context.Canceled) || errors.Is(err, async.ErrTaskOverwritten) {
   314  			res.ErrorCode = ErrorCodeTaskCanceled
   315  		} else if err != nil {
   316  			res.ErrorCode = ErrorCodeFailed
   317  		}
   318  
   319  		sendResponseEvent(s.eventFeed, &requestID, EventActivityGetRecipientsDone, result, err)
   320  	})
   321  }
   322  
// GetOldestTimestampResponse is the JSON payload of EventActivityGetOldestTimestampDone.
type GetOldestTimestampResponse struct {
	// Timestamp is the value returned by GetOldestTimestamp for the requested
	// addresses, converted from uint64; only set when ErrorCode is ErrorCodeSuccess.
	Timestamp int64     `json:"timestamp"`
	ErrorCode ErrorCode `json:"errorCode"`
}
   327  
   328  func (s *Service) GetOldestTimestampAsync(requestID int32, addresses []common.Address) {
   329  	s.scheduler.Enqueue(requestID, getOldestTimestampTask, func(ctx context.Context) (interface{}, error) {
   330  		timestamp, err := GetOldestTimestamp(ctx, s.db, addresses)
   331  		return timestamp, err
   332  	}, func(result interface{}, taskType async.TaskType, err error) {
   333  		res := GetOldestTimestampResponse{
   334  			ErrorCode: ErrorCodeFailed,
   335  		}
   336  
   337  		if errors.Is(err, context.Canceled) || errors.Is(err, async.ErrTaskOverwritten) {
   338  			res.ErrorCode = ErrorCodeTaskCanceled
   339  		} else if err == nil {
   340  			res.Timestamp = int64(result.(uint64))
   341  			res.ErrorCode = ErrorCodeSuccess
   342  		}
   343  
   344  		sendResponseEvent(s.eventFeed, &requestID, EventActivityGetOldestTimestampDone, res, err)
   345  	})
   346  }
   347  
   348  func (s *Service) CancelFilterTask(requestID int32) {
   349  	s.scheduler.Enqueue(requestID, filterTask, func(ctx context.Context) (interface{}, error) {
   350  		// No-op
   351  		return nil, nil
   352  	}, func(result interface{}, taskType async.TaskType, err error) {
   353  		// Ignore result
   354  	})
   355  }
   356  
   357  func (s *Service) Stop() {
   358  	s.scheduler.Stop()
   359  }
   360  
   361  func (s *Service) getDeps() FilterDependencies {
   362  	return FilterDependencies{
   363  		db: s.db,
   364  		tokenSymbol: func(t Token) string {
   365  			info := s.tokenManager.LookupTokenIdentity(uint64(t.ChainID), t.Address, t.TokenType == Native)
   366  			if info == nil {
   367  				return ""
   368  			}
   369  			return info.Symbol
   370  		},
   371  		tokenFromSymbol: func(chainID *w_common.ChainID, symbol string) *Token {
   372  			var cID *uint64
   373  			if chainID != nil {
   374  				cID = new(uint64)
   375  				*cID = uint64(*chainID)
   376  			}
   377  			t, detectedNative := s.tokenManager.LookupToken(cID, symbol)
   378  			if t == nil {
   379  				return nil
   380  			}
   381  			tokenType := Native
   382  			if !detectedNative {
   383  				tokenType = Erc20
   384  			}
   385  			return &Token{
   386  				TokenType: tokenType,
   387  				ChainID:   w_common.ChainID(t.ChainID),
   388  				Address:   t.Address,
   389  			}
   390  		},
   391  		currentTimestamp: func() int64 {
   392  			return time.Now().Unix()
   393  		},
   394  	}
   395  }
   396  
   397  func sendResponseEvent(eventFeed *event.Feed, requestID *int32, eventType walletevent.EventType, payloadObj interface{}, resErr error) {
   398  	payload, err := json.Marshal(payloadObj)
   399  	if err != nil {
   400  		log.Error("Error marshaling response: %v; result error: %w", err, resErr)
   401  	} else {
   402  		err = resErr
   403  	}
   404  
   405  	requestIDStr := nilStr
   406  	if requestID != nil {
   407  		requestIDStr = strconv.Itoa(int(*requestID))
   408  	}
   409  	log.Debug("wallet.api.activity.Service RESPONSE", "requestID", requestIDStr, "eventType", eventType, "error", err, "payload.len", len(payload))
   410  
   411  	event := walletevent.Event{
   412  		Type:    eventType,
   413  		Message: string(payload),
   414  	}
   415  
   416  	if requestID != nil {
   417  		event.RequestID = new(int)
   418  		*event.RequestID = int(*requestID)
   419  	}
   420  
   421  	eventFeed.Send(event)
   422  }
   423  
   424  func (s *Service) getWalletAddreses() ([]common.Address, error) {
   425  	ethAddresses, err := s.accountsDB.GetWalletAddresses()
   426  	if err != nil {
   427  		return nil, err
   428  	}
   429  
   430  	addresses := make([]common.Address, 0, len(ethAddresses))
   431  	for _, ethAddress := range ethAddresses {
   432  		addresses = append(addresses, common.Address(ethAddress))
   433  	}
   434  
   435  	return addresses, nil
   436  }
   437  
   438  func (s *Service) areAllAddresses(addresses []common.Address) bool {
   439  	// Compare with addresses in accountsDB
   440  	walletAddresses, err := s.getWalletAddreses()
   441  	if err != nil {
   442  		log.Error("Error getting wallet addresses", "error", err)
   443  		return false
   444  	}
   445  
   446  	// Check if passed addresses are the same as in the accountsDB ignoring the order
   447  	return areSlicesEqual(walletAddresses, addresses)
   448  }
   449  
   450  // Comparison function to check if slices are the same ignoring the order
   451  func areSlicesEqual(a, b []common.Address) bool {
   452  	if len(a) != len(b) {
   453  		return false
   454  	}
   455  
   456  	// Create a map of addresses
   457  	aMap := make(map[common.Address]struct{}, len(a))
   458  	for _, address := range a {
   459  		aMap[address] = struct{}{}
   460  	}
   461  
   462  	// Check if all passed addresses are in the map
   463  	for _, address := range b {
   464  		if _, ok := aMap[address]; !ok {
   465  			return false
   466  		}
   467  	}
   468  
   469  	return true
   470  }