github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/neatptc/filters/api.go (about)

     1  package filters
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"math/big"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/neatlab/neatio"
    13  	"github.com/neatlab/neatio/chain/core/types"
    14  	"github.com/neatlab/neatio/neatdb"
    15  	"github.com/neatlab/neatio/network/rpc"
    16  	"github.com/neatlab/neatio/utilities/common"
    17  	"github.com/neatlab/neatio/utilities/common/hexutil"
    18  	"github.com/neatlab/neatio/utilities/event"
    19  )
    20  
    // deadline is the idle lifetime of a polling filter. NewFilter, NewBlockFilter
// and NewPendingTransactionFilter arm each filter's timer with it,
// GetFilterChanges re-arms it on every poll, and timeoutLoop removes filters
// whose timer has fired. It is a var rather than a const so it can be adjusted.
var deadline = 5 * time.Minute
    24  
    25  type filter struct {
    26  	typ      Type
    27  	deadline *time.Timer
    28  	hashes   []common.Hash
    29  	crit     FilterCriteria
    30  	logs     []*types.Log
    31  	s        *Subscription
    32  }
    33  
    34  type PublicFilterAPI struct {
    35  	backend   Backend
    36  	mux       *event.TypeMux
    37  	quit      chan struct{}
    38  	chainDb   neatdb.Database
    39  	events    *EventSystem
    40  	filtersMu sync.Mutex
    41  	filters   map[rpc.ID]*filter
    42  }
    43  
    44  func NewPublicFilterAPI(backend Backend, lightMode bool) *PublicFilterAPI {
    45  	api := &PublicFilterAPI{
    46  		backend: backend,
    47  		mux:     backend.EventMux(),
    48  		chainDb: backend.ChainDb(),
    49  		events:  NewEventSystem(backend.EventMux(), backend, lightMode),
    50  		filters: make(map[rpc.ID]*filter),
    51  	}
    52  	go api.timeoutLoop()
    53  
    54  	return api
    55  }
    56  
    57  func (api *PublicFilterAPI) timeoutLoop() {
    58  	ticker := time.NewTicker(5 * time.Minute)
    59  	for {
    60  		<-ticker.C
    61  		api.filtersMu.Lock()
    62  		for id, f := range api.filters {
    63  			select {
    64  			case <-f.deadline.C:
    65  				f.s.Unsubscribe()
    66  				delete(api.filters, id)
    67  			default:
    68  				continue
    69  			}
    70  		}
    71  		api.filtersMu.Unlock()
    72  	}
    73  }
    74  
    75  func (api *PublicFilterAPI) NewPendingTransactionFilter() rpc.ID {
    76  	var (
    77  		pendingTxs   = make(chan common.Hash)
    78  		pendingTxSub = api.events.SubscribePendingTxEvents(pendingTxs)
    79  	)
    80  
    81  	api.filtersMu.Lock()
    82  	api.filters[pendingTxSub.ID] = &filter{typ: PendingTransactionsSubscription, deadline: time.NewTimer(deadline), hashes: make([]common.Hash, 0), s: pendingTxSub}
    83  	api.filtersMu.Unlock()
    84  
    85  	go func() {
    86  		for {
    87  			select {
    88  			case ph := <-pendingTxs:
    89  				api.filtersMu.Lock()
    90  				if f, found := api.filters[pendingTxSub.ID]; found {
    91  					f.hashes = append(f.hashes, ph)
    92  				}
    93  				api.filtersMu.Unlock()
    94  			case <-pendingTxSub.Err():
    95  				api.filtersMu.Lock()
    96  				delete(api.filters, pendingTxSub.ID)
    97  				api.filtersMu.Unlock()
    98  				return
    99  			}
   100  		}
   101  	}()
   102  
   103  	return pendingTxSub.ID
   104  }
   105  
   106  func (api *PublicFilterAPI) NewPendingTransactions(ctx context.Context) (*rpc.Subscription, error) {
   107  	notifier, supported := rpc.NotifierFromContext(ctx)
   108  	if !supported {
   109  		return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported
   110  	}
   111  
   112  	rpcSub := notifier.CreateSubscription()
   113  
   114  	go func() {
   115  		txHashes := make(chan common.Hash)
   116  		pendingTxSub := api.events.SubscribePendingTxEvents(txHashes)
   117  
   118  		for {
   119  			select {
   120  			case h := <-txHashes:
   121  				notifier.Notify(rpcSub.ID, h)
   122  			case <-rpcSub.Err():
   123  				pendingTxSub.Unsubscribe()
   124  				return
   125  			case <-notifier.Closed():
   126  				pendingTxSub.Unsubscribe()
   127  				return
   128  			}
   129  		}
   130  	}()
   131  
   132  	return rpcSub, nil
   133  }
   134  
   135  func (api *PublicFilterAPI) NewBlockFilter() rpc.ID {
   136  	var (
   137  		headers   = make(chan *types.Header)
   138  		headerSub = api.events.SubscribeNewHeads(headers)
   139  	)
   140  
   141  	api.filtersMu.Lock()
   142  	api.filters[headerSub.ID] = &filter{typ: BlocksSubscription, deadline: time.NewTimer(deadline), hashes: make([]common.Hash, 0), s: headerSub}
   143  	api.filtersMu.Unlock()
   144  
   145  	go func() {
   146  		for {
   147  			select {
   148  			case h := <-headers:
   149  				api.filtersMu.Lock()
   150  				if f, found := api.filters[headerSub.ID]; found {
   151  					f.hashes = append(f.hashes, h.Hash())
   152  				}
   153  				api.filtersMu.Unlock()
   154  			case <-headerSub.Err():
   155  				api.filtersMu.Lock()
   156  				delete(api.filters, headerSub.ID)
   157  				api.filtersMu.Unlock()
   158  				return
   159  			}
   160  		}
   161  	}()
   162  
   163  	return headerSub.ID
   164  }
   165  
   166  func (api *PublicFilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, error) {
   167  	notifier, supported := rpc.NotifierFromContext(ctx)
   168  	if !supported {
   169  		return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported
   170  	}
   171  
   172  	rpcSub := notifier.CreateSubscription()
   173  
   174  	go func() {
   175  		headers := make(chan *types.Header)
   176  		headersSub := api.events.SubscribeNewHeads(headers)
   177  
   178  		for {
   179  			select {
   180  			case h := <-headers:
   181  				notifier.Notify(rpcSub.ID, h)
   182  			case <-rpcSub.Err():
   183  				headersSub.Unsubscribe()
   184  				return
   185  			case <-notifier.Closed():
   186  				headersSub.Unsubscribe()
   187  				return
   188  			}
   189  		}
   190  	}()
   191  
   192  	return rpcSub, nil
   193  }
   194  
   195  func (api *PublicFilterAPI) Logs(ctx context.Context, crit FilterCriteria) (*rpc.Subscription, error) {
   196  	notifier, supported := rpc.NotifierFromContext(ctx)
   197  	if !supported {
   198  		return &rpc.Subscription{}, rpc.ErrNotificationsUnsupported
   199  	}
   200  
   201  	var (
   202  		rpcSub      = notifier.CreateSubscription()
   203  		matchedLogs = make(chan []*types.Log)
   204  	)
   205  
   206  	logsSub, err := api.events.SubscribeLogs(neatio.FilterQuery(crit), matchedLogs)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	go func() {
   212  
   213  		for {
   214  			select {
   215  			case logs := <-matchedLogs:
   216  				for _, log := range logs {
   217  					notifier.Notify(rpcSub.ID, &log)
   218  				}
   219  			case <-rpcSub.Err():
   220  				logsSub.Unsubscribe()
   221  				return
   222  			case <-notifier.Closed():
   223  				logsSub.Unsubscribe()
   224  				return
   225  			}
   226  		}
   227  	}()
   228  
   229  	return rpcSub, nil
   230  }
   231  
   232  type FilterCriteria struct {
   233  	FromBlock *big.Int
   234  	ToBlock   *big.Int
   235  	Addresses []common.Address
   236  	Topics    [][]common.Hash
   237  }
   238  
   239  func (api *PublicFilterAPI) NewFilter(crit FilterCriteria) (rpc.ID, error) {
   240  	logs := make(chan []*types.Log)
   241  	logsSub, err := api.events.SubscribeLogs(neatio.FilterQuery(crit), logs)
   242  	if err != nil {
   243  		return rpc.ID(""), err
   244  	}
   245  
   246  	api.filtersMu.Lock()
   247  	api.filters[logsSub.ID] = &filter{typ: LogsSubscription, crit: crit, deadline: time.NewTimer(deadline), logs: make([]*types.Log, 0), s: logsSub}
   248  	api.filtersMu.Unlock()
   249  
   250  	go func() {
   251  		for {
   252  			select {
   253  			case l := <-logs:
   254  				api.filtersMu.Lock()
   255  				if f, found := api.filters[logsSub.ID]; found {
   256  					f.logs = append(f.logs, l...)
   257  				}
   258  				api.filtersMu.Unlock()
   259  			case <-logsSub.Err():
   260  				api.filtersMu.Lock()
   261  				delete(api.filters, logsSub.ID)
   262  				api.filtersMu.Unlock()
   263  				return
   264  			}
   265  		}
   266  	}()
   267  
   268  	return logsSub.ID, nil
   269  }
   270  
   271  func (api *PublicFilterAPI) GetLogs(ctx context.Context, crit FilterCriteria) ([]*types.Log, error) {
   272  
   273  	if crit.FromBlock == nil {
   274  		crit.FromBlock = big.NewInt(rpc.LatestBlockNumber.Int64())
   275  	}
   276  	if crit.ToBlock == nil {
   277  		crit.ToBlock = big.NewInt(rpc.LatestBlockNumber.Int64())
   278  	}
   279  
   280  	filter := New(api.backend, crit.FromBlock.Int64(), crit.ToBlock.Int64(), crit.Addresses, crit.Topics)
   281  
   282  	logs, err := filter.Logs(ctx)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	return returnLogs(logs), err
   287  }
   288  
   289  func (api *PublicFilterAPI) UninstallFilter(id rpc.ID) bool {
   290  	api.filtersMu.Lock()
   291  	f, found := api.filters[id]
   292  	if found {
   293  		delete(api.filters, id)
   294  	}
   295  	api.filtersMu.Unlock()
   296  	if found {
   297  		f.s.Unsubscribe()
   298  	}
   299  
   300  	return found
   301  }
   302  
   303  func (api *PublicFilterAPI) GetFilterLogs(ctx context.Context, id rpc.ID) ([]*types.Log, error) {
   304  	api.filtersMu.Lock()
   305  	f, found := api.filters[id]
   306  	api.filtersMu.Unlock()
   307  
   308  	if !found || f.typ != LogsSubscription {
   309  		return nil, fmt.Errorf("filter not found")
   310  	}
   311  
   312  	begin := rpc.LatestBlockNumber.Int64()
   313  	if f.crit.FromBlock != nil {
   314  		begin = f.crit.FromBlock.Int64()
   315  	}
   316  	end := rpc.LatestBlockNumber.Int64()
   317  	if f.crit.ToBlock != nil {
   318  		end = f.crit.ToBlock.Int64()
   319  	}
   320  
   321  	filter := New(api.backend, begin, end, f.crit.Addresses, f.crit.Topics)
   322  
   323  	logs, err := filter.Logs(ctx)
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  	return returnLogs(logs), nil
   328  }
   329  
   330  func (api *PublicFilterAPI) GetFilterChanges(id rpc.ID) (interface{}, error) {
   331  	api.filtersMu.Lock()
   332  	defer api.filtersMu.Unlock()
   333  
   334  	if f, found := api.filters[id]; found {
   335  		if !f.deadline.Stop() {
   336  
   337  			<-f.deadline.C
   338  		}
   339  		f.deadline.Reset(deadline)
   340  
   341  		switch f.typ {
   342  		case PendingTransactionsSubscription, BlocksSubscription:
   343  			hashes := f.hashes
   344  			f.hashes = nil
   345  			return returnHashes(hashes), nil
   346  		case LogsSubscription:
   347  			logs := f.logs
   348  			f.logs = nil
   349  			return returnLogs(logs), nil
   350  		}
   351  	}
   352  
   353  	return []interface{}{}, fmt.Errorf("filter not found")
   354  }
   355  
   356  func returnHashes(hashes []common.Hash) []common.Hash {
   357  	if hashes == nil {
   358  		return []common.Hash{}
   359  	}
   360  	return hashes
   361  }
   362  
   363  func returnLogs(logs []*types.Log) []*types.Log {
   364  	if logs == nil {
   365  		return []*types.Log{}
   366  	}
   367  	return logs
   368  }
   369  
   370  func (args *FilterCriteria) UnmarshalJSON(data []byte) error {
   371  	type input struct {
   372  		From      *rpc.BlockNumber `json:"fromBlock"`
   373  		ToBlock   *rpc.BlockNumber `json:"toBlock"`
   374  		Addresses interface{}      `json:"address"`
   375  		Topics    []interface{}    `json:"topics"`
   376  	}
   377  
   378  	var raw input
   379  	if err := json.Unmarshal(data, &raw); err != nil {
   380  		return err
   381  	}
   382  
   383  	if raw.From != nil {
   384  		args.FromBlock = big.NewInt(raw.From.Int64())
   385  	}
   386  
   387  	if raw.ToBlock != nil {
   388  		args.ToBlock = big.NewInt(raw.ToBlock.Int64())
   389  	}
   390  
   391  	args.Addresses = []common.Address{}
   392  
   393  	if raw.Addresses != nil {
   394  
   395  		switch rawAddr := raw.Addresses.(type) {
   396  		case []interface{}:
   397  			for i, addr := range rawAddr {
   398  				if strAddr, ok := addr.(string); ok {
   399  					addr, err := decodeAddress(strAddr)
   400  					if err != nil {
   401  						return fmt.Errorf("invalid address at index %d: %v", i, err)
   402  					}
   403  					args.Addresses = append(args.Addresses, addr)
   404  				} else {
   405  					return fmt.Errorf("non-string address at index %d", i)
   406  				}
   407  			}
   408  		case string:
   409  			addr, err := decodeAddress(rawAddr)
   410  			if err != nil {
   411  				return fmt.Errorf("invalid address: %v", err)
   412  			}
   413  			args.Addresses = []common.Address{addr}
   414  		default:
   415  			return errors.New("invalid addresses in query")
   416  		}
   417  	}
   418  
   419  	if len(raw.Topics) > 0 {
   420  		args.Topics = make([][]common.Hash, len(raw.Topics))
   421  		for i, t := range raw.Topics {
   422  			switch topic := t.(type) {
   423  			case nil:
   424  
   425  			case string:
   426  
   427  				top, err := decodeTopic(topic)
   428  				if err != nil {
   429  					return err
   430  				}
   431  				args.Topics[i] = []common.Hash{top}
   432  
   433  			case []interface{}:
   434  
   435  				for _, rawTopic := range topic {
   436  					if rawTopic == nil {
   437  
   438  						args.Topics[i] = nil
   439  						break
   440  					}
   441  					if topic, ok := rawTopic.(string); ok {
   442  						parsed, err := decodeTopic(topic)
   443  						if err != nil {
   444  							return err
   445  						}
   446  						args.Topics[i] = append(args.Topics[i], parsed)
   447  					} else {
   448  						return fmt.Errorf("invalid topic(s)")
   449  					}
   450  				}
   451  			default:
   452  				return fmt.Errorf("invalid topic(s)")
   453  			}
   454  		}
   455  	}
   456  
   457  	return nil
   458  }
   459  
   460  func decodeAddress(s string) (common.Address, error) {
   461  	b, err := hexutil.Decode(s)
   462  	if err == nil && len(b) != common.AddressLength {
   463  		err = fmt.Errorf("hex has invalid length %d after decoding", len(b))
   464  	}
   465  	return common.BytesToAddress(b), err
   466  }
   467  
   468  func decodeTopic(s string) (common.Hash, error) {
   469  	b, err := hexutil.Decode(s)
   470  	if err == nil && len(b) != common.HashLength {
   471  		err = fmt.Errorf("hex has invalid length %d after decoding", len(b))
   472  	}
   473  	return common.BytesToHash(b), err
   474  }