github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/entity/index_test.go (about) 1 // Copyright (C) 2019-2021 Zilliz. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance 4 // with the License. You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software distributed under the License 9 // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 // or implied. See the License for the specific language governing permissions and limitations under the License. 11 12 package entity 13 14 import ( 15 "fmt" 16 "math/rand" 17 "testing" 18 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 ) 22 23 func TestGenericIndex(t *testing.T) { 24 name := fmt.Sprintf("generic_index_%d", rand.Int()) 25 gi := NewGenericIndex(name, IvfFlat, map[string]string{ 26 tMetricType: string(IP), 27 }) 28 assert.Equal(t, name, gi.Name()) 29 assert.EqualValues(t, IvfFlat, gi.Params()[tIndexType]) 30 } 31 32 func TestAddRadius(t *testing.T) { 33 params := newBaseSearchParams() 34 params.AddRadius(10) 35 assert.Equal(t, params.Params()["radius"], float64(10)) 36 } 37 38 func TestAddRangeFilter(t *testing.T) { 39 params := newBaseSearchParams() 40 params.AddRangeFilter(20) 41 assert.Equal(t, params.Params()["range_filter"], float64(20)) 42 } 43 44 func TestIndexGPUCagra(t *testing.T) { 45 t.Run("index", func(t *testing.T) { 46 index, err := NewIndexGPUCagra(L2, 64, 64) 47 require.NoError(t, err) 48 require.NotNil(t, index) 49 50 assert.Equal(t, "GPUCagra", index.Name()) 51 assert.Equal(t, GPUCagra, index.IndexType()) 52 assert.False(t, index.SupportBinary()) 53 54 params := index.Params() 55 56 metricType, ok := params["metric_type"] 57 require.True(t, ok) 58 assert.Equal(t, string(L2), metricType) 59 60 indexType, ok := params["index_type"] 61 require.True(t, ok) 62 assert.Equal(t, string(GPUCagra), indexType) 63 64 _, err = NewIndexGPUCagra(L2, 32, 64) 65 assert.Error(t, err) 66 }) 67 68 t.Run("search_param", func(t *testing.T) { 69 sp, err := NewIndexGPUCagraSearchParam( 70 64, 71 1, 72 0, 73 0, 74 4, 75 ) 76 require.NoError(t, err) 77 require.NotNil(t, sp) 78 79 params := sp.Params() 80 itopkSize, ok := params["itopk_size"] 81 require.True(t, ok) 82 assert.EqualValues(t, 64, itopkSize) 83 searchWidth, ok := params["search_width"] 84 require.True(t, ok) 85 assert.EqualValues(t, 1, searchWidth) 86 maxIterations, ok := params["max_iterations"] 87 require.True(t, ok) 88 assert.EqualValues(t, 0, maxIterations) 89 minIterations, ok := params["min_iterations"] 90 require.True(t, ok) 91 assert.EqualValues(t, 0, minIterations) 92 teamSize, ok := params["team_size"] 93 require.True(t, ok) 94 assert.EqualValues(t, 4, teamSize) 95 96 _, err = NewIndexGPUCagraSearchParam( 97 64, 98 1, 99 0, 100 0, 101 3, 102 ) 103 assert.Error(t, err) 104 }) 105 } 106 107 func TestIndexGPUBruteForce(t *testing.T) { 108 t.Run("index", func(t *testing.T) { 109 index, err := NewIndexGPUBruteForce(L2) 110 require.NoError(t, err) 111 require.NotNil(t, index) 112 113 assert.Equal(t, "GPUBruteForce", index.Name()) 114 assert.Equal(t, GPUBruteForce, index.IndexType()) 115 assert.False(t, index.SupportBinary()) 116 117 params := index.Params() 118 119 metricType, ok := params["metric_type"] 120 require.True(t, ok) 121 assert.Equal(t, string(L2), metricType) 122 123 indexType, ok := params["index_type"] 124 require.True(t, ok) 125 assert.Equal(t, string(GPUBruteForce), indexType) 126 }) 127 128 t.Run("search_param", func(t *testing.T) { 129 sp, err := NewIndexGPUBruteForceSearchParam() 130 assert.NoError(t, err) 131 assert.NotNil(t, sp) 132 }) 133 }