github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/ann_request.go (about) 1 package client 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "strconv" 7 8 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 9 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 10 "github.com/milvus-io/milvus-sdk-go/v2/entity" 11 ) 12 13 type ANNSearchRequest struct { 14 fieldName string 15 vectors []entity.Vector 16 metricType entity.MetricType 17 expr string 18 searchParam entity.SearchParam 19 options []SearchQueryOptionFunc 20 limit int 21 } 22 23 func NewANNSearchRequest(fieldName string, metricsType entity.MetricType, expr string, vectors []entity.Vector, searchParam entity.SearchParam, limit int, options ...SearchQueryOptionFunc) *ANNSearchRequest { 24 return &ANNSearchRequest{ 25 fieldName: fieldName, 26 vectors: vectors, 27 metricType: metricsType, 28 expr: expr, 29 searchParam: searchParam, 30 limit: limit, 31 options: options, 32 } 33 } 34 func (r *ANNSearchRequest) WithExpr(expr string) *ANNSearchRequest { 35 r.expr = expr 36 return r 37 } 38 39 func (r *ANNSearchRequest) getMilvusSearchRequest(collectionInfo *collInfo, opts ...SearchQueryOptionFunc) (*milvuspb.SearchRequest, error) { 40 opt := &SearchQueryOption{ 41 ConsistencyLevel: collectionInfo.ConsistencyLevel, // default 42 } 43 for _, o := range r.options { 44 o(opt) 45 } 46 for _, o := range opts { 47 o(opt) 48 } 49 params := r.searchParam.Params() 50 params[forTuningKey] = opt.ForTuning 51 bs, err := json.Marshal(params) 52 if err != nil { 53 return nil, err 54 } 55 56 sp := map[string]string{ 57 "anns_field": r.fieldName, 58 "topk": fmt.Sprintf("%d", r.limit), 59 "params": string(bs), 60 "metric_type": string(r.metricType), 61 "round_decimal": "-1", 62 ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing), 63 offsetKey: fmt.Sprintf("%d", opt.Offset), 64 } 65 if opt.GroupByField != "" { 66 sp[groupByKey] = opt.GroupByField 67 } 68 69 searchParams := entity.MapKvPairs(sp) 70 71 result := &milvuspb.SearchRequest{ 72 DbName: "", 73 Dsl: r.expr, 74 PlaceholderGroup: vector2PlaceholderGroupBytes(r.vectors), 75 DslType: commonpb.DslType_BoolExprV1, 76 SearchParams: searchParams, 77 GuaranteeTimestamp: opt.GuaranteeTimestamp, 78 Nq: int64(len(r.vectors)), 79 } 80 return result, nil 81 }