github.com/status-im/status-go@v1.1.0/rpc/chain/rpc_limiter.go (about)

     1  package chain
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/google/uuid"
    11  
    12  	"github.com/ethereum/go-ethereum/log"
    13  )
    14  
    15  const (
    16  	defaultMaxRequestsPerSecond = 50
    17  	minRequestsPerSecond        = 20
    18  	requestsPerSecondStep       = 10
    19  
    20  	tickerInterval  = 1 * time.Second
    21  	LimitInfinitely = 0
    22  )
    23  
    24  var (
    25  	ErrRequestsOverLimit = errors.New("number of requests over limit")
    26  )
    27  
    28  type callerOnWait struct {
    29  	requests int
    30  	ch       chan bool
    31  }
    32  
    33  type LimitsStorage interface {
    34  	Get(tag string) (*LimitData, error)
    35  	Set(data *LimitData) error
    36  	Delete(tag string) error
    37  }
    38  
    39  type InMemRequestsMapStorage struct {
    40  	data sync.Map
    41  }
    42  
    43  func NewInMemRequestsMapStorage() *InMemRequestsMapStorage {
    44  	return &InMemRequestsMapStorage{}
    45  }
    46  
    47  func (s *InMemRequestsMapStorage) Get(tag string) (*LimitData, error) {
    48  	data, ok := s.data.Load(tag)
    49  	if !ok {
    50  		return nil, nil
    51  	}
    52  
    53  	return data.(*LimitData), nil
    54  }
    55  
    56  func (s *InMemRequestsMapStorage) Set(data *LimitData) error {
    57  	if data == nil {
    58  		return fmt.Errorf("data is nil")
    59  	}
    60  
    61  	s.data.Store(data.Tag, data)
    62  	return nil
    63  }
    64  
    65  func (s *InMemRequestsMapStorage) Delete(tag string) error {
    66  	s.data.Delete(tag)
    67  	return nil
    68  }
    69  
    70  type LimitsDBStorage struct {
    71  	db *RPCLimiterDB
    72  }
    73  
    74  func NewLimitsDBStorage(db *sql.DB) *LimitsDBStorage {
    75  	return &LimitsDBStorage{
    76  		db: NewRPCLimiterDB(db),
    77  	}
    78  }
    79  
    80  func (s *LimitsDBStorage) Get(tag string) (*LimitData, error) {
    81  	return s.db.GetRPCLimit(tag)
    82  }
    83  
    84  func (s *LimitsDBStorage) Set(data *LimitData) error {
    85  	if data == nil {
    86  		return fmt.Errorf("data is nil")
    87  	}
    88  
    89  	limit, err := s.db.GetRPCLimit(data.Tag)
    90  	if err != nil && err != sql.ErrNoRows {
    91  		return err
    92  	}
    93  
    94  	if limit == nil {
    95  		return s.db.CreateRPCLimit(*data)
    96  	}
    97  
    98  	return s.db.UpdateRPCLimit(*data)
    99  }
   100  
   101  func (s *LimitsDBStorage) Delete(tag string) error {
   102  	return s.db.DeleteRPCLimit(tag)
   103  }
   104  
   105  type LimitData struct {
   106  	Tag       string
   107  	CreatedAt time.Time
   108  	Period    time.Duration
   109  	MaxReqs   int
   110  	NumReqs   int
   111  }
   112  
   113  type RequestLimiter interface {
   114  	SetLimit(tag string, maxRequests int, interval time.Duration) error
   115  	GetLimit(tag string) (*LimitData, error)
   116  	DeleteLimit(tag string) error
   117  	Allow(tag string) (bool, error)
   118  }
   119  
   120  type RPCRequestLimiter struct {
   121  	storage LimitsStorage
   122  	mu      sync.Mutex
   123  }
   124  
   125  func NewRequestLimiter(storage LimitsStorage) *RPCRequestLimiter {
   126  	return &RPCRequestLimiter{
   127  		storage: storage,
   128  	}
   129  }
   130  
   131  func (rl *RPCRequestLimiter) SetLimit(tag string, maxRequests int, interval time.Duration) error {
   132  	err := rl.saveToStorage(tag, maxRequests, interval, 0, time.Now())
   133  	if err != nil {
   134  		log.Error("Failed to save request data to storage", "error", err)
   135  		return err
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  func (rl *RPCRequestLimiter) GetLimit(tag string) (*LimitData, error) {
   142  	data, err := rl.storage.Get(tag)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	return data, nil
   148  }
   149  
   150  func (rl *RPCRequestLimiter) DeleteLimit(tag string) error {
   151  	err := rl.storage.Delete(tag)
   152  	if err != nil {
   153  		log.Error("Failed to delete request data from storage", "error", err)
   154  		return err
   155  	}
   156  
   157  	return nil
   158  }
   159  
   160  func (rl *RPCRequestLimiter) saveToStorage(tag string, maxRequests int, interval time.Duration, numReqs int, timestamp time.Time) error {
   161  	data := &LimitData{
   162  		Tag:       tag,
   163  		CreatedAt: timestamp,
   164  		Period:    interval,
   165  		MaxReqs:   maxRequests,
   166  		NumReqs:   numReqs,
   167  	}
   168  
   169  	err := rl.storage.Set(data)
   170  	if err != nil {
   171  		log.Error("Failed to save request data to storage", "error", err)
   172  		return err
   173  	}
   174  
   175  	return nil
   176  }
   177  
   178  func (rl *RPCRequestLimiter) Allow(tag string) (bool, error) {
   179  	rl.mu.Lock()
   180  	defer rl.mu.Unlock()
   181  
   182  	data, err := rl.storage.Get(tag)
   183  	if err != nil {
   184  		return true, err
   185  	}
   186  
   187  	if data == nil {
   188  		return true, nil
   189  	}
   190  
   191  	// Check if the interval has passed and reset the number of requests
   192  	if time.Since(data.CreatedAt) >= data.Period && data.Period.Milliseconds() != LimitInfinitely {
   193  		err = rl.saveToStorage(tag, data.MaxReqs, data.Period, 0, time.Now())
   194  		if err != nil {
   195  			return true, err
   196  		}
   197  
   198  		return true, nil
   199  	}
   200  
   201  	// Check if a number of requests is over the limit within the interval
   202  	if time.Since(data.CreatedAt) < data.Period || data.Period.Milliseconds() == LimitInfinitely {
   203  		if data.NumReqs >= data.MaxReqs {
   204  			log.Info("Number of requests over limit",
   205  				"tag", tag,
   206  				"numReqs", data.NumReqs,
   207  				"maxReqs", data.MaxReqs,
   208  				"period", data.Period,
   209  				"createdAt", data.CreatedAt.UTC())
   210  			return false, ErrRequestsOverLimit
   211  		}
   212  
   213  		return true, rl.saveToStorage(tag, data.MaxReqs, data.Period, data.NumReqs+1, data.CreatedAt)
   214  	}
   215  
   216  	// Reset the number of requests if the interval has passed
   217  	return true, rl.saveToStorage(tag, data.MaxReqs, data.Period, 0, time.Now()) // still allow the request if failed to save as not critical
   218  }
   219  
   220  type RPCRpsLimiter struct {
   221  	uuid uuid.UUID
   222  
   223  	maxRequestsPerSecond      int
   224  	maxRequestsPerSecondMutex sync.RWMutex
   225  
   226  	requestsMadeWithinSecond      int
   227  	requestsMadeWithinSecondMutex sync.RWMutex
   228  
   229  	callersOnWaitForRequests      []callerOnWait
   230  	callersOnWaitForRequestsMutex sync.RWMutex
   231  
   232  	quit chan bool
   233  }
   234  
   235  func NewRPCRpsLimiter() *RPCRpsLimiter {
   236  
   237  	limiter := RPCRpsLimiter{
   238  		uuid:                 uuid.New(),
   239  		maxRequestsPerSecond: defaultMaxRequestsPerSecond,
   240  		quit:                 make(chan bool),
   241  	}
   242  
   243  	limiter.start()
   244  
   245  	return &limiter
   246  }
   247  
   248  func (rl *RPCRpsLimiter) ReduceLimit() {
   249  	rl.maxRequestsPerSecondMutex.Lock()
   250  	defer rl.maxRequestsPerSecondMutex.Unlock()
   251  	if rl.maxRequestsPerSecond <= minRequestsPerSecond {
   252  		return
   253  	}
   254  	rl.maxRequestsPerSecond = rl.maxRequestsPerSecond - requestsPerSecondStep
   255  }
   256  
   257  func (rl *RPCRpsLimiter) start() {
   258  	ticker := time.NewTicker(tickerInterval)
   259  	go func() {
   260  		for {
   261  			select {
   262  			case <-ticker.C:
   263  				{
   264  					rl.requestsMadeWithinSecondMutex.Lock()
   265  					oldrequestsMadeWithinSecond := rl.requestsMadeWithinSecond
   266  					if rl.requestsMadeWithinSecond != 0 {
   267  						rl.requestsMadeWithinSecond = 0
   268  					}
   269  					rl.requestsMadeWithinSecondMutex.Unlock()
   270  					if oldrequestsMadeWithinSecond == 0 {
   271  						continue
   272  					}
   273  				}
   274  
   275  				rl.callersOnWaitForRequestsMutex.Lock()
   276  				numOfRequestsToMakeAvailable := rl.maxRequestsPerSecond
   277  				for {
   278  					if numOfRequestsToMakeAvailable == 0 || len(rl.callersOnWaitForRequests) == 0 {
   279  						break
   280  					}
   281  
   282  					var index = -1
   283  					for i := 0; i < len(rl.callersOnWaitForRequests); i++ {
   284  						if rl.callersOnWaitForRequests[i].requests <= numOfRequestsToMakeAvailable {
   285  							index = i
   286  							break
   287  						}
   288  					}
   289  
   290  					if index == -1 {
   291  						break
   292  					}
   293  
   294  					callerOnWait := rl.callersOnWaitForRequests[index]
   295  					numOfRequestsToMakeAvailable -= callerOnWait.requests
   296  					rl.callersOnWaitForRequests = append(rl.callersOnWaitForRequests[:index], rl.callersOnWaitForRequests[index+1:]...)
   297  
   298  					callerOnWait.ch <- true
   299  				}
   300  				rl.callersOnWaitForRequestsMutex.Unlock()
   301  
   302  			case <-rl.quit:
   303  				ticker.Stop()
   304  				return
   305  			}
   306  		}
   307  	}()
   308  }
   309  
   310  func (rl *RPCRpsLimiter) Stop() {
   311  	rl.quit <- true
   312  	close(rl.quit)
   313  	for _, callerOnWait := range rl.callersOnWaitForRequests {
   314  		close(callerOnWait.ch)
   315  	}
   316  	rl.callersOnWaitForRequests = nil
   317  }
   318  
   319  func (rl *RPCRpsLimiter) WaitForRequestsAvailability(requests int) error {
   320  	if requests > rl.maxRequestsPerSecond {
   321  		return ErrRequestsOverLimit
   322  	}
   323  
   324  	{
   325  		rl.requestsMadeWithinSecondMutex.Lock()
   326  		if rl.requestsMadeWithinSecond+requests <= rl.maxRequestsPerSecond {
   327  			rl.requestsMadeWithinSecond += requests
   328  			rl.requestsMadeWithinSecondMutex.Unlock()
   329  			return nil
   330  		}
   331  		rl.requestsMadeWithinSecondMutex.Unlock()
   332  	}
   333  
   334  	callerOnWait := callerOnWait{
   335  		requests: requests,
   336  		ch:       make(chan bool),
   337  	}
   338  
   339  	{
   340  		rl.callersOnWaitForRequestsMutex.Lock()
   341  		rl.callersOnWaitForRequests = append(rl.callersOnWaitForRequests, callerOnWait)
   342  		rl.callersOnWaitForRequestsMutex.Unlock()
   343  	}
   344  
   345  	<-callerOnWait.ch
   346  
   347  	close(callerOnWait.ch)
   348  
   349  	rl.requestsMadeWithinSecondMutex.Lock()
   350  	rl.requestsMadeWithinSecond += requests
   351  	rl.requestsMadeWithinSecondMutex.Unlock()
   352  
   353  	return nil
   354  }