github.com/cosmos/cosmos-sdk@v0.50.10/x/group/internal/orm/iterator.go (about) 1 package orm 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/cosmos/gogoproto/proto" 8 9 errorsmod "cosmossdk.io/errors" 10 11 "github.com/cosmos/cosmos-sdk/types/query" 12 "github.com/cosmos/cosmos-sdk/x/group/errors" 13 ) 14 15 // defaultPageLimit is the default limit value for pagination requests. 16 const defaultPageLimit = 100 17 18 // IteratorFunc is a function type that satisfies the Iterator interface 19 // The passed function is called on LoadNext operations. 20 type IteratorFunc func(dest proto.Message) (RowID, error) 21 22 // LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there 23 // are no more items the errors.ErrORMIteratorDone error is returned 24 // The key is the rowID and not any MultiKeyIndex key. 25 func (i IteratorFunc) LoadNext(dest proto.Message) (RowID, error) { 26 return i(dest) 27 } 28 29 // Close always returns nil 30 func (i IteratorFunc) Close() error { 31 return nil 32 } 33 34 func NewSingleValueIterator(rowID RowID, val []byte) Iterator { 35 var closed bool 36 return IteratorFunc(func(dest proto.Message) (RowID, error) { 37 if dest == nil { 38 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination object must not be nil") 39 } 40 if closed || val == nil { 41 return nil, errors.ErrORMIteratorDone 42 } 43 closed = true 44 return rowID, proto.Unmarshal(val, dest) 45 }) 46 } 47 48 // Iterator that return ErrORMInvalidIterator only. 49 func NewInvalidIterator() Iterator { 50 return IteratorFunc(func(dest proto.Message) (RowID, error) { 51 return nil, errors.ErrORMInvalidIterator 52 }) 53 } 54 55 // LimitedIterator returns up to defined maximum number of elements. 56 type LimitedIterator struct { 57 remainingCount int 58 parentIterator Iterator 59 } 60 61 // LimitIterator returns a new iterator that returns max number of elements. 62 // The parent iterator must not be nil 63 // max can be 0 or any positive number 64 func LimitIterator(parent Iterator, max int) (*LimitedIterator, error) { 65 if max < 0 { 66 return nil, errors.ErrORMInvalidArgument.Wrap("quantity must not be negative") 67 } 68 if parent == nil { 69 return nil, errors.ErrORMInvalidArgument.Wrap("parent iterator must not be nil") 70 } 71 return &LimitedIterator{remainingCount: max, parentIterator: parent}, nil 72 } 73 74 // LoadNext loads the next value in the sequence into the pointer passed as dest and returns the key. If there 75 // are no more items or the defined max number of elements was returned the `errors.ErrORMIteratorDone` error is returned 76 // The key is the rowID and not any MultiKeyIndex key. 77 func (i *LimitedIterator) LoadNext(dest proto.Message) (RowID, error) { 78 if i.remainingCount == 0 { 79 return nil, errors.ErrORMIteratorDone 80 } 81 i.remainingCount-- 82 return i.parentIterator.LoadNext(dest) 83 } 84 85 // Close releases the iterator and should be called at the end of iteration 86 func (i LimitedIterator) Close() error { 87 return i.parentIterator.Close() 88 } 89 90 // First loads the first element into the given destination type and closes the iterator. 91 // When the iterator is closed or has no elements the according error is passed as return value. 92 func First(it Iterator, dest proto.Message) (RowID, error) { 93 if it == nil { 94 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil") 95 } 96 defer it.Close() 97 binKey, err := it.LoadNext(dest) 98 if err != nil { 99 return nil, err 100 } 101 return binKey, nil 102 } 103 104 // Paginate does pagination with a given Iterator based on the provided 105 // PageRequest and unmarshals the results into the dest interface that must be 106 // an non-nil pointer to a slice. 107 // 108 // If pageRequest is nil, then we will use these default values: 109 // - Offset: 0 110 // - Key: nil 111 // - Limit: 100 112 // - CountTotal: true 113 // 114 // If pageRequest.Key was provided, it got used beforehand to instantiate the Iterator, 115 // using for instance UInt64Index.GetPaginated method. Only one of pageRequest.Offset or 116 // pageRequest.Key should be set. Using pageRequest.Key is more efficient for querying 117 // the next page. 118 // 119 // If pageRequest.CountTotal is set, we'll visit all iterators elements. 120 // pageRequest.CountTotal is only respected when offset is used. 121 // 122 // This function will call it.Close(). 123 func Paginate( 124 it Iterator, 125 pageRequest *query.PageRequest, 126 dest ModelSlicePtr, 127 ) (*query.PageResponse, error) { 128 // if the PageRequest is nil, use default PageRequest 129 if pageRequest == nil { 130 pageRequest = &query.PageRequest{} 131 } 132 133 offset := pageRequest.Offset 134 key := pageRequest.Key 135 limit := pageRequest.Limit 136 countTotal := pageRequest.CountTotal 137 138 if offset > 0 && key != nil { 139 return nil, fmt.Errorf("invalid request, either offset or key is expected, got both") 140 } 141 142 if limit == 0 { 143 limit = defaultPageLimit 144 145 // count total results when the limit is zero/not supplied 146 countTotal = true 147 } 148 149 if it == nil { 150 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil") 151 } 152 defer it.Close() 153 154 var destRef, tmpSlice reflect.Value 155 elemType, err := assertDest(dest, &destRef, &tmpSlice) 156 if err != nil { 157 return nil, err 158 } 159 160 end := offset + limit 161 var count uint64 162 var nextKey []byte 163 for { 164 obj := reflect.New(elemType) 165 val := obj.Elem() 166 model := obj 167 if elemType.Kind() == reflect.Ptr { 168 val.Set(reflect.New(elemType.Elem())) 169 // if elemType is already a pointer (e.g. dest being some pointer to a slice of pointers, 170 // like []*GroupMember), then obj is a pointer to a pointer which might cause issues 171 // if we try to do obj.Interface().(codec.ProtoMarshaler). 172 // For that reason, we copy obj into model if we have a simple pointer 173 // but in case elemType.Kind() == reflect.Ptr, we overwrite it with model = val 174 // so we can safely call model.Interface().(codec.ProtoMarshaler) afterwards. 175 model = val 176 } 177 178 modelProto, ok := model.Interface().(proto.Message) 179 if !ok { 180 return nil, errorsmod.Wrapf(errors.ErrORMInvalidArgument, "%s should implement codec.ProtoMarshaler", elemType) 181 } 182 binKey, err := it.LoadNext(modelProto) 183 if err != nil { 184 if errors.ErrORMIteratorDone.Is(err) { 185 break 186 } 187 return nil, err 188 } 189 190 count++ 191 192 // During the first loop, count value at this point will be 1, 193 // so if offset is >= 1, it will continue to load the next value until count > offset 194 // else (offset = 0, key might be set or not), 195 // it will start to append values to tmpSlice. 196 if count <= offset { 197 continue 198 } 199 200 if count <= end { 201 tmpSlice = reflect.Append(tmpSlice, val) 202 } else if count == end+1 { 203 nextKey = binKey 204 205 // countTotal is set to true to indicate that the result set should include 206 // a count of the total number of items available for pagination in UIs. 207 // countTotal is only respected when offset is used. It is ignored when key 208 // is set. 209 if !countTotal || len(key) != 0 { 210 break 211 } 212 } 213 } 214 destRef.Set(tmpSlice) 215 216 res := &query.PageResponse{NextKey: nextKey} 217 if countTotal && len(key) == 0 { 218 res.Total = count 219 } 220 221 return res, nil 222 } 223 224 // ModelSlicePtr represents a pointer to a slice of models. Think of it as 225 // *[]Model Because of Go's type system, using []Model type would not work for us. 226 // Instead we use a placeholder type and the validation is done during the 227 // runtime. 228 type ModelSlicePtr interface{} 229 230 // ReadAll consumes all values for the iterator and stores them in a new slice at the passed ModelSlicePtr. 231 // The slice can be empty when the iterator does not return any values but not nil. The iterator 232 // is closed afterwards. 233 // Example: 234 // 235 // var loaded []testdata.GroupInfo 236 // rowIDs, err := ReadAll(it, &loaded) 237 // require.NoError(t, err) 238 func ReadAll(it Iterator, dest ModelSlicePtr) ([]RowID, error) { 239 if it == nil { 240 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "iterator must not be nil") 241 } 242 defer it.Close() 243 244 var destRef, tmpSlice reflect.Value 245 elemType, err := assertDest(dest, &destRef, &tmpSlice) 246 if err != nil { 247 return nil, err 248 } 249 250 var rowIDs []RowID 251 for { 252 obj := reflect.New(elemType) 253 val := obj.Elem() 254 model := obj 255 if elemType.Kind() == reflect.Ptr { 256 val.Set(reflect.New(elemType.Elem())) 257 model = val 258 } 259 260 binKey, err := it.LoadNext(model.Interface().(proto.Message)) 261 switch { 262 case err == nil: 263 tmpSlice = reflect.Append(tmpSlice, val) 264 case errors.ErrORMIteratorDone.Is(err): 265 destRef.Set(tmpSlice) 266 return rowIDs, nil 267 default: 268 return nil, err 269 } 270 rowIDs = append(rowIDs, binKey) 271 } 272 } 273 274 // assertDest checks that the provided dest is not nil and a pointer to a slice. 275 // It also verifies that the slice elements implement *codec.ProtoMarshaler. 276 // It overwrites destRef and tmpSlice using reflection. 277 func assertDest(dest ModelSlicePtr, destRef, tmpSlice *reflect.Value) (reflect.Type, error) { 278 if dest == nil { 279 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination must not be nil") 280 } 281 tp := reflect.ValueOf(dest) 282 if tp.Kind() != reflect.Ptr { 283 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination must be a pointer to a slice") 284 } 285 if tp.Elem().Kind() != reflect.Slice { 286 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination must point to a slice") 287 } 288 289 // Since dest is just an interface{}, we overwrite destRef using reflection 290 // to have an assignable copy of it. 291 *destRef = tp.Elem() 292 // We need to verify that we can call Set() on destRef. 293 if !destRef.CanSet() { 294 return nil, errorsmod.Wrap(errors.ErrORMInvalidArgument, "destination not assignable") 295 } 296 297 elemType := reflect.TypeOf(dest).Elem().Elem() 298 299 protoMarshaler := reflect.TypeOf((*proto.Message)(nil)).Elem() 300 if !elemType.Implements(protoMarshaler) && 301 !reflect.PtrTo(elemType).Implements(protoMarshaler) { 302 return nil, errorsmod.Wrapf(errors.ErrORMInvalidArgument, "unsupported type :%s", elemType) 303 } 304 305 // tmpSlice is a slice value for the specified type 306 // that we'll use for appending new elements. 307 *tmpSlice = reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0) 308 309 return elemType, nil 310 }