github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/entity/columns_sparse.go (about)

     1  // Copyright (C) 2019-2021 Zilliz. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
     4  // with the License. You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software distributed under the License
     9  // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
    10  // or implied. See the License for the specific language governing permissions and limitations under the License.
    11  
    12  package entity
    13  
    14  import (
    15  	"encoding/binary"
    16  	"fmt"
    17  	"math"
    18  	"sort"
    19  
    20  	"github.com/cockroachdb/errors"
    21  	schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    22  )
    23  
// SparseEmbedding is the interface of a sparse vector, i.e. a collection of
// (position, value) items.
type SparseEmbedding interface {
	Dim() int // the dimension
	Len() int // the actual items in this vector
	// Get returns the position and value of the idx-th item; ok reports
	// whether idx refers to an existing item.
	Get(idx int) (pos uint32, value float32, ok bool)
	// Serialize returns the embedding encoded as bytes.
	Serialize() []byte
	// FieldType returns the FieldType of the embedding.
	FieldType() FieldType
}
    31  
    32  var _ SparseEmbedding = sliceSparseEmbedding{}
    33  var _ Vector = sliceSparseEmbedding{}
    34  
// sliceSparseEmbedding is a SparseEmbedding backed by two parallel slices:
// positions[i] and values[i] together form the i-th item.
type sliceSparseEmbedding struct {
	positions []uint32  // position of each item
	values    []float32 // value of each item, parallel to positions
	dim       int       // dimension reported by Dim
	len       int       // number of items reported by Len
}
    41  
// Dim returns the stored dimension of the embedding.
func (e sliceSparseEmbedding) Dim() int {
	return e.dim
}
    45  
// Len returns the stored number of items in the embedding.
// It also serves as the Len method of sort.Interface.
func (e sliceSparseEmbedding) Len() int {
	return e.len
}
    49  
// FieldType returns the field type of the embedding, which is always
// FieldTypeSparseVector.
func (e sliceSparseEmbedding) FieldType() FieldType {
	return FieldTypeSparseVector
}
    53  
    54  func (e sliceSparseEmbedding) Get(idx int) (uint32, float32, bool) {
    55  	if idx < 0 || idx >= int(e.len) {
    56  		return 0, 0, false
    57  	}
    58  	return e.positions[idx], e.values[idx], true
    59  }
    60  
    61  func (e sliceSparseEmbedding) Serialize() []byte {
    62  	row := make([]byte, 8*e.Len())
    63  	for idx := 0; idx < e.Len(); idx++ {
    64  		pos, value, _ := e.Get(idx)
    65  		binary.LittleEndian.PutUint32(row[idx*8:], pos)
    66  		binary.LittleEndian.PutUint32(row[idx*8+4:], math.Float32bits(value))
    67  	}
    68  	return row
    69  }
    70  
// Less implements sort.Interface.
// Items are ordered by ascending position.
func (e sliceSparseEmbedding) Less(i, j int) bool {
	return e.positions[i] < e.positions[j]
}
    75  
    76  func (e sliceSparseEmbedding) Swap(i, j int) {
    77  	e.positions[i], e.positions[j] = e.positions[j], e.positions[i]
    78  	e.values[i], e.values[j] = e.values[j], e.values[i]
    79  }
    80  
    81  func deserializeSliceSparceEmbedding(bs []byte) (sliceSparseEmbedding, error) {
    82  	length := len(bs)
    83  	if length%8 != 0 {
    84  		return sliceSparseEmbedding{}, errors.New("not valid sparse embedding bytes")
    85  	}
    86  
    87  	length = length / 8
    88  
    89  	result := sliceSparseEmbedding{
    90  		positions: make([]uint32, length),
    91  		values:    make([]float32, length),
    92  		len:       length,
    93  	}
    94  
    95  	for i := 0; i < length; i++ {
    96  		result.positions[i] = binary.LittleEndian.Uint32(bs[i*8 : i*8+4])
    97  		result.values[i] = math.Float32frombits(binary.LittleEndian.Uint32(bs[i*8+4 : i*8+8]))
    98  	}
    99  	return result, nil
   100  }
   101  
   102  func NewSliceSparseEmbedding(positions []uint32, values []float32) (SparseEmbedding, error) {
   103  	if len(positions) != len(values) {
   104  		return nil, errors.New("invalid sparse embedding input, positions shall have same number of values")
   105  	}
   106  
   107  	se := sliceSparseEmbedding{
   108  		positions: positions,
   109  		values:    values,
   110  		len:       len(positions),
   111  	}
   112  
   113  	sort.Sort(se)
   114  
   115  	if se.len > 0 {
   116  		se.dim = int(se.positions[se.len-1]) + 1
   117  	}
   118  
   119  	return se, nil
   120  }
   121  
// Compile-time check that *ColumnSparseFloatVector satisfies Column.
var _ (Column) = (*ColumnSparseFloatVector)(nil)

// ColumnSparseFloatVector is a named column whose rows are sparse embeddings.
type ColumnSparseFloatVector struct {
	ColumnBase

	vectors []SparseEmbedding // row values, one embedding per row
	name    string            // column name
}
   130  
// Name returns the name of the column.
func (c *ColumnSparseFloatVector) Name() string {
	return c.name
}
   135  
// Type returns the FieldType of the column, which is always
// FieldTypeSparseVector.
func (c *ColumnSparseFloatVector) Type() FieldType {
	return FieldTypeSparseVector
}
   140  
// Len returns the number of rows (embeddings) held by the column.
func (c *ColumnSparseFloatVector) Len() int {
	return len(c.vectors)
}
   145  
   146  func (c *ColumnSparseFloatVector) Slice(start, end int) Column {
   147  	l := c.Len()
   148  	if start > l {
   149  		start = l
   150  	}
   151  	if end == -1 || end > l {
   152  		end = l
   153  	}
   154  	return &ColumnSparseFloatVector{
   155  		ColumnBase: c.ColumnBase,
   156  		name:       c.name,
   157  		vectors:    c.vectors[start:end],
   158  	}
   159  }
   160  
   161  // Get returns value at index as interface{}.
   162  func (c *ColumnSparseFloatVector) Get(idx int) (interface{}, error) {
   163  	if idx < 0 || idx >= c.Len() {
   164  		return nil, errors.New("index out of range")
   165  	}
   166  	return c.vectors[idx], nil
   167  }
   168  
   169  // ValueByIdx returns value of the provided index
   170  // error occurs when index out of range
   171  func (c *ColumnSparseFloatVector) ValueByIdx(idx int) (SparseEmbedding, error) {
   172  	var r SparseEmbedding // use default value
   173  	if idx < 0 || idx >= c.Len() {
   174  		return r, errors.New("index out of range")
   175  	}
   176  	return c.vectors[idx], nil
   177  }
   178  
   179  func (c *ColumnSparseFloatVector) FieldData() *schema.FieldData {
   180  	fd := &schema.FieldData{
   181  		Type:      schema.DataType_SparseFloatVector,
   182  		FieldName: c.name,
   183  	}
   184  
   185  	dim := int(0)
   186  	data := make([][]byte, 0, len(c.vectors))
   187  	for _, vector := range c.vectors {
   188  		row := make([]byte, 8*vector.Len())
   189  		for idx := 0; idx < vector.Len(); idx++ {
   190  			pos, value, _ := vector.Get(idx)
   191  			binary.LittleEndian.PutUint32(row[idx*8:], pos)
   192  			binary.LittleEndian.PutUint32(row[idx*8+4:], math.Float32bits(value))
   193  		}
   194  		data = append(data, row)
   195  		if vector.Dim() > dim {
   196  			dim = vector.Dim()
   197  		}
   198  	}
   199  
   200  	fd.Field = &schema.FieldData_Vectors{
   201  		Vectors: &schema.VectorField{
   202  			Dim: int64(dim),
   203  			Data: &schema.VectorField_SparseFloatVector{
   204  				SparseFloatVector: &schema.SparseFloatArray{
   205  					Dim:      int64(dim),
   206  					Contents: data,
   207  				},
   208  			},
   209  		},
   210  	}
   211  	return fd
   212  }
   213  
   214  func (c *ColumnSparseFloatVector) AppendValue(i interface{}) error {
   215  	v, ok := i.(SparseEmbedding)
   216  	if !ok {
   217  		return fmt.Errorf("invalid type, expect SparseEmbedding interface, got %T", i)
   218  	}
   219  	c.vectors = append(c.vectors, v)
   220  
   221  	return nil
   222  }
   223  
// Data returns the embeddings held by the column.
// The returned slice is the column's own storage, not a copy.
func (c *ColumnSparseFloatVector) Data() []SparseEmbedding {
	return c.vectors
}
   227  
   228  func NewColumnSparseVectors(name string, values []SparseEmbedding) *ColumnSparseFloatVector {
   229  	return &ColumnSparseFloatVector{
   230  		name:    name,
   231  		vectors: values,
   232  	}
   233  }