github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/ann_request.go (about)

     1  package client
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"strconv"
     7  
     8  	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
     9  	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    10  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    11  )
    12  
// ANNSearchRequest holds the inputs of a single approximate-nearest-neighbour
// search against one vector field. It is turned into a milvuspb.SearchRequest
// by getMilvusSearchRequest.
type ANNSearchRequest struct {
	// fieldName is sent as the "anns_field" search parameter.
	fieldName string
	// vectors are the query vectors; they are serialized into the placeholder
	// group and their count is sent as Nq.
	vectors []entity.Vector
	// metricType is sent as the "metric_type" search parameter.
	metricType entity.MetricType
	// expr is the boolean filter expression, sent as the request Dsl.
	expr string
	// searchParam supplies the parameters that are JSON-encoded into the
	// "params" search parameter.
	searchParam entity.SearchParam
	// options are applied to the search option set before any options passed
	// directly to getMilvusSearchRequest.
	options []SearchQueryOptionFunc
	// limit is sent as the "topk" search parameter.
	limit int
}
    22  
    23  func NewANNSearchRequest(fieldName string, metricsType entity.MetricType, expr string, vectors []entity.Vector, searchParam entity.SearchParam, limit int, options ...SearchQueryOptionFunc) *ANNSearchRequest {
    24  	return &ANNSearchRequest{
    25  		fieldName:   fieldName,
    26  		vectors:     vectors,
    27  		metricType:  metricsType,
    28  		expr:        expr,
    29  		searchParam: searchParam,
    30  		limit:       limit,
    31  		options:     options,
    32  	}
    33  }
    34  func (r *ANNSearchRequest) WithExpr(expr string) *ANNSearchRequest {
    35  	r.expr = expr
    36  	return r
    37  }
    38  
    39  func (r *ANNSearchRequest) getMilvusSearchRequest(collectionInfo *collInfo, opts ...SearchQueryOptionFunc) (*milvuspb.SearchRequest, error) {
    40  	opt := &SearchQueryOption{
    41  		ConsistencyLevel: collectionInfo.ConsistencyLevel, // default
    42  	}
    43  	for _, o := range r.options {
    44  		o(opt)
    45  	}
    46  	for _, o := range opts {
    47  		o(opt)
    48  	}
    49  	params := r.searchParam.Params()
    50  	params[forTuningKey] = opt.ForTuning
    51  	bs, err := json.Marshal(params)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	sp := map[string]string{
    57  		"anns_field":     r.fieldName,
    58  		"topk":           fmt.Sprintf("%d", r.limit),
    59  		"params":         string(bs),
    60  		"metric_type":    string(r.metricType),
    61  		"round_decimal":  "-1",
    62  		ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
    63  		offsetKey:        fmt.Sprintf("%d", opt.Offset),
    64  	}
    65  	if opt.GroupByField != "" {
    66  		sp[groupByKey] = opt.GroupByField
    67  	}
    68  
    69  	searchParams := entity.MapKvPairs(sp)
    70  
    71  	result := &milvuspb.SearchRequest{
    72  		DbName:             "",
    73  		Dsl:                r.expr,
    74  		PlaceholderGroup:   vector2PlaceholderGroupBytes(r.vectors),
    75  		DslType:            commonpb.DslType_BoolExprV1,
    76  		SearchParams:       searchParams,
    77  		GuaranteeTimestamp: opt.GuaranteeTimestamp,
    78  		Nq:                 int64(len(r.vectors)),
    79  	}
    80  	return result, nil
    81  }