github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/data_test.go (about) 1 // Copyright (C) 2019-2021 Zilliz. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance 4 // with the License. You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software distributed under the License 9 // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 10 // or implied. See the License for the specific language governing permissions and limitations under the License. 11 12 package client 13 14 import ( 15 "context" 16 "math/rand" 17 "testing" 18 "time" 19 20 "github.com/cockroachdb/errors" 21 22 "github.com/golang/protobuf/proto" 23 "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 24 "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 25 "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" 26 "github.com/milvus-io/milvus-sdk-go/v2/entity" 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/mock" 29 "github.com/stretchr/testify/require" 30 "github.com/stretchr/testify/suite" 31 ) 32 33 func TestGrpcClientFlush(t *testing.T) { 34 ctx := context.Background() 35 36 c := testClient(ctx, t) 37 38 t.Run("test async flush", func(t *testing.T) { 39 assert.Nil(t, c.Flush(ctx, testCollectionName, true, WithFlushMsgBase(&commonpb.MsgBase{}))) 40 }) 41 42 t.Run("test sync flush", func(t *testing.T) { 43 // 1~10 segments 44 segCount := rand.Intn(10) + 1 45 segments := make([]int64, 0, segCount) 46 for i := 0; i < segCount; i++ { 47 segments = append(segments, rand.Int63()) 48 } 49 // 510ms ~ 2s 50 flushTime := 510 + rand.Intn(1500) 51 start := time.Now() 52 flag := false 53 mockServer.SetInjection(MFlush, func(_ context.Context, raw proto.Message) (proto.Message, error) { 54 req, ok := raw.(*milvuspb.FlushRequest) 55 resp := &milvuspb.FlushResponse{} 56 if !ok { 57 s, err := BadRequestStatus() 58 resp.Status = s 59 return resp, err 60 } 61 assert.ElementsMatch(t, []string{testCollectionName}, req.GetCollectionNames()) 62 63 resp.CollSegIDs = make(map[string]*schemapb.LongArray) 64 resp.CollSegIDs[testCollectionName] = &schemapb.LongArray{ 65 Data: segments, 66 } 67 68 s, err := SuccessStatus() 69 resp.Status = s 70 return resp, err 71 }) 72 73 mockServer.SetInjection(MGetFlushState, func(_ context.Context, raw proto.Message) (proto.Message, error) { 74 req, ok := raw.(*milvuspb.GetFlushStateRequest) 75 resp := &milvuspb.GetFlushStateResponse{} 76 if !ok { 77 s, err := BadRequestStatus() 78 resp.Status = s 79 return resp, err 80 } 81 assert.ElementsMatch(t, segments, req.GetSegmentIDs()) 82 resp.Flushed = false 83 if time.Since(start) > time.Duration(flushTime)*time.Millisecond { 84 resp.Flushed = true 85 flag = true 86 } 87 88 s, err := SuccessStatus() 89 resp.Status = s 90 return resp, err 91 }) 92 assert.Nil(t, c.Flush(ctx, testCollectionName, false)) 93 assert.True(t, flag) 94 95 start = time.Now() 96 flag = false 97 quickCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) 98 defer cancel() 99 assert.NotNil(t, c.Flush(quickCtx, testCollectionName, false)) 100 }) 101 } 102 103 func TestGrpcDeleteByPks(t *testing.T) { 104 ctx := context.Background() 105 106 c := testClient(ctx, t) 107 defer c.Close() 108 109 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 1, testCollectionName, defaultSchema())) 110 defer mockServer.DelInjection(MDescribeCollection) 111 112 t.Run("normal delete by pks", func(t *testing.T) { 113 partName := "testPart" 114 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, true, partName)) 115 defer mockServer.DelInjection(MHasPartition) 116 mockServer.SetInjection(MDelete, func(_ context.Context, raw proto.Message) (proto.Message, error) { 117 req, ok := raw.(*milvuspb.DeleteRequest) 118 if !ok { 119 t.FailNow() 120 } 121 assert.Equal(t, testCollectionName, req.GetCollectionName()) 122 assert.Equal(t, partName, req.GetPartitionName()) 123 124 resp := &milvuspb.MutationResult{} 125 s, err := SuccessStatus() 126 resp.Status = s 127 return resp, err 128 }) 129 defer mockServer.DelInjection(MDelete) 130 131 err := c.DeleteByPks(ctx, testCollectionName, partName, entity.NewColumnInt64(testPrimaryField, []int64{1, 2, 3})) 132 assert.NoError(t, err) 133 }) 134 135 t.Run("Bad request deletes", func(t *testing.T) { 136 partName := "testPart" 137 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, false, partName)) 138 defer mockServer.DelInjection(MHasPartition) 139 140 // non-exist collection 141 err := c.DeleteByPks(ctx, "non-exists-collection", "", entity.NewColumnInt64("pk", []int64{})) 142 assert.Error(t, err) 143 144 // non-exist parition 145 err = c.DeleteByPks(ctx, testCollectionName, "non-exists-part", entity.NewColumnInt64("pk", []int64{})) 146 assert.Error(t, err) 147 148 // zero length pk 149 err = c.DeleteByPks(ctx, testCollectionName, "", entity.NewColumnInt64(testPrimaryField, []int64{})) 150 assert.Error(t, err) 151 152 // string pk field 153 err = c.DeleteByPks(ctx, testCollectionName, "", entity.NewColumnString(testPrimaryField, []string{"1"})) 154 assert.Error(t, err) 155 156 // pk name not match 157 err = c.DeleteByPks(ctx, testCollectionName, "", entity.NewColumnInt64("not_pk", []int64{1})) 158 assert.Error(t, err) 159 }) 160 161 t.Run("delete services fail", func(t *testing.T) { 162 mockServer.SetInjection(MDelete, func(_ context.Context, raw proto.Message) (proto.Message, error) { 163 resp := &milvuspb.MutationResult{} 164 return resp, errors.New("mockServer.d error") 165 }) 166 167 err := c.DeleteByPks(ctx, testCollectionName, "", entity.NewColumnInt64(testPrimaryField, []int64{1})) 168 assert.Error(t, err) 169 170 mockServer.SetInjection(MDelete, func(_ context.Context, raw proto.Message) (proto.Message, error) { 171 resp := &milvuspb.MutationResult{} 172 resp.Status = &commonpb.Status{ 173 ErrorCode: commonpb.ErrorCode_UnexpectedError, 174 } 175 return resp, nil 176 }) 177 err = c.DeleteByPks(ctx, testCollectionName, "", entity.NewColumnInt64(testPrimaryField, []int64{1})) 178 assert.Error(t, err) 179 }) 180 } 181 182 func TestGrpcDelete(t *testing.T) { 183 ctx := context.Background() 184 185 c := testClient(ctx, t) 186 defer c.Close() 187 188 mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 1, testCollectionName, defaultSchema())) 189 defer mockServer.DelInjection(MDescribeCollection) 190 191 t.Run("normal delete by pks", func(t *testing.T) { 192 partName := "testPart" 193 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, true, partName)) 194 defer mockServer.DelInjection(MHasPartition) 195 mockServer.SetInjection(MDelete, func(_ context.Context, raw proto.Message) (proto.Message, error) { 196 req, ok := raw.(*milvuspb.DeleteRequest) 197 if !ok { 198 t.FailNow() 199 } 200 assert.Equal(t, testCollectionName, req.GetCollectionName()) 201 assert.Equal(t, partName, req.GetPartitionName()) 202 203 resp := &milvuspb.MutationResult{} 204 s, err := SuccessStatus() 205 resp.Status = s 206 return resp, err 207 }) 208 defer mockServer.DelInjection(MDelete) 209 210 err := c.Delete(ctx, testCollectionName, partName, "") 211 assert.NoError(t, err) 212 }) 213 214 t.Run("Bad request deletes", func(t *testing.T) { 215 partName := "testPart" 216 mockServer.SetInjection(MHasPartition, hasPartitionInjection(t, testCollectionName, false, partName)) 217 defer mockServer.DelInjection(MHasPartition) 218 219 // non-exist collection 220 err := c.Delete(ctx, "non-exists-collection", "", "") 221 assert.Error(t, err) 222 223 // non-exist parition 224 err = c.Delete(ctx, testCollectionName, "non-exists-part", "") 225 assert.Error(t, err) 226 }) 227 t.Run("delete services fail", func(t *testing.T) { 228 mockServer.SetInjection(MDelete, func(_ context.Context, raw proto.Message) (proto.Message, error) { 229 resp := &milvuspb.MutationResult{} 230 return resp, errors.New("mockServer.d error") 231 }) 232 233 err := c.Delete(ctx, testCollectionName, "", "") 234 assert.Error(t, err) 235 236 mockServer.SetInjection(MDelete, func(_ context.Context, raw proto.Message) (proto.Message, error) { 237 resp := &milvuspb.MutationResult{} 238 resp.Status = &commonpb.Status{ 239 ErrorCode: commonpb.ErrorCode_UnexpectedError, 240 } 241 return resp, nil 242 }) 243 err = c.Delete(ctx, testCollectionName, "", "") 244 assert.Error(t, err) 245 }) 246 } 247 248 type SearchSuite struct { 249 MockSuiteBase 250 sch *entity.Schema 251 schDynamic *entity.Schema 252 } 253 254 func (s *SearchSuite) SetupSuite() { 255 s.MockSuiteBase.SetupSuite() 256 257 s.sch = entity.NewSchema().WithName(testCollectionName). 258 WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 259 WithField(entity.NewField().WithName("Attr").WithDataType(entity.FieldTypeInt64)). 260 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) 261 s.schDynamic = entity.NewSchema().WithName(testCollectionName).WithDynamicFieldEnabled(true). 262 WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 263 WithField(entity.NewField().WithName("$meta").WithDataType(entity.FieldTypeJSON).WithIsDynamic(true)). 264 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) 265 } 266 267 func (s *SearchSuite) TestSearchFail() { 268 c := s.client 269 ctx, cancel := context.WithCancel(context.Background()) 270 defer cancel() 271 272 partName := "part_1" 273 vectors := generateFloatVector(10, testVectorDim) 274 sp, err := entity.NewIndexFlatSearchParam() 275 s.Require().NoError(err) 276 s.resetMock() 277 278 s.Run("service_not_ready", func() { 279 _, err := (&GrpcClient{}).Search(ctx, testCollectionName, []string{}, partName, []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector", 280 entity.L2, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong)) 281 s.Error(err) 282 s.ErrorIs(err, ErrClientNotReady) 283 }) 284 285 s.Run("fail_describecollection_error", func() { 286 defer s.resetMock() 287 288 s.setupDescribeCollectionError(commonpb.ErrorCode_Success, errors.New("mock error")) 289 290 _, err := c.Search(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector", 291 entity.L2, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong)) 292 s.Error(err) 293 }) 294 295 s.Run("fail_describecollection_errcode", func() { 296 defer s.resetMock() 297 298 s.setupDescribeCollectionError(commonpb.ErrorCode_UnexpectedError, nil) 299 300 _, err := c.Search(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector", 301 entity.L2, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong)) 302 s.Error(err) 303 }) 304 305 s.Run("fail_guaranteed_non_custom_cl", func() { 306 defer s.resetMock() 307 308 s.setupDescribeCollection(testCollectionName, s.sch) 309 310 _, err := c.Search(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector", 311 entity.L2, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong), WithGuaranteeTimestamp(1000000)) 312 s.Error(err) 313 }) 314 315 s.Run("fail_search_error", func() { 316 defer s.resetMock() 317 318 s.setupDescribeCollection(testCollectionName, s.sch) 319 s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")). 320 Return(nil, errors.New("mock error")) 321 322 _, err := c.Search(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector", 323 entity.L2, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong)) 324 s.Error(err) 325 }) 326 327 s.Run("fail_search_errcode", func() { 328 defer s.resetMock() 329 330 s.setupDescribeCollection(testCollectionName, s.sch) 331 s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")). 332 Return(&milvuspb.SearchResults{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil) 333 334 _, err := c.Search(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, "vector", 335 entity.L2, 5, sp, WithSearchQueryConsistencyLevel(entity.ClStrong)) 336 s.Error(err) 337 }) 338 } 339 340 func (s *SearchSuite) TestSearchSuccess() { 341 c := s.client 342 ctx, cancel := context.WithCancel(context.Background()) 343 defer cancel() 344 345 partName := "part_1" 346 vectors := generateFloatVector(10, testVectorDim) 347 sp, err := entity.NewIndexFlatSearchParam() 348 s.Require().NoError(err) 349 s.resetMock() 350 351 expr := "ID > 0" 352 353 s.Run("non_dynamic_schema", func() { 354 defer s.resetMock() 355 s.setupDescribeCollection(testCollectionName, s.sch) 356 s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")). 357 Run(func(_ context.Context, req *milvuspb.SearchRequest) { 358 s.Equal(testCollectionName, req.GetCollectionName()) 359 s.Equal(expr, req.GetDsl()) 360 s.Equal(commonpb.DslType_BoolExprV1, req.GetDslType()) 361 s.ElementsMatch([]string{"ID"}, req.GetOutputFields()) 362 s.ElementsMatch([]string{partName}, req.GetPartitionNames()) 363 }). 364 Return(&milvuspb.SearchResults{ 365 Status: getSuccessStatus(), 366 Results: &schemapb.SearchResultData{ 367 NumQueries: 1, 368 TopK: 10, 369 FieldsData: []*schemapb.FieldData{ 370 s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 371 }, 372 Ids: &schemapb.IDs{ 373 IdField: &schemapb.IDs_IntId{ 374 IntId: &schemapb.LongArray{ 375 Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 376 }, 377 }, 378 }, 379 Scores: make([]float32, 10), 380 Topks: []int64{10}, 381 }, 382 }, nil) 383 384 r, err := c.Search(ctx, testCollectionName, []string{partName}, expr, []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, 385 testVectorField, entity.L2, 10, sp, WithIgnoreGrowing(), WithForTuning(), WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(10000000000)) 386 s.NoError(err) 387 s.Require().Equal(1, len(r)) 388 result := r[0] 389 s.Require().NotNil(result.Fields.GetColumn("ID")) 390 }) 391 392 s.Run("dynamic_schema", func() { 393 defer s.resetMock() 394 s.setupDescribeCollection(testCollectionName, s.schDynamic) 395 s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")). 396 Run(func(_ context.Context, req *milvuspb.SearchRequest) { 397 s.Equal(testCollectionName, req.GetCollectionName()) 398 s.Equal(expr, req.GetDsl()) 399 s.Equal(commonpb.DslType_BoolExprV1, req.GetDslType()) 400 s.ElementsMatch([]string{"A", "B"}, req.GetOutputFields()) 401 s.ElementsMatch([]string{partName}, req.GetPartitionNames()) 402 }). 403 Return(&milvuspb.SearchResults{ 404 Status: getSuccessStatus(), 405 Results: &schemapb.SearchResultData{ 406 NumQueries: 1, 407 TopK: 2, 408 FieldsData: []*schemapb.FieldData{ 409 s.getJSONBytesFieldData("", [][]byte{ 410 []byte(`{"A": 123, "B": "456"}`), 411 []byte(`{"B": "abc", "A": 456}`), 412 }, true), 413 }, 414 Ids: &schemapb.IDs{ 415 IdField: &schemapb.IDs_IntId{ 416 IntId: &schemapb.LongArray{ 417 Data: []int64{1, 2}, 418 }, 419 }, 420 }, 421 Scores: make([]float32, 2), 422 Topks: []int64{2}, 423 }, 424 }, nil) 425 426 r, err := c.Search(ctx, testCollectionName, []string{partName}, expr, []string{"A", "B"}, []entity.Vector{entity.FloatVector(vectors[0])}, 427 testVectorField, entity.L2, 2, sp, WithIgnoreGrowing(), WithForTuning(), WithSearchQueryConsistencyLevel(entity.ClBounded)) 428 s.NoError(err) 429 s.Require().Equal(1, len(r)) 430 result := r[0] 431 columnA := result.Fields.GetColumn("A") 432 s.Require().NotNil(columnA) 433 column, ok := columnA.(*entity.ColumnDynamic) 434 s.Require().True(ok) 435 v, err := column.GetAsInt64(0) 436 s.NoError(err) 437 s.Equal(int64(123), v) 438 439 columnB := result.Fields.GetColumn("B") 440 s.Require().NotNil(columnB) 441 column, ok = columnB.(*entity.ColumnDynamic) 442 s.Require().True(ok) 443 str, err := column.GetAsString(1) 444 s.NoError(err) 445 s.Equal("abc", str) 446 }) 447 448 s.Run("group_by", func() { 449 defer s.resetMock() 450 s.setupDescribeCollection(testCollectionName, s.sch) 451 s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")). 452 Run(func(_ context.Context, req *milvuspb.SearchRequest) { 453 s.Equal(testCollectionName, req.GetCollectionName()) 454 s.Equal(expr, req.GetDsl()) 455 s.Equal(commonpb.DslType_BoolExprV1, req.GetDslType()) 456 s.ElementsMatch([]string{"ID"}, req.GetOutputFields()) 457 s.ElementsMatch([]string{partName}, req.GetPartitionNames()) 458 }). 459 Return(&milvuspb.SearchResults{ 460 Status: getSuccessStatus(), 461 Results: &schemapb.SearchResultData{ 462 NumQueries: 1, 463 TopK: 10, 464 FieldsData: []*schemapb.FieldData{ 465 s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), 466 }, 467 Ids: &schemapb.IDs{ 468 IdField: &schemapb.IDs_IntId{ 469 IntId: &schemapb.LongArray{ 470 Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 471 }, 472 }, 473 }, 474 Scores: make([]float32, 10), 475 Topks: []int64{10}, 476 GroupByFieldValue: s.getInt64FieldData("Attr", []int64{10, 10, 10, 10, 10, 10, 10, 10, 10, 10}), 477 }, 478 }, nil) 479 480 r, err := c.Search(ctx, testCollectionName, []string{partName}, expr, []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, 481 testVectorField, entity.L2, 10, sp, WithIgnoreGrowing(), WithForTuning(), WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(10000000000), WithGroupByField("Attr")) 482 s.NoError(err) 483 s.Require().Equal(1, len(r)) 484 result := r[0] 485 s.Require().NotNil(result.Fields.GetColumn("ID")) 486 }) 487 } 488 489 func TestSearch(t *testing.T) { 490 suite.Run(t, new(SearchSuite)) 491 } 492 493 type QuerySuite struct { 494 MockSuiteBase 495 sch *entity.Schema 496 schDynamic *entity.Schema 497 } 498 499 func (s *QuerySuite) SetupSuite() { 500 s.MockSuiteBase.SetupSuite() 501 502 s.sch = entity.NewSchema().WithName(testCollectionName). 503 WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 504 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) 505 s.schDynamic = entity.NewSchema().WithName(testCollectionName).WithDynamicFieldEnabled(true). 506 WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeVarChar).WithIsPrimaryKey(true)). 507 WithField(entity.NewField().WithName("$meta").WithDataType(entity.FieldTypeJSON).WithIsDynamic(true)). 508 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) 509 } 510 511 func (s *QuerySuite) GetFail() { 512 c := s.client 513 ctx, cancel := context.WithCancel((context.Background())) 514 defer cancel() 515 516 idCol := entity.NewColumnInt64("ID", []int64{1}) 517 s.Run("service_not_ready", func() { 518 _, err := (&GrpcClient{}).Get(ctx, testCollectionName, idCol) 519 s.Error(err) 520 s.ErrorIs(err, ErrClientNotReady) 521 }) 522 523 s.Run("ids_len_0", func() { 524 _, err := c.Get(ctx, testCollectionName, entity.NewColumnInt64("ID", []int64{}), GetWithOutputFields("ID")) 525 s.Error(err) 526 }) 527 528 s.Run("describe_failed", func() { 529 defer s.resetMock() 530 s.setupDescribeCollectionError(commonpb.ErrorCode_Success, errors.New("mock error")) 531 532 _, err := c.Get(ctx, testCollectionName, idCol) 533 s.Error(err) 534 }) 535 } 536 537 func (s *QuerySuite) TestQueryByPksFail() { 538 c := s.client 539 ctx, cancel := context.WithCancel(context.Background()) 540 defer cancel() 541 542 partName := "part_1" 543 idCol := entity.NewColumnInt64("ID", []int64{1}) 544 s.Run("service_not_ready", func() { 545 _, err := (&GrpcClient{}).QueryByPks(ctx, testCollectionName, []string{partName}, idCol, []string{"ID"}) 546 s.Error(err) 547 s.ErrorIs(err, ErrClientNotReady) 548 }) 549 550 s.Run("ids_len_0", func() { 551 _, err := c.QueryByPks(ctx, testCollectionName, []string{partName}, entity.NewColumnInt64("ID", []int64{}), []string{"ID"}) 552 s.Error(err) 553 }) 554 555 s.Run("query_failed", func() { 556 defer s.resetMock() 557 s.setupDescribeCollectionError(commonpb.ErrorCode_Success, errors.New("mock error")) 558 559 _, err := c.QueryByPks(ctx, testCollectionName, []string{partName}, idCol, []string{"ID"}) 560 s.Error(err) 561 }) 562 } 563 564 func (s *QuerySuite) TestQueryFail() { 565 c := s.client 566 ctx, cancel := context.WithCancel(context.Background()) 567 defer cancel() 568 569 partName := "part_1" 570 s.resetMock() 571 572 s.Run("service_not_ready", func() { 573 _, err := (&GrpcClient{}).Query(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 574 s.Error(err) 575 s.ErrorIs(err, ErrClientNotReady) 576 }) 577 578 s.Run("fail_describecollection_error", func() { 579 defer s.resetMock() 580 581 s.setupDescribeCollectionError(commonpb.ErrorCode_Success, errors.New("mock error")) 582 583 _, err := c.Query(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 584 s.Error(err) 585 }) 586 587 s.Run("fail_describecollection_errcode", func() { 588 defer s.resetMock() 589 590 s.setupDescribeCollectionError(commonpb.ErrorCode_UnexpectedError, nil) 591 592 _, err := c.Query(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 593 s.Error(err) 594 }) 595 596 s.Run("fail_guaranteed_non_custom_cl", func() { 597 defer s.resetMock() 598 599 s.setupDescribeCollection(testCollectionName, s.sch) 600 601 _, err := c.Query(ctx, testCollectionName, []string{partName}, "", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong), WithGuaranteeTimestamp(1000000)) 602 s.Error(err) 603 }) 604 605 s.Run("fail_search_error", func() { 606 defer s.resetMock() 607 608 s.setupDescribeCollection(testCollectionName, s.sch) 609 s.mock.EXPECT().Query(mock.Anything, mock.AnythingOfType("*milvuspb.QueryRequest")). 610 Return(nil, errors.New("mock error")) 611 612 _, err := c.Query(ctx, testCollectionName, []string{partName}, "ID in {1}", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 613 s.Error(err) 614 }) 615 616 s.Run("fail_search_errcode", func() { 617 defer s.resetMock() 618 619 s.setupDescribeCollection(testCollectionName, s.sch) 620 s.mock.EXPECT().Query(mock.Anything, mock.AnythingOfType("*milvuspb.QueryRequest")). 621 Return(&milvuspb.QueryResults{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil) 622 623 _, err := c.Query(ctx, testCollectionName, []string{partName}, "ID in {1}", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 624 s.Error(err) 625 }) 626 627 s.Run("fail_response_type_error", func() { 628 defer s.resetMock() 629 630 s.setupDescribeCollection(testCollectionName, s.sch) 631 s.mock.EXPECT().Query(mock.Anything, mock.AnythingOfType("*milvuspb.QueryRequest")). 632 Return(&milvuspb.QueryResults{ 633 Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, 634 FieldsData: []*schemapb.FieldData{ 635 { 636 FieldName: "ID", 637 Type: schemapb.DataType_String, //wrong data type here 638 Field: &schemapb.FieldData_Scalars{ 639 Scalars: &schemapb.ScalarField{ 640 Data: &schemapb.ScalarField_LongData{ 641 LongData: &schemapb.LongArray{ 642 Data: []int64{1}, 643 }, 644 }, 645 }, 646 }, 647 }, 648 }, 649 }, nil) 650 651 _, err := c.Query(ctx, testCollectionName, []string{partName}, "ID in {1}", []string{"ID"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 652 s.Error(err) 653 }) 654 } 655 656 func (s *QuerySuite) TestQuerySuccess() { 657 c := s.client 658 ctx, cancel := context.WithCancel(context.Background()) 659 defer cancel() 660 661 partName := "part_1" 662 s.resetMock() 663 664 expr := "ID in {1}" 665 666 s.Run("non_dynamic", func() { 667 defer s.resetMock() 668 669 s.setupDescribeCollection(testCollectionName, s.sch) 670 s.mock.EXPECT().Query(mock.Anything, mock.AnythingOfType("*milvuspb.QueryRequest")). 671 Run(func(_ context.Context, req *milvuspb.QueryRequest) {}). 672 Return(&milvuspb.QueryResults{ 673 Status: getSuccessStatus(), 674 FieldsData: []*schemapb.FieldData{ 675 s.getInt64FieldData("ID", []int64{1}), 676 s.getFloatVectorFieldData("vector", 1, []float32{0.1}), 677 }, 678 }, nil) 679 680 rs, err := c.Query(ctx, testCollectionName, []string{partName}, expr, []string{"ID", "vector"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 681 s.NoError(err) 682 s.Require().Equal(2, len(rs)) 683 colID, ok := rs.GetColumn("ID").(*entity.ColumnInt64) 684 s.Require().True(ok) 685 s.NotNil(colID) 686 v, err := colID.Get(0) 687 s.NoError(err) 688 s.EqualValues(1, v) 689 colVector, ok := rs.GetColumn("vector").(*entity.ColumnFloatVector) 690 s.Require().True(ok) 691 s.NotNil(colVector) 692 v, err = colVector.Get(0) 693 s.NoError(err) 694 s.EqualValues([]float32{0.1}, v) 695 }) 696 697 s.Run("dynamic_schema", func() { 698 defer s.resetMock() 699 700 s.setupDescribeCollection(testCollectionName, s.schDynamic) 701 s.mock.EXPECT().Query(mock.Anything, mock.AnythingOfType("*milvuspb.QueryRequest")). 702 Run(func(_ context.Context, req *milvuspb.QueryRequest) {}). 703 Return(&milvuspb.QueryResults{ 704 Status: getSuccessStatus(), 705 FieldsData: []*schemapb.FieldData{ 706 s.getVarcharFieldData("ID", []string{"1"}), 707 s.getFloatVectorFieldData("vector", 1, []float32{0.1}), 708 s.getJSONBytesFieldData("$meta", [][]byte{ 709 []byte(`{"A": 123, "B": "456"}`), 710 []byte(`{"B": "abc", "A": 456}`), 711 }, true), 712 }, 713 }, nil) 714 715 rs, err := c.Query(ctx, testCollectionName, []string{partName}, `id in {"1"}`, []string{"ID", "vector", "A"}, WithSearchQueryConsistencyLevel(entity.ClStrong)) 716 s.NoError(err) 717 s.Require().Equal(3, len(rs)) 718 colID, ok := rs.GetColumn("ID").(*entity.ColumnVarChar) 719 s.Require().True(ok) 720 s.NotNil(colID) 721 v, err := colID.Get(0) 722 s.NoError(err) 723 s.EqualValues("1", v) 724 colVector, ok := rs.GetColumn("vector").(*entity.ColumnFloatVector) 725 s.Require().True(ok) 726 s.NotNil(colVector) 727 v, err = colVector.Get(0) 728 s.NoError(err) 729 s.EqualValues([]float32{0.1}, v) 730 731 columnA := rs.GetColumn("A").(*entity.ColumnDynamic) 732 s.Require().True(ok) 733 s.Require().NotNil(columnA) 734 v, err = columnA.GetAsInt64(0) 735 s.NoError(err) 736 s.Equal(int64(123), v) 737 }) 738 } 739 740 func TestQuery(t *testing.T) { 741 suite.Run(t, new(QuerySuite)) 742 } 743 744 func TestGrpcCalcDistanceWithIDs(t *testing.T) { 745 ctx := context.Background() 746 t.Run("bad client calls CalcDistance", func(t *testing.T) { 747 c := &GrpcClient{} 748 r, err := c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, nil, nil) 749 assert.Nil(t, r) 750 assert.NotNil(t, err) 751 assert.EqualValues(t, ErrClientNotReady, err) 752 }) 753 754 c := testClient(ctx, t) 755 mockServer.SetInjection(MDescribeCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) { 756 req, ok := raw.(*milvuspb.DescribeCollectionRequest) 757 resp := &milvuspb.DescribeCollectionResponse{} 758 if !ok { 759 s, err := BadRequestStatus() 760 resp.Status = s 761 return resp, err 762 } 763 assert.Equal(t, testCollectionName, req.GetCollectionName()) 764 765 sch := defaultSchema() 766 resp.Schema = sch.ProtoMessage() 767 768 s, err := SuccessStatus() 769 resp.Status = s 770 return resp, err 771 }) 772 t.Run("call with ctx done", func(t *testing.T) { 773 ctxDone, cancel := context.WithCancel(context.Background()) 774 cancel() 775 776 r, err := c.CalcDistance(ctxDone, testCollectionName, []string{}, entity.L2, 777 entity.NewColumnInt64("int64", []int64{1}), entity.NewColumnInt64("int64", []int64{1})) 778 assert.Nil(t, r) 779 assert.NotNil(t, err) 780 }) 781 782 t.Run("invalid ids call", func(t *testing.T) { 783 r, err := c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 784 nil, nil) 785 assert.Nil(t, r) 786 assert.NotNil(t, err) 787 788 r, err = c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 789 entity.NewColumnInt64("non-exists", []int64{1}), entity.NewColumnInt64("non-exists", []int64{1})) 790 assert.Nil(t, r) 791 assert.NotNil(t, err) 792 793 r, err = c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 794 entity.NewColumnInt64("non-exists", []int64{1}), entity.NewColumnInt64("int64", []int64{1})) 795 assert.Nil(t, r) 796 assert.NotNil(t, err) 797 798 }) 799 800 t.Run("valid calls", func(t *testing.T) { 801 mockServer.SetInjection(MCalcDistance, func(_ context.Context, raw proto.Message) (proto.Message, error) { 802 req, ok := raw.(*milvuspb.CalcDistanceRequest) 803 resp := &milvuspb.CalcDistanceResults{} 804 if !ok { 805 s, err := BadRequestStatus() 806 resp.Status = s 807 return resp, err 808 } 809 idsLeft := req.GetOpLeft().GetIdArray() 810 valuesLeft := req.GetOpLeft().GetDataArray() 811 idsRight := req.GetOpRight().GetIdArray() 812 valuesRight := req.GetOpRight().GetDataArray() 813 assert.True(t, idsLeft != nil || valuesLeft != nil) 814 assert.True(t, idsRight != nil || valuesRight != nil) 815 816 if idsLeft != nil { 817 assert.Equal(t, testCollectionName, idsLeft.CollectionName) 818 } 819 if idsRight != nil { 820 assert.Equal(t, testCollectionName, idsRight.CollectionName) 821 } 822 823 // this injection returns float distance 824 dl := 0 825 if idsLeft != nil { 826 dl = len(idsLeft.IdArray.GetIntId().GetData()) 827 } 828 if valuesLeft != nil { 829 dl = len(valuesLeft.GetFloatVector().GetData()) / int(valuesLeft.Dim) 830 } 831 dr := 0 832 if idsRight != nil { 833 dr = len(idsRight.IdArray.GetIntId().GetData()) 834 } 835 if valuesRight != nil { 836 dr = len(valuesRight.GetFloatVector().GetData()) / int(valuesRight.Dim) 837 } 838 839 resp.Array = &milvuspb.CalcDistanceResults_FloatDist{ 840 FloatDist: &schemapb.FloatArray{ 841 Data: make([]float32, dl*dr), 842 }, 843 } 844 845 s, err := SuccessStatus() 846 resp.Status = s 847 return resp, err 848 }) 849 r, err := c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 850 entity.NewColumnInt64("vector", []int64{1}), entity.NewColumnInt64("vector", []int64{1})) 851 assert.Nil(t, err) 852 assert.NotNil(t, r) 853 854 vectors := generateFloatVector(5, testVectorDim) 855 r, err = c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 856 entity.NewColumnInt64("vector", []int64{1}), entity.NewColumnFloatVector("vector", testVectorDim, vectors)) 857 assert.Nil(t, err) 858 assert.NotNil(t, r) 859 860 r, err = c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 861 entity.NewColumnFloatVector("vector", testVectorDim, vectors), entity.NewColumnInt64("vector", []int64{1})) 862 assert.Nil(t, err) 863 assert.NotNil(t, r) 864 865 // test IntDistance, 866 mockServer.SetInjection(MCalcDistance, func(_ context.Context, raw proto.Message) (proto.Message, error) { 867 req, ok := raw.(*milvuspb.CalcDistanceRequest) 868 resp := &milvuspb.CalcDistanceResults{} 869 if !ok { 870 s, err := BadRequestStatus() 871 resp.Status = s 872 return resp, err 873 } 874 idsLeft := req.GetOpLeft().GetIdArray() 875 valuesLeft := req.GetOpLeft().GetDataArray() 876 idsRight := req.GetOpRight().GetIdArray() 877 valuesRight := req.GetOpRight().GetDataArray() 878 assert.True(t, idsLeft != nil || valuesLeft != nil) 879 assert.True(t, idsRight != nil || valuesRight != nil) 880 881 if idsLeft != nil { 882 assert.Equal(t, testCollectionName, idsLeft.CollectionName) 883 } 884 if idsRight != nil { 885 assert.Equal(t, testCollectionName, idsRight.CollectionName) 886 } 887 888 // this injection returns float distance 889 dl := 0 890 if idsLeft != nil { 891 dl = len(idsLeft.IdArray.GetIntId().GetData()) 892 } 893 if valuesLeft != nil { 894 dl = len(valuesLeft.GetFloatVector().GetData()) / int(valuesLeft.Dim) 895 } 896 dr := 0 897 if idsRight != nil { 898 dr = len(idsRight.IdArray.GetIntId().GetData()) 899 } 900 if valuesRight != nil { 901 dr = len(valuesRight.GetFloatVector().GetData()) / int(valuesRight.Dim) 902 } 903 904 resp.Array = &milvuspb.CalcDistanceResults_IntDist{ 905 IntDist: &schemapb.IntArray{ 906 Data: make([]int32, dl*dr), 907 }, 908 } 909 910 s, err := SuccessStatus() 911 resp.Status = s 912 return resp, err 913 }) 914 r, err = c.CalcDistance(ctx, testCollectionName, []string{}, entity.HAMMING, 915 entity.NewColumnInt64("vector", []int64{1}), entity.NewColumnInt64("vector", []int64{1})) 916 assert.Nil(t, err) 917 assert.NotNil(t, r) 918 919 // test str id 920 mockServer.SetInjection(MDescribeCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) { 921 req, ok := raw.(*milvuspb.DescribeCollectionRequest) 922 resp := &milvuspb.DescribeCollectionResponse{} 923 if !ok { 924 s, err := BadRequestStatus() 925 resp.Status = s 926 return resp, err 927 } 928 assert.Equal(t, testCollectionName, req.GetCollectionName()) 929 930 sch := defaultSchema() 931 sch.Fields[0].DataType = entity.FieldTypeString 932 sch.Fields[0].Name = "str" 933 resp.Schema = sch.ProtoMessage() 934 935 s, err := SuccessStatus() 936 resp.Status = s 937 return resp, err 938 }) 939 mockServer.SetInjection(MCalcDistance, func(_ context.Context, raw proto.Message) (proto.Message, error) { 940 req, ok := raw.(*milvuspb.CalcDistanceRequest) 941 resp := &milvuspb.CalcDistanceResults{} 942 if !ok { 943 s, err := BadRequestStatus() 944 resp.Status = s 945 return resp, err 946 } 947 idsLeft := req.GetOpLeft().GetIdArray() 948 idsRight := req.GetOpRight().GetIdArray() 949 assert.NotNil(t, idsLeft) 950 assert.NotNil(t, idsRight) 951 952 assert.Equal(t, testCollectionName, idsLeft.CollectionName) 953 assert.Equal(t, testCollectionName, idsRight.CollectionName) 954 955 // only int ids supported for now TODO update string test cases 956 assert.NotNil(t, idsLeft.IdArray.GetStrId()) 957 assert.NotNil(t, idsRight.IdArray.GetStrId()) 958 959 // this injection returns float distance 960 dl := len(idsLeft.IdArray.GetStrId().GetData()) 961 962 resp.Array = &milvuspb.CalcDistanceResults_FloatDist{ 963 FloatDist: &schemapb.FloatArray{ 964 Data: make([]float32, dl), 965 }, 966 } 967 968 s, err := SuccessStatus() 969 resp.Status = s 970 return resp, err 971 }) 972 r, err = c.CalcDistance(ctx, testCollectionName, []string{}, entity.L2, 973 entity.NewColumnString("vector", []string{"1"}), entity.NewColumnString("vector", []string{"1"})) 974 assert.Nil(t, err) 975 assert.NotNil(t, r) 976 }) 977 } 978 979 func TestIsCollectionPrimaryKey(t *testing.T) { 980 t.Run("nil cases", func(t *testing.T) { 981 assert.False(t, isCollectionPrimaryKey(nil, nil)) 982 assert.False(t, isCollectionPrimaryKey(&entity.Collection{}, entity.NewColumnInt64("id", []int64{}))) 983 }) 984 985 t.Run("check cases", func(t *testing.T) { 986 assert.False(t, isCollectionPrimaryKey(&entity.Collection{ 987 Schema: defaultSchema(), 988 }, entity.NewColumnInt64("id", []int64{}))) 989 assert.False(t, isCollectionPrimaryKey(&entity.Collection{ 990 Schema: defaultSchema(), 991 }, entity.NewColumnInt32("int64", []int32{}))) 992 assert.True(t, isCollectionPrimaryKey(&entity.Collection{ 993 Schema: defaultSchema(), 994 }, entity.NewColumnInt64("int64", []int64{}))) 995 996 }) 997 } 998 999 func TestEstRowSize(t *testing.T) { 1000 // a schema contains all supported vector 1001 sch := entity.NewSchema().WithName(testCollectionName).WithAutoID(false). 1002 WithField(entity.NewField().WithName(testPrimaryField).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)). 1003 WithField(entity.NewField().WithName("attr1").WithDataType(entity.FieldTypeInt8)). 1004 WithField(entity.NewField().WithName("attr2").WithDataType(entity.FieldTypeInt16)). 1005 WithField(entity.NewField().WithName("attr3").WithDataType(entity.FieldTypeInt32)). 1006 WithField(entity.NewField().WithName("attr4").WithDataType(entity.FieldTypeFloat)). 1007 WithField(entity.NewField().WithName("attr5").WithDataType(entity.FieldTypeDouble)). 1008 WithField(entity.NewField().WithName("attr6").WithDataType(entity.FieldTypeBool)). 1009 WithField(entity.NewField().WithName("attr6").WithDataType(entity.FieldTypeBool)). 1010 WithField(entity.NewField().WithName(testVectorField).WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)). 1011 WithField(entity.NewField().WithName("binary_vector").WithDataType(entity.FieldTypeBinaryVector).WithDim(testVectorDim)) 1012 1013 // one row 1014 columnID := entity.NewColumnInt64(testPrimaryField, []int64{0}) 1015 columnAttr1 := entity.NewColumnInt8("attr1", []int8{0}) 1016 columnAttr2 := entity.NewColumnInt16("attr2", []int16{0}) 1017 columnAttr3 := entity.NewColumnInt32("attr3", []int32{0}) 1018 columnAttr4 := entity.NewColumnFloat("attr4", []float32{0}) 1019 columnAttr5 := entity.NewColumnDouble("attr5", []float64{0}) 1020 columnAttr6 := entity.NewColumnBool("attr6", []bool{true}) 1021 columnFv := entity.NewColumnFloatVector(testVectorField, testVectorDim, generateFloatVector(1, testVectorDim)) 1022 columnBv := entity.NewColumnBinaryVector("binary_vector", testVectorDim, generateBinaryVector(1, testVectorDim)) 1023 1024 sr := &milvuspb.SearchResults{ 1025 Results: &schemapb.SearchResultData{ 1026 FieldsData: []*schemapb.FieldData{ 1027 columnID.FieldData(), 1028 columnAttr1.FieldData(), 1029 columnAttr2.FieldData(), 1030 columnAttr3.FieldData(), 1031 columnAttr4.FieldData(), 1032 columnAttr5.FieldData(), 1033 columnAttr6.FieldData(), 1034 columnFv.FieldData(), 1035 columnBv.FieldData(), 1036 }, 1037 }, 1038 } 1039 bs, err := proto.Marshal(sr) 1040 assert.Nil(t, err) 1041 sr1l := len(bs) 1042 // 2Row 1043 columnID = entity.NewColumnInt64(testPrimaryField, []int64{0, 1}) 1044 columnAttr1 = entity.NewColumnInt8("attr1", []int8{0, 1}) 1045 columnAttr2 = entity.NewColumnInt16("attr2", []int16{0, 1}) 1046 columnAttr3 = entity.NewColumnInt32("attr3", []int32{0, 1}) 1047 columnAttr4 = entity.NewColumnFloat("attr4", []float32{0, 1}) 1048 columnAttr5 = entity.NewColumnDouble("attr5", []float64{0, 1}) 1049 columnAttr6 = entity.NewColumnBool("attr6", []bool{true, true}) 1050 columnFv = entity.NewColumnFloatVector(testVectorField, testVectorDim, generateFloatVector(2, testVectorDim)) 1051 columnBv = entity.NewColumnBinaryVector("binary_vector", testVectorDim, generateBinaryVector(2, testVectorDim)) 1052 1053 sr = &milvuspb.SearchResults{ 1054 Results: &schemapb.SearchResultData{ 1055 FieldsData: []*schemapb.FieldData{ 1056 columnID.FieldData(), 1057 columnAttr1.FieldData(), 1058 columnAttr2.FieldData(), 1059 columnAttr3.FieldData(), 1060 columnAttr4.FieldData(), 1061 columnAttr5.FieldData(), 1062 columnAttr6.FieldData(), 1063 columnFv.FieldData(), 1064 columnBv.FieldData(), 1065 }, 1066 }, 1067 } 1068 bs, err = proto.Marshal(sr) 1069 assert.Nil(t, err) 1070 sr2l := len(bs) 1071 1072 t.Log(sr1l, sr2l, sr2l-sr1l) 1073 est := estRowSize(sch, []string{}) 1074 t.Log(est) 1075 1076 assert.Greater(t, est, int64(sr2l-sr1l)) 1077 } 1078 1079 func generateFloatVector(num, dim int) [][]float32 { 1080 r := make([][]float32, 0, num) 1081 for i := 0; i < num; i++ { 1082 v := make([]float32, 0, dim) 1083 for j := 0; j < dim; j++ { 1084 v = append(v, rand.Float32()) 1085 } 1086 r = append(r, v) 1087 } 1088 return r 1089 } 1090 1091 func generateBinaryVector(num, dim int) [][]byte { 1092 r := make([][]byte, 0, num) 1093 for i := 0; i < num; i++ { 1094 v := make([]byte, 0, dim/8) 1095 rand.Read(v) 1096 r = append(r, v) 1097 } 1098 return r 1099 } 1100 1101 func TestVector2PlaceHolder(t *testing.T) { 1102 t.Run("FloatVector", func(t *testing.T) { 1103 data := generateFloatVector(10, 32) 1104 vectors := make([]entity.Vector, 0, len(data)) 1105 for _, row := range data { 1106 vectors = append(vectors, entity.FloatVector(row)) 1107 } 1108 1109 phv := vector2Placeholder(vectors) 1110 assert.Equal(t, "$0", phv.Tag) 1111 assert.Equal(t, commonpb.PlaceholderType_FloatVector, phv.Type) 1112 require.Equal(t, len(vectors), len(phv.Values)) 1113 for idx, line := range phv.Values { 1114 assert.Equal(t, vectors[idx].Serialize(), line) 1115 } 1116 }) 1117 1118 t.Run("BinaryVector", func(t *testing.T) { 1119 data := generateBinaryVector(10, 32) 1120 vectors := make([]entity.Vector, 0, len(data)) 1121 for _, row := range data { 1122 vectors = append(vectors, entity.BinaryVector(row)) 1123 } 1124 1125 phv := vector2Placeholder(vectors) 1126 assert.Equal(t, "$0", phv.Tag) 1127 assert.Equal(t, commonpb.PlaceholderType_BinaryVector, phv.Type) 1128 require.Equal(t, len(vectors), len(phv.Values)) 1129 for idx, line := range phv.Values { 1130 assert.Equal(t, vectors[idx].Serialize(), line) 1131 } 1132 }) 1133 } 1134 1135 type WildcardSuite struct { 1136 suite.Suite 1137 1138 schema *entity.Schema 1139 } 1140 1141 func (s *WildcardSuite) SetupTest() { 1142 s.schema = entity.NewSchema(). 1143 WithField(entity.NewField().WithName("pk").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). 1144 WithField(entity.NewField().WithName("attr").WithDataType(entity.FieldTypeInt64)). 1145 WithField(entity.NewField().WithName("$meta").WithDataType(entity.FieldTypeJSON).WithIsDynamic(true)). 1146 WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(128)) 1147 } 1148 1149 func (s *WildcardSuite) TestExpandWildcard() { 1150 type testCase struct { 1151 tag string 1152 input []string 1153 expect []string 1154 expectWildCard bool 1155 } 1156 1157 cases := []testCase{ 1158 {tag: "normal", input: []string{"pk", "attr"}, expect: []string{"pk", "attr"}}, 1159 {tag: "with_wildcard", input: []string{"*"}, expect: []string{"pk", "attr", "$meta", "vector"}, expectWildCard: true}, 1160 {tag: "wildcard_dynamic", input: []string{"*", "a"}, expect: []string{"pk", "attr", "$meta", "vector", "a"}, expectWildCard: true}, 1161 } 1162 1163 for _, tc := range cases { 1164 s.Run(tc.tag, func() { 1165 output, wildCard := expandWildcard(s.schema, tc.input) 1166 s.ElementsMatch(tc.expect, output) 1167 s.Equal(tc.expectWildCard, wildCard) 1168 }) 1169 } 1170 } 1171 1172 func TestExpandWildcard(t *testing.T) { 1173 suite.Run(t, new(WildcardSuite)) 1174 }