github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/entity/columns_sparse_test.go (about)

     1  // Copyright (C) 2019-2021 Zilliz. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
     4  // with the License. You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software distributed under the License
     9  // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
    10  // or implied. See the License for the specific language governing permissions and limitations under the License.
    11  
    12  package entity
    13  
    14  import (
    15  	"fmt"
    16  	"math/rand"
    17  	"testing"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestSliceSparseEmbedding(t *testing.T) {
    24  	t.Run("normal_case", func(t *testing.T) {
    25  
    26  		length := 1 + rand.Intn(5)
    27  		positions := make([]uint32, length)
    28  		values := make([]float32, length)
    29  		for i := 0; i < length; i++ {
    30  			positions[i] = uint32(i)
    31  			values[i] = rand.Float32()
    32  		}
    33  		se, err := NewSliceSparseEmbedding(positions, values)
    34  		require.NoError(t, err)
    35  
    36  		assert.EqualValues(t, length, se.Dim())
    37  		assert.EqualValues(t, length, se.Len())
    38  
    39  		bs := se.Serialize()
    40  		nv, err := deserializeSliceSparceEmbedding(bs)
    41  		require.NoError(t, err)
    42  
    43  		for i := 0; i < length; i++ {
    44  			pos, val, ok := se.Get(i)
    45  			require.True(t, ok)
    46  			assert.Equal(t, positions[i], pos)
    47  			assert.Equal(t, values[i], val)
    48  
    49  			npos, nval, ok := nv.Get(i)
    50  			require.True(t, ok)
    51  			assert.Equal(t, positions[i], npos)
    52  			assert.Equal(t, values[i], nval)
    53  		}
    54  
    55  		_, _, ok := se.Get(-1)
    56  		assert.False(t, ok)
    57  		_, _, ok = se.Get(length)
    58  		assert.False(t, ok)
    59  	})
    60  
    61  	t.Run("position values not match", func(t *testing.T) {
    62  		_, err := NewSliceSparseEmbedding([]uint32{1}, []float32{})
    63  		assert.Error(t, err)
    64  	})
    65  
    66  }
    67  
    68  func TestColumnSparseEmbedding(t *testing.T) {
    69  	columnName := fmt.Sprintf("column_sparse_embedding_%d", rand.Int())
    70  	columnLen := 8 + rand.Intn(10)
    71  
    72  	v := make([]SparseEmbedding, 0, columnLen)
    73  	for i := 0; i < columnLen; i++ {
    74  		length := 1 + rand.Intn(5)
    75  		positions := make([]uint32, length)
    76  		values := make([]float32, length)
    77  		for j := 0; j < length; j++ {
    78  			positions[j] = uint32(j)
    79  			values[j] = rand.Float32()
    80  		}
    81  		se, err := NewSliceSparseEmbedding(positions, values)
    82  		require.NoError(t, err)
    83  		v = append(v, se)
    84  	}
    85  	column := NewColumnSparseVectors(columnName, v)
    86  
    87  	t.Run("test column attribute", func(t *testing.T) {
    88  		assert.Equal(t, columnName, column.Name())
    89  		assert.Equal(t, FieldTypeSparseVector, column.Type())
    90  		assert.Equal(t, columnLen, column.Len())
    91  		assert.EqualValues(t, v, column.Data())
    92  	})
    93  
    94  	t.Run("test column field data", func(t *testing.T) {
    95  		fd := column.FieldData()
    96  		assert.NotNil(t, fd)
    97  		assert.Equal(t, fd.GetFieldName(), columnName)
    98  	})
    99  
   100  	t.Run("test column value by idx", func(t *testing.T) {
   101  		_, err := column.ValueByIdx(-1)
   102  		assert.Error(t, err)
   103  		_, err = column.ValueByIdx(columnLen)
   104  		assert.Error(t, err)
   105  
   106  		_, err = column.Get(-1)
   107  		assert.Error(t, err)
   108  		_, err = column.Get(columnLen)
   109  		assert.Error(t, err)
   110  
   111  		for i := 0; i < columnLen; i++ {
   112  			v, err := column.ValueByIdx(i)
   113  			assert.NoError(t, err)
   114  			assert.Equal(t, column.vectors[i], v)
   115  			getV, err := column.Get(i)
   116  			assert.NoError(t, err)
   117  			assert.Equal(t, v, getV)
   118  		}
   119  	})
   120  }