github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/iterator.go (about) 1 package client 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "strings" 8 9 "github.com/cockroachdb/errors" 10 "github.com/milvus-io/milvus-sdk-go/v2/entity" 11 ) 12 13 func NewQueryIteratorOption(collectionName string) *QueryIteratorOption { 14 return &QueryIteratorOption{ 15 collectionName: collectionName, 16 batchSize: 1000, 17 } 18 } 19 20 type QueryIteratorOption struct { 21 collectionName string 22 partitionNames []string 23 expr string 24 outputFields []string 25 batchSize int 26 } 27 28 func (opt *QueryIteratorOption) WithPartitions(partitionNames ...string) *QueryIteratorOption { 29 opt.partitionNames = partitionNames 30 return opt 31 } 32 33 func (opt *QueryIteratorOption) WithExpr(expr string) *QueryIteratorOption { 34 opt.expr = expr 35 return opt 36 } 37 38 func (opt *QueryIteratorOption) WithOutputFields(outputFields ...string) *QueryIteratorOption { 39 opt.outputFields = outputFields 40 return opt 41 } 42 43 func (opt *QueryIteratorOption) WithBatchSize(batchSize int) *QueryIteratorOption { 44 opt.batchSize = batchSize 45 return opt 46 } 47 48 func (c *GrpcClient) QueryIterator(ctx context.Context, opt *QueryIteratorOption) (*QueryIterator, error) { 49 collectionName := opt.collectionName 50 var sch *entity.Schema 51 collInfo, ok := MetaCache.getCollectionInfo(collectionName) 52 if !ok { 53 coll, err := c.DescribeCollection(ctx, collectionName) 54 if err != nil { 55 return nil, err 56 } 57 sch = coll.Schema 58 } else { 59 sch = collInfo.Schema 60 } 61 62 itr := &QueryIterator{ 63 client: c, 64 65 collectionName: opt.collectionName, 66 partitionNames: opt.partitionNames, 67 outputFields: opt.outputFields, 68 sch: sch, 69 pkField: sch.PKField(), 70 71 batchSize: opt.batchSize, 72 expr: opt.expr, 73 } 74 75 err := itr.init(ctx) 76 if err != nil { 77 return nil, err 78 } 79 return itr, nil 80 } 81 82 type QueryIterator struct { 83 // user provided expression 84 expr string 85 86 batchSize int 87 88 cached ResultSet 89 90 collectionName string 91 partitionNames []string 92 outputFields []string 93 sch *entity.Schema 94 pkField *entity.Field 95 96 lastPK interface{} 97 98 // internal grpc client 99 client *GrpcClient 100 } 101 102 // init fetches the first batch of data and put it into cache. 103 // this operation could be used to check all the parameters before returning the iterator. 104 func (itr *QueryIterator) init(ctx context.Context) error { 105 if itr.batchSize <= 0 { 106 return errors.New("batch size cannot less than 1") 107 } 108 109 rs, err := itr.fetchNextBatch(ctx) 110 if err != nil { 111 return err 112 } 113 itr.cached = rs 114 return nil 115 } 116 117 func (itr *QueryIterator) composeIteratorExpr() string { 118 if itr.lastPK == nil { 119 return itr.expr 120 } 121 122 expr := strings.TrimSpace(itr.expr) 123 124 switch itr.pkField.DataType { 125 case entity.FieldTypeInt64: 126 if len(expr) == 0 { 127 expr = fmt.Sprintf("%s > %d", itr.pkField.Name, itr.lastPK) 128 } else { 129 expr = fmt.Sprintf("(%s) and %s > %d", expr, itr.pkField.Name, itr.lastPK) 130 } 131 case entity.FieldTypeVarChar: 132 if len(expr) == 0 { 133 expr = fmt.Sprintf(`%s > "%s"`, itr.pkField.Name, itr.lastPK) 134 } else { 135 expr = fmt.Sprintf(`(%s) and %s > "%s"`, expr, itr.pkField.Name, itr.lastPK) 136 } 137 default: 138 return itr.expr 139 } 140 return expr 141 } 142 143 func (itr *QueryIterator) fetchNextBatch(ctx context.Context) (ResultSet, error) { 144 return itr.client.Query(ctx, itr.collectionName, itr.partitionNames, itr.composeIteratorExpr(), itr.outputFields, 145 WithLimit(int64(float64(itr.batchSize))), withIterator(), reduceForBest(true)) 146 } 147 148 func (itr *QueryIterator) cachedSufficient() bool { 149 return itr.cached != nil && itr.cached.Len() >= itr.batchSize 150 } 151 152 func (itr *QueryIterator) cacheNextBatch(rs ResultSet) (ResultSet, error) { 153 result := rs.Slice(0, itr.batchSize) 154 itr.cached = rs.Slice(itr.batchSize, -1) 155 156 pkColumn := result.GetColumn(itr.pkField.Name) 157 switch itr.pkField.DataType { 158 case entity.FieldTypeInt64: 159 itr.lastPK, _ = pkColumn.GetAsInt64(pkColumn.Len() - 1) 160 case entity.FieldTypeVarChar: 161 itr.lastPK, _ = pkColumn.GetAsString(pkColumn.Len() - 1) 162 default: 163 return nil, errors.Newf("unsupported pk type: %v", itr.pkField.DataType) 164 } 165 return result, nil 166 } 167 168 func (itr *QueryIterator) Next(ctx context.Context) (ResultSet, error) { 169 var rs ResultSet 170 var err error 171 172 // check cache sufficient for next batch 173 if !itr.cachedSufficient() { 174 rs, err = itr.fetchNextBatch(ctx) 175 if err != nil { 176 return nil, err 177 } 178 } else { 179 rs = itr.cached 180 } 181 182 // if resultset is empty, return EOF 183 if rs.Len() == 0 { 184 return nil, io.EOF 185 } 186 187 return itr.cacheNextBatch(rs) 188 }