github.com/wormhole-foundation/wormhole-explorer/common@v0.0.0-20240604151348-09585b5b97c5/client/cache/notional/cache.go (about)

     1  package notional
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/go-redis/redis/v8"
    12  	"github.com/shopspring/decimal"
    13  	"go.uber.org/zap"
    14  )
    15  
    16  const (
    17  	wormscanNotionalUpdated            = "NOTIONAL_UPDATED"
    18  	wormscanTokenNotionalCacheKeyRegex = "WORMSCAN:NOTIONAL:TOKEN:*"
    19  	KeyTokenFormatString               = "WORMSCAN:NOTIONAL:TOKEN:%s"
    20  )
    21  
    22  var (
    23  	ErrNotFound          = errors.New("NOT FOUND")
    24  	ErrInvalidCacheField = errors.New("INVALID CACHE FIELD")
    25  )
    26  
    27  // NotionalLocalCacheReadable is the interface for notional local cache.
    28  type NotionalLocalCacheReadable interface {
    29  	Get(tokenID string) (PriceData, error)
    30  	Close() error
    31  }
    32  
    33  // PriceData is the notional value of assets in cache.
    34  type PriceData struct {
    35  	NotionalUsd decimal.Decimal `json:"notional_usd"`
    36  	UpdatedAt   time.Time       `json:"updated_at"`
    37  }
    38  
    39  // MarshalBinary implements the encoding.BinaryMarshaler interface.
    40  //
    41  // This function is used when the notional job writes data to redis.
    42  func (p PriceData) MarshalBinary() ([]byte, error) {
    43  	return json.Marshal(p)
    44  }
    45  
    46  // NotionalCacheClient redis cache client.
    47  type NotionalCache struct {
    48  	client      *redis.Client
    49  	pubSub      *redis.PubSub
    50  	notionalMap sync.Map
    51  	prefix      string
    52  	logger      *zap.Logger
    53  }
    54  
    55  // NewNotionalCache create a new cache client.
    56  // After create a NotionalCache use the Init method to initialize pubsub and load the cache.
    57  func NewNotionalCache(ctx context.Context, redisClient *redis.Client, prefix string, channel string, log *zap.Logger) (*NotionalCache, error) {
    58  	if redisClient == nil {
    59  		return nil, errors.New("redis client is nil")
    60  	}
    61  	pubsub := redisClient.Subscribe(ctx, formatChannel(prefix, channel))
    62  	return &NotionalCache{
    63  		client:      redisClient,
    64  		pubSub:      pubsub,
    65  		notionalMap: sync.Map{},
    66  		prefix:      prefix,
    67  		logger:      log}, nil
    68  }
    69  
    70  // Init subscribe to notional pubsub and load the cache.
    71  func (c *NotionalCache) Init(ctx context.Context) error {
    72  
    73  	// load notional cache
    74  	err := c.loadCache(ctx)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	// notional cache updated channel subscribe
    80  	c.subscribe(ctx)
    81  
    82  	return nil
    83  }
    84  
    85  // loadCache load notional cache from redis.
    86  func (c *NotionalCache) loadCache(ctx context.Context) error {
    87  
    88  	var cursor uint64
    89  	var err error
    90  	for {
    91  		// Get a page of results from the cursor
    92  		var keys []string
    93  		scanCmd := c.client.Scan(ctx, cursor, c.renderRegExp(), 100)
    94  		if scanCmd.Err() != nil {
    95  			c.logger.Error("redis.ScanCmd has errors", zap.Error(err))
    96  			return fmt.Errorf("redis.ScanCmd has errors: %w", err)
    97  		}
    98  		keys, cursor, err = scanCmd.Result()
    99  		if err != nil {
   100  			c.logger.Error("call to redis.ScanCmd.Result() failed", zap.Error(err))
   101  			return fmt.Errorf("call to redis.ScanCmd.Result() failed: %w", err)
   102  		}
   103  
   104  		// Get notional value from keys
   105  		for _, key := range keys {
   106  			var field PriceData
   107  			value, err := c.client.Get(ctx, key).Result()
   108  			json.Unmarshal([]byte(value), &field)
   109  			if err != nil {
   110  				c.logger.Error("loadCache", zap.Error(err))
   111  				return err
   112  			}
   113  			// Save notional value to local cache
   114  			c.notionalMap.Store(key, field)
   115  		}
   116  
   117  		// If we've reached the end of the cursor, return
   118  		if cursor == 0 {
   119  			return nil
   120  		}
   121  	}
   122  }
   123  
   124  // Subscribe to a notional update channel and load new values for the notional cache.
   125  func (c *NotionalCache) subscribe(ctx context.Context) {
   126  	ch := c.pubSub.Channel()
   127  
   128  	go func() {
   129  		for msg := range ch {
   130  			c.logger.Info("receive message from channel", zap.String("channel", msg.Channel), zap.String("payload", msg.Payload))
   131  			if wormscanNotionalUpdated == msg.Payload {
   132  				// update notional cache
   133  				c.loadCache(ctx)
   134  			}
   135  		}
   136  	}()
   137  }
   138  
   139  // Close the pubsub channel.
   140  func (c *NotionalCache) Close() error {
   141  	return c.pubSub.Close()
   142  }
   143  
   144  // Get notional cache value.
   145  func (c *NotionalCache) Get(tokenID string) (PriceData, error) {
   146  	var notional PriceData
   147  
   148  	// get notional cache key
   149  	key := fmt.Sprintf(KeyTokenFormatString, tokenID)
   150  	key = c.renderKey(key)
   151  
   152  	// get notional cache value
   153  	field, ok := c.notionalMap.Load(key)
   154  	if !ok {
   155  		return notional, ErrNotFound
   156  	}
   157  
   158  	// convert any field to NotionalCacheField
   159  	notional, ok = field.(PriceData)
   160  	if !ok {
   161  		c.logger.Error("invalid notional cache field",
   162  			zap.Any("field", field),
   163  			zap.String("tokenId", tokenID))
   164  		return notional, ErrInvalidCacheField
   165  	}
   166  	return notional, nil
   167  }
   168  
   169  func (c *NotionalCache) renderKey(key string) string {
   170  	if c.prefix != "" {
   171  		return fmt.Sprintf("%s:%s", c.prefix, key)
   172  	} else {
   173  		return key
   174  	}
   175  }
   176  
   177  func (c *NotionalCache) renderRegExp() string {
   178  	return "*" + c.renderKey(wormscanTokenNotionalCacheKeyRegex)
   179  }
   180  
   181  func formatChannel(prefix string, channel string) string {
   182  	if prefix != "" {
   183  		return fmt.Sprintf("%s:%s", prefix, channel)
   184  	}
   185  	return channel
   186  }