github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/test/common/response_check.go (about) 1 package common 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "log" 8 "strings" 9 "testing" 10 11 "github.com/milvus-io/milvus-sdk-go/v2/client" 12 "github.com/milvus-io/milvus-sdk-go/v2/entity" 13 "github.com/stretchr/testify/require" 14 ) 15 16 // CheckErr check err and errMsg 17 func CheckErr(t *testing.T, actualErr error, expErrNil bool, expErrorMsg ...string) { 18 if expErrNil { 19 require.NoError(t, actualErr) 20 } else { 21 require.Error(t, actualErr) 22 switch len(expErrorMsg) { 23 case 0: 24 log.Fatal("expect error message should not be empty") 25 case 1: 26 require.ErrorContains(t, actualErr, expErrorMsg[0]) 27 default: 28 contains := false 29 for i := 0; i < len(expErrorMsg); i++ { 30 if strings.Contains(actualErr.Error(), expErrorMsg[i]) { 31 contains = true 32 } 33 } 34 if !contains { 35 t.FailNow() 36 } 37 } 38 } 39 } 40 41 // EqualFields equal two fields 42 func EqualFields(t *testing.T, fieldA *entity.Field, fieldB *entity.Field) { 43 require.Equal(t, fieldA.Name, fieldB.Name, fmt.Sprintf("Expected field name: %s, actual: %s", fieldA.Name, fieldB.Name)) 44 require.Equal(t, fieldA.AutoID, fieldB.AutoID, fmt.Sprintf("Expected field AutoID: %t, actual: %t", fieldA.AutoID, fieldB.AutoID)) 45 require.Equal(t, fieldA.PrimaryKey, fieldB.PrimaryKey, fmt.Sprintf("Expected field PrimaryKey: %t, actual: %t", fieldA.PrimaryKey, fieldB.PrimaryKey)) 46 require.Equal(t, fieldA.Description, fieldB.Description, fmt.Sprintf("Expected field Description: %s, actual: %s", fieldA.Description, fieldB.Description)) 47 require.Equal(t, fieldA.DataType, fieldB.DataType, fmt.Sprintf("Expected field DataType: %v, actual: %v", fieldA.DataType, fieldB.DataType)) 48 require.Equal(t, fieldA.IsPartitionKey, fieldB.IsPartitionKey, fmt.Sprintf("Expected field IsPartitionKey: %t, actual: %t", fieldA.IsPartitionKey, fieldB.IsPartitionKey)) 49 require.Equal(t, fieldA.IsDynamic, fieldB.IsDynamic, fmt.Sprintf("Expected field IsDynamic: %t, actual: %t", fieldA.IsDynamic, fieldB.IsDynamic)) 50 51 // check vector field dim 52 switch fieldA.DataType { 53 case entity.FieldTypeFloatVector: 54 require.Equal(t, fieldA.TypeParams[entity.TypeParamDim], fieldB.TypeParams[entity.TypeParamDim]) 55 case entity.FieldTypeBinaryVector: 56 require.Equal(t, fieldA.TypeParams[entity.TypeParamDim], fieldB.TypeParams[entity.TypeParamDim]) 57 // check varchar field max_length 58 case entity.FieldTypeVarChar: 59 require.Equal(t, fieldA.TypeParams[entity.TypeParamMaxLength], fieldB.TypeParams[entity.TypeParamMaxLength]) 60 61 } 62 require.Empty(t, fieldA.IndexParams) 63 require.Empty(t, fieldB.IndexParams) 64 //require.Equal(t, fieldA.IndexParams, fieldB.IndexParams) 65 } 66 67 // EqualSchema equal two schemas 68 func EqualSchema(t *testing.T, schemaA entity.Schema, schemaB entity.Schema) { 69 require.Equal(t, schemaA.CollectionName, schemaB.CollectionName, fmt.Sprintf("Expected schame CollectionName: %s, actual: %s", schemaA.CollectionName, schemaB.CollectionName)) 70 require.Equal(t, schemaA.Description, schemaB.Description, fmt.Sprintf("Expected Description: %s, actual: %s", schemaA.Description, schemaB.Description)) 71 require.Equal(t, schemaA.AutoID, schemaB.AutoID, fmt.Sprintf("Expected schema AutoID: %t, actual: %t", schemaA.AutoID, schemaB.AutoID)) 72 require.Equal(t, len(schemaA.Fields), len(schemaB.Fields), fmt.Sprintf("Expected schame fields num: %d, actual: %d", len(schemaA.Fields), len(schemaB.Fields))) 73 require.Equal(t, schemaA.EnableDynamicField, schemaB.EnableDynamicField, fmt.Sprintf("Expected schame EnableDynamicField: %t, actual: %t", schemaA.EnableDynamicField, schemaB.EnableDynamicField)) 74 for i := 0; i < len(schemaA.Fields); i++ { 75 EqualFields(t, schemaA.Fields[i], schemaB.Fields[i]) 76 } 77 } 78 79 // CheckCollection check collection 80 func CheckCollection(t *testing.T, actualCollection *entity.Collection, expCollName string, expShardNum int32, 81 expSchema *entity.Schema, expConsistencyLevel entity.ConsistencyLevel) { 82 require.Equalf(t, expCollName, actualCollection.Name, fmt.Sprintf("Expected collection name: %s, actual: %v", expCollName, actualCollection.Name)) 83 require.Equalf(t, expShardNum, actualCollection.ShardNum, fmt.Sprintf("Expected ShardNum: %d, actual: %d", expShardNum, actualCollection.ShardNum)) 84 require.Equal(t, expConsistencyLevel, actualCollection.ConsistencyLevel, fmt.Sprintf("Expected ConsistencyLevel: %v, actual: %v", expConsistencyLevel, actualCollection.ConsistencyLevel)) 85 EqualSchema(t, *expSchema, *actualCollection.Schema) 86 } 87 88 // CheckContainsCollection check collections contains collName 89 func CheckContainsCollection(t *testing.T, collections []*entity.Collection, collName string) { 90 allCollNames := make([]string, 0, len(collections)) 91 for _, collection := range collections { 92 allCollNames = append(allCollNames, collection.Name) 93 } 94 require.Containsf(t, allCollNames, collName, fmt.Sprintf("The collection %s not in: %v", collName, allCollNames)) 95 } 96 97 // CheckNotContainsCollection check collections not contains collName 98 func CheckNotContainsCollection(t *testing.T, collections []*entity.Collection, collName string) { 99 allCollNames := make([]string, 0, len(collections)) 100 for _, collection := range collections { 101 allCollNames = append(allCollNames, collection.Name) 102 } 103 require.NotContainsf(t, allCollNames, collName, fmt.Sprintf("The collection %s should not be in: %v", collName, allCollNames)) 104 } 105 106 // CheckInsertResult check insert result, ids len (insert count), ids data (pks, but no auto ids) 107 func CheckInsertResult(t *testing.T, actualIds entity.Column, expIds entity.Column) { 108 require.Equal(t, actualIds.Len(), expIds.Len()) 109 switch expIds.Type() { 110 // pk field support int64 and varchar type 111 case entity.FieldTypeInt64: 112 require.ElementsMatch(t, actualIds.(*entity.ColumnInt64).Data(), expIds.(*entity.ColumnInt64).Data()) 113 case entity.FieldTypeVarChar: 114 require.ElementsMatch(t, actualIds.(*entity.ColumnVarChar).Data(), expIds.(*entity.ColumnVarChar).Data()) 115 default: 116 log.Printf("The primary field only support type: [%v, %v]", entity.FieldTypeInt64, entity.FieldTypeVarChar) 117 } 118 } 119 120 // CheckIndexResult check index result, index type, metric type, index params 121 func CheckIndexResult(t *testing.T, actualIndexes []entity.Index, expIndexes ...entity.Index) { 122 mNameActualIndex := make(map[string]entity.Index) 123 allActualIndexNames := make([]string, 0, len(actualIndexes)) 124 for _, actualIndex := range actualIndexes { 125 mNameActualIndex[actualIndex.Name()] = actualIndex 126 allActualIndexNames = append(allActualIndexNames, actualIndex.Name()) 127 } 128 for _, expIndex := range expIndexes { 129 _, has := mNameActualIndex[expIndex.Name()] 130 require.Truef(t, has, "expIndex name %s not in actualIndexes %v", expIndex.Name(), allActualIndexNames) 131 require.Equal(t, mNameActualIndex[expIndex.Name()].IndexType(), expIndex.IndexType()) 132 require.Equal(t, mNameActualIndex[expIndex.Name()].Params(), expIndex.Params()) 133 } 134 } 135 136 // EqualColumn assert field data is equal of two columns 137 func EqualColumn(t *testing.T, columnA entity.Column, columnB entity.Column) { 138 require.Equal(t, columnA.Name(), columnB.Name()) 139 require.Equal(t, columnA.Type(), columnB.Type()) 140 switch columnA.Type() { 141 case entity.FieldTypeBool: 142 require.ElementsMatch(t, columnA.(*entity.ColumnBool).Data(), columnB.(*entity.ColumnBool).Data()) 143 case entity.FieldTypeInt8: 144 require.ElementsMatch(t, columnA.(*entity.ColumnInt8).Data(), columnB.(*entity.ColumnInt8).Data()) 145 case entity.FieldTypeInt16: 146 require.ElementsMatch(t, columnA.(*entity.ColumnInt16).Data(), columnB.(*entity.ColumnInt16).Data()) 147 case entity.FieldTypeInt32: 148 require.ElementsMatch(t, columnA.(*entity.ColumnInt32).Data(), columnB.(*entity.ColumnInt32).Data()) 149 case entity.FieldTypeInt64: 150 require.ElementsMatch(t, columnA.(*entity.ColumnInt64).Data(), columnB.(*entity.ColumnInt64).Data()) 151 case entity.FieldTypeFloat: 152 require.ElementsMatch(t, columnA.(*entity.ColumnFloat).Data(), columnB.(*entity.ColumnFloat).Data()) 153 case entity.FieldTypeDouble: 154 require.ElementsMatch(t, columnA.(*entity.ColumnDouble).Data(), columnB.(*entity.ColumnDouble).Data()) 155 case entity.FieldTypeVarChar: 156 require.ElementsMatch(t, columnA.(*entity.ColumnVarChar).Data(), columnB.(*entity.ColumnVarChar).Data()) 157 case entity.FieldTypeJSON: 158 log.Printf("columnA: %s", columnA.(*entity.ColumnJSONBytes).Data()) 159 log.Printf("columnB: %s", columnB.(*entity.ColumnJSONBytes).Data()) 160 require.ElementsMatch(t, columnA.(*entity.ColumnJSONBytes).Data(), columnB.(*entity.ColumnJSONBytes).Data()) 161 case entity.FieldTypeFloatVector: 162 require.ElementsMatch(t, columnA.(*entity.ColumnFloatVector).Data(), columnB.(*entity.ColumnFloatVector).Data()) 163 case entity.FieldTypeBinaryVector: 164 require.ElementsMatch(t, columnA.(*entity.ColumnBinaryVector).Data(), columnB.(*entity.ColumnBinaryVector).Data()) 165 case entity.FieldTypeFloat16Vector: 166 require.ElementsMatch(t, columnA.(*entity.ColumnFloat16Vector).Data(), columnB.(*entity.ColumnFloat16Vector).Data()) 167 case entity.FieldTypeBFloat16Vector: 168 require.ElementsMatch(t, columnA.(*entity.ColumnBFloat16Vector).Data(), columnB.(*entity.ColumnBFloat16Vector).Data()) 169 case entity.FieldTypeSparseVector: 170 require.ElementsMatch(t, columnA.(*entity.ColumnSparseFloatVector).Data(), columnB.(*entity.ColumnSparseFloatVector).Data()) 171 case entity.FieldTypeArray: 172 EqualArrayColumn(t, columnA, columnB) 173 default: 174 log.Printf("The column type not in: [%v, %v, %v, %v, %v, %v, %v, %v, %v, %v, %v, %v]", 175 entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32, 176 entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeString, 177 entity.FieldTypeVarChar, entity.FieldTypeArray, entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector) 178 } 179 } 180 181 // EqualColumn assert field data is equal of two columns 182 func EqualArrayColumn(t *testing.T, columnA entity.Column, columnB entity.Column) { 183 require.Equal(t, columnA.Name(), columnB.Name()) 184 require.IsType(t, columnA.Type(), entity.FieldTypeArray) 185 require.IsType(t, columnB.Type(), entity.FieldTypeArray) 186 switch columnA.(type) { 187 case *entity.ColumnBoolArray: 188 require.ElementsMatch(t, columnA.(*entity.ColumnBoolArray).Data(), columnB.(*entity.ColumnBoolArray).Data()) 189 case *entity.ColumnInt8Array: 190 require.ElementsMatch(t, columnA.(*entity.ColumnInt8Array).Data(), columnB.(*entity.ColumnInt8Array).Data()) 191 case *entity.ColumnInt16Array: 192 require.ElementsMatch(t, columnA.(*entity.ColumnInt16Array).Data(), columnB.(*entity.ColumnInt16Array).Data()) 193 case *entity.ColumnInt32Array: 194 require.ElementsMatch(t, columnA.(*entity.ColumnInt32Array).Data(), columnB.(*entity.ColumnInt32Array).Data()) 195 case *entity.ColumnInt64Array: 196 require.ElementsMatch(t, columnA.(*entity.ColumnInt64Array).Data(), columnB.(*entity.ColumnInt64Array).Data()) 197 case *entity.ColumnFloatArray: 198 require.ElementsMatch(t, columnA.(*entity.ColumnFloatArray).Data(), columnB.(*entity.ColumnFloatArray).Data()) 199 case *entity.ColumnDoubleArray: 200 require.ElementsMatch(t, columnA.(*entity.ColumnDoubleArray).Data(), columnB.(*entity.ColumnDoubleArray).Data()) 201 case *entity.ColumnVarCharArray: 202 require.ElementsMatch(t, columnA.(*entity.ColumnVarCharArray).Data(), columnB.(*entity.ColumnVarCharArray).Data()) 203 default: 204 log.Printf("Now support array type: [%v, %v, %v, %v, %v, %v, %v, %v]", 205 entity.FieldTypeBool, entity.FieldTypeInt8, entity.FieldTypeInt16, entity.FieldTypeInt32, 206 entity.FieldTypeInt64, entity.FieldTypeFloat, entity.FieldTypeDouble, entity.FieldTypeVarChar) 207 } 208 } 209 210 // CheckQueryResult check query result, column name, type and field 211 func CheckQueryResult(t *testing.T, actualColumns []entity.Column, expColumns []entity.Column) { 212 require.GreaterOrEqual(t, len(actualColumns), len(expColumns), 213 "The len of actual columns %d should greater or equal to the expected columns %d", len(actualColumns), len(expColumns)) 214 for _, expColumn := range expColumns { 215 for _, actualColumn := range actualColumns { 216 if expColumn.Name() == actualColumn.Name() { 217 EqualColumn(t, expColumn, actualColumn) 218 } 219 } 220 } 221 } 222 223 // CheckOutputFields check query output fields 224 func CheckOutputFields(t *testing.T, actualColumns []entity.Column, expFields []string) { 225 actualFields := make([]string, 0) 226 for _, actualColumn := range actualColumns { 227 actualFields = append(actualFields, actualColumn.Name()) 228 } 229 require.ElementsMatchf(t, expFields, actualFields, fmt.Sprintf("Expected search output fields: %v, actual: %v", expFields, actualFields)) 230 } 231 232 // CheckSearchResult check search result, check nq, topk, ids, score 233 func CheckSearchResult(t *testing.T, actualSearchResults []client.SearchResult, expNq int, expTopK int) { 234 require.Equal(t, len(actualSearchResults), expNq) 235 for _, actualSearchResult := range actualSearchResults { 236 require.Equal(t, actualSearchResult.ResultCount, expTopK) 237 } 238 //expContainedIds entity.Column 239 240 } 241 242 func EqualIntSlice(a []int, b []int) bool { 243 if len(a) != len(b) { 244 return false 245 } 246 for i := range a { 247 if a[i] != b[i] { 248 return false 249 } 250 } 251 return true 252 } 253 254 type CheckIteratorOption func(opt *checkIteratorOpt) 255 256 type checkIteratorOpt struct { 257 expBatchSize []int 258 expOutputFields []string 259 } 260 261 func WithExpBatchSize(expBatchSize []int) CheckIteratorOption { 262 return func(opt *checkIteratorOpt) { 263 opt.expBatchSize = expBatchSize 264 } 265 } 266 267 func WithExpOutputFields(expOutputFields []string) CheckIteratorOption { 268 return func(opt *checkIteratorOpt) { 269 opt.expOutputFields = expOutputFields 270 } 271 } 272 273 // check queryIterator: result limit, each batch size, output fields 274 func CheckQueryIteratorResult(ctx context.Context, t *testing.T, itr *client.QueryIterator, expLimit int, opts ...CheckIteratorOption) { 275 opt := &checkIteratorOpt{} 276 for _, o := range opts { 277 o(opt) 278 } 279 actualLimit := 0 280 var actualBatchSize []int 281 for { 282 rs, err := itr.Next(ctx) 283 if err != nil { 284 if err == io.EOF { 285 break 286 } 287 log.Fatalf("QueryIterator next gets error: %v", err) 288 } 289 //log.Printf("QueryIterator result len: %d", rs.Len()) 290 //log.Printf("QueryIterator result data: %d", rs.GetColumn("int64")) 291 292 if opt.expBatchSize != nil { 293 actualBatchSize = append(actualBatchSize, rs.Len()) 294 } 295 var actualOutputFields []string 296 if opt.expOutputFields != nil { 297 for _, column := range rs { 298 actualOutputFields = append(actualOutputFields, column.Name()) 299 } 300 require.ElementsMatch(t, opt.expOutputFields, actualOutputFields) 301 } 302 actualLimit = actualLimit + rs.Len() 303 } 304 require.Equal(t, expLimit, actualLimit) 305 if opt.expBatchSize != nil { 306 log.Printf("QueryIterator result len: %v", actualBatchSize) 307 require.True(t, EqualIntSlice(opt.expBatchSize, actualBatchSize)) 308 } 309 } 310 311 // CheckPersistentSegments check persistent segments 312 func CheckPersistentSegments(t *testing.T, actualSegments []*entity.Segment, expNb int64) { 313 actualNb := int64(0) 314 for _, segment := range actualSegments { 315 actualNb = segment.NumRows + actualNb 316 } 317 require.Equal(t, actualNb, expNb) 318 } 319 320 func CheckResourceGroup(t *testing.T, actualRg *entity.ResourceGroup, expRg *entity.ResourceGroup) { 321 require.EqualValues(t, expRg, actualRg) 322 } 323 324 func getDbNames(dbs []entity.Database) []string { 325 allDbNames := make([]string, 0, len(dbs)) 326 for _, db := range dbs { 327 allDbNames = append(allDbNames, db.Name) 328 } 329 return allDbNames 330 } 331 332 // CheckContainsDb check collections contains collName 333 func CheckContainsDb(t *testing.T, dbs []entity.Database, dbName string) { 334 allDbNames := getDbNames(dbs) 335 require.Containsf(t, allDbNames, dbName, fmt.Sprintf("%s db not in dbs: %v", dbName, dbs)) 336 } 337 338 // CheckNotContainsDb check collections contains collName 339 func CheckNotContainsDb(t *testing.T, dbs []entity.Database, dbName string) { 340 allDbNames := getDbNames(dbs) 341 require.NotContainsf(t, allDbNames, dbName, fmt.Sprintf("%s db should not be in dbs: %v", dbName, dbs)) 342 }