github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/test/testcases/hybrid_search_test.go (about) 1 //go:build L0 2 3 package testcases 4 5 import ( 6 "fmt" 7 "log" 8 "testing" 9 "time" 10 11 "github.com/stretchr/testify/require" 12 13 "github.com/milvus-io/milvus-sdk-go/v2/client" 14 15 "github.com/milvus-io/milvus-sdk-go/v2/entity" 16 "github.com/milvus-io/milvus-sdk-go/v2/test/common" 17 ) 18 19 func TestHybridSearchDefault(t *testing.T) { 20 ctx := createContext(t, time.Second*common.DefaultTimeout) 21 // connect 22 mc := createMilvusClient(ctx, t) 23 24 // create -> insert [0, 3000) -> flush -> index -> load 25 cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: false, 26 ShardsNum: common.DefaultShards, Dim: common.DefaultDim} 27 28 dp := DataParams{DoInsert: true, CollectionFieldsType: Int64FloatVec, start: 0, nb: common.DefaultNb, 29 dim: common.DefaultDim, EnableDynamicField: false} 30 31 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 32 33 // hybrid search 34 ranker := client.NewRRFReranker() 35 expr := fmt.Sprintf("%s > 10", common.DefaultIntFieldName) 36 sp, _ := entity.NewIndexFlatSearchParam() 37 queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) 38 queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) 39 sReqs := []*client.ANNSearchRequest{ 40 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK), 41 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK), 42 } 43 searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, ranker, sReqs) 44 common.CheckErr(t, errSearch, true) 45 common.CheckSearchResult(t, searchRes, 1, common.DefaultTopK) 46 common.CheckOutputFields(t, searchRes[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultFloatVecFieldName}) 47 } 48 49 // hybrid search default -> verify success 50 func TestHybridSearchMultiVectorsDefault(t *testing.T) { 51 t.Parallel() 52 ctx := createContext(t, time.Second*common.DefaultTimeout*3) 53 // connect 54 mc := createMilvusClient(ctx, t) 55 for _, enableDynamic := range []bool{false, true} { 56 // create -> insert [0, 3000) -> flush -> index -> load 57 cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: enableDynamic, 58 ShardsNum: common.DefaultShards, Dim: common.DefaultDim} 59 60 dp := DataParams{DoInsert: true, CollectionFieldsType: AllFields, start: 0, nb: common.DefaultNb * 3, 61 dim: common.DefaultDim, EnableDynamicField: enableDynamic} 62 63 ips := GenDefaultIndexParamsForAllVectors() 64 65 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 66 67 // hybrid search with different limit 68 type limitGroup struct { 69 limit1 int 70 limit2 int 71 limit3 int 72 expLimit int 73 } 74 limits := []limitGroup{ 75 {limit1: 10, limit2: 5, limit3: 8, expLimit: 8}, 76 {limit1: 10, limit2: 5, limit3: 15, expLimit: 15}, 77 {limit1: 10, limit2: 5, limit3: 20, expLimit: 15}, 78 } 79 sp, _ := entity.NewIndexFlatSearchParam() 80 expr := fmt.Sprintf("%s > 5", common.DefaultIntFieldName) 81 queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) 82 queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloat16Vector) 83 84 // search with different reranker and limit 85 for _, reranker := range []client.Reranker{client.NewRRFReranker(), 86 client.NewWeightedReranker([]float64{0.8, 0.2}), 87 client.NewWeightedReranker([]float64{0.0, 0.2}), 88 client.NewWeightedReranker([]float64{0.4, 1.0}), 89 } { 90 for _, limit := range limits { 91 // hybrid search 92 sReqs := []*client.ANNSearchRequest{ 93 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, limit.limit1), 94 client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, limit.limit2), 95 } 96 searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, limit.limit3, []string{"*"}, reranker, sReqs) 97 common.CheckErr(t, errSearch, true) 98 common.CheckSearchResult(t, searchRes, common.DefaultNq, limit.expLimit) 99 common.CheckOutputFields(t, searchRes[0].Fields, common.GetAllFieldsName(enableDynamic, false)) 100 } 101 } 102 } 103 } 104 105 // invalid limit: 0, -1, max+1 106 // invalid WeightedReranker params 107 // invalid fieldName: not exist 108 // invalid metric type: mismatch 109 func TestHybridSearchInvalidParams(t *testing.T) { 110 ctx := createContext(t, time.Second*common.DefaultTimeout*2) 111 // connect 112 mc := createMilvusClient(ctx, t) 113 114 // create -> insert [0, 3000) -> flush -> index -> load 115 cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false, 116 ShardsNum: common.DefaultShards, Dim: common.DefaultDim} 117 118 dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb, 119 dim: common.DefaultDim, EnableDynamicField: false} 120 121 // index params 122 ips := GenDefaultIndexParamsForAllVectors() 123 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), 124 WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 125 126 // hybrid search with invalid limit 127 ranker := client.NewRRFReranker() 128 sp, _ := entity.NewIndexFlatSearchParam() 129 queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) 130 queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeBinaryVector) 131 sReqs := []*client.ANNSearchRequest{ 132 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec1, sp, common.DefaultTopK), 133 client.NewANNSearchRequest(common.DefaultBinaryVecFieldName, entity.JACCARD, "", queryVec2, sp, common.DefaultTopK), 134 } 135 for _, invalidLimit := range []int{-1, 0, common.MaxTopK + 1} { 136 sReqsInvalid := []*client.ANNSearchRequest{ 137 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec1, sp, invalidLimit)} 138 139 for _, sReq := range [][]*client.ANNSearchRequest{sReqs, sReqsInvalid} { 140 _, errSearch := mc.HybridSearch(ctx, collName, []string{}, invalidLimit, []string{}, ranker, sReq) 141 common.CheckErr(t, errSearch, false, "should be greater than 0", "should be in range [1, 16384]") 142 } 143 } 144 145 // hybrid search with invalid WeightedReranker params 146 for _, invalidRanker := range []client.Reranker{ 147 client.NewWeightedReranker([]float64{-1, 0.2}), 148 client.NewWeightedReranker([]float64{1.2, 0.2}), 149 client.NewWeightedReranker([]float64{0.2}), 150 client.NewWeightedReranker([]float64{0.2, 0.7, 0.5}), 151 } { 152 _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, invalidRanker, sReqs) 153 common.CheckErr(t, errSearch, false, "rank param weight should be in range [0, 1]", 154 "the length of weights param mismatch with ann search requests") 155 } 156 157 // invalid fieldName: not exist 158 sReqs = append(sReqs, client.NewANNSearchRequest("a", entity.L2, "", queryVec1, sp, common.DefaultTopK)) 159 _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, ranker, sReqs) 160 common.CheckErr(t, errSearch, false, "failed to get field schema by name: fieldName(a) not found") 161 162 // invalid metric type: mismatch 163 sReqsInvalidMetric := []*client.ANNSearchRequest{ 164 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.COSINE, "", queryVec1, sp, common.DefaultTopK), 165 client.NewANNSearchRequest(common.DefaultBinaryVecFieldName, entity.JACCARD, "", queryVec2, sp, common.DefaultTopK), 166 } 167 _, errSearch = mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, ranker, sReqsInvalidMetric) 168 common.CheckErr(t, errSearch, false, "metric type not match: invalid parameter") 169 } 170 171 // len(nq) != 1 172 // vector type mismatch: vectors: float32, queryVec: binary 173 // vector dim mismatch 174 func TestHybridSearchInvalidVectors(t *testing.T) { 175 ctx := createContext(t, time.Second*common.DefaultTimeout*2) 176 // connect 177 mc := createMilvusClient(ctx, t) 178 179 // create -> insert [0, 3000) -> flush -> index -> load 180 cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: false, 181 ShardsNum: common.DefaultShards, Dim: common.DefaultDim} 182 183 dp := DataParams{DoInsert: true, CollectionFieldsType: Int64FloatVec, start: 0, nb: common.DefaultNb, 184 dim: common.DefaultDim, EnableDynamicField: false} 185 186 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 187 188 // hybrid search with invalid limit 189 ranker := client.NewRRFReranker() 190 sp, _ := entity.NewIndexFlatSearchParam() 191 // queryVecNq := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) 192 queryVecBinary := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeBinaryVector) 193 queryVecType := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) 194 queryVecDim := common.GenSearchVectors(1, common.DefaultDim*2, entity.FieldTypeFloatVector) 195 sReqs := [][]*client.ANNSearchRequest{ 196 // {client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVecNq, sp, common.DefaultTopK)}, // nq != 1 197 {client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVecType, sp, common.DefaultTopK)}, // TODO vector type not match 198 {client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVecDim, sp, common.DefaultTopK)}, // vector dim not match 199 {client.NewANNSearchRequest(common.DefaultBinaryVecFieldName, entity.JACCARD, "", queryVecBinary, sp, common.DefaultTopK)}, // not exist vector types 200 } 201 for idx, invalidSReq := range sReqs { 202 log.Println(idx) 203 _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, ranker, invalidSReq) 204 common.CheckErr(t, errSearch, false, "nq should be equal to 1", "vector dimension mismatch", 205 "failed to get field schema by name", "vector type must be the same") 206 } 207 } 208 209 // hybrid search Pagination -> verify success 210 func TestHybridSearchMultiVectorsPagination(t *testing.T) { 211 t.Parallel() 212 ctx := createContext(t, time.Second*common.DefaultTimeout*2) 213 // connect 214 mc := createMilvusClient(ctx, t) 215 216 // create -> insert [0, 3000) -> flush -> index -> load 217 cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false, 218 ShardsNum: common.DefaultShards, Dim: common.DefaultDim} 219 220 dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb * 5, 221 dim: common.DefaultDim, EnableDynamicField: false} 222 223 // index params 224 ips := GenDefaultIndexParamsForAllVectors() 225 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 226 227 // hybrid search with different limit 228 sp, _ := entity.NewIndexFlatSearchParam() 229 expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName) 230 queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) 231 queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) 232 // milvus ignore invalid offset with ANNSearchRequest 233 for _, invalidOffset := range []int64{-1, common.MaxTopK + 1} { 234 sReqs := []*client.ANNSearchRequest{ 235 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec1, sp, common.DefaultTopK, client.WithOffset(invalidOffset)), 236 client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, "", queryVec2, sp, common.DefaultTopK), 237 } 238 _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs) 239 common.CheckErr(t, errSearch, true) 240 241 //hybrid search with invalid offset 242 _, errSearch = mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithOffset(invalidOffset)) 243 common.CheckErr(t, errSearch, false, "should be gte than 0", "(offset+limit) should be in range [1, 16384]") 244 } 245 246 // search with different reranker and offset 247 for _, reranker := range []client.Reranker{ 248 client.NewRRFReranker(), 249 client.NewWeightedReranker([]float64{0.8, 0.2}), 250 client.NewWeightedReranker([]float64{0.0, 0.2}), 251 client.NewWeightedReranker([]float64{0.4, 1.0}), 252 } { 253 sReqs := []*client.ANNSearchRequest{ 254 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK), 255 client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK), 256 } 257 // hybrid search 258 searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, reranker, sReqs) 259 common.CheckErr(t, errSearch, true) 260 offsetRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, 5, []string{}, reranker, sReqs, client.WithOffset(5)) 261 common.CheckErr(t, errSearch, true) 262 common.CheckSearchResult(t, searchRes, 1, common.DefaultTopK) 263 common.CheckSearchResult(t, offsetRes, 1, 5) 264 for i := 0; i < len(searchRes); i++ { 265 require.Equal(t, searchRes[i].IDs.(*entity.ColumnInt64).Data()[5:], offsetRes[i].IDs.(*entity.ColumnInt64).Data()) 266 } 267 } 268 } 269 270 // hybrid search Pagination -> verify success 271 func TestHybridSearchMultiVectorsRangeSearch(t *testing.T) { 272 ctx := createContext(t, time.Second*common.DefaultTimeout*5) 273 // connect 274 mc := createMilvusClient(ctx, t) 275 276 // create -> insert [0, 3000) -> flush -> index -> load 277 cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false, 278 ShardsNum: common.DefaultShards, Dim: common.DefaultDim} 279 280 dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb * 3, 281 dim: common.DefaultDim, EnableDynamicField: false} 282 283 // index params 284 ips := GenDefaultIndexParamsForAllVectors() 285 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 286 287 // hybrid search 288 sp, _ := entity.NewIndexFlatSearchParam() 289 expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName) 290 queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector) 291 queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector) 292 293 // search with different reranker and offset 294 sp.AddRadius(20) 295 sp.AddRangeFilter(0.01) 296 for _, reranker := range []client.Reranker{ 297 client.NewRRFReranker(), 298 client.NewWeightedReranker([]float64{0.8, 0.2}), 299 client.NewWeightedReranker([]float64{0.5, 0.5}), 300 } { 301 sReqs := []*client.ANNSearchRequest{ 302 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK*2, client.WithOffset(1)), 303 client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK), 304 } 305 // hybrid search 306 resRange, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, reranker, sReqs) 307 common.CheckErr(t, errSearch, true) 308 common.CheckSearchResult(t, resRange, 1, common.DefaultTopK) 309 for _, res := range resRange { 310 for _, score := range res.Scores { 311 require.GreaterOrEqual(t, score, float32(0.01)) 312 require.LessOrEqual(t, score, float32(20)) 313 } 314 } 315 } 316 } 317 318 func TestHybridSearchSparseVector(t *testing.T) { 319 t.Parallel() 320 idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"}) 321 idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"}) 322 for _, idx := range []entity.Index{idxInverted, idxWand} { 323 ctx := createContext(t, time.Second*common.DefaultTimeout*2) 324 // connect 325 mc := createMilvusClient(ctx, t) 326 327 // create -> insert [0, 3000) -> flush -> index -> load 328 cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true, 329 ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen} 330 331 dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 3, 332 dim: common.DefaultDim, EnableDynamicField: true} 333 334 // index params 335 idxHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96) 336 ips := []IndexParams{ 337 {BuildIndex: true, Index: idx, FieldName: common.DefaultSparseVecFieldName, async: false}, 338 {BuildIndex: true, Index: idxHnsw, FieldName: common.DefaultFloatVecFieldName, async: false}, 339 } 340 collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong))) 341 342 // search 343 queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector) 344 queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) 345 sp1, _ := entity.NewIndexSparseInvertedSearchParam(0.2) 346 sp2, _ := entity.NewIndexHNSWSearchParam(20) 347 expr := fmt.Sprintf("%s > 1", common.DefaultIntFieldName) 348 sReqs := []*client.ANNSearchRequest{ 349 client.NewANNSearchRequest(common.DefaultSparseVecFieldName, entity.IP, expr, queryVec1, sp1, common.DefaultTopK), 350 client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec2, sp2, common.DefaultTopK), 351 } 352 for _, reranker := range []client.Reranker{ 353 client.NewRRFReranker(), 354 client.NewWeightedReranker([]float64{0.5, 0.6}), 355 } { 356 // hybrid search 357 searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, reranker, sReqs) 358 common.CheckErr(t, errSearch, true) 359 common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultTopK) 360 common.CheckErr(t, errSearch, true) 361 outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName, 362 common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName} 363 common.CheckOutputFields(t, searchRes[0].Fields, outputFields) 364 } 365 } 366 }