github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/test/common/response_check.go (about)

     1  package common
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/milvus-io/milvus-sdk-go/v2/client"
    12  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  // CheckErr check err and errMsg
    17  func CheckErr(t *testing.T, actualErr error, expErrNil bool, expErrorMsg ...string) {
    18  	if expErrNil {
    19  		require.NoError(t, actualErr)
    20  	} else {
    21  		require.Error(t, actualErr)
    22  		switch len(expErrorMsg) {
    23  		case 0:
    24  			log.Fatal("expect error message should not be empty")
    25  		case 1:
    26  			require.ErrorContains(t, actualErr, expErrorMsg[0])
    27  		default:
    28  			contains := false
    29  			for i := 0; i < len(expErrorMsg); i++ {
    30  				if strings.Contains(actualErr.Error(), expErrorMsg[i]) {
    31  					contains = true
    32  				}
    33  			}
    34  			if !contains {
    35  				t.FailNow()
    36  			}
    37  		}
    38  	}
    39  }
    40  
    41  // EqualFields equal two fields
    42  func EqualFields(t *testing.T, fieldA *entity.Field, fieldB *entity.Field) {
    43  	require.Equal(t, fieldA.Name, fieldB.Name, fmt.Sprintf("Expected field name: %s, actual: %s", fieldA.Name, fieldB.Name))
    44  	require.Equal(t, fieldA.AutoID, fieldB.AutoID, fmt.Sprintf("Expected field AutoID: %t, actual: %t", fieldA.AutoID, fieldB.AutoID))
    45  	require.Equal(t, fieldA.PrimaryKey, fieldB.PrimaryKey, fmt.Sprintf("Expected field PrimaryKey: %t, actual: %t", fieldA.PrimaryKey, fieldB.PrimaryKey))
    46  	require.Equal(t, fieldA.Description, fieldB.Description, fmt.Sprintf("Expected field Description: %s, actual: %s", fieldA.Description, fieldB.Description))
    47  	require.Equal(t, fieldA.DataType, fieldB.DataType, fmt.Sprintf("Expected field DataType: %v, actual: %v", fieldA.DataType, fieldB.DataType))
    48  	require.Equal(t, fieldA.IsPartitionKey, fieldB.IsPartitionKey, fmt.Sprintf("Expected field IsPartitionKey: %t, actual: %t", fieldA.IsPartitionKey, fieldB.IsPartitionKey))
    49  	require.Equal(t, fieldA.IsDynamic, fieldB.IsDynamic, fmt.Sprintf("Expected field IsDynamic: %t, actual: %t", fieldA.IsDynamic, fieldB.IsDynamic))
    50  
    51  	// check vector field dim
    52  	switch fieldA.DataType {
    53  	case entity.FieldTypeFloatVector:
    54  		require.Equal(t, fieldA.TypeParams[entity.TypeParamDim], fieldB.TypeParams[entity.TypeParamDim])
    55  	case entity.FieldTypeBinaryVector:
    56  		require.Equal(t, fieldA.TypeParams[entity.TypeParamDim], fieldB.TypeParams[entity.TypeParamDim])
    57  	// check varchar field max_length
    58  	case entity.FieldTypeVarChar:
    59  		require.Equal(t, fieldA.TypeParams[entity.TypeParamMaxLength], fieldB.TypeParams[entity.TypeParamMaxLength])
    60  
    61  	}
    62  	require.Empty(t, fieldA.IndexParams)
    63  	require.Empty(t, fieldB.IndexParams)
    64  	//require.Equal(t, fieldA.IndexParams, fieldB.IndexParams)
    65  }
    66  
    67  // EqualSchema equal two schemas
    68  func EqualSchema(t *testing.T, schemaA entity.Schema, schemaB entity.Schema) {
    69  	require.Equal(t, schemaA.CollectionName, schemaB.CollectionName, fmt.Sprintf("Expected schame CollectionName: %s, actual: %s", schemaA.CollectionName, schemaB.CollectionName))
    70  	require.Equal(t, schemaA.Description, schemaB.Description, fmt.Sprintf("Expected Description: %s, actual: %s", schemaA.Description, schemaB.Description))
    71  	require.Equal(t, schemaA.AutoID, schemaB.AutoID, fmt.Sprintf("Expected schema AutoID: %t, actual: %t", schemaA.AutoID, schemaB.AutoID))
    72  	require.Equal(t, len(schemaA.Fields), len(schemaB.Fields), fmt.Sprintf("Expected schame fields num: %d, actual: %d", len(schemaA.Fields), len(schemaB.Fields)))
    73  	require.Equal(t, schemaA.EnableDynamicField, schemaB.EnableDynamicField, fmt.Sprintf("Expected schame EnableDynamicField: %t, actual: %t", schemaA.EnableDynamicField, schemaB.EnableDynamicField))
    74  	for i := 0; i < len(schemaA.Fields); i++ {
    75  		EqualFields(t, schemaA.Fields[i], schemaB.Fields[i])
    76  	}
    77  }
    78  
    79  // CheckCollection check collection
    80  func CheckCollection(t *testing.T, actualCollection *entity.Collection, expCollName string, expShardNum int32,
    81  	expSchema *entity.Schema, expConsistencyLevel entity.ConsistencyLevel) {
    82  	require.Equalf(t, expCollName, actualCollection.Name, fmt.Sprintf("Expected collection name: %s, actual: %v", expCollName, actualCollection.Name))
    83  	require.Equalf(t, expShardNum, actualCollection.ShardNum, fmt.Sprintf("Expected ShardNum: %d, actual: %d", expShardNum, actualCollection.ShardNum))
    84  	require.Equal(t, expConsistencyLevel, actualCollection.ConsistencyLevel, fmt.Sprintf("Expected ConsistencyLevel: %v, actual: %v", expConsistencyLevel, actualCollection.ConsistencyLevel))
    85  	EqualSchema(t, *expSchema, *actualCollection.Schema)
    86  }
    87  
    88  // CheckContainsCollection check collections contains collName
    89  func CheckContainsCollection(t *testing.T, collections []*entity.Collection, collName string) {
    90  	allCollNames := make([]string, 0, len(collections))
    91  	for _, collection := range collections {
    92  		allCollNames = append(allCollNames, collection.Name)
    93  	}
    94  	require.Containsf(t, allCollNames, collName, fmt.Sprintf("The collection %s not in: %v", collName, allCollNames))
    95  }
    96  
    97  // CheckNotContainsCollection check collections not contains collName
    98  func CheckNotContainsCollection(t *testing.T, collections []*entity.Collection, collName string) {
    99  	allCollNames := make([]string, 0, len(collections))
   100  	for _, collection := range collections {
   101  		allCollNames = append(allCollNames, collection.Name)
   102  	}
   103  	require.NotContainsf(t, allCollNames, collName, fmt.Sprintf("The collection %s should not be in: %v", collName, allCollNames))
   104  }
   105  
   106  // CheckInsertResult check insert result, ids len (insert count), ids data (pks, but no auto ids)
   107  func CheckInsertResult(t *testing.T, actualIds entity.Column, expIds entity.Column) {
   108  	require.Equal(t, actualIds.Len(), expIds.Len())
   109  	switch expIds.Type() {
   110  	// pk field support int64 and varchar type
   111  	case entity.FieldTypeInt64:
   112  		require.ElementsMatch(t, actualIds.(*entity.ColumnInt64).Data(), expIds.(*entity.ColumnInt64).Data())
   113  	case entity.FieldTypeVarChar:
   114  		require.ElementsMatch(t, actualIds.(*entity.ColumnVarChar).Data(), expIds.(*entity.ColumnVarChar).Data())
   115  	default:
   116  		log.Printf("The primary field only support type: [%v, %v]", entity.FieldTypeInt64, entity.FieldTypeVarChar)
   117  	}
   118  }
   119  
   120  // CheckIndexResult check index result, index type, metric type, index params
   121  func CheckIndexResult(t *testing.T, actualIndexes []entity.Index, expIndexes ...entity.Index) {
   122  	mNameActualIndex := make(map[string]entity.Index)
   123  	allActualIndexNames := make([]string, 0, len(actualIndexes))
   124  	for _, actualIndex := range actualIndexes {
   125  		mNameActualIndex[actualIndex.Name()] = actualIndex
   126  		allActualIndexNames = append(allActualIndexNames, actualIndex.Name())
   127  	}
   128  	for _, expIndex := range expIndexes {
   129  		_, has := mNameActualIndex[expIndex.Name()]
   130  		require.Truef(t, has, "expIndex name %s not in actualIndexes %v", expIndex.Name(), allActualIndexNames)
   131  		require.Equal(t, mNameActualIndex[expIndex.Name()].IndexType(), expIndex.IndexType())
   132  		require.Equal(t, mNameActualIndex[expIndex.Name()].Params(), expIndex.Params())
   133  	}
   134  }
   135  
   136  // EqualColumn assert field data is equal of two columns
   137  func EqualColumn(t *testing.T, columnA entity.Column, columnB entity.Column) {
   138  	require.Equal(t, columnA.Name(), columnB.Name())
   139  	require.Equal(t, columnA.Type(), columnB.Type())
   140  	switch columnA.Type() {
   141  	case entity.FieldTypeBool:
   142  		require.ElementsMatch(t, columnA.(*entity.ColumnBool).Data(), columnB.(*entity.ColumnBool).Data())
   143  	case entity.FieldTypeInt8:
   144  		require.ElementsMatch(t, columnA.(*entity.ColumnInt8).Data(), columnB.(*entity.ColumnInt8).Data())
   145  	case entity.FieldTypeInt16:
   146  		require.ElementsMatch(t, columnA.(*entity.ColumnInt16).Data(), columnB.(*entity.ColumnInt16).Data())
   147  	case entity.FieldTypeInt32:
   148  		require.ElementsMatch(t, columnA.(*entity.ColumnInt32).Data(), columnB.(*entity.ColumnInt32).Data())
   149  	case entity.FieldTypeInt64:
   150  		require.ElementsMatch(t, columnA.(*entity.ColumnInt64).Data(), columnB.(*entity.ColumnInt64).Data())
   151  	case entity.FieldTypeFloat:
   152  		require.ElementsMatch(t, columnA.(*entity.ColumnFloat).Data(), columnB.(*entity.ColumnFloat).Data())
   153  	case entity.FieldTypeDouble:
   154  		require.ElementsMatch(t, columnA.(*entity.ColumnDouble).Data(), columnB.(*entity.ColumnDouble).Data())
   155  	case entity.FieldTypeVarChar:
   156  		require.ElementsMatch(t, columnA.(*entity.ColumnVarChar).Data(), columnB.(*entity.ColumnVarChar).Data())
   157  	case entity.FieldTypeJSON:
   158  		log.Printf("columnA: %s", columnA.(*entity.ColumnJSONBytes).Data())
   159  		log.Printf("columnB: %s", columnB.(*entity.ColumnJSONBytes).Data())
   160  		require.ElementsMatch(t, columnA.(*entity.ColumnJSONBytes).Data(), columnB.(*entity.ColumnJSONBytes).Data())
   161  	case entity.FieldTypeFloatVector:
   162  		require.ElementsMatch(t, columnA.(*entity.ColumnFloatVector).Data(), columnB.(*entity.ColumnFloatVector).Data())
   163  	case entity.FieldTypeBinaryVector:
   164  		require.ElementsMatch(t, columnA.(*entity.ColumnBinaryVector).Data(), columnB.(*entity.ColumnBinaryVector).Data())
   165  	case entity.FieldTypeFloat16Vector:
   166  		require.ElementsMatch(t, columnA.(*entity.ColumnFloat16Vector).Data(), columnB.(*entity.ColumnFloat16Vector).Data())
   167  	case entity.FieldTypeBFloat16Vector:
   168  		require.ElementsMatch(t, columnA.(*entity.ColumnBFloat16Vector).Data(), columnB.(*entity.ColumnBFloat16Vector).Data())
   169  	case entity.FieldTypeSparseVector:
   170  		require.ElementsMatch(t, columnA.(*entity.ColumnSparseFloatVector).Data(), columnB.(*entity.ColumnSparseFloatVector).Data())
   171  	case entity.FieldTypeArray:
   172  		EqualArrayColumn(t, columnA, columnB)
   173  	default:
   174  		log.Printf("The column type not in: [%v, %v, %v,  %v, %v,  %v, %v,  %v, %v,  %v, %v, %v]",
   175  			entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32,
   176  			entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeString,
   177  			entity.FieldTypeVarChar, entity.FieldTypeArray, entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector)
   178  	}
   179  }
   180  
   181  // EqualColumn assert field data is equal of two columns
   182  func EqualArrayColumn(t *testing.T, columnA entity.Column, columnB entity.Column) {
   183  	require.Equal(t, columnA.Name(), columnB.Name())
   184  	require.IsType(t, columnA.Type(), entity.FieldTypeArray)
   185  	require.IsType(t, columnB.Type(), entity.FieldTypeArray)
   186  	switch columnA.(type) {
   187  	case *entity.ColumnBoolArray:
   188  		require.ElementsMatch(t, columnA.(*entity.ColumnBoolArray).Data(), columnB.(*entity.ColumnBoolArray).Data())
   189  	case *entity.ColumnInt8Array:
   190  		require.ElementsMatch(t, columnA.(*entity.ColumnInt8Array).Data(), columnB.(*entity.ColumnInt8Array).Data())
   191  	case *entity.ColumnInt16Array:
   192  		require.ElementsMatch(t, columnA.(*entity.ColumnInt16Array).Data(), columnB.(*entity.ColumnInt16Array).Data())
   193  	case *entity.ColumnInt32Array:
   194  		require.ElementsMatch(t, columnA.(*entity.ColumnInt32Array).Data(), columnB.(*entity.ColumnInt32Array).Data())
   195  	case *entity.ColumnInt64Array:
   196  		require.ElementsMatch(t, columnA.(*entity.ColumnInt64Array).Data(), columnB.(*entity.ColumnInt64Array).Data())
   197  	case *entity.ColumnFloatArray:
   198  		require.ElementsMatch(t, columnA.(*entity.ColumnFloatArray).Data(), columnB.(*entity.ColumnFloatArray).Data())
   199  	case *entity.ColumnDoubleArray:
   200  		require.ElementsMatch(t, columnA.(*entity.ColumnDoubleArray).Data(), columnB.(*entity.ColumnDoubleArray).Data())
   201  	case *entity.ColumnVarCharArray:
   202  		require.ElementsMatch(t, columnA.(*entity.ColumnVarCharArray).Data(), columnB.(*entity.ColumnVarCharArray).Data())
   203  	default:
   204  		log.Printf("Now support array type: [%v, %v, %v,  %v, %v,  %v, %v,  %v]",
   205  			entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32,
   206  			entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar)
   207  	}
   208  }
   209  
   210  // CheckQueryResult check query result, column name, type and field
   211  func CheckQueryResult(t *testing.T, actualColumns []entity.Column, expColumns []entity.Column) {
   212  	require.GreaterOrEqual(t, len(actualColumns), len(expColumns),
   213  		"The len of actual columns %d should greater or equal to the expected columns %d", len(actualColumns), len(expColumns))
   214  	for _, expColumn := range expColumns {
   215  		for _, actualColumn := range actualColumns {
   216  			if expColumn.Name() == actualColumn.Name() {
   217  				EqualColumn(t, expColumn, actualColumn)
   218  			}
   219  		}
   220  	}
   221  }
   222  
   223  // CheckOutputFields check query output fields
   224  func CheckOutputFields(t *testing.T, actualColumns []entity.Column, expFields []string) {
   225  	actualFields := make([]string, 0)
   226  	for _, actualColumn := range actualColumns {
   227  		actualFields = append(actualFields, actualColumn.Name())
   228  	}
   229  	require.ElementsMatchf(t, expFields, actualFields, fmt.Sprintf("Expected search output fields: %v, actual: %v", expFields, actualFields))
   230  }
   231  
   232  // CheckSearchResult check search result, check nq, topk, ids, score
   233  func CheckSearchResult(t *testing.T, actualSearchResults []client.SearchResult, expNq int, expTopK int) {
   234  	require.Equal(t, len(actualSearchResults), expNq)
   235  	for _, actualSearchResult := range actualSearchResults {
   236  		require.Equal(t, actualSearchResult.ResultCount, expTopK)
   237  	}
   238  	//expContainedIds entity.Column
   239  
   240  }
   241  
   242  func EqualIntSlice(a []int, b []int) bool {
   243  	if len(a) != len(b) {
   244  		return false
   245  	}
   246  	for i := range a {
   247  		if a[i] != b[i] {
   248  			return false
   249  		}
   250  	}
   251  	return true
   252  }
   253  
   254  type CheckIteratorOption func(opt *checkIteratorOpt)
   255  
   256  type checkIteratorOpt struct {
   257  	expBatchSize    []int
   258  	expOutputFields []string
   259  }
   260  
   261  func WithExpBatchSize(expBatchSize []int) CheckIteratorOption {
   262  	return func(opt *checkIteratorOpt) {
   263  		opt.expBatchSize = expBatchSize
   264  	}
   265  }
   266  
   267  func WithExpOutputFields(expOutputFields []string) CheckIteratorOption {
   268  	return func(opt *checkIteratorOpt) {
   269  		opt.expOutputFields = expOutputFields
   270  	}
   271  }
   272  
   273  // check queryIterator: result limit, each batch size, output fields
   274  func CheckQueryIteratorResult(ctx context.Context, t *testing.T, itr *client.QueryIterator, expLimit int, opts ...CheckIteratorOption) {
   275  	opt := &checkIteratorOpt{}
   276  	for _, o := range opts {
   277  		o(opt)
   278  	}
   279  	actualLimit := 0
   280  	var actualBatchSize []int
   281  	for {
   282  		rs, err := itr.Next(ctx)
   283  		if err != nil {
   284  			if err == io.EOF {
   285  				break
   286  			}
   287  			log.Fatalf("QueryIterator next gets error: %v", err)
   288  		}
   289  		//log.Printf("QueryIterator result len: %d", rs.Len())
   290  		//log.Printf("QueryIterator result data: %d", rs.GetColumn("int64"))
   291  
   292  		if opt.expBatchSize != nil {
   293  			actualBatchSize = append(actualBatchSize, rs.Len())
   294  		}
   295  		var actualOutputFields []string
   296  		if opt.expOutputFields != nil {
   297  			for _, column := range rs {
   298  				actualOutputFields = append(actualOutputFields, column.Name())
   299  			}
   300  			require.ElementsMatch(t, opt.expOutputFields, actualOutputFields)
   301  		}
   302  		actualLimit = actualLimit + rs.Len()
   303  	}
   304  	require.Equal(t, expLimit, actualLimit)
   305  	if opt.expBatchSize != nil {
   306  		log.Printf("QueryIterator result len: %v", actualBatchSize)
   307  		require.True(t, EqualIntSlice(opt.expBatchSize, actualBatchSize))
   308  	}
   309  }
   310  
   311  // CheckPersistentSegments check persistent segments
   312  func CheckPersistentSegments(t *testing.T, actualSegments []*entity.Segment, expNb int64) {
   313  	actualNb := int64(0)
   314  	for _, segment := range actualSegments {
   315  		actualNb = segment.NumRows + actualNb
   316  	}
   317  	require.Equal(t, actualNb, expNb)
   318  }
   319  
   320  func CheckResourceGroup(t *testing.T, actualRg *entity.ResourceGroup, expRg *entity.ResourceGroup) {
   321  	require.EqualValues(t, expRg, actualRg)
   322  }
   323  
   324  func getDbNames(dbs []entity.Database) []string {
   325  	allDbNames := make([]string, 0, len(dbs))
   326  	for _, db := range dbs {
   327  		allDbNames = append(allDbNames, db.Name)
   328  	}
   329  	return allDbNames
   330  }
   331  
   332  // CheckContainsDb check collections contains collName
   333  func CheckContainsDb(t *testing.T, dbs []entity.Database, dbName string) {
   334  	allDbNames := getDbNames(dbs)
   335  	require.Containsf(t, allDbNames, dbName, fmt.Sprintf("%s db not in dbs: %v", dbName, dbs))
   336  }
   337  
   338  // CheckNotContainsDb check collections contains collName
   339  func CheckNotContainsDb(t *testing.T, dbs []entity.Database, dbName string) {
   340  	allDbNames := getDbNames(dbs)
   341  	require.NotContainsf(t, allDbNames, dbName, fmt.Sprintf("%s db should not be in dbs: %v", dbName, dbs))
   342  }