github.com/weaviate/weaviate@v1.24.6/usecases/traverser/near_params_vector_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package traverser 13 14 import ( 15 "context" 16 "reflect" 17 "testing" 18 19 "github.com/go-openapi/strfmt" 20 "github.com/stretchr/testify/assert" 21 "github.com/weaviate/weaviate/entities/additional" 22 "github.com/weaviate/weaviate/entities/schema/crossref" 23 "github.com/weaviate/weaviate/entities/search" 24 "github.com/weaviate/weaviate/entities/searchparams" 25 ) 26 27 func Test_nearParamsVector_validateNearParams(t *testing.T) { 28 type args struct { 29 nearVector *searchparams.NearVector 30 nearObject *searchparams.NearObject 31 moduleParams map[string]interface{} 32 className []string 33 } 34 tests := []struct { 35 name string 36 args args 37 wantErr bool 38 errMessage string 39 }{ 40 { 41 name: "Should be OK, when all near params are nil", 42 args: args{ 43 nearVector: nil, 44 nearObject: nil, 45 moduleParams: nil, 46 className: nil, 47 }, 48 wantErr: false, 49 }, 50 { 51 name: "Should be OK, when nearVector param is set", 52 args: args{ 53 nearVector: &searchparams.NearVector{}, 54 nearObject: nil, 55 moduleParams: nil, 56 className: nil, 57 }, 58 wantErr: false, 59 }, 60 { 61 name: "Should be OK, when nearObject param is set", 62 args: args{ 63 nearVector: nil, 64 nearObject: &searchparams.NearObject{}, 65 moduleParams: nil, 66 className: nil, 67 }, 68 wantErr: false, 69 }, 70 { 71 name: "Should be OK, when moduleParams param is set", 72 args: args{ 73 nearVector: nil, 74 nearObject: nil, 75 moduleParams: map[string]interface{}{ 76 "nearCustomText": &nearCustomTextParams{}, 77 }, 78 className: nil, 79 }, 80 wantErr: false, 81 }, 82 { 83 name: "Should throw error, when nearVector and nearObject is set", 84 args: args{ 85 nearVector: &searchparams.NearVector{}, 86 nearObject: &searchparams.NearObject{}, 87 moduleParams: nil, 88 className: nil, 89 }, 90 wantErr: true, 91 errMessage: "found both 'nearVector' and 'nearObject' parameters which are conflicting, choose one instead", 92 }, 93 { 94 name: "Should throw error, when nearVector and moduleParams is set", 95 args: args{ 96 nearVector: &searchparams.NearVector{}, 97 nearObject: nil, 98 moduleParams: map[string]interface{}{ 99 "nearCustomText": &nearCustomTextParams{}, 100 }, 101 className: nil, 102 }, 103 wantErr: true, 104 errMessage: "found both 'nearText' and 'nearVector' parameters which are conflicting, choose one instead", 105 }, 106 { 107 name: "Should throw error, when nearObject and moduleParams is set", 108 args: args{ 109 nearVector: nil, 110 nearObject: &searchparams.NearObject{}, 111 moduleParams: map[string]interface{}{ 112 "nearCustomText": &nearCustomTextParams{}, 113 }, 114 className: nil, 115 }, 116 wantErr: true, 117 errMessage: "found both 'nearText' and 'nearObject' parameters which are conflicting, choose one instead", 118 }, 119 { 120 name: "Should throw error, when nearVector and nearObject and moduleParams is set", 121 args: args{ 122 nearVector: &searchparams.NearVector{}, 123 nearObject: &searchparams.NearObject{}, 124 moduleParams: map[string]interface{}{ 125 "nearCustomText": &nearCustomTextParams{}, 126 }, 127 className: nil, 128 }, 129 wantErr: true, 130 errMessage: "found 'nearText' and 'nearVector' and 'nearObject' parameters which are conflicting, choose one instead", 131 }, 132 { 133 name: "Should throw error, when nearVector certainty and distance are set", 134 args: args{ 135 nearVector: &searchparams.NearVector{ 136 Certainty: 0.1, 137 Distance: 0.9, 138 WithDistance: true, 139 }, 140 className: nil, 141 }, 142 wantErr: true, 143 errMessage: "found 'certainty' and 'distance' set in nearVector which are conflicting, choose one instead", 144 }, 145 { 146 name: "Should throw error, when nearObject certainty and distance are set", 147 args: args{ 148 nearObject: &searchparams.NearObject{ 149 Certainty: 0.1, 150 Distance: 0.9, 151 WithDistance: true, 152 }, 153 className: nil, 154 }, 155 wantErr: true, 156 errMessage: "found 'certainty' and 'distance' set in nearObject which are conflicting, choose one instead", 157 }, 158 { 159 name: "Should throw error, when nearText certainty and distance are set", 160 args: args{ 161 moduleParams: map[string]interface{}{ 162 "nearCustomText": &nearCustomTextParams{ 163 Certainty: 0.1, 164 Distance: 0.9, 165 WithDistance: true, 166 }, 167 }, 168 className: nil, 169 }, 170 wantErr: true, 171 errMessage: "nearText cannot provide both distance and certainty", 172 }, 173 } 174 for _, tt := range tests { 175 t.Run(tt.name, func(t *testing.T) { 176 e := &nearParamsVector{ 177 modulesProvider: &fakeModulesProvider{}, 178 search: &fakeNearParamsSearcher{}, 179 } 180 err := e.validateNearParams(tt.args.nearVector, tt.args.nearObject, tt.args.moduleParams, tt.args.className...) 181 if (err != nil) != tt.wantErr { 182 t.Errorf("nearParamsVector.validateNearParams() error = %v, wantErr %v", err, tt.wantErr) 183 } 184 if err != nil && tt.errMessage != err.Error() { 185 t.Errorf("nearParamsVector.validateNearParams() error = %v, errMessage = %v", err, tt.errMessage) 186 } 187 }) 188 } 189 } 190 191 func Test_nearParamsVector_vectorFromParams(t *testing.T) { 192 type args struct { 193 ctx context.Context 194 nearVector *searchparams.NearVector 195 nearObject *searchparams.NearObject 196 moduleParams map[string]interface{} 197 className string 198 } 199 tests := []struct { 200 name string 201 args args 202 want []float32 203 wantErr bool 204 }{ 205 { 206 name: "Should get vector from nearVector", 207 args: args{ 208 nearVector: &searchparams.NearVector{ 209 Vector: []float32{1.1, 1.0, 0.1}, 210 }, 211 }, 212 want: []float32{1.1, 1.0, 0.1}, 213 wantErr: false, 214 }, 215 { 216 name: "Should get vector from nearObject", 217 args: args{ 218 nearObject: &searchparams.NearObject{ 219 ID: "uuid", 220 }, 221 }, 222 want: []float32{1.0, 1.0, 1.0}, 223 wantErr: false, 224 }, 225 { 226 name: "Should get vector from nearText", 227 args: args{ 228 moduleParams: map[string]interface{}{ 229 "nearCustomText": &nearCustomTextParams{ 230 Values: []string{"a"}, 231 }, 232 }, 233 }, 234 want: []float32{1, 2, 3}, 235 wantErr: false, 236 }, 237 { 238 name: "Should get vector from nearObject", 239 args: args{ 240 nearObject: &searchparams.NearObject{ 241 Beacon: crossref.NewLocalhost("Class", "uuid").String(), 242 }, 243 }, 244 wantErr: true, 245 }, 246 { 247 name: "Should get vector from nearObject", 248 args: args{ 249 nearObject: &searchparams.NearObject{ 250 Beacon: crossref.NewLocalhost("Class", "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf").String(), 251 }, 252 }, 253 want: []float32{1.0, 1.0, 1.0}, 254 wantErr: false, 255 }, 256 { 257 name: "Should get vector from nearObject across classes", 258 args: args{ 259 nearObject: &searchparams.NearObject{ 260 Beacon: crossref.NewLocalhost("SpecifiedClass", "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf").String(), 261 }, 262 }, 263 want: []float32{0.0, 0.0, 0.0}, 264 wantErr: false, 265 }, 266 } 267 for _, tt := range tests { 268 t.Run(tt.name, func(t *testing.T) { 269 e := &nearParamsVector{ 270 modulesProvider: &fakeModulesProvider{}, 271 search: &fakeNearParamsSearcher{}, 272 } 273 got, targetVector, err := e.vectorFromParams(tt.args.ctx, tt.args.nearVector, tt.args.nearObject, tt.args.moduleParams, tt.args.className, "") 274 if (err != nil) != tt.wantErr { 275 t.Errorf("nearParamsVector.vectorFromParams() error = %v, wantErr %v", err, tt.wantErr) 276 return 277 } 278 if !reflect.DeepEqual(got, tt.want) { 279 t.Errorf("nearParamsVector.vectorFromParams() = %v, want %v", got, tt.want) 280 } 281 assert.Equal(t, "", targetVector) 282 }) 283 } 284 } 285 286 func Test_nearParamsVector_extractCertaintyFromParams(t *testing.T) { 287 type args struct { 288 nearVector *searchparams.NearVector 289 nearObject *searchparams.NearObject 290 moduleParams map[string]interface{} 291 } 292 tests := []struct { 293 name string 294 args args 295 want float64 296 }{ 297 { 298 name: "Should extract distance from nearVector", 299 args: args{ 300 nearVector: &searchparams.NearVector{ 301 Distance: 0.88, 302 WithDistance: true, 303 }, 304 }, 305 want: 1 - 0.88/2, 306 }, 307 { 308 name: "Should extract certainty from nearVector", 309 args: args{ 310 nearVector: &searchparams.NearVector{ 311 Certainty: 0.88, 312 }, 313 }, 314 want: 0.88, 315 }, 316 { 317 name: "Should extract distance from nearObject", 318 args: args{ 319 nearObject: &searchparams.NearObject{ 320 Distance: 0.99, 321 WithDistance: true, 322 }, 323 }, 324 want: 1 - 0.99/2, 325 }, 326 { 327 name: "Should extract certainty from nearObject", 328 args: args{ 329 nearObject: &searchparams.NearObject{ 330 Certainty: 0.99, 331 }, 332 }, 333 want: 0.99, 334 }, 335 { 336 name: "Should extract distance from nearText", 337 args: args{ 338 moduleParams: map[string]interface{}{ 339 "nearCustomText": &nearCustomTextParams{ 340 Distance: 0.77, 341 WithDistance: true, 342 }, 343 }, 344 }, 345 want: 1 - 0.77/2, 346 }, 347 { 348 name: "Should extract certainty from nearText", 349 args: args{ 350 moduleParams: map[string]interface{}{ 351 "nearCustomText": &nearCustomTextParams{ 352 Certainty: 0.77, 353 }, 354 }, 355 }, 356 want: 0.77, 357 }, 358 } 359 for _, tt := range tests { 360 t.Run(tt.name, func(t *testing.T) { 361 e := &nearParamsVector{ 362 modulesProvider: &fakeModulesProvider{}, 363 search: &fakeNearParamsSearcher{}, 364 } 365 got := e.extractCertaintyFromParams(tt.args.nearVector, tt.args.nearObject, tt.args.moduleParams) 366 if !assert.InDelta(t, tt.want, got, 1e-9) { 367 t.Errorf("nearParamsVector.extractCertaintyFromParams() = %v, want %v", got, tt.want) 368 } 369 }) 370 } 371 } 372 373 type fakeNearParamsSearcher struct{} 374 375 func (f *fakeNearParamsSearcher) ObjectsByID(ctx context.Context, id strfmt.UUID, 376 props search.SelectProperties, additional additional.Properties, tenant string, 377 ) (search.Results, error) { 378 return search.Results{ 379 {Vector: []float32{1.0, 1.0, 1.0}}, 380 }, nil 381 } 382 383 func (f *fakeNearParamsSearcher) Object(ctx context.Context, className string, id strfmt.UUID, 384 props search.SelectProperties, additional additional.Properties, 385 repl *additional.ReplicationProperties, tenant string, 386 ) (*search.Result, error) { 387 if className == "SpecifiedClass" { 388 return &search.Result{ 389 Vector: []float32{0.0, 0.0, 0.0}, 390 }, nil 391 } else { 392 return &search.Result{ 393 Vector: []float32{1.0, 1.0, 1.0}, 394 }, nil 395 } 396 }