github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/test/testcases/hybrid_search_test.go (about)

     1  //go:build L0
     2  
     3  package testcases
     4  
     5  import (
     6  	"fmt"
     7  	"log"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/milvus-io/milvus-sdk-go/v2/client"
    14  
    15  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    16  	"github.com/milvus-io/milvus-sdk-go/v2/test/common"
    17  )
    18  
    19  func TestHybridSearchDefault(t *testing.T) {
    20  	ctx := createContext(t, time.Second*common.DefaultTimeout)
    21  	// connect
    22  	mc := createMilvusClient(ctx, t)
    23  
    24  	// create -> insert [0, 3000) -> flush -> index -> load
    25  	cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: false,
    26  		ShardsNum: common.DefaultShards, Dim: common.DefaultDim}
    27  
    28  	dp := DataParams{DoInsert: true, CollectionFieldsType: Int64FloatVec, start: 0, nb: common.DefaultNb,
    29  		dim: common.DefaultDim, EnableDynamicField: false}
    30  
    31  	collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
    32  
    33  	// hybrid search
    34  	ranker := client.NewRRFReranker()
    35  	expr := fmt.Sprintf("%s > 10", common.DefaultIntFieldName)
    36  	sp, _ := entity.NewIndexFlatSearchParam()
    37  	queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
    38  	queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
    39  	sReqs := []*client.ANNSearchRequest{
    40  		client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK),
    41  		client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK),
    42  	}
    43  	searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, ranker, sReqs)
    44  	common.CheckErr(t, errSearch, true)
    45  	common.CheckSearchResult(t, searchRes, 1, common.DefaultTopK)
    46  	common.CheckOutputFields(t, searchRes[0].Fields, []string{common.DefaultIntFieldName, common.DefaultFloatFieldName, common.DefaultFloatVecFieldName})
    47  }
    48  
    49  // hybrid search default -> verify success
    50  func TestHybridSearchMultiVectorsDefault(t *testing.T) {
    51  	t.Parallel()
    52  	ctx := createContext(t, time.Second*common.DefaultTimeout*3)
    53  	// connect
    54  	mc := createMilvusClient(ctx, t)
    55  	for _, enableDynamic := range []bool{false, true} {
    56  		// create -> insert [0, 3000) -> flush -> index -> load
    57  		cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: enableDynamic,
    58  			ShardsNum: common.DefaultShards, Dim: common.DefaultDim}
    59  
    60  		dp := DataParams{DoInsert: true, CollectionFieldsType: AllFields, start: 0, nb: common.DefaultNb * 3,
    61  			dim: common.DefaultDim, EnableDynamicField: enableDynamic}
    62  
    63  		ips := GenDefaultIndexParamsForAllVectors()
    64  
    65  		collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
    66  
    67  		// hybrid search with different limit
    68  		type limitGroup struct {
    69  			limit1   int
    70  			limit2   int
    71  			limit3   int
    72  			expLimit int
    73  		}
    74  		limits := []limitGroup{
    75  			{limit1: 10, limit2: 5, limit3: 8, expLimit: 8},
    76  			{limit1: 10, limit2: 5, limit3: 15, expLimit: 15},
    77  			{limit1: 10, limit2: 5, limit3: 20, expLimit: 15},
    78  		}
    79  		sp, _ := entity.NewIndexFlatSearchParam()
    80  		expr := fmt.Sprintf("%s > 5", common.DefaultIntFieldName)
    81  		queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
    82  		queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloat16Vector)
    83  
    84  		// search with different reranker and limit
    85  		for _, reranker := range []client.Reranker{client.NewRRFReranker(),
    86  			client.NewWeightedReranker([]float64{0.8, 0.2}),
    87  			client.NewWeightedReranker([]float64{0.0, 0.2}),
    88  			client.NewWeightedReranker([]float64{0.4, 1.0}),
    89  		} {
    90  			for _, limit := range limits {
    91  				// hybrid search
    92  				sReqs := []*client.ANNSearchRequest{
    93  					client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, limit.limit1),
    94  					client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, limit.limit2),
    95  				}
    96  				searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, limit.limit3, []string{"*"}, reranker, sReqs)
    97  				common.CheckErr(t, errSearch, true)
    98  				common.CheckSearchResult(t, searchRes, common.DefaultNq, limit.expLimit)
    99  				common.CheckOutputFields(t, searchRes[0].Fields, common.GetAllFieldsName(enableDynamic, false))
   100  			}
   101  		}
   102  	}
   103  }
   104  
   105  // invalid limit: 0, -1, max+1
   106  // invalid WeightedReranker params
   107  // invalid fieldName: not exist
   108  // invalid metric type: mismatch
   109  func TestHybridSearchInvalidParams(t *testing.T) {
   110  	ctx := createContext(t, time.Second*common.DefaultTimeout*2)
   111  	// connect
   112  	mc := createMilvusClient(ctx, t)
   113  
   114  	// create -> insert [0, 3000) -> flush -> index -> load
   115  	cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false,
   116  		ShardsNum: common.DefaultShards, Dim: common.DefaultDim}
   117  
   118  	dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb,
   119  		dim: common.DefaultDim, EnableDynamicField: false}
   120  
   121  	// index params
   122  	ips := GenDefaultIndexParamsForAllVectors()
   123  	collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips),
   124  		WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
   125  
   126  	// hybrid search with invalid limit
   127  	ranker := client.NewRRFReranker()
   128  	sp, _ := entity.NewIndexFlatSearchParam()
   129  	queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
   130  	queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeBinaryVector)
   131  	sReqs := []*client.ANNSearchRequest{
   132  		client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec1, sp, common.DefaultTopK),
   133  		client.NewANNSearchRequest(common.DefaultBinaryVecFieldName, entity.JACCARD, "", queryVec2, sp, common.DefaultTopK),
   134  	}
   135  	for _, invalidLimit := range []int{-1, 0, common.MaxTopK + 1} {
   136  		sReqsInvalid := []*client.ANNSearchRequest{
   137  			client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec1, sp, invalidLimit)}
   138  
   139  		for _, sReq := range [][]*client.ANNSearchRequest{sReqs, sReqsInvalid} {
   140  			_, errSearch := mc.HybridSearch(ctx, collName, []string{}, invalidLimit, []string{}, ranker, sReq)
   141  			common.CheckErr(t, errSearch, false, "should be greater than 0", "should be in range [1, 16384]")
   142  		}
   143  	}
   144  
   145  	// hybrid search with invalid WeightedReranker params
   146  	for _, invalidRanker := range []client.Reranker{
   147  		client.NewWeightedReranker([]float64{-1, 0.2}),
   148  		client.NewWeightedReranker([]float64{1.2, 0.2}),
   149  		client.NewWeightedReranker([]float64{0.2}),
   150  		client.NewWeightedReranker([]float64{0.2, 0.7, 0.5}),
   151  	} {
   152  		_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, invalidRanker, sReqs)
   153  		common.CheckErr(t, errSearch, false, "rank param weight should be in range [0, 1]",
   154  			"the length of weights param mismatch with ann search requests")
   155  	}
   156  
   157  	// invalid fieldName: not exist
   158  	sReqs = append(sReqs, client.NewANNSearchRequest("a", entity.L2, "", queryVec1, sp, common.DefaultTopK))
   159  	_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, ranker, sReqs)
   160  	common.CheckErr(t, errSearch, false, "failed to get field schema by name: fieldName(a) not found")
   161  
   162  	// invalid metric type: mismatch
   163  	sReqsInvalidMetric := []*client.ANNSearchRequest{
   164  		client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.COSINE, "", queryVec1, sp, common.DefaultTopK),
   165  		client.NewANNSearchRequest(common.DefaultBinaryVecFieldName, entity.JACCARD, "", queryVec2, sp, common.DefaultTopK),
   166  	}
   167  	_, errSearch = mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, ranker, sReqsInvalidMetric)
   168  	common.CheckErr(t, errSearch, false, "metric type not match: invalid parameter")
   169  }
   170  
   171  // len(nq) != 1
   172  // vector type mismatch: vectors: float32, queryVec: binary
   173  // vector dim mismatch
   174  func TestHybridSearchInvalidVectors(t *testing.T) {
   175  	ctx := createContext(t, time.Second*common.DefaultTimeout*2)
   176  	// connect
   177  	mc := createMilvusClient(ctx, t)
   178  
   179  	// create -> insert [0, 3000) -> flush -> index -> load
   180  	cp := CollectionParams{CollectionFieldsType: Int64FloatVec, AutoID: false, EnableDynamicField: false,
   181  		ShardsNum: common.DefaultShards, Dim: common.DefaultDim}
   182  
   183  	dp := DataParams{DoInsert: true, CollectionFieldsType: Int64FloatVec, start: 0, nb: common.DefaultNb,
   184  		dim: common.DefaultDim, EnableDynamicField: false}
   185  
   186  	collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
   187  
   188  	// hybrid search with invalid limit
   189  	ranker := client.NewRRFReranker()
   190  	sp, _ := entity.NewIndexFlatSearchParam()
   191  	// queryVecNq := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
   192  	queryVecBinary := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeBinaryVector)
   193  	queryVecType := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector)
   194  	queryVecDim := common.GenSearchVectors(1, common.DefaultDim*2, entity.FieldTypeFloatVector)
   195  	sReqs := [][]*client.ANNSearchRequest{
   196  		// {client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVecNq, sp, common.DefaultTopK)},           // nq != 1
   197  		{client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVecType, sp, common.DefaultTopK)},         // TODO vector type not match
   198  		{client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVecDim, sp, common.DefaultTopK)},          // vector dim not match
   199  		{client.NewANNSearchRequest(common.DefaultBinaryVecFieldName, entity.JACCARD, "", queryVecBinary, sp, common.DefaultTopK)}, // not exist vector types
   200  	}
   201  	for idx, invalidSReq := range sReqs {
   202  		log.Println(idx)
   203  		_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, ranker, invalidSReq)
   204  		common.CheckErr(t, errSearch, false, "nq should be equal to 1", "vector dimension mismatch",
   205  			"failed to get field schema by name", "vector type must be the same")
   206  	}
   207  }
   208  
   209  // hybrid search Pagination -> verify success
   210  func TestHybridSearchMultiVectorsPagination(t *testing.T) {
   211  	t.Parallel()
   212  	ctx := createContext(t, time.Second*common.DefaultTimeout*2)
   213  	// connect
   214  	mc := createMilvusClient(ctx, t)
   215  
   216  	// create -> insert [0, 3000) -> flush -> index -> load
   217  	cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false,
   218  		ShardsNum: common.DefaultShards, Dim: common.DefaultDim}
   219  
   220  	dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb * 5,
   221  		dim: common.DefaultDim, EnableDynamicField: false}
   222  
   223  	// index params
   224  	ips := GenDefaultIndexParamsForAllVectors()
   225  	collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
   226  
   227  	// hybrid search with different limit
   228  	sp, _ := entity.NewIndexFlatSearchParam()
   229  	expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName)
   230  	queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
   231  	queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector)
   232  	// milvus ignore invalid offset with ANNSearchRequest
   233  	for _, invalidOffset := range []int64{-1, common.MaxTopK + 1} {
   234  		sReqs := []*client.ANNSearchRequest{
   235  			client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec1, sp, common.DefaultTopK, client.WithOffset(invalidOffset)),
   236  			client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, "", queryVec2, sp, common.DefaultTopK),
   237  		}
   238  		_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs)
   239  		common.CheckErr(t, errSearch, true)
   240  
   241  		//hybrid search with invalid offset
   242  		_, errSearch = mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithOffset(invalidOffset))
   243  		common.CheckErr(t, errSearch, false, "should be gte than 0", "(offset+limit) should be in range [1, 16384]")
   244  	}
   245  
   246  	// search with different reranker and offset
   247  	for _, reranker := range []client.Reranker{
   248  		client.NewRRFReranker(),
   249  		client.NewWeightedReranker([]float64{0.8, 0.2}),
   250  		client.NewWeightedReranker([]float64{0.0, 0.2}),
   251  		client.NewWeightedReranker([]float64{0.4, 1.0}),
   252  	} {
   253  		sReqs := []*client.ANNSearchRequest{
   254  			client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK),
   255  			client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK),
   256  		}
   257  		// hybrid search
   258  		searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, reranker, sReqs)
   259  		common.CheckErr(t, errSearch, true)
   260  		offsetRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, 5, []string{}, reranker, sReqs, client.WithOffset(5))
   261  		common.CheckErr(t, errSearch, true)
   262  		common.CheckSearchResult(t, searchRes, 1, common.DefaultTopK)
   263  		common.CheckSearchResult(t, offsetRes, 1, 5)
   264  		for i := 0; i < len(searchRes); i++ {
   265  			require.Equal(t, searchRes[i].IDs.(*entity.ColumnInt64).Data()[5:], offsetRes[i].IDs.(*entity.ColumnInt64).Data())
   266  		}
   267  	}
   268  }
   269  
   270  // hybrid search Pagination -> verify success
   271  func TestHybridSearchMultiVectorsRangeSearch(t *testing.T) {
   272  	ctx := createContext(t, time.Second*common.DefaultTimeout*5)
   273  	// connect
   274  	mc := createMilvusClient(ctx, t)
   275  
   276  	// create -> insert [0, 3000) -> flush -> index -> load
   277  	cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false,
   278  		ShardsNum: common.DefaultShards, Dim: common.DefaultDim}
   279  
   280  	dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb * 3,
   281  		dim: common.DefaultDim, EnableDynamicField: false}
   282  
   283  	// index params
   284  	ips := GenDefaultIndexParamsForAllVectors()
   285  	collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
   286  
   287  	// hybrid search
   288  	sp, _ := entity.NewIndexFlatSearchParam()
   289  	expr := fmt.Sprintf("%s > 4", common.DefaultIntFieldName)
   290  	queryVec1 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloatVector)
   291  	queryVec2 := common.GenSearchVectors(1, common.DefaultDim, entity.FieldTypeFloat16Vector)
   292  
   293  	// search with different reranker and offset
   294  	sp.AddRadius(20)
   295  	sp.AddRangeFilter(0.01)
   296  	for _, reranker := range []client.Reranker{
   297  		client.NewRRFReranker(),
   298  		client.NewWeightedReranker([]float64{0.8, 0.2}),
   299  		client.NewWeightedReranker([]float64{0.5, 0.5}),
   300  	} {
   301  		sReqs := []*client.ANNSearchRequest{
   302  			client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK*2, client.WithOffset(1)),
   303  			client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK),
   304  		}
   305  		// hybrid search
   306  		resRange, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, reranker, sReqs)
   307  		common.CheckErr(t, errSearch, true)
   308  		common.CheckSearchResult(t, resRange, 1, common.DefaultTopK)
   309  		for _, res := range resRange {
   310  			for _, score := range res.Scores {
   311  				require.GreaterOrEqual(t, score, float32(0.01))
   312  				require.LessOrEqual(t, score, float32(20))
   313  			}
   314  		}
   315  	}
   316  }
   317  
   318  func TestHybridSearchSparseVector(t *testing.T) {
   319  	t.Parallel()
   320  	idxInverted := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_INVERTED_INDEX", map[string]string{"drop_ratio_build": "0.2", "metric_type": "IP"})
   321  	idxWand := entity.NewGenericIndex(common.DefaultSparseVecFieldName, "SPARSE_WAND", map[string]string{"drop_ratio_build": "0.3", "metric_type": "IP"})
   322  	for _, idx := range []entity.Index{idxInverted, idxWand} {
   323  		ctx := createContext(t, time.Second*common.DefaultTimeout*2)
   324  		// connect
   325  		mc := createMilvusClient(ctx, t)
   326  
   327  		// create -> insert [0, 3000) -> flush -> index -> load
   328  		cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true,
   329  			ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen}
   330  
   331  		dp := DataParams{DoInsert: true, CollectionFieldsType: Int64VarcharSparseVec, start: 0, nb: common.DefaultNb * 3,
   332  			dim: common.DefaultDim, EnableDynamicField: true}
   333  
   334  		// index params
   335  		idxHnsw, _ := entity.NewIndexHNSW(entity.L2, 8, 96)
   336  		ips := []IndexParams{
   337  			{BuildIndex: true, Index: idx, FieldName: common.DefaultSparseVecFieldName, async: false},
   338  			{BuildIndex: true, Index: idxHnsw, FieldName: common.DefaultFloatVecFieldName, async: false},
   339  		}
   340  		collName := prepareCollection(ctx, t, mc, cp, WithDataParams(dp), WithIndexParams(ips), WithCreateOption(client.WithConsistencyLevel(entity.ClStrong)))
   341  
   342  		// search
   343  		queryVec1 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim*2, entity.FieldTypeSparseVector)
   344  		queryVec2 := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
   345  		sp1, _ := entity.NewIndexSparseInvertedSearchParam(0.2)
   346  		sp2, _ := entity.NewIndexHNSWSearchParam(20)
   347  		expr := fmt.Sprintf("%s > 1", common.DefaultIntFieldName)
   348  		sReqs := []*client.ANNSearchRequest{
   349  			client.NewANNSearchRequest(common.DefaultSparseVecFieldName, entity.IP, expr, queryVec1, sp1, common.DefaultTopK),
   350  			client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, "", queryVec2, sp2, common.DefaultTopK),
   351  		}
   352  		for _, reranker := range []client.Reranker{
   353  			client.NewRRFReranker(),
   354  			client.NewWeightedReranker([]float64{0.5, 0.6}),
   355  		} {
   356  			// hybrid search
   357  			searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{"*"}, reranker, sReqs)
   358  			common.CheckErr(t, errSearch, true)
   359  			common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultTopK)
   360  			common.CheckErr(t, errSearch, true)
   361  			outputFields := []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName, common.DefaultFloatVecFieldName,
   362  				common.DefaultSparseVecFieldName, common.DefaultDynamicFieldName}
   363  			common.CheckOutputFields(t, searchRes[0].Fields, outputFields)
   364  		}
   365  	}
   366  }