github.com/onflow/flow-go@v0.33.17/engine/access/state_stream/filter.go (about)

     1  package state_stream
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/onflow/flow-go/model/events"
     8  	"github.com/onflow/flow-go/model/flow"
     9  )
    10  
    11  const (
    12  	// DefaultMaxEventTypes is the default maximum number of event types that can be specified in a filter
    13  	DefaultMaxEventTypes = 1000
    14  
    15  	// DefaultMaxAddresses is the default maximum number of addresses that can be specified in a filter
    16  	DefaultMaxAddresses = 1000
    17  
    18  	// DefaultMaxContracts is the default maximum number of contracts that can be specified in a filter
    19  	DefaultMaxContracts = 1000
    20  )
    21  
    22  // EventFilterConfig is used to configure the limits for EventFilters
    23  type EventFilterConfig struct {
    24  	MaxEventTypes int
    25  	MaxAddresses  int
    26  	MaxContracts  int
    27  }
    28  
    29  // DefaultEventFilterConfig is the default configuration for EventFilters
    30  var DefaultEventFilterConfig = EventFilterConfig{
    31  	MaxEventTypes: DefaultMaxEventTypes,
    32  	MaxAddresses:  DefaultMaxAddresses,
    33  	MaxContracts:  DefaultMaxContracts,
    34  }
    35  
    36  // EventFilter represents a filter applied to events for a given subscription
    37  type EventFilter struct {
    38  	hasFilters bool
    39  	EventTypes map[flow.EventType]struct{}
    40  	Addresses  map[string]struct{}
    41  	Contracts  map[string]struct{}
    42  }
    43  
    44  func NewEventFilter(
    45  	config EventFilterConfig,
    46  	chain flow.Chain,
    47  	eventTypes []string,
    48  	addresses []string,
    49  	contracts []string,
    50  ) (EventFilter, error) {
    51  	// put some reasonable limits on the number of filters. Lookups use a map so they are fast,
    52  	// this just puts a cap on the memory consumed per filter.
    53  	if len(eventTypes) > config.MaxEventTypes {
    54  		return EventFilter{}, fmt.Errorf("too many event types in filter (%d). use %d or fewer", len(eventTypes), config.MaxEventTypes)
    55  	}
    56  
    57  	if len(addresses) > config.MaxAddresses {
    58  		return EventFilter{}, fmt.Errorf("too many addresses in filter (%d). use %d or fewer", len(addresses), config.MaxAddresses)
    59  	}
    60  
    61  	if len(contracts) > config.MaxContracts {
    62  		return EventFilter{}, fmt.Errorf("too many contracts in filter (%d). use %d or fewer", len(contracts), config.MaxContracts)
    63  	}
    64  
    65  	f := EventFilter{
    66  		EventTypes: make(map[flow.EventType]struct{}, len(eventTypes)),
    67  		Addresses:  make(map[string]struct{}, len(addresses)),
    68  		Contracts:  make(map[string]struct{}, len(contracts)),
    69  	}
    70  
    71  	// Check all of the filters to ensure they are correctly formatted. This helps avoid searching
    72  	// with criteria that will never match.
    73  	for _, event := range eventTypes {
    74  		eventType := flow.EventType(event)
    75  		if err := validateEventType(eventType, chain); err != nil {
    76  			return EventFilter{}, err
    77  		}
    78  		f.EventTypes[eventType] = struct{}{}
    79  	}
    80  
    81  	for _, address := range addresses {
    82  		addr := flow.HexToAddress(address)
    83  		if err := validateAddress(addr, chain); err != nil {
    84  			return EventFilter{}, err
    85  		}
    86  		// use the parsed address to make sure it will match the event address string exactly
    87  		f.Addresses[addr.String()] = struct{}{}
    88  	}
    89  
    90  	for _, contract := range contracts {
    91  		if err := validateContract(contract); err != nil {
    92  			return EventFilter{}, err
    93  		}
    94  		f.Contracts[contract] = struct{}{}
    95  	}
    96  
    97  	f.hasFilters = len(f.EventTypes) > 0 || len(f.Addresses) > 0 || len(f.Contracts) > 0
    98  	return f, nil
    99  }
   100  
   101  // Filter applies the all filters on the provided list of events, and returns a list of events that
   102  // match
   103  func (f *EventFilter) Filter(events flow.EventsList) flow.EventsList {
   104  	var filteredEvents flow.EventsList
   105  	for _, event := range events {
   106  		if f.Match(event) {
   107  			filteredEvents = append(filteredEvents, event)
   108  		}
   109  	}
   110  	return filteredEvents
   111  }
   112  
   113  // Match applies all filters to a specific event, and returns true if the event matches
   114  func (f *EventFilter) Match(event flow.Event) bool {
   115  	// No filters means all events match
   116  	if !f.hasFilters {
   117  		return true
   118  	}
   119  
   120  	if _, ok := f.EventTypes[event.Type]; ok {
   121  		return true
   122  	}
   123  
   124  	parsed, err := events.ParseEvent(event.Type)
   125  	if err != nil {
   126  		// TODO: log this error
   127  		return false
   128  	}
   129  
   130  	if _, ok := f.Contracts[parsed.Contract]; ok {
   131  		return true
   132  	}
   133  
   134  	if parsed.Type == events.AccountEventType {
   135  		_, ok := f.Addresses[parsed.Address]
   136  		return ok
   137  	}
   138  
   139  	return false
   140  }
   141  
   142  // validateEventType ensures that the event type matches the expected format
   143  func validateEventType(eventType flow.EventType, chain flow.Chain) error {
   144  	_, err := events.ValidateEvent(flow.EventType(eventType), chain)
   145  	if err != nil {
   146  		return fmt.Errorf("invalid event type %s: %w", eventType, err)
   147  	}
   148  	return nil
   149  }
   150  
   151  // validateAddress ensures that the address is valid for the given chain
   152  func validateAddress(address flow.Address, chain flow.Chain) error {
   153  	if !chain.IsValid(address) {
   154  		return fmt.Errorf("invalid address for chain: %s", address)
   155  	}
   156  	return nil
   157  }
   158  
   159  // validateContract ensures that the contract is in the correct format
   160  func validateContract(contract string) error {
   161  	if contract == "flow" {
   162  		return nil
   163  	}
   164  
   165  	parts := strings.Split(contract, ".")
   166  	if len(parts) != 3 || parts[0] != "A" {
   167  		return fmt.Errorf("invalid contract: %s", contract)
   168  	}
   169  	return nil
   170  }