github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/entity/index_test.go (about)

     1  // Copyright (C) 2019-2021 Zilliz. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
     4  // with the License. You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software distributed under the License
     9  // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
    10  // or implied. See the License for the specific language governing permissions and limitations under the License.
    11  
    12  package entity
    13  
    14  import (
    15  	"fmt"
    16  	"math/rand"
    17  	"testing"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestGenericIndex(t *testing.T) {
    24  	name := fmt.Sprintf("generic_index_%d", rand.Int())
    25  	gi := NewGenericIndex(name, IvfFlat, map[string]string{
    26  		tMetricType: string(IP),
    27  	})
    28  	assert.Equal(t, name, gi.Name())
    29  	assert.EqualValues(t, IvfFlat, gi.Params()[tIndexType])
    30  }
    31  
    32  func TestAddRadius(t *testing.T) {
    33  	params := newBaseSearchParams()
    34  	params.AddRadius(10)
    35  	assert.Equal(t, params.Params()["radius"], float64(10))
    36  }
    37  
    38  func TestAddRangeFilter(t *testing.T) {
    39  	params := newBaseSearchParams()
    40  	params.AddRangeFilter(20)
    41  	assert.Equal(t, params.Params()["range_filter"], float64(20))
    42  }
    43  
    44  func TestIndexGPUCagra(t *testing.T) {
    45  	t.Run("index", func(t *testing.T) {
    46  		index, err := NewIndexGPUCagra(L2, 64, 64)
    47  		require.NoError(t, err)
    48  		require.NotNil(t, index)
    49  
    50  		assert.Equal(t, "GPUCagra", index.Name())
    51  		assert.Equal(t, GPUCagra, index.IndexType())
    52  		assert.False(t, index.SupportBinary())
    53  
    54  		params := index.Params()
    55  
    56  		metricType, ok := params["metric_type"]
    57  		require.True(t, ok)
    58  		assert.Equal(t, string(L2), metricType)
    59  
    60  		indexType, ok := params["index_type"]
    61  		require.True(t, ok)
    62  		assert.Equal(t, string(GPUCagra), indexType)
    63  
    64  		_, err = NewIndexGPUCagra(L2, 32, 64)
    65  		assert.Error(t, err)
    66  	})
    67  
    68  	t.Run("search_param", func(t *testing.T) {
    69  		sp, err := NewIndexGPUCagraSearchParam(
    70  			64,
    71  			1,
    72  			0,
    73  			0,
    74  			4,
    75  		)
    76  		require.NoError(t, err)
    77  		require.NotNil(t, sp)
    78  
    79  		params := sp.Params()
    80  		itopkSize, ok := params["itopk_size"]
    81  		require.True(t, ok)
    82  		assert.EqualValues(t, 64, itopkSize)
    83  		searchWidth, ok := params["search_width"]
    84  		require.True(t, ok)
    85  		assert.EqualValues(t, 1, searchWidth)
    86  		maxIterations, ok := params["max_iterations"]
    87  		require.True(t, ok)
    88  		assert.EqualValues(t, 0, maxIterations)
    89  		minIterations, ok := params["min_iterations"]
    90  		require.True(t, ok)
    91  		assert.EqualValues(t, 0, minIterations)
    92  		teamSize, ok := params["team_size"]
    93  		require.True(t, ok)
    94  		assert.EqualValues(t, 4, teamSize)
    95  
    96  		_, err = NewIndexGPUCagraSearchParam(
    97  			64,
    98  			1,
    99  			0,
   100  			0,
   101  			3,
   102  		)
   103  		assert.Error(t, err)
   104  	})
   105  }
   106  
   107  func TestIndexGPUBruteForce(t *testing.T) {
   108  	t.Run("index", func(t *testing.T) {
   109  		index, err := NewIndexGPUBruteForce(L2)
   110  		require.NoError(t, err)
   111  		require.NotNil(t, index)
   112  
   113  		assert.Equal(t, "GPUBruteForce", index.Name())
   114  		assert.Equal(t, GPUBruteForce, index.IndexType())
   115  		assert.False(t, index.SupportBinary())
   116  
   117  		params := index.Params()
   118  
   119  		metricType, ok := params["metric_type"]
   120  		require.True(t, ok)
   121  		assert.Equal(t, string(L2), metricType)
   122  
   123  		indexType, ok := params["index_type"]
   124  		require.True(t, ok)
   125  		assert.Equal(t, string(GPUBruteForce), indexType)
   126  	})
   127  
   128  	t.Run("search_param", func(t *testing.T) {
   129  		sp, err := NewIndexGPUBruteForceSearchParam()
   130  		assert.NoError(t, err)
   131  		assert.NotNil(t, sp)
   132  	})
   133  }