github.com/cosmos/cosmos-sdk@v0.50.10/x/group/internal/orm/iterator.go (about)

     1  package orm
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/cosmos/gogoproto/proto"
     8  
     9  	errorsmod "cosmossdk.io/errors"
    10  
    11  	"github.com/cosmos/cosmos-sdk/types/query"
    12  	"github.com/cosmos/cosmos-sdk/x/group/errors"
    13  )
    14  
    15  // defaultPageLimit is the default limit value for pagination requests.
    16  const defaultPageLimit = 100
    17  
    18  // IteratorFunc is a function type that satisfies the Iterator interface
    19  // The passed function is called on LoadNext operations.
    20  type IteratorFunc func(dest proto.Message) (RowID, error)
    21  
    22  // LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there
    23  // are no more items the errors.ErrORMIteratorDone error is returned
    24  // The key is the rowID and not any MultiKeyIndex key.
    25  func (i IteratorFunc) LoadNext(dest proto.Message) (RowID, error) {
    26  	return i(dest)
    27  }
    28  
    29  // Close always returns nil
    30  func (i IteratorFunc) Close() error {
    31  	return nil
    32  }
    33  
    34  func NewSingleValueIterator(rowID RowID, val []byte) Iterator {
    35  	var closed bool
    36  	return IteratorFunc(func(dest proto.Message) (RowID, error) {
    37  		if dest == nil {
    38  			return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination object must not be nil")
    39  		}
    40  		if closed || val == nil {
    41  			return nil, errors.ErrORMIteratorDone
    42  		}
    43  		closed = true
    44  		return rowID, proto.Unmarshal(val, dest)
    45  	})
    46  }
    47  
    48  // Iterator that return ErrORMInvalidIterator only.
    49  func NewInvalidIterator() Iterator {
    50  	return IteratorFunc(func(dest proto.Message) (RowID, error) {
    51  		return nil, errors.ErrORMInvalidIterator
    52  	})
    53  }
    54  
    55  // LimitedIterator returns up to defined maximum number of elements.
    56  type LimitedIterator struct {
    57  	remainingCount int
    58  	parentIterator Iterator
    59  }
    60  
    61  // LimitIterator returns a new iterator that returns max number of elements.
    62  // The parent iterator must not be nil
    63  // max can be 0 or any positive number
    64  func LimitIterator(parent Iterator, max int) (*LimitedIterator, error) {
    65  	if max < 0 {
    66  		return nil, errors.ErrORMInvalidArgument.Wrap("quantity must not be negative")
    67  	}
    68  	if parent == nil {
    69  		return nil, errors.ErrORMInvalidArgument.Wrap("parent iterator must not be nil")
    70  	}
    71  	return &LimitedIterator{remainingCount: max, parentIterator: parent}, nil
    72  }
    73  
    74  // LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there
    75  // are no more items or the defined max number of elements was returned the `errors.ErrORMIteratorDone` error is returned
    76  // The key is the rowID and not any MultiKeyIndex key.
    77  func (i *LimitedIterator) LoadNext(dest proto.Message) (RowID, error) {
    78  	if i.remainingCount == 0 {
    79  		return nil, errors.ErrORMIteratorDone
    80  	}
    81  	i.remainingCount--
    82  	return i.parentIterator.LoadNext(dest)
    83  }
    84  
    85  // Close releases the iterator and should be called at the end of iteration
    86  func (i LimitedIterator) Close() error {
    87  	return i.parentIterator.Close()
    88  }
    89  
    90  // First loads the first element into the given destination type and closes the iterator.
    91  // When the iterator is closed or has no elements the according error is passed as return value.
    92  func First(it Iterator, dest proto.Message) (RowID, error) {
    93  	if it == nil {
    94  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil")
    95  	}
    96  	defer it.Close()
    97  	binKey, err := it.LoadNext(dest)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return binKey, nil
   102  }
   103  
   104  // Paginate does pagination with a given Iterator based on the provided
   105  // PageRequest and unmarshals the results into the dest interface that must be
   106  // an non-nil pointer to a slice.
   107  //
   108  // If pageRequest is nil, then we will use these default values:
   109  //   - Offset: 0
   110  //   - Key: nil
   111  //   - Limit: 100
   112  //   - CountTotal: true
   113  //
   114  // If pageRequest.Key was provided, it got used beforehand to instantiate the Iterator,
   115  // using for instance UInt64Index.GetPaginated method. Only one of pageRequest.Offset or
   116  // pageRequest.Key should be set. Using pageRequest.Key is more efficient for querying
   117  // the next page.
   118  //
   119  // If pageRequest.CountTotal is set, we'll visit all iterators elements.
   120  // pageRequest.CountTotal is only respected when offset is used.
   121  //
   122  // This function will call it.Close().
   123  func Paginate(
   124  	it Iterator,
   125  	pageRequest *query.PageRequest,
   126  	dest ModelSlicePtr,
   127  ) (*query.PageResponse, error) {
   128  	// if the PageRequest is nil, use default PageRequest
   129  	if pageRequest == nil {
   130  		pageRequest = &query.PageRequest{}
   131  	}
   132  
   133  	offset := pageRequest.Offset
   134  	key := pageRequest.Key
   135  	limit := pageRequest.Limit
   136  	countTotal := pageRequest.CountTotal
   137  
   138  	if offset > 0 && key != nil {
   139  		return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
   140  	}
   141  
   142  	if limit == 0 {
   143  		limit = defaultPageLimit
   144  
   145  		// count total results when the limit is zero/not supplied
   146  		countTotal = true
   147  	}
   148  
   149  	if it == nil {
   150  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil")
   151  	}
   152  	defer it.Close()
   153  
   154  	var destRef, tmpSlice reflect.Value
   155  	elemType, err := assertDest(dest, &destRef, &tmpSlice)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	end := offset + limit
   161  	var count uint64
   162  	var nextKey []byte
   163  	for {
   164  		obj := reflect.New(elemType)
   165  		val := obj.Elem()
   166  		model := obj
   167  		if elemType.Kind() == reflect.Ptr {
   168  			val.Set(reflect.New(elemType.Elem()))
   169  			// if elemType is already a pointer (e.g. dest being some pointer to a slice of pointers,
   170  			// like []*GroupMember), then obj is a pointer to a pointer which might cause issues
   171  			// if we try to do obj.Interface().(codec.ProtoMarshaler).
   172  			// For that reason, we copy obj into model if we have a simple pointer
   173  			// but in case elemType.Kind() == reflect.Ptr, we overwrite it with model = val
   174  			// so we can safely call model.Interface().(codec.ProtoMarshaler) afterwards.
   175  			model = val
   176  		}
   177  
   178  		modelProto, ok := model.Interface().(proto.Message)
   179  		if !ok {
   180  			return nil, errorsmod.Wrapf(errors.ErrORMInvalidArgument, "%s should implement codec.ProtoMarshaler", elemType)
   181  		}
   182  		binKey, err := it.LoadNext(modelProto)
   183  		if err != nil {
   184  			if errors.ErrORMIteratorDone.Is(err) {
   185  				break
   186  			}
   187  			return nil, err
   188  		}
   189  
   190  		count++
   191  
   192  		// During the first loop, count value at this point will be 1,
   193  		// so if offset is >= 1, it will continue to load the next value until count > offset
   194  		// else (offset = 0, key might be set or not),
   195  		// it will start to append values to tmpSlice.
   196  		if count <= offset {
   197  			continue
   198  		}
   199  
   200  		if count <= end {
   201  			tmpSlice = reflect.Append(tmpSlice, val)
   202  		} else if count == end+1 {
   203  			nextKey = binKey
   204  
   205  			// countTotal is set to true to indicate that the result set should include
   206  			// a count of the total number of items available for pagination in UIs.
   207  			// countTotal is only respected when offset is used. It is ignored when key
   208  			// is set.
   209  			if !countTotal || len(key) != 0 {
   210  				break
   211  			}
   212  		}
   213  	}
   214  	destRef.Set(tmpSlice)
   215  
   216  	res := &query.PageResponse{NextKey: nextKey}
   217  	if countTotal && len(key) == 0 {
   218  		res.Total = count
   219  	}
   220  
   221  	return res, nil
   222  }
   223  
   224  // ModelSlicePtr represents a pointer to a slice of models. Think of it as
   225  // *[]Model Because of Go's type system, using []Model type would not work for us.
   226  // Instead we use a placeholder type and the validation is done during the
   227  // runtime.
   228  type ModelSlicePtr interface{}
   229  
   230  // ReadAll consumes all values for the iterator and stores them in a new slice at the passed ModelSlicePtr.
   231  // The slice can be empty when the iterator does not return any values but not nil. The iterator
   232  // is closed afterwards.
   233  // Example:
   234  //
   235  //	var loaded []testdata.GroupInfo
   236  //	rowIDs, err := ReadAll(it, &loaded)
   237  //	require.NoError(t, err)
   238  func ReadAll(it Iterator, dest ModelSlicePtr) ([]RowID, error) {
   239  	if it == nil {
   240  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil")
   241  	}
   242  	defer it.Close()
   243  
   244  	var destRef, tmpSlice reflect.Value
   245  	elemType, err := assertDest(dest, &destRef, &tmpSlice)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	var rowIDs []RowID
   251  	for {
   252  		obj := reflect.New(elemType)
   253  		val := obj.Elem()
   254  		model := obj
   255  		if elemType.Kind() == reflect.Ptr {
   256  			val.Set(reflect.New(elemType.Elem()))
   257  			model = val
   258  		}
   259  
   260  		binKey, err := it.LoadNext(model.Interface().(proto.Message))
   261  		switch {
   262  		case err == nil:
   263  			tmpSlice = reflect.Append(tmpSlice, val)
   264  		case errors.ErrORMIteratorDone.Is(err):
   265  			destRef.Set(tmpSlice)
   266  			return rowIDs, nil
   267  		default:
   268  			return nil, err
   269  		}
   270  		rowIDs = append(rowIDs, binKey)
   271  	}
   272  }
   273  
   274  // assertDest checks that the provided dest is not nil and a pointer to a slice.
   275  // It also verifies that the slice elements implement *codec.ProtoMarshaler.
   276  // It overwrites destRef and tmpSlice using reflection.
   277  func assertDest(dest ModelSlicePtr, destRef, tmpSlice *reflect.Value) (reflect.Type, error) {
   278  	if dest == nil {
   279  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination must not be nil")
   280  	}
   281  	tp := reflect.ValueOf(dest)
   282  	if tp.Kind() != reflect.Ptr {
   283  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination must be a pointer to a slice")
   284  	}
   285  	if tp.Elem().Kind() != reflect.Slice {
   286  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination must point to a slice")
   287  	}
   288  
   289  	// Since dest is just an interface{}, we overwrite destRef using reflection
   290  	// to have an assignable copy of it.
   291  	*destRef = tp.Elem()
   292  	// We need to verify that we can call Set() on destRef.
   293  	if !destRef.CanSet() {
   294  		return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination not assignable")
   295  	}
   296  
   297  	elemType := reflect.TypeOf(dest).Elem().Elem()
   298  
   299  	protoMarshaler := reflect.TypeOf((*proto.Message)(nil)).Elem()
   300  	if !elemType.Implements(protoMarshaler) &&
   301  		!reflect.PtrTo(elemType).Implements(protoMarshaler) {
   302  		return nil, errorsmod.Wrapf(errors.ErrORMInvalidArgument, "unsupported type :%s", elemType)
   303  	}
   304  
   305  	// tmpSlice is a slice value for the specified type
   306  	// that we'll use for appending new elements.
   307  	*tmpSlice = reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0)
   308  
   309  	return elemType, nil
   310  }