github.com/jxskiss/gopkg/v2@v2.14.9-0.20240514120614-899f3e7952b4/exp/kvutil/sharding.go

package kvutil

import (
	"bytes"
	"context"
	"crypto/sha1"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/jxskiss/gopkg/v2/easy"
)

// ShardingModel is the interface implemented by types that can be cached
// by ShardingCache.
//
// A ShardingModel implementation type should save ShardingData with it,
// and can tell whether it's a shard or a complete model.
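//
// For illustration only, a minimal sketch of an implementing type;
// UserData and its Sharding field are hypothetical, and the embedded
// Model interface is assumed to be satisfied elsewhere:
//
//	type UserData struct {
//		ID       int64
//		Payload  []byte
//		Sharding *ShardingData
//	}
//
//	func (u *UserData) GetShardingData() (ShardingData, bool) {
//		if u.Sharding != nil {
//			return *u.Sharding, true
//		}
//		return ShardingData{}, false
//	}
//
//	func (u *UserData) SetShardingData(data ShardingData) {
//		u.Sharding = &data
//	}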
type ShardingModel interface {
	Model

	// GetShardingData reports whether the model is a shard, and returns
	// the sharding data if it is. When the returned bool value is false,
	// the returned ShardingData can be a zero value, and shall not be used.
	GetShardingData() (ShardingData, bool)

	// SetShardingData sets shard data on a model; it is called on newly
	// created models during serialization, when data is split into shards
	// to save to storage.
	SetShardingData(data ShardingData)
}

// ShardingData holds sharding information and the partial data of one shard.
//
// Example:
//
//	// protobuf
//	message ShardingData {
//		int32 total_num = 1;
//		int32 shard_num = 2;
//		bytes digest = 3;
//		bytes data = 4;
//	}
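//
// For example, with ShardingSize 1000, a 2500-byte serialized value buf
// is stored as three key-values whose ShardingData look like the following
// (d abbreviates the SHA-1 digest of the full value; a hypothetical
// illustration):
//
//	{TotalNum: 3, ShardNum: 0, Digest: d, Data: buf[0:1000]}    // at key
//	{TotalNum: 3, ShardNum: 1, Digest: d, Data: buf[1000:2000]} // at key + "__1"
//	{TotalNum: 3, ShardNum: 2, Digest: d, Data: buf[2000:2500]} // at key + "__2"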
type ShardingData struct {
	TotalNum int32  `protobuf:"varint,1,opt,name=total_num,json=totalNum,proto3" json:"total_num,omitempty"`
	ShardNum int32  `protobuf:"varint,2,opt,name=shard_num,json=shardNum,proto3" json:"shard_num,omitempty"`
	Digest   []byte `protobuf:"bytes,3,opt,name=digest,proto3" json:"digest,omitempty"`
	Data     []byte `protobuf:"bytes,4,opt,name=data,json=data,proto3" json:"data,omitempty"`
}

// ShardingCacheConfig configures a ShardingCache instance.
type ShardingCacheConfig[K comparable, V ShardingModel] struct {

	// Storage must return a Storage implementation which will be used
	// as the underlying key-value storage.
	Storage func(ctx context.Context) Storage

	// IDFunc returns the primary key of a ShardingModel object.
	IDFunc func(V) K

	// KeyFunc specifies the key function to use with the storage.
	KeyFunc Key

	// ShardingSize configures the maximum length of data in a shard.
	// When the serialization result of V is longer than ShardingSize,
	// the data will be split into shards to save to storage.
	//
	// ShardingSize must be greater than zero, else it panics.
	ShardingSize int

	// MGetBatchSize optionally specifies the batch size for one MGet
	// call to storage. The default is 200.
	MGetBatchSize int

	// MSetBatchSize optionally specifies the batch size for one MSet
	// call to storage. The default is 200.
	MSetBatchSize int

	// DeleteBatchSize optionally specifies the batch size for one Delete
	// call to storage. The default is 200.
	DeleteBatchSize int
}

func (p *ShardingCacheConfig[K, V]) checkAndSetDefaults() {
	if p.ShardingSize <= 0 {
		panic("kvutil: ShardingCacheConfig.ShardingSize must be greater than zero")
	}
	if p.MGetBatchSize <= 0 {
		p.MGetBatchSize = DefaultBatchSize
	}
	if p.MSetBatchSize <= 0 {
		p.MSetBatchSize = DefaultBatchSize
	}
	if p.DeleteBatchSize <= 0 {
		p.DeleteBatchSize = DefaultBatchSize
	}
}

// NewShardingCache returns a new ShardingCache instance.
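//
// A minimal usage sketch; store, userKey and the User type are
// hypothetical stand-ins for application code:
//
//	cache := NewShardingCache[int64, *User](&ShardingCacheConfig[int64, *User]{
//		Storage:      func(ctx context.Context) Storage { return store },
//		IDFunc:       func(u *User) int64 { return u.ID },
//		KeyFunc:      userKey, // a Key which builds the storage key from the pk
//		ShardingSize: 1 << 20, // split values larger than 1 MiB
//	})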
func NewShardingCache[K comparable, V ShardingModel](config *ShardingCacheConfig[K, V]) *ShardingCache[K, V] {
	config.checkAndSetDefaults()
	newElemFn := buildNewElemFunc[V]()
	return &ShardingCache[K, V]{
		config:      config,
		newElemFunc: newElemFn,
	}
}

// ShardingCache implements common cache operations for big cache values;
// it helps to split a big value into shards according to
// ShardingCacheConfig.ShardingSize.
//
// When saving data to cache storage, it checks the length of the
// serialization result; if it does not exceed ShardingSize, it saves one
// key-value to storage, else it splits the serialization result into
// multiple shards and saves multiple key-values to storage.
//
// When doing a query, it first loads the first shard from storage and
// checks whether there are more shards; if yes, it builds the keys of the
// other shards using the information in the first shard, then reads the
// other shards, and concatenates all data to deserialize the complete model.
//
// A ShardingCache must not be copied after initialization.
type ShardingCache[K comparable, V ShardingModel] struct {
	config *ShardingCacheConfig[K, V]

	newElemFunc func() V
}

// Set writes a key value pair to ShardingCache.
func (p *ShardingCache[K, V]) Set(ctx context.Context, pk K, elem V, expiration time.Duration) error {
	_ = pk // the storage key is derived from elem via IDFunc and KeyFunc
	return p.MSet(ctx, []V{elem}, expiration)
}

// MSet serializes and writes multiple models to ShardingCache.
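//
// A usage sketch; cache and users are assumed from surrounding code:
//
//	err := cache.MSet(ctx, users, 10*time.Minute)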
func (p *ShardingCache[K, V]) MSet(ctx context.Context, models []V, expiration time.Duration) error {
	if len(models) == 0 {
		return nil
	}
	kvPairs, err := p.marshalModels(models)
	if err != nil {
		return err
	}
	stor := p.config.Storage(ctx)
	return msetToStorage(ctx, stor, kvPairs, expiration, p.config.MSetBatchSize)
}

// Delete deletes key values from ShardingCache.
//
// By default, it deletes only the first shard of each key from storage;
// if the underlying storage is Redis, the other shards shall be evicted
// when they expire.
// If the underlying storage does not support auto-eviction, or the data
// does not expire, or the user wants to release storage space actively,
// deleteAllShards should be set to true, which tells it to read
// the first shard from storage and check whether there are more shards;
// if yes, it builds the keys of the other shards using the information
// in the first shard, then deletes all shards from storage.
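//
// A usage sketch; cache and pks are assumed from surrounding code:
//
//	err := cache.Delete(ctx, true, pks...)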
func (p *ShardingCache[K, V]) Delete(ctx context.Context, deleteAllShards bool, pks ...K) error {
	if len(pks) == 0 {
		return nil
	}
	if deleteAllShards {
		return p.deleteAllShards(ctx, pks)
	}

	keys := make([]string, 0, len(pks))
	for _, pk := range pks {
		keys = append(keys, p.config.KeyFunc(pk))
	}

	stor := p.config.Storage(ctx)
	batches := easy.Split(keys, p.config.DeleteBatchSize)
	for _, bat := range batches {
		err := stor.Delete(ctx, bat...)
		if err != nil {
			return err
		}
	}
	return nil
}

func (p *ShardingCache[K, V]) deleteAllShards(ctx context.Context, pks []K) error {
	keys := make([]string, 0, len(pks))
	for _, pk := range pks {
		keys = append(keys, p.config.KeyFunc(pk))
	}

	// Read the first shards to discover how many shards each key has,
	// then append the keys of the other shards for deletion.
	stor := p.config.Storage(ctx)
	mgetRet, err := mgetFromStorage(ctx, stor, keys, p.config.MGetBatchSize)
	if err != nil {
		return err
	}
	for i, data := range mgetRet {
		if len(data) == 0 {
			continue
		}
		key := keys[i]
		elem := p.newElemFunc()
		err = elem.UnmarshalBinary(data)
		if err != nil {
			return fmt.Errorf("cannot unmarshal data: %w", err)
		}
		shard0, isShard := elem.GetShardingData()
		if isShard {
			for j := 1; j < int(shard0.TotalNum); j++ {
				ithKey := GetShardKey(key, j)
				keys = append(keys, ithKey)
			}
		}
	}

	batches := easy.Split(keys, p.config.DeleteBatchSize)
	for _, bat := range batches {
		err = stor.Delete(ctx, bat...)
		if err != nil {
			return err
		}
	}
	return nil
}

func (p *ShardingCache[K, V]) marshalModels(entityList []V) (result []KVPair, err error) {
	shardSize := p.config.ShardingSize
	for _, elem := range entityList {
		buf, err := elem.MarshalBinary()
		if err != nil {
			return nil, err
		}
		pk := p.config.IDFunc(elem)
		key := p.config.KeyFunc(pk)
		if len(buf) <= shardSize {
			result = append(result, KVPair{K: key, V: buf})
			continue
		}

		// Split big value into shards.
		totalNum := (len(buf) + shardSize - 1) / shardSize // ceil(len(buf) / shardSize)
		digest := calcDigest(buf)

		// The first shard is saved at the original key; it carries the
		// sharding metadata together with the first chunk of data.
		shard0 := p.newElemFunc()
		shard0.SetShardingData(ShardingData{
			TotalNum: int32(totalNum),
			ShardNum: 0,
			Digest:   digest,
			Data:     buf[:shardSize],
		})
		shard0Buf, err := shard0.MarshalBinary()
		if err != nil {
			return nil, err
		}

		result = append(result, KVPair{K: key, V: shard0Buf})
		for num := 1; num < totalNum; num++ {
			i := num * shardSize
			j := min((num+1)*shardSize, len(buf))
			ithKey := GetShardKey(key, num)
			ithShard := p.newElemFunc()
			ithShard.SetShardingData(ShardingData{
				TotalNum: int32(totalNum),
				ShardNum: int32(num),
				Digest:   digest,
				Data:     buf[i:j],
			})
			ithBuf, err := ithShard.MarshalBinary()
			if err != nil {
				return nil, err
			}
			result = append(result, KVPair{K: ithKey, V: ithBuf})
		}
	}
	return result, nil
}

// Get queries ShardingCache for a given pk.
//
// If pk cannot be found in the cache, it returns the error ErrDataNotFound.
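//
// A usage sketch; cache and pk are assumed from surrounding code:
//
//	user, err := cache.Get(ctx, pk)
//	if errors.Is(err, ErrDataNotFound) {
//		// cache miss, load from the source of truth
//	}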
func (p *ShardingCache[K, V]) Get(ctx context.Context, pk K) (V, error) {
	var zeroVal V
	stor := p.config.Storage(ctx)
	key := p.config.KeyFunc(pk)
	cacheResult, err := stor.MGet(ctx, key)
	if err != nil {
		return zeroVal, err
	}

	if len(cacheResult) == 0 || len(cacheResult[0]) == 0 {
		return zeroVal, ErrDataNotFound
	}

	elem := p.newElemFunc()
	err = elem.UnmarshalBinary(cacheResult[0])
	if err != nil {
		return zeroVal, fmt.Errorf("cannot unmarshal data: %w", err)
	}
	shard0, isShard := elem.GetShardingData()
	if !isShard {
		return elem, nil
	}

	// The value is sharded; build the keys of the other shards
	// and read them from storage.
	ithKeys := make([]string, 0, shard0.TotalNum-1)
	for i := 1; i < int(shard0.TotalNum); i++ {
		ithKey := GetShardKey(key, i)
		ithKeys = append(ithKeys, ithKey)
	}
	ithCacheResult, err := mgetFromStorage(ctx, stor, ithKeys, p.config.MGetBatchSize)
	if err != nil {
		return zeroVal, err
	}

	// Validate each shard and concatenate the data,
	// then deserialize the complete model.
	buf := shard0.Data
	for i := 0; i < len(ithKeys); i++ {
		ithKey := ithKeys[i]
		ithRet := ithCacheResult[i]
		if len(ithRet) == 0 {
			return zeroVal, fmt.Errorf("sharding data not found: %s", ithKey)
		}
		ithVal := p.newElemFunc()
		err = ithVal.UnmarshalBinary(ithRet)
		if err != nil {
			return zeroVal, fmt.Errorf("cannot unmarshal data: %w", err)
		}
		shardNum := getShardNumFromKey(ithKey)
		ithShard, isShard := ithVal.GetShardingData()
		if !isShard || ithShard.ShardNum != shardNum {
			return zeroVal, fmt.Errorf("sharding data is invalid: %s", ithKey)
		}
		if !bytes.Equal(ithShard.Digest, shard0.Digest) {
			return zeroVal, fmt.Errorf("sharding data digest not match: %s", ithKey)
		}
		buf = append(buf, ithShard.Data...)
	}
	elem = p.newElemFunc()
	err = elem.UnmarshalBinary(buf)
	if err != nil {
		return zeroVal, fmt.Errorf("cannot unmarshal data: %w", err)
	}
	return elem, nil
}

// MGet queries ShardingCache for multiple pks.
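//
// It returns the models found in cache, a map of per-key errors for data
// which cannot be processed, and a storage error if any. A usage sketch;
// cache and pks are assumed from surrounding code:
//
//	result, errMap, storageErr := cache.MGet(ctx, pks)
//	if storageErr != nil {
//		return storageErr
//	}
//	for pk, e := range errMap {
//		log.Printf("invalid cache data for %v: %v", pk, e)
//	}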
func (p *ShardingCache[K, V]) MGet(ctx context.Context, pks []K) (
	result map[K]V, errMap map[K]error, storageErr error) {
	query := &shardingQuery[K, V]{
		stor:          p.config.Storage(ctx),
		mgetBatchSize: p.config.MGetBatchSize,
		keyFunc:       p.config.KeyFunc,
		newElemFunc:   p.newElemFunc,
		pks:           pks,
	}
	query.Do(ctx)
	return query.Result()
}

type shardingQuery[K comparable, V ShardingModel] struct {
	stor          Storage
	mgetBatchSize int
	keyFunc       Key
	newElemFunc   func() V

	pks  []K
	keys []string

	shardingPKs  []K
	shard0Data   []ShardingData
	ithShardKeys []string

	result     map[K]V
	retErrs    map[K]error
	storageErr error
}

// setError records the first error seen for a key.
func (sq *shardingQuery[K, V]) setError(pk K, err error) {
	if sq.retErrs[pk] != nil { // reading a nil map is safe and yields nil
		return
	}
	if sq.retErrs == nil {
		sq.retErrs = make(map[K]error)
	}
	sq.retErrs[pk] = err
}

func (sq *shardingQuery[K, V]) setStorageError(err error) {
	if sq.storageErr == nil {
		sq.storageErr = err
	}
}

func (sq *shardingQuery[K, V]) Do(ctx context.Context) {
	for _, pk := range sq.pks {
		key := sq.keyFunc(pk)
		sq.keys = append(sq.keys, key)
	}

	mgetRet, err := mgetFromStorage(ctx, sq.stor, sq.keys, sq.mgetBatchSize)
	if err != nil {
		sq.setStorageError(err)
		return
	}

	sq.result = make(map[K]V, len(sq.keys))
	sq.retErrs = make(map[K]error, len(sq.keys))
	for i, buf := range mgetRet {
		if len(buf) == 0 {
			continue
		}

		pk := sq.pks[i]
		key := sq.keys[i]
		elem := sq.newElemFunc()
		err = elem.UnmarshalBinary(buf)
		if err != nil {
			tmpErr := fmt.Errorf("cannot unmarshal data: %w", err)
			sq.setError(pk, tmpErr)
			continue
		}

		shard0, isShard := elem.GetShardingData()
		if !isShard {
			sq.result[pk] = elem
			continue
		}

		// The cache data is a shard, we need to read data from all shards.
		sq.shardingPKs = append(sq.shardingPKs, pk)
		sq.shard0Data = append(sq.shard0Data, shard0)
		for j := 1; j < int(shard0.TotalNum); j++ {
			ithKey := GetShardKey(key, j)
			sq.ithShardKeys = append(sq.ithShardKeys, ithKey)
		}
	}

	if len(sq.shardingPKs) > 0 {
		sq.queryShardingData(ctx)
	}
}

// ithResult holds an unmarshalled shard, or the error encountered
// when processing it.
type ithResult struct {
	Value any
	Err   error
}

func (sq *shardingQuery[K, V]) queryShardingData(ctx context.Context) {
	ithShardKeys := sq.ithShardKeys
	shardRet, err := mgetFromStorage(ctx, sq.stor, ithShardKeys, sq.mgetBatchSize)
	if err != nil {
		sq.setStorageError(err)
		return
	}

	// Unmarshal all shards, indexed by shard key.
	ithShardMap := make(map[string]ithResult, len(ithShardKeys))
	for i, ithKey := range ithShardKeys {
		buf := shardRet[i]
		if len(buf) == 0 {
			continue
		}
		elem := sq.newElemFunc()
		err = elem.UnmarshalBinary(buf)
		if err != nil {
			tmpErr := fmt.Errorf("cannot unmarshal data: %w", err)
			ithShardMap[ithKey] = ithResult{Err: tmpErr}
			continue
		}
		ithShardMap[ithKey] = ithResult{Value: elem}
	}

	// For each sharded pk, validate and concatenate its shards,
	// then deserialize the complete model.
mergeShards:
	for i, pk := range sq.shardingPKs {
		key := sq.keyFunc(pk)
		shard0 := sq.shard0Data[i]
		buf := shard0.Data
		for j := 1; j < int(shard0.TotalNum); j++ {
			ithKey := GetShardKey(key, j)
			ithRet, exists := ithShardMap[ithKey]
			if !exists {
				tmpErr := fmt.Errorf("sharding data not found: %s", ithKey)
				sq.retErrs[pk] = tmpErr
				continue mergeShards
			}
			if ithRet.Err != nil {
				sq.retErrs[pk] = ithRet.Err
				continue mergeShards
			}
			ithVal := ithRet.Value.(V)
			ithShard, isShard := ithVal.GetShardingData()
			if !isShard || ithShard.ShardNum != int32(j) {
				tmpErr := fmt.Errorf("sharding data is invalid: %s", ithKey)
				sq.retErrs[pk] = tmpErr
				continue mergeShards
			}
			if !bytes.Equal(ithShard.Digest, shard0.Digest) {
				tmpErr := fmt.Errorf("sharding data digest not match: %s", ithKey)
				sq.retErrs[pk] = tmpErr
				continue mergeShards
			}
			buf = append(buf, ithShard.Data...)
		}
		elem := sq.newElemFunc()
		err = elem.UnmarshalBinary(buf)
		if err != nil {
			tmpErr := fmt.Errorf("cannot unmarshal data: %w", err)
			sq.retErrs[pk] = tmpErr
			continue mergeShards
		}
		sq.result[pk] = elem
	}
}

func (sq *shardingQuery[K, V]) Result() (map[K]V, map[K]error, error) {
	return sq.result, sq.retErrs, sq.storageErr
}

// shardNumSep separates the base key and the shard number in a shard key.
const shardNumSep = "__"

// GetShardKey returns the shard key for given key and index of shard.
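//
// For example, GetShardKey("user:123", 0) returns "user:123",
// and GetShardKey("user:123", 2) returns "user:123__2".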
func GetShardKey(key string, i int) string {
	if i == 0 {
		return key
	}
	return key + shardNumSep + strconv.Itoa(i)
}

// getShardNumFromKey parses the shard number from a shard key;
// it returns 0 for a key without a shard number suffix.
func getShardNumFromKey(key string) int32 {
	sepIdx := strings.LastIndex(key, shardNumSep)
	if sepIdx <= 0 {
		return 0
	}
	shardNum, _ := strconv.Atoi(key[sepIdx+len(shardNumSep):])
	return int32(shardNum)
}

// calcDigest returns the SHA-1 digest of data, which is used to verify
// that all shards belong to the same value.
func calcDigest(data []byte) []byte {
	sum := sha1.Sum(data)
	return sum[:]
}

// mgetFromStorage reads keys from storage in batches of batchSize.
func mgetFromStorage(ctx context.Context, stor Storage, keys []string, batchSize int) ([][]byte, error) {
	ret := make([][]byte, 0, len(keys))
	for _, batchKeys := range easy.Split(keys, batchSize) {
		batchRet, err := stor.MGet(ctx, batchKeys...)
		if err != nil {
			return nil, err
		}
		ret = append(ret, batchRet...)
	}
	return ret, nil
}

// msetToStorage writes key-value pairs to storage in batches of batchSize.
func msetToStorage(ctx context.Context, stor Storage, kvPairs []KVPair, expiration time.Duration, batchSize int) error {
	for _, batchKVPairs := range easy.Split(kvPairs, batchSize) {
		err := stor.MSet(ctx, batchKVPairs, expiration)
		if err != nil {
			return err
		}
	}
	return nil
}