github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/iterator.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"strings"
     8  
     9  	"github.com/cockroachdb/errors"
    10  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    11  )
    12  
    13  func NewQueryIteratorOption(collectionName string) *QueryIteratorOption {
    14  	return &QueryIteratorOption{
    15  		collectionName: collectionName,
    16  		batchSize:      1000,
    17  	}
    18  }
    19  
// QueryIteratorOption bundles the parameters used to build a QueryIterator.
type QueryIteratorOption struct {
	collectionName string   // target collection to iterate over
	partitionNames []string // optional partition restriction; empty means all partitions
	expr           string   // optional filter expression applied to every batch
	outputFields   []string // fields returned for each row
	batchSize      int      // rows fetched per batch; defaults to 1000 via NewQueryIteratorOption
}
    27  
    28  func (opt *QueryIteratorOption) WithPartitions(partitionNames ...string) *QueryIteratorOption {
    29  	opt.partitionNames = partitionNames
    30  	return opt
    31  }
    32  
    33  func (opt *QueryIteratorOption) WithExpr(expr string) *QueryIteratorOption {
    34  	opt.expr = expr
    35  	return opt
    36  }
    37  
    38  func (opt *QueryIteratorOption) WithOutputFields(outputFields ...string) *QueryIteratorOption {
    39  	opt.outputFields = outputFields
    40  	return opt
    41  }
    42  
    43  func (opt *QueryIteratorOption) WithBatchSize(batchSize int) *QueryIteratorOption {
    44  	opt.batchSize = batchSize
    45  	return opt
    46  }
    47  
    48  func (c *GrpcClient) QueryIterator(ctx context.Context, opt *QueryIteratorOption) (*QueryIterator, error) {
    49  	collectionName := opt.collectionName
    50  	var sch *entity.Schema
    51  	collInfo, ok := MetaCache.getCollectionInfo(collectionName)
    52  	if !ok {
    53  		coll, err := c.DescribeCollection(ctx, collectionName)
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  		sch = coll.Schema
    58  	} else {
    59  		sch = collInfo.Schema
    60  	}
    61  
    62  	itr := &QueryIterator{
    63  		client: c,
    64  
    65  		collectionName: opt.collectionName,
    66  		partitionNames: opt.partitionNames,
    67  		outputFields:   opt.outputFields,
    68  		sch:            sch,
    69  		pkField:        sch.PKField(),
    70  
    71  		batchSize: opt.batchSize,
    72  		expr:      opt.expr,
    73  	}
    74  
    75  	err := itr.init(ctx)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return itr, nil
    80  }
    81  
// QueryIterator iterates over a query result set batch by batch, using the
// primary-key value of the last returned row to page through the collection.
type QueryIterator struct {
	// user provided expression
	expr string

	// number of rows returned per Next call; validated > 0 in init.
	batchSize int

	// rows already fetched from the server but not yet handed to the caller.
	cached ResultSet

	collectionName string
	partitionNames []string
	outputFields   []string
	sch            *entity.Schema
	pkField        *entity.Field // primary key field, used for pagination predicates

	// primary key of the last row returned to the caller; nil before the
	// first batch is delivered. Holds int64 or string per pkField.DataType.
	lastPK interface{}

	// internal grpc client
	client *GrpcClient
}
   101  
   102  // init fetches the first batch of data and put it into cache.
   103  // this operation could be used to check all the parameters before returning the iterator.
   104  func (itr *QueryIterator) init(ctx context.Context) error {
   105  	if itr.batchSize <= 0 {
   106  		return errors.New("batch size cannot less than 1")
   107  	}
   108  
   109  	rs, err := itr.fetchNextBatch(ctx)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	itr.cached = rs
   114  	return nil
   115  }
   116  
   117  func (itr *QueryIterator) composeIteratorExpr() string {
   118  	if itr.lastPK == nil {
   119  		return itr.expr
   120  	}
   121  
   122  	expr := strings.TrimSpace(itr.expr)
   123  
   124  	switch itr.pkField.DataType {
   125  	case entity.FieldTypeInt64:
   126  		if len(expr) == 0 {
   127  			expr = fmt.Sprintf("%s > %d", itr.pkField.Name, itr.lastPK)
   128  		} else {
   129  			expr = fmt.Sprintf("(%s) and %s > %d", expr, itr.pkField.Name, itr.lastPK)
   130  		}
   131  	case entity.FieldTypeVarChar:
   132  		if len(expr) == 0 {
   133  			expr = fmt.Sprintf(`%s > "%s"`, itr.pkField.Name, itr.lastPK)
   134  		} else {
   135  			expr = fmt.Sprintf(`(%s) and %s > "%s"`, expr, itr.pkField.Name, itr.lastPK)
   136  		}
   137  	default:
   138  		return itr.expr
   139  	}
   140  	return expr
   141  }
   142  
   143  func (itr *QueryIterator) fetchNextBatch(ctx context.Context) (ResultSet, error) {
   144  	return itr.client.Query(ctx, itr.collectionName, itr.partitionNames, itr.composeIteratorExpr(), itr.outputFields,
   145  		WithLimit(int64(float64(itr.batchSize))), withIterator(), reduceForBest(true))
   146  }
   147  
   148  func (itr *QueryIterator) cachedSufficient() bool {
   149  	return itr.cached != nil && itr.cached.Len() >= itr.batchSize
   150  }
   151  
   152  func (itr *QueryIterator) cacheNextBatch(rs ResultSet) (ResultSet, error) {
   153  	result := rs.Slice(0, itr.batchSize)
   154  	itr.cached = rs.Slice(itr.batchSize, -1)
   155  
   156  	pkColumn := result.GetColumn(itr.pkField.Name)
   157  	switch itr.pkField.DataType {
   158  	case entity.FieldTypeInt64:
   159  		itr.lastPK, _ = pkColumn.GetAsInt64(pkColumn.Len() - 1)
   160  	case entity.FieldTypeVarChar:
   161  		itr.lastPK, _ = pkColumn.GetAsString(pkColumn.Len() - 1)
   162  	default:
   163  		return nil, errors.Newf("unsupported pk type: %v", itr.pkField.DataType)
   164  	}
   165  	return result, nil
   166  }
   167  
   168  func (itr *QueryIterator) Next(ctx context.Context) (ResultSet, error) {
   169  	var rs ResultSet
   170  	var err error
   171  
   172  	// check cache sufficient for next batch
   173  	if !itr.cachedSufficient() {
   174  		rs, err = itr.fetchNextBatch(ctx)
   175  		if err != nil {
   176  			return nil, err
   177  		}
   178  	} else {
   179  		rs = itr.cached
   180  	}
   181  
   182  	// if resultset is empty, return EOF
   183  	if rs.Len() == 0 {
   184  		return nil, io.EOF
   185  	}
   186  
   187  	return itr.cacheNextBatch(rs)
   188  }