github.com/cosmos/cosmos-sdk@v0.50.10/types/query/collections_pagination.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"cosmossdk.io/collections"
     9  	collcodec "cosmossdk.io/collections/codec"
    10  	storetypes "cosmossdk.io/store/types"
    11  )
    12  
    13  // WithCollectionPaginationPairPrefix applies a prefix to a collection, whose key is a collection.Pair,
    14  // being paginated that needs prefixing.
    15  func WithCollectionPaginationPairPrefix[K1, K2 any](prefix K1) func(o *CollectionsPaginateOptions[collections.Pair[K1, K2]]) {
    16  	return func(o *CollectionsPaginateOptions[collections.Pair[K1, K2]]) {
    17  		prefix := collections.PairPrefix[K1, K2](prefix)
    18  		o.Prefix = &prefix
    19  	}
    20  }
    21  
// CollectionsPaginateOptions provides extra options for pagination in collections.
// K is the key type of the collection being paginated.
type CollectionsPaginateOptions[K any] struct {
	// Prefix optionally sets a key prefix for the pagination.
	// A nil Prefix means that no prefix is applied.
	Prefix *K
}
    27  
// Collection defines the minimum API a collection is required to expose
// in order to work with pagination.
type Collection[K, V any] interface {
	// IterateRaw iterates over the collection using raw (byte-encoded) start and
	// end keys, in the given order.
	IterateRaw(ctx context.Context, start, end []byte, order collections.Order) (collections.Iterator[K, V], error)
	// KeyCodec exposes the KeyCodec of the collection, which is required to encode
	// collection keys to bytes (and back) for pagination requests and responses.
	KeyCodec() collcodec.KeyCodec[K]
}
    37  
    38  // CollectionPaginate follows the same logic as Paginate but for collection types.
    39  // transformFunc is used to transform the result to a different type.
    40  func CollectionPaginate[K, V any, C Collection[K, V], T any](
    41  	ctx context.Context,
    42  	coll C,
    43  	pageReq *PageRequest,
    44  	transformFunc func(key K, value V) (T, error),
    45  	opts ...func(opt *CollectionsPaginateOptions[K]),
    46  ) ([]T, *PageResponse, error) {
    47  	return CollectionFilteredPaginate(
    48  		ctx,
    49  		coll,
    50  		pageReq,
    51  		nil,
    52  		transformFunc,
    53  		opts...,
    54  	)
    55  }
    56  
    57  // CollectionFilteredPaginate works in the same way as CollectionPaginate but allows to filter
    58  // results using a predicateFunc.
    59  // A nil predicateFunc means no filtering is applied and results are collected as is.
    60  // TransformFunc is applied only to results which are in range of the pagination and allow
    61  // to convert the result to a different type.
    62  // NOTE: do not collect results using the values/keys passed to predicateFunc as they are not
    63  // guaranteed to be in the pagination range requested.
    64  func CollectionFilteredPaginate[K, V any, C Collection[K, V], T any](
    65  	ctx context.Context,
    66  	coll C,
    67  	pageReq *PageRequest,
    68  	predicateFunc func(key K, value V) (include bool, err error),
    69  	transformFunc func(key K, value V) (T, error),
    70  	opts ...func(opt *CollectionsPaginateOptions[K]),
    71  ) (results []T, pageRes *PageResponse, err error) {
    72  	pageReq = initPageRequestDefaults(pageReq)
    73  
    74  	offset := pageReq.Offset
    75  	key := pageReq.Key
    76  	limit := pageReq.Limit
    77  	countTotal := pageReq.CountTotal
    78  	reverse := pageReq.Reverse
    79  
    80  	if offset > 0 && key != nil {
    81  		return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
    82  	}
    83  
    84  	opt := new(CollectionsPaginateOptions[K])
    85  	for _, o := range opts {
    86  		o(opt)
    87  	}
    88  
    89  	var prefix []byte
    90  	if opt.Prefix != nil {
    91  		prefix, err = encodeCollKey[K, V](coll, *opt.Prefix)
    92  		if err != nil {
    93  			return nil, nil, err
    94  		}
    95  	}
    96  
    97  	if len(key) != 0 {
    98  		results, pageRes, err = collFilteredPaginateByKey(ctx, coll, prefix, key, reverse, limit, predicateFunc, transformFunc)
    99  	} else {
   100  		results, pageRes, err = collFilteredPaginateNoKey(ctx, coll, prefix, reverse, offset, limit, countTotal, predicateFunc, transformFunc)
   101  	}
   102  	// invalid iter error is ignored to retain Paginate behavior
   103  	if errors.Is(err, collections.ErrInvalidIterator) {
   104  		return results, new(PageResponse), nil
   105  	}
   106  	// strip the prefix from next key
   107  	if len(pageRes.NextKey) != 0 && prefix != nil {
   108  		pageRes.NextKey = pageRes.NextKey[len(prefix):]
   109  	}
   110  	return results, pageRes, err
   111  }
   112  
   113  // collFilteredPaginateNoKey applies the provided pagination on the collection when the starting key is not set.
   114  // If predicateFunc is nil no filtering is applied.
   115  func collFilteredPaginateNoKey[K, V any, C Collection[K, V], T any](
   116  	ctx context.Context,
   117  	coll C,
   118  	prefix []byte,
   119  	reverse bool,
   120  	offset uint64,
   121  	limit uint64,
   122  	countTotal bool,
   123  	predicateFunc func(K, V) (bool, error),
   124  	transformFunc func(K, V) (T, error),
   125  ) ([]T, *PageResponse, error) {
   126  	iterator, err := getCollIter[K, V](ctx, coll, prefix, nil, reverse)
   127  	if err != nil {
   128  		return nil, nil, err
   129  	}
   130  	defer iterator.Close()
   131  	// we advance the iter equal to the provided offset
   132  	if !advanceIter(iterator, offset) {
   133  		return nil, nil, collections.ErrInvalidIterator
   134  	}
   135  
   136  	var (
   137  		count   uint64
   138  		nextKey []byte
   139  		results []T
   140  	)
   141  
   142  	for ; iterator.Valid(); iterator.Next() {
   143  		switch {
   144  		// first case, we still haven't found all the results up to the limit
   145  		case count < limit:
   146  			kv, err := iterator.KeyValue()
   147  			if err != nil {
   148  				return nil, nil, err
   149  			}
   150  			// if no predicate function is specified then we just include the result
   151  			if predicateFunc == nil {
   152  				transformed, err := transformFunc(kv.Key, kv.Value)
   153  				if err != nil {
   154  					return nil, nil, err
   155  				}
   156  				results = append(results, transformed)
   157  				count++
   158  
   159  				// if predicate function is defined we check if the result matches the filtering criteria
   160  			} else {
   161  				include, err := predicateFunc(kv.Key, kv.Value)
   162  				if err != nil {
   163  					return nil, nil, err
   164  				}
   165  				if include {
   166  					transformed, err := transformFunc(kv.Key, kv.Value)
   167  					if err != nil {
   168  						return nil, nil, err
   169  					}
   170  					results = append(results, transformed)
   171  					count++
   172  				}
   173  			}
   174  		// second case, we found all the objects specified within the limit
   175  		case count == limit:
   176  			key, err := iterator.Key()
   177  			if err != nil {
   178  				return nil, nil, err
   179  			}
   180  			nextKey, err = encodeCollKey[K, V](coll, key)
   181  			if err != nil {
   182  				return nil, nil, err
   183  			}
   184  			// if count total was not specified, we return the next key only
   185  			if !countTotal {
   186  				return results, &PageResponse{
   187  					NextKey: nextKey,
   188  				}, nil
   189  			}
   190  			// otherwise we fallthrough the third case
   191  			fallthrough
   192  		// this is the case in which we found all the required results
   193  		// but we need to count how many possible results exist in total.
   194  		// so we keep increasing the count until the iterator is fully consumed.
   195  		case count > limit:
   196  			if predicateFunc == nil {
   197  				count++
   198  
   199  				// if predicate function is defined we check if the result matches the filtering criteria
   200  			} else {
   201  				kv, err := iterator.KeyValue()
   202  				if err != nil {
   203  					return nil, nil, err
   204  				}
   205  
   206  				include, err := predicateFunc(kv.Key, kv.Value)
   207  				if err != nil {
   208  					return nil, nil, err
   209  				}
   210  				if include {
   211  					count++
   212  				}
   213  			}
   214  		}
   215  	}
   216  
   217  	resp := &PageResponse{
   218  		NextKey: nextKey,
   219  	}
   220  
   221  	if countTotal {
   222  		resp.Total = count + offset
   223  	}
   224  	return results, resp, nil
   225  }
   226  
   227  func advanceIter[I interface {
   228  	Next()
   229  	Valid() bool
   230  }](iter I, offset uint64,
   231  ) bool {
   232  	for i := uint64(0); i < offset; i++ {
   233  		if !iter.Valid() {
   234  			return false
   235  		}
   236  		iter.Next()
   237  	}
   238  	return true
   239  }
   240  
   241  // collFilteredPaginateByKey paginates a collection when a starting key
   242  // is provided in the PageRequest. Predicate is applied only if not nil.
   243  func collFilteredPaginateByKey[K, V any, C Collection[K, V], T any](
   244  	ctx context.Context,
   245  	coll C,
   246  	prefix []byte,
   247  	key []byte,
   248  	reverse bool,
   249  	limit uint64,
   250  	predicateFunc func(key K, value V) (bool, error),
   251  	transformFunc func(key K, value V) (transformed T, err error),
   252  ) (results []T, pageRes *PageResponse, err error) {
   253  	iterator, err := getCollIter[K, V](ctx, coll, prefix, key, reverse)
   254  	if err != nil {
   255  		return nil, nil, err
   256  	}
   257  	defer iterator.Close()
   258  
   259  	var (
   260  		count   uint64
   261  		nextKey []byte
   262  	)
   263  
   264  	for ; iterator.Valid(); iterator.Next() {
   265  		// if we reached the specified limit
   266  		// then we get the next key, and we exit the iteration.
   267  		if count == limit {
   268  			concreteKey, err := iterator.Key()
   269  			if err != nil {
   270  				return nil, nil, err
   271  			}
   272  
   273  			nextKey, err = encodeCollKey[K, V](coll, concreteKey)
   274  			if err != nil {
   275  				return nil, nil, err
   276  			}
   277  			break
   278  		}
   279  
   280  		kv, err := iterator.KeyValue()
   281  		if err != nil {
   282  			return nil, nil, err
   283  		}
   284  		// if no predicate is specified then we just append the result
   285  		if predicateFunc == nil {
   286  			transformed, err := transformFunc(kv.Key, kv.Value)
   287  			if err != nil {
   288  				return nil, nil, err
   289  			}
   290  			results = append(results, transformed)
   291  			// if predicate is applied we execute the predicate function
   292  			// and append only if predicateFunc yields true.
   293  		} else {
   294  			include, err := predicateFunc(kv.Key, kv.Value)
   295  			if err != nil {
   296  				return nil, nil, err
   297  			}
   298  			if include {
   299  				transformed, err := transformFunc(kv.Key, kv.Value)
   300  				if err != nil {
   301  					return nil, nil, err
   302  				}
   303  				results = append(results, transformed)
   304  			}
   305  		}
   306  		count++
   307  	}
   308  
   309  	return results, &PageResponse{
   310  		NextKey: nextKey,
   311  	}, nil
   312  }
   313  
   314  // todo maybe move to collections?
   315  func encodeCollKey[K, V any, C Collection[K, V]](coll C, key K) ([]byte, error) {
   316  	buffer := make([]byte, coll.KeyCodec().Size(key))
   317  	_, err := coll.KeyCodec().Encode(buffer, key)
   318  	return buffer, err
   319  }
   320  
   321  func getCollIter[K, V any, C Collection[K, V]](ctx context.Context, coll C, prefix, start []byte, reverse bool) (collections.Iterator[K, V], error) {
   322  	// TODO: maybe can be simplified
   323  	if reverse {
   324  		// if we are in reverse mode, we need to increase the start key
   325  		// to include the start key in the iteration.
   326  		start = storetypes.PrefixEndBytes(append(prefix, start...))
   327  		end := prefix
   328  
   329  		return coll.IterateRaw(ctx, end, start, collections.OrderDescending)
   330  	}
   331  	var end []byte
   332  	if prefix != nil {
   333  		start = append(prefix, start...)
   334  		end = storetypes.PrefixEndBytes(prefix)
   335  	}
   336  	return coll.IterateRaw(ctx, start, end, collections.OrderAscending)
   337  }