github.com/weaviate/weaviate@v1.24.6/usecases/traverser/near_params_vector_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package traverser
    13  
    14  import (
    15  	"context"
    16  	"reflect"
    17  	"testing"
    18  
    19  	"github.com/go-openapi/strfmt"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/weaviate/weaviate/entities/additional"
    22  	"github.com/weaviate/weaviate/entities/schema/crossref"
    23  	"github.com/weaviate/weaviate/entities/search"
    24  	"github.com/weaviate/weaviate/entities/searchparams"
    25  )
    26  
    27  func Test_nearParamsVector_validateNearParams(t *testing.T) {
    28  	type args struct {
    29  		nearVector   *searchparams.NearVector
    30  		nearObject   *searchparams.NearObject
    31  		moduleParams map[string]interface{}
    32  		className    []string
    33  	}
    34  	tests := []struct {
    35  		name       string
    36  		args       args
    37  		wantErr    bool
    38  		errMessage string
    39  	}{
    40  		{
    41  			name: "Should be OK, when all near params are nil",
    42  			args: args{
    43  				nearVector:   nil,
    44  				nearObject:   nil,
    45  				moduleParams: nil,
    46  				className:    nil,
    47  			},
    48  			wantErr: false,
    49  		},
    50  		{
    51  			name: "Should be OK, when nearVector param is set",
    52  			args: args{
    53  				nearVector:   &searchparams.NearVector{},
    54  				nearObject:   nil,
    55  				moduleParams: nil,
    56  				className:    nil,
    57  			},
    58  			wantErr: false,
    59  		},
    60  		{
    61  			name: "Should be OK, when nearObject param is set",
    62  			args: args{
    63  				nearVector:   nil,
    64  				nearObject:   &searchparams.NearObject{},
    65  				moduleParams: nil,
    66  				className:    nil,
    67  			},
    68  			wantErr: false,
    69  		},
    70  		{
    71  			name: "Should be OK, when moduleParams param is set",
    72  			args: args{
    73  				nearVector: nil,
    74  				nearObject: nil,
    75  				moduleParams: map[string]interface{}{
    76  					"nearCustomText": &nearCustomTextParams{},
    77  				},
    78  				className: nil,
    79  			},
    80  			wantErr: false,
    81  		},
    82  		{
    83  			name: "Should throw error, when nearVector and nearObject is set",
    84  			args: args{
    85  				nearVector:   &searchparams.NearVector{},
    86  				nearObject:   &searchparams.NearObject{},
    87  				moduleParams: nil,
    88  				className:    nil,
    89  			},
    90  			wantErr:    true,
    91  			errMessage: "found both 'nearVector' and 'nearObject' parameters which are conflicting, choose one instead",
    92  		},
    93  		{
    94  			name: "Should throw error, when nearVector and moduleParams is set",
    95  			args: args{
    96  				nearVector: &searchparams.NearVector{},
    97  				nearObject: nil,
    98  				moduleParams: map[string]interface{}{
    99  					"nearCustomText": &nearCustomTextParams{},
   100  				},
   101  				className: nil,
   102  			},
   103  			wantErr:    true,
   104  			errMessage: "found both 'nearText' and 'nearVector' parameters which are conflicting, choose one instead",
   105  		},
   106  		{
   107  			name: "Should throw error, when nearObject and moduleParams is set",
   108  			args: args{
   109  				nearVector: nil,
   110  				nearObject: &searchparams.NearObject{},
   111  				moduleParams: map[string]interface{}{
   112  					"nearCustomText": &nearCustomTextParams{},
   113  				},
   114  				className: nil,
   115  			},
   116  			wantErr:    true,
   117  			errMessage: "found both 'nearText' and 'nearObject' parameters which are conflicting, choose one instead",
   118  		},
   119  		{
   120  			name: "Should throw error, when nearVector and nearObject and moduleParams is set",
   121  			args: args{
   122  				nearVector: &searchparams.NearVector{},
   123  				nearObject: &searchparams.NearObject{},
   124  				moduleParams: map[string]interface{}{
   125  					"nearCustomText": &nearCustomTextParams{},
   126  				},
   127  				className: nil,
   128  			},
   129  			wantErr:    true,
   130  			errMessage: "found 'nearText' and 'nearVector' and 'nearObject' parameters which are conflicting, choose one instead",
   131  		},
   132  		{
   133  			name: "Should throw error, when nearVector certainty and distance are set",
   134  			args: args{
   135  				nearVector: &searchparams.NearVector{
   136  					Certainty:    0.1,
   137  					Distance:     0.9,
   138  					WithDistance: true,
   139  				},
   140  				className: nil,
   141  			},
   142  			wantErr:    true,
   143  			errMessage: "found 'certainty' and 'distance' set in nearVector which are conflicting, choose one instead",
   144  		},
   145  		{
   146  			name: "Should throw error, when nearObject certainty and distance are set",
   147  			args: args{
   148  				nearObject: &searchparams.NearObject{
   149  					Certainty:    0.1,
   150  					Distance:     0.9,
   151  					WithDistance: true,
   152  				},
   153  				className: nil,
   154  			},
   155  			wantErr:    true,
   156  			errMessage: "found 'certainty' and 'distance' set in nearObject which are conflicting, choose one instead",
   157  		},
   158  		{
   159  			name: "Should throw error, when nearText certainty and distance are set",
   160  			args: args{
   161  				moduleParams: map[string]interface{}{
   162  					"nearCustomText": &nearCustomTextParams{
   163  						Certainty:    0.1,
   164  						Distance:     0.9,
   165  						WithDistance: true,
   166  					},
   167  				},
   168  				className: nil,
   169  			},
   170  			wantErr:    true,
   171  			errMessage: "nearText cannot provide both distance and certainty",
   172  		},
   173  	}
   174  	for _, tt := range tests {
   175  		t.Run(tt.name, func(t *testing.T) {
   176  			e := &nearParamsVector{
   177  				modulesProvider: &fakeModulesProvider{},
   178  				search:          &fakeNearParamsSearcher{},
   179  			}
   180  			err := e.validateNearParams(tt.args.nearVector, tt.args.nearObject, tt.args.moduleParams, tt.args.className...)
   181  			if (err != nil) != tt.wantErr {
   182  				t.Errorf("nearParamsVector.validateNearParams() error = %v, wantErr %v", err, tt.wantErr)
   183  			}
   184  			if err != nil && tt.errMessage != err.Error() {
   185  				t.Errorf("nearParamsVector.validateNearParams() error = %v, errMessage = %v", err, tt.errMessage)
   186  			}
   187  		})
   188  	}
   189  }
   190  
   191  func Test_nearParamsVector_vectorFromParams(t *testing.T) {
   192  	type args struct {
   193  		ctx          context.Context
   194  		nearVector   *searchparams.NearVector
   195  		nearObject   *searchparams.NearObject
   196  		moduleParams map[string]interface{}
   197  		className    string
   198  	}
   199  	tests := []struct {
   200  		name    string
   201  		args    args
   202  		want    []float32
   203  		wantErr bool
   204  	}{
   205  		{
   206  			name: "Should get vector from nearVector",
   207  			args: args{
   208  				nearVector: &searchparams.NearVector{
   209  					Vector: []float32{1.1, 1.0, 0.1},
   210  				},
   211  			},
   212  			want:    []float32{1.1, 1.0, 0.1},
   213  			wantErr: false,
   214  		},
   215  		{
   216  			name: "Should get vector from nearObject",
   217  			args: args{
   218  				nearObject: &searchparams.NearObject{
   219  					ID: "uuid",
   220  				},
   221  			},
   222  			want:    []float32{1.0, 1.0, 1.0},
   223  			wantErr: false,
   224  		},
   225  		{
   226  			name: "Should get vector from nearText",
   227  			args: args{
   228  				moduleParams: map[string]interface{}{
   229  					"nearCustomText": &nearCustomTextParams{
   230  						Values: []string{"a"},
   231  					},
   232  				},
   233  			},
   234  			want:    []float32{1, 2, 3},
   235  			wantErr: false,
   236  		},
   237  		{
   238  			name: "Should get vector from nearObject",
   239  			args: args{
   240  				nearObject: &searchparams.NearObject{
   241  					Beacon: crossref.NewLocalhost("Class", "uuid").String(),
   242  				},
   243  			},
   244  			wantErr: true,
   245  		},
   246  		{
   247  			name: "Should get vector from nearObject",
   248  			args: args{
   249  				nearObject: &searchparams.NearObject{
   250  					Beacon: crossref.NewLocalhost("Class", "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf").String(),
   251  				},
   252  			},
   253  			want:    []float32{1.0, 1.0, 1.0},
   254  			wantErr: false,
   255  		},
   256  		{
   257  			name: "Should get vector from nearObject across classes",
   258  			args: args{
   259  				nearObject: &searchparams.NearObject{
   260  					Beacon: crossref.NewLocalhost("SpecifiedClass", "e5dc4a4c-ef0f-3aed-89a3-a73435c6bbcf").String(),
   261  				},
   262  			},
   263  			want:    []float32{0.0, 0.0, 0.0},
   264  			wantErr: false,
   265  		},
   266  	}
   267  	for _, tt := range tests {
   268  		t.Run(tt.name, func(t *testing.T) {
   269  			e := &nearParamsVector{
   270  				modulesProvider: &fakeModulesProvider{},
   271  				search:          &fakeNearParamsSearcher{},
   272  			}
   273  			got, targetVector, err := e.vectorFromParams(tt.args.ctx, tt.args.nearVector, tt.args.nearObject, tt.args.moduleParams, tt.args.className, "")
   274  			if (err != nil) != tt.wantErr {
   275  				t.Errorf("nearParamsVector.vectorFromParams() error = %v, wantErr %v", err, tt.wantErr)
   276  				return
   277  			}
   278  			if !reflect.DeepEqual(got, tt.want) {
   279  				t.Errorf("nearParamsVector.vectorFromParams() = %v, want %v", got, tt.want)
   280  			}
   281  			assert.Equal(t, "", targetVector)
   282  		})
   283  	}
   284  }
   285  
   286  func Test_nearParamsVector_extractCertaintyFromParams(t *testing.T) {
   287  	type args struct {
   288  		nearVector   *searchparams.NearVector
   289  		nearObject   *searchparams.NearObject
   290  		moduleParams map[string]interface{}
   291  	}
   292  	tests := []struct {
   293  		name string
   294  		args args
   295  		want float64
   296  	}{
   297  		{
   298  			name: "Should extract distance from nearVector",
   299  			args: args{
   300  				nearVector: &searchparams.NearVector{
   301  					Distance:     0.88,
   302  					WithDistance: true,
   303  				},
   304  			},
   305  			want: 1 - 0.88/2,
   306  		},
   307  		{
   308  			name: "Should extract certainty from nearVector",
   309  			args: args{
   310  				nearVector: &searchparams.NearVector{
   311  					Certainty: 0.88,
   312  				},
   313  			},
   314  			want: 0.88,
   315  		},
   316  		{
   317  			name: "Should extract distance from nearObject",
   318  			args: args{
   319  				nearObject: &searchparams.NearObject{
   320  					Distance:     0.99,
   321  					WithDistance: true,
   322  				},
   323  			},
   324  			want: 1 - 0.99/2,
   325  		},
   326  		{
   327  			name: "Should extract certainty from nearObject",
   328  			args: args{
   329  				nearObject: &searchparams.NearObject{
   330  					Certainty: 0.99,
   331  				},
   332  			},
   333  			want: 0.99,
   334  		},
   335  		{
   336  			name: "Should extract distance from nearText",
   337  			args: args{
   338  				moduleParams: map[string]interface{}{
   339  					"nearCustomText": &nearCustomTextParams{
   340  						Distance:     0.77,
   341  						WithDistance: true,
   342  					},
   343  				},
   344  			},
   345  			want: 1 - 0.77/2,
   346  		},
   347  		{
   348  			name: "Should extract certainty from nearText",
   349  			args: args{
   350  				moduleParams: map[string]interface{}{
   351  					"nearCustomText": &nearCustomTextParams{
   352  						Certainty: 0.77,
   353  					},
   354  				},
   355  			},
   356  			want: 0.77,
   357  		},
   358  	}
   359  	for _, tt := range tests {
   360  		t.Run(tt.name, func(t *testing.T) {
   361  			e := &nearParamsVector{
   362  				modulesProvider: &fakeModulesProvider{},
   363  				search:          &fakeNearParamsSearcher{},
   364  			}
   365  			got := e.extractCertaintyFromParams(tt.args.nearVector, tt.args.nearObject, tt.args.moduleParams)
   366  			if !assert.InDelta(t, tt.want, got, 1e-9) {
   367  				t.Errorf("nearParamsVector.extractCertaintyFromParams() = %v, want %v", got, tt.want)
   368  			}
   369  		})
   370  	}
   371  }
   372  
   373  type fakeNearParamsSearcher struct{}
   374  
   375  func (f *fakeNearParamsSearcher) ObjectsByID(ctx context.Context, id strfmt.UUID,
   376  	props search.SelectProperties, additional additional.Properties, tenant string,
   377  ) (search.Results, error) {
   378  	return search.Results{
   379  		{Vector: []float32{1.0, 1.0, 1.0}},
   380  	}, nil
   381  }
   382  
   383  func (f *fakeNearParamsSearcher) Object(ctx context.Context, className string, id strfmt.UUID,
   384  	props search.SelectProperties, additional additional.Properties,
   385  	repl *additional.ReplicationProperties, tenant string,
   386  ) (*search.Result, error) {
   387  	if className == "SpecifiedClass" {
   388  		return &search.Result{
   389  			Vector: []float32{0.0, 0.0, 0.0},
   390  		}, nil
   391  	} else {
   392  		return &search.Result{
   393  			Vector: []float32{1.0, 1.0, 1.0},
   394  		}, nil
   395  	}
   396  }