github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cache/redis.go (about)

     1  package cache
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/gob"
     6  	"strconv"
     7  	"time"
     8  
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    10  	"github.com/gomodule/redigo/redis"
    11  )
    12  
    13  // RedisStore redis存储驱动
    14  type RedisStore struct {
    15  	pool *redis.Pool
    16  }
    17  
    18  type item struct {
    19  	Value interface{}
    20  }
    21  
    22  func serializer(value interface{}) ([]byte, error) {
    23  	var buffer bytes.Buffer
    24  	enc := gob.NewEncoder(&buffer)
    25  	storeValue := item{
    26  		Value: value,
    27  	}
    28  	err := enc.Encode(storeValue)
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  	return buffer.Bytes(), nil
    33  }
    34  
    35  func deserializer(value []byte) (interface{}, error) {
    36  	var res item
    37  	buffer := bytes.NewReader(value)
    38  	dec := gob.NewDecoder(buffer)
    39  	err := dec.Decode(&res)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	return res.Value, nil
    44  }
    45  
    46  // NewRedisStore 创建新的redis存储
    47  func NewRedisStore(size int, network, address, user, password, database string) *RedisStore {
    48  	return &RedisStore{
    49  		pool: &redis.Pool{
    50  			MaxIdle:     size,
    51  			IdleTimeout: 240 * time.Second,
    52  			TestOnBorrow: func(c redis.Conn, t time.Time) error {
    53  				_, err := c.Do("PING")
    54  				return err
    55  			},
    56  			Dial: func() (redis.Conn, error) {
    57  				db, err := strconv.Atoi(database)
    58  				if err != nil {
    59  					return nil, err
    60  				}
    61  
    62  				c, err := redis.Dial(
    63  					network,
    64  					address,
    65  					redis.DialDatabase(db),
    66  					redis.DialUsername(user),
    67  					redis.DialPassword(password),
    68  				)
    69  				if err != nil {
    70  					util.Log().Panic("Failed to create Redis connection: %s", err)
    71  				}
    72  				return c, nil
    73  			},
    74  		},
    75  	}
    76  }
    77  
    78  // Set 存储值
    79  func (store *RedisStore) Set(key string, value interface{}, ttl int) error {
    80  	rc := store.pool.Get()
    81  	defer rc.Close()
    82  
    83  	serialized, err := serializer(value)
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	if rc.Err() != nil {
    89  		return rc.Err()
    90  	}
    91  
    92  	if ttl > 0 {
    93  		_, err = rc.Do("SETEX", key, ttl, serialized)
    94  	} else {
    95  		_, err = rc.Do("SET", key, serialized)
    96  	}
    97  
    98  	if err != nil {
    99  		return err
   100  	}
   101  	return nil
   102  
   103  }
   104  
   105  // Get 取值
   106  func (store *RedisStore) Get(key string) (interface{}, bool) {
   107  	rc := store.pool.Get()
   108  	defer rc.Close()
   109  	if rc.Err() != nil {
   110  		return nil, false
   111  	}
   112  
   113  	v, err := redis.Bytes(rc.Do("GET", key))
   114  	if err != nil || v == nil {
   115  		return nil, false
   116  	}
   117  
   118  	finalValue, err := deserializer(v)
   119  	if err != nil {
   120  		return nil, false
   121  	}
   122  
   123  	return finalValue, true
   124  
   125  }
   126  
   127  // Gets 批量取值
   128  func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
   129  	rc := store.pool.Get()
   130  	defer rc.Close()
   131  	if rc.Err() != nil {
   132  		return nil, keys
   133  	}
   134  
   135  	var queryKeys = make([]string, len(keys))
   136  	for key, value := range keys {
   137  		queryKeys[key] = prefix + value
   138  	}
   139  
   140  	v, err := redis.ByteSlices(rc.Do("MGET", redis.Args{}.AddFlat(queryKeys)...))
   141  	if err != nil {
   142  		return nil, keys
   143  	}
   144  
   145  	var res = make(map[string]interface{})
   146  	var missed = make([]string, 0, len(keys))
   147  
   148  	for key, value := range v {
   149  		decoded, err := deserializer(value)
   150  		if err != nil || decoded == nil {
   151  			missed = append(missed, keys[key])
   152  		} else {
   153  			res[keys[key]] = decoded
   154  		}
   155  	}
   156  	// 解码所得值
   157  	return res, missed
   158  }
   159  
   160  // Sets 批量设置值
   161  func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error {
   162  	rc := store.pool.Get()
   163  	defer rc.Close()
   164  	if rc.Err() != nil {
   165  		return rc.Err()
   166  	}
   167  	var setValues = make(map[string]interface{})
   168  
   169  	// 编码待设置值
   170  	for key, value := range values {
   171  		serialized, err := serializer(value)
   172  		if err != nil {
   173  			return err
   174  		}
   175  		setValues[prefix+key] = serialized
   176  	}
   177  
   178  	_, err := rc.Do("MSET", redis.Args{}.AddFlat(setValues)...)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	return nil
   183  
   184  }
   185  
   186  // Delete 批量删除给定的键
   187  func (store *RedisStore) Delete(keys []string, prefix string) error {
   188  	rc := store.pool.Get()
   189  	defer rc.Close()
   190  	if rc.Err() != nil {
   191  		return rc.Err()
   192  	}
   193  
   194  	// 处理前缀
   195  	for i := 0; i < len(keys); i++ {
   196  		keys[i] = prefix + keys[i]
   197  	}
   198  
   199  	_, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...)
   200  	if err != nil {
   201  		return err
   202  	}
   203  	return nil
   204  }
   205  
   206  // DeleteAll 批量所有键
   207  func (store *RedisStore) DeleteAll() error {
   208  	rc := store.pool.Get()
   209  	defer rc.Close()
   210  	if rc.Err() != nil {
   211  		return rc.Err()
   212  	}
   213  
   214  	_, err := rc.Do("FLUSHDB")
   215  
   216  	return err
   217  }
   218  
   219  // Persist Dummy implementation
   220  func (store *RedisStore) Persist(path string) error {
   221  	return nil
   222  }
   223  
   224  // Restore dummy implementation
   225  func (store *RedisStore) Restore(path string) error {
   226  	return nil
   227  }