github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/entity/rows_test.go (about)

     1  // Copyright (C) 2019-2021 Zilliz. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
     4  // with the License. You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software distributed under the License
     9  // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
    10  // or implied. See the License for the specific language governing permissions and limitations under the License.
    11  
    12  package entity
    13  
    14  import (
    15  	"reflect"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/suite"
    20  )
    21  
    22  // ArrayRow test case type
    23  type ArrayRow [16]float32
    24  
    25  func (ar *ArrayRow) Collection() string  { return "" }
    26  func (ar *ArrayRow) Partition() string   { return "" }
    27  func (ar *ArrayRow) Description() string { return "" }
    28  
    29  type Uint8Struct struct {
    30  	RowBase
    31  	Attr uint8
    32  }
    33  
    34  type StringArrayStruct struct {
    35  	RowBase
    36  	Vector [8]string
    37  }
    38  
    39  type StringSliceStruct struct {
    40  	RowBase
    41  	Vector []string `milvus:"dim:8"`
    42  }
    43  
    44  type SliceNoDimStruct struct {
    45  	RowBase
    46  	Vector []float32 `milvus:""`
    47  }
    48  
    49  type SliceBadDimStruct struct {
    50  	RowBase
    51  	Vector []float32 `milvus:"dim:str"`
    52  }
    53  
    54  type SliceBadDimStruct2 struct {
    55  	RowBase
    56  	Vector []float32 `milvus:"dim:0"`
    57  }
    58  
    59  func TestParseSchema(t *testing.T) {
    60  
    61  	t.Run("invalid cases", func(t *testing.T) {
    62  		// anonymous struct with default collection name ("") will cause error
    63  		anonymusStruct := struct {
    64  			RowBase
    65  		}{}
    66  		sch, err := ParseSchema(anonymusStruct)
    67  		assert.Nil(t, sch)
    68  		assert.NotNil(t, err)
    69  
    70  		// MapRow
    71  		m := make(MapRow)
    72  		sch, err = ParseSchema(m)
    73  		assert.Nil(t, sch)
    74  		assert.NotNil(t, err)
    75  
    76  		// non struct
    77  		arrayRow := ArrayRow([16]float32{})
    78  		sch, err = ParseSchema(&arrayRow)
    79  		assert.Nil(t, sch)
    80  		assert.NotNil(t, err)
    81  
    82  		// uint8 not supported
    83  		sch, err = ParseSchema(&Uint8Struct{})
    84  		assert.Nil(t, sch)
    85  		assert.NotNil(t, err)
    86  
    87  		// string array not supported
    88  		sch, err = ParseSchema(&StringArrayStruct{})
    89  		assert.Nil(t, sch)
    90  		assert.NotNil(t, err)
    91  
    92  		// string slice not supported
    93  		sch, err = ParseSchema(&StringSliceStruct{})
    94  		assert.Nil(t, sch)
    95  		assert.NotNil(t, err)
    96  
    97  		// slice vector with no dim
    98  		sch, err = ParseSchema(&SliceNoDimStruct{})
    99  		assert.Nil(t, sch)
   100  		assert.NotNil(t, err)
   101  
   102  		// slice vector with bad format dim
   103  		sch, err = ParseSchema(&SliceBadDimStruct{})
   104  		assert.Nil(t, sch)
   105  		assert.NotNil(t, err)
   106  
   107  		// slice vector with bad format dim 2
   108  		sch, err = ParseSchema(&SliceBadDimStruct2{})
   109  		assert.Nil(t, sch)
   110  		assert.NotNil(t, err)
   111  
   112  	})
   113  
   114  	t.Run("valid cases", func(t *testing.T) {
   115  
   116  		sch, err := ParseSchema(RowBase{})
   117  		assert.Nil(t, err)
   118  		assert.Equal(t, "RowBase", sch.CollectionName)
   119  
   120  		getVectorField := func(schema *Schema) *Field {
   121  			for _, field := range schema.Fields {
   122  				if field.DataType == FieldTypeFloatVector ||
   123  					field.DataType == FieldTypeBinaryVector ||
   124  					field.DataType == FieldTypeBFloat16Vector ||
   125  					field.DataType == FieldTypeFloat16Vector {
   126  					return field
   127  				}
   128  			}
   129  			return nil
   130  		}
   131  
   132  		type ValidStruct struct {
   133  			RowBase
   134  			ID     int64 `milvus:"primary_key"`
   135  			Attr1  int8
   136  			Attr2  int16
   137  			Attr3  int32
   138  			Attr4  float32
   139  			Attr5  float64
   140  			Attr6  string
   141  			Vector []float32 `milvus:"dim:128"`
   142  		}
   143  		vs := &ValidStruct{}
   144  		sch, err = ParseSchema(vs)
   145  		assert.Nil(t, err)
   146  		assert.NotNil(t, sch)
   147  		assert.Equal(t, "ValidStruct", sch.CollectionName)
   148  
   149  		type ValidFp16Struct struct {
   150  			RowBase
   151  			ID     int64 `milvus:"primary_key"`
   152  			Attr1  int8
   153  			Attr2  int16
   154  			Attr3  int32
   155  			Attr4  float32
   156  			Attr5  float64
   157  			Attr6  string
   158  			Vector []byte `milvus:"dim:128;vector_type:fp16"`
   159  		}
   160  		fp16Vs := &ValidFp16Struct{}
   161  		sch, err = ParseSchema(fp16Vs)
   162  		assert.Nil(t, err)
   163  		assert.NotNil(t, sch)
   164  		assert.Equal(t, "ValidFp16Struct", sch.CollectionName)
   165  		vectorField := getVectorField(sch)
   166  		assert.Equal(t, FieldTypeFloat16Vector, vectorField.DataType)
   167  
   168  		type ValidBf16Struct struct {
   169  			RowBase
   170  			ID     int64 `milvus:"primary_key"`
   171  			Attr1  int8
   172  			Attr2  int16
   173  			Attr3  int32
   174  			Attr4  float32
   175  			Attr5  float64
   176  			Attr6  string
   177  			Vector []byte `milvus:"dim:128;vector_type:bf16"`
   178  		}
   179  		bf16Vs := &ValidBf16Struct{}
   180  		sch, err = ParseSchema(bf16Vs)
   181  		assert.Nil(t, err)
   182  		assert.NotNil(t, sch)
   183  		assert.Equal(t, "ValidBf16Struct", sch.CollectionName)
   184  		vectorField = getVectorField(sch)
   185  		assert.Equal(t, FieldTypeBFloat16Vector, vectorField.DataType)
   186  
   187  		type ValidByteStruct struct {
   188  			RowBase
   189  			ID     int64  `milvus:"primary_key"`
   190  			Vector []byte `milvus:"dim:128"`
   191  		}
   192  		vs2 := &ValidByteStruct{}
   193  		sch, err = ParseSchema(vs2)
   194  		assert.Nil(t, err)
   195  		assert.NotNil(t, sch)
   196  
   197  		type ValidArrayStruct struct {
   198  			RowBase
   199  			ID     int64 `milvus:"primary_key"`
   200  			Vector [64]float32
   201  		}
   202  		vs3 := &ValidArrayStruct{}
   203  		sch, err = ParseSchema(vs3)
   204  		assert.Nil(t, err)
   205  		assert.NotNil(t, sch)
   206  
   207  		type ValidArrayStructByte struct {
   208  			RowBase
   209  			ID     int64   `milvus:"primary_key;auto_id"`
   210  			Data   *string `milvus:"extra:test\\;false"`
   211  			Vector [64]byte
   212  		}
   213  		vs4 := &ValidArrayStructByte{}
   214  		sch, err = ParseSchema(vs4)
   215  		assert.Nil(t, err)
   216  		assert.NotNil(t, sch)
   217  
   218  		vs5 := &ValidStructWithNamedTag{}
   219  		sch, err = ParseSchema(vs5)
   220  		assert.Nil(t, err)
   221  		assert.NotNil(t, sch)
   222  		i64f, vecf := false, false
   223  		for _, field := range sch.Fields {
   224  			if field.Name == "id" {
   225  				i64f = true
   226  			}
   227  			if field.Name == "vector" {
   228  				vecf = true
   229  			}
   230  		}
   231  
   232  		assert.True(t, i64f)
   233  		assert.True(t, vecf)
   234  	})
   235  }
   236  
   237  type ValidStruct struct {
   238  	RowBase
   239  	ID      int64 `milvus:"primary_key"`
   240  	Attr1   int8
   241  	Attr2   int16
   242  	Attr3   int32
   243  	Attr4   float32
   244  	Attr5   float64
   245  	Attr6   string
   246  	Attr7   bool
   247  	Vector  []float32 `milvus:"dim:16"`
   248  	Vector2 []byte    `milvus:"dim:32"`
   249  }
   250  
   251  type ValidStruct2 struct {
   252  	RowBase
   253  	ID      int64 `milvus:"primary_key"`
   254  	Vector  [16]float32
   255  	Vector2 [4]byte
   256  	Ignored bool `milvus:"-"`
   257  }
   258  
   259  type ValidStructWithNamedTag struct {
   260  	RowBase
   261  	ID     int64       `milvus:"primary_key;name:id"`
   262  	Vector [16]float32 `milvus:"name:vector"`
   263  }
   264  
   265  type RowsSuite struct {
   266  	suite.Suite
   267  }
   268  
   269  func (s *RowsSuite) TestRowsToColumns() {
   270  	s.Run("valid_cases", func() {
   271  
   272  		columns, err := RowsToColumns([]Row{&ValidStruct{}})
   273  		s.Nil(err)
   274  		s.Equal(10, len(columns))
   275  
   276  		columns, err = RowsToColumns([]Row{&ValidStruct2{}})
   277  		s.Nil(err)
   278  		s.Equal(3, len(columns))
   279  	})
   280  
   281  	s.Run("auto_id_pk", func() {
   282  		type AutoPK struct {
   283  			RowBase
   284  			ID     int64     `milvus:"primary_key;auto_id"`
   285  			Vector []float32 `milvus:"dim:32"`
   286  		}
   287  		columns, err := RowsToColumns([]Row{&AutoPK{}})
   288  		s.Nil(err)
   289  		s.Require().Equal(1, len(columns))
   290  		s.Equal("Vector", columns[0].Name())
   291  	})
   292  
   293  	s.Run("fp16", func() {
   294  		type BF16Struct struct {
   295  			RowBase
   296  			ID     int64  `milvus:"primary_key;auto_id"`
   297  			Vector []byte `milvus:"dim:16;vector_type:bf16"`
   298  		}
   299  		columns, err := RowsToColumns([]Row{&BF16Struct{}})
   300  		s.Nil(err)
   301  		s.Require().Equal(1, len(columns))
   302  		s.Equal("Vector", columns[0].Name())
   303  		s.Equal(FieldTypeBFloat16Vector, columns[0].Type())
   304  	})
   305  
   306  	s.Run("fp16", func() {
   307  		type FP16Struct struct {
   308  			RowBase
   309  			ID     int64  `milvus:"primary_key;auto_id"`
   310  			Vector []byte `milvus:"dim:16;vector_type:fp16"`
   311  		}
   312  		columns, err := RowsToColumns([]Row{&FP16Struct{}})
   313  		s.Nil(err)
   314  		s.Require().Equal(1, len(columns))
   315  		s.Equal("Vector", columns[0].Name())
   316  		s.Equal(FieldTypeFloat16Vector, columns[0].Type())
   317  	})
   318  
   319  	s.Run("invalid_cases", func() {
   320  		// empty input
   321  		_, err := RowsToColumns([]Row{})
   322  		s.NotNil(err)
   323  
   324  		// incompatible rows
   325  		_, err = RowsToColumns([]Row{&ValidStruct{}, &ValidStruct2{}})
   326  		s.NotNil(err)
   327  
   328  		// schema & row not compatible
   329  		_, err = RowsToColumns([]Row{&ValidStruct{}}, &Schema{
   330  			Fields: []*Field{
   331  				{
   332  					Name:     "int64",
   333  					DataType: FieldTypeInt64,
   334  				},
   335  			},
   336  		})
   337  		s.NotNil(err)
   338  	})
   339  }
   340  
   341  func (s *RowsSuite) TestDynamicSchema() {
   342  	s.Run("all_fallback_dynamic", func() {
   343  		columns, err := RowsToColumns([]Row{&ValidStruct{}},
   344  			NewSchema().WithDynamicFieldEnabled(true),
   345  		)
   346  		s.NoError(err)
   347  		s.Equal(1, len(columns))
   348  	})
   349  
   350  	s.Run("dynamic_not_found", func() {
   351  		_, err := RowsToColumns([]Row{&ValidStruct{}},
   352  			NewSchema().WithField(
   353  				NewField().WithName("ID").WithDataType(FieldTypeInt64).WithIsPrimaryKey(true),
   354  			).WithDynamicFieldEnabled(true),
   355  		)
   356  		s.NoError(err)
   357  	})
   358  }
   359  
   360  func (s *RowsSuite) TestReflectValueCandi() {
   361  	cases := []struct {
   362  		tag       string
   363  		v         reflect.Value
   364  		expect    map[string]fieldCandi
   365  		expectErr bool
   366  	}{
   367  		{
   368  			tag: "MapRow",
   369  			v: reflect.ValueOf(MapRow(map[string]interface{}{
   370  				"A": "abd", "B": int64(8),
   371  			})),
   372  			expect: map[string]fieldCandi{
   373  				"A": {
   374  					name: "A",
   375  					v:    reflect.ValueOf("abd"),
   376  				},
   377  				"B": {
   378  					name: "B",
   379  					v:    reflect.ValueOf(int64(8)),
   380  				},
   381  			},
   382  			expectErr: false,
   383  		},
   384  	}
   385  
   386  	for _, c := range cases {
   387  		s.Run(c.tag, func() {
   388  			r, err := reflectValueCandi(c.v)
   389  			if c.expectErr {
   390  				s.Error(err)
   391  				return
   392  			}
   393  			s.NoError(err)
   394  			s.Equal(len(c.expect), len(r))
   395  			for k, v := range c.expect {
   396  				rv, has := r[k]
   397  				s.Require().True(has)
   398  				s.Equal(v.name, rv.name)
   399  			}
   400  		})
   401  	}
   402  }
   403  
   404  func TestRows(t *testing.T) {
   405  	suite.Run(t, new(RowsSuite))
   406  }