github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/entity/columns_sparse_test.go (about) 1 // Copyright (C) 2019-2021 Zilliz. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance 4 // with the License. You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software distributed under the License 9 // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 // or implied. See the License for the specific language governing permissions and limitations under the License. 11 12 package entity 13 14 import ( 15 "fmt" 16 "math/rand" 17 "testing" 18 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 ) 22 23 func TestSliceSparseEmbedding(t *testing.T) { 24 t.Run("normal_case", func(t *testing.T) { 25 26 length := 1 + rand.Intn(5) 27 positions := make([]uint32, length) 28 values := make([]float32, length) 29 for i := 0; i < length; i++ { 30 positions[i] = uint32(i) 31 values[i] = rand.Float32() 32 } 33 se, err := NewSliceSparseEmbedding(positions, values) 34 require.NoError(t, err) 35 36 assert.EqualValues(t, length, se.Dim()) 37 assert.EqualValues(t, length, se.Len()) 38 39 bs := se.Serialize() 40 nv, err := deserializeSliceSparceEmbedding(bs) 41 require.NoError(t, err) 42 43 for i := 0; i < length; i++ { 44 pos, val, ok := se.Get(i) 45 require.True(t, ok) 46 assert.Equal(t, positions[i], pos) 47 assert.Equal(t, values[i], val) 48 49 npos, nval, ok := nv.Get(i) 50 require.True(t, ok) 51 assert.Equal(t, positions[i], npos) 52 assert.Equal(t, values[i], nval) 53 } 54 55 _, _, ok := se.Get(-1) 56 assert.False(t, ok) 57 _, _, ok = se.Get(length) 58 assert.False(t, ok) 59 }) 60 61 t.Run("position values not match", func(t *testing.T) { 62 _, err := NewSliceSparseEmbedding([]uint32{1}, []float32{}) 63 assert.Error(t, err) 64 }) 65 66 } 67 68 func TestColumnSparseEmbedding(t *testing.T) { 69 columnName := fmt.Sprintf("column_sparse_embedding_%d", rand.Int()) 70 columnLen := 8 + rand.Intn(10) 71 72 v := make([]SparseEmbedding, 0, columnLen) 73 for i := 0; i < columnLen; i++ { 74 length := 1 + rand.Intn(5) 75 positions := make([]uint32, length) 76 values := make([]float32, length) 77 for j := 0; j < length; j++ { 78 positions[j] = uint32(j) 79 values[j] = rand.Float32() 80 } 81 se, err := NewSliceSparseEmbedding(positions, values) 82 require.NoError(t, err) 83 v = append(v, se) 84 } 85 column := NewColumnSparseVectors(columnName, v) 86 87 t.Run("test column attribute", func(t *testing.T) { 88 assert.Equal(t, columnName, column.Name()) 89 assert.Equal(t, FieldTypeSparseVector, column.Type()) 90 assert.Equal(t, columnLen, column.Len()) 91 assert.EqualValues(t, v, column.Data()) 92 }) 93 94 t.Run("test column field data", func(t *testing.T) { 95 fd := column.FieldData() 96 assert.NotNil(t, fd) 97 assert.Equal(t, fd.GetFieldName(), columnName) 98 }) 99 100 t.Run("test column value by idx", func(t *testing.T) { 101 _, err := column.ValueByIdx(-1) 102 assert.Error(t, err) 103 _, err = column.ValueByIdx(columnLen) 104 assert.Error(t, err) 105 106 _, err = column.Get(-1) 107 assert.Error(t, err) 108 _, err = column.Get(columnLen) 109 assert.Error(t, err) 110 111 for i := 0; i < columnLen; i++ { 112 v, err := column.ValueByIdx(i) 113 assert.NoError(t, err) 114 assert.Equal(t, column.vectors[i], v) 115 getV, err := column.Get(i) 116 assert.NoError(t, err) 117 assert.Equal(t, v, getV) 118 } 119 }) 120 }