github.com/weaviate/weaviate@v1.24.6/modules/text2vec-transformers/vectorizer/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"errors"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/weaviate/weaviate/entities/models"
    21  	"github.com/weaviate/weaviate/usecases/modules"
    22  )
    23  
    24  func TestClassSettings(t *testing.T) {
    25  	t.Run("with all defaults", func(t *testing.T) {
    26  		class := &models.Class{
    27  			Class: "MyClass",
    28  			Properties: []*models.Property{{
    29  				Name: "someProp",
    30  			}},
    31  		}
    32  
    33  		cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant", "")
    34  		ic := NewClassSettings(cfg)
    35  
    36  		assert.True(t, ic.PropertyIndexed("someProp"))
    37  		assert.False(t, ic.VectorizePropertyName("someProp"))
    38  		assert.True(t, ic.VectorizeClassName())
    39  		assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
    40  	})
    41  
    42  	t.Run("with a nil config", func(t *testing.T) {
    43  		// this is the case if we were running in a situation such as a
    44  		// cross-class vectorization of search time, as is the case with Explore
    45  		// {}, we then expect all default values
    46  
    47  		ic := NewClassSettings(nil)
    48  
    49  		assert.True(t, ic.PropertyIndexed("someProp"))
    50  		assert.False(t, ic.VectorizePropertyName("someProp"))
    51  		assert.True(t, ic.VectorizeClassName())
    52  		assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
    53  	})
    54  
    55  	t.Run("with all explicit config matching the defaults", func(t *testing.T) {
    56  		class := &models.Class{
    57  			Class: "MyClass",
    58  			ModuleConfig: map[string]interface{}{
    59  				"my-module": map[string]interface{}{
    60  					"vectorizeClassName": true,
    61  					"poolingStrategy":    "masked_mean",
    62  				},
    63  			},
    64  			Properties: []*models.Property{{
    65  				Name: "someProp",
    66  				ModuleConfig: map[string]interface{}{
    67  					"my-module": map[string]interface{}{
    68  						"skip":                  false,
    69  						"vectorizePropertyName": false,
    70  					},
    71  				},
    72  			}},
    73  		}
    74  
    75  		cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant", "")
    76  		ic := NewClassSettings(cfg)
    77  
    78  		assert.True(t, ic.PropertyIndexed("someProp"))
    79  		assert.False(t, ic.VectorizePropertyName("someProp"))
    80  		assert.True(t, ic.VectorizeClassName())
    81  		assert.Equal(t, ic.PoolingStrategy(), "masked_mean")
    82  	})
    83  
    84  	t.Run("with all explicit config using non-default values", func(t *testing.T) {
    85  		class := &models.Class{
    86  			Class: "MyClass",
    87  			ModuleConfig: map[string]interface{}{
    88  				"my-module": map[string]interface{}{
    89  					"vectorizeClassName": false,
    90  					"poolingStrategy":    "cls",
    91  				},
    92  			},
    93  			Properties: []*models.Property{{
    94  				Name: "someProp",
    95  				ModuleConfig: map[string]interface{}{
    96  					"my-module": map[string]interface{}{
    97  						"skip":                  true,
    98  						"vectorizePropertyName": true,
    99  					},
   100  				},
   101  			}},
   102  		}
   103  
   104  		cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant", "")
   105  		ic := NewClassSettings(cfg)
   106  
   107  		assert.False(t, ic.PropertyIndexed("someProp"))
   108  		assert.True(t, ic.VectorizePropertyName("someProp"))
   109  		assert.False(t, ic.VectorizeClassName())
   110  		assert.Equal(t, ic.PoolingStrategy(), "cls")
   111  	})
   112  
   113  	t.Run("with target vector and properties", func(t *testing.T) {
   114  		targetVector := "targetVector"
   115  		propertyToIndex := "someProp"
   116  		class := &models.Class{
   117  			Class: "MyClass",
   118  			VectorConfig: map[string]models.VectorConfig{
   119  				targetVector: {
   120  					Vectorizer: map[string]interface{}{
   121  						"my-module": map[string]interface{}{
   122  							"vectorizeClassName": false,
   123  							"properties":         []interface{}{propertyToIndex},
   124  						},
   125  					},
   126  					VectorIndexType: "hnsw",
   127  				},
   128  			},
   129  			Properties: []*models.Property{
   130  				{
   131  					Name: propertyToIndex,
   132  					ModuleConfig: map[string]interface{}{
   133  						"my-module": map[string]interface{}{
   134  							"skip":                  true,
   135  							"vectorizePropertyName": true,
   136  						},
   137  					},
   138  				},
   139  				{
   140  					Name: "otherProp",
   141  				},
   142  			},
   143  		}
   144  
   145  		cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant", targetVector)
   146  		ic := NewClassSettings(cfg)
   147  
   148  		assert.True(t, ic.PropertyIndexed(propertyToIndex))
   149  		assert.True(t, ic.VectorizePropertyName(propertyToIndex))
   150  		assert.False(t, ic.PropertyIndexed("otherProp"))
   151  		assert.False(t, ic.VectorizePropertyName("otherProp"))
   152  		assert.False(t, ic.VectorizeClassName())
   153  	})
   154  
   155  	t.Run("with inferenceUrl setting", func(t *testing.T) {
   156  		class := &models.Class{
   157  			Class: "MyClass",
   158  			VectorConfig: map[string]models.VectorConfig{
   159  				"withInferenceUrl": {
   160  					Vectorizer: map[string]interface{}{
   161  						"my-module": map[string]interface{}{
   162  							"vectorizeClassName": false,
   163  							"poolingStrategy":    "cls",
   164  							"inferenceUrl":       "http://inference.url",
   165  						},
   166  					},
   167  				},
   168  				"withPassageAndQueryInferenceUrl": {
   169  					Vectorizer: map[string]interface{}{
   170  						"my-module": map[string]interface{}{
   171  							"vectorizeClassName":  false,
   172  							"poolingStrategy":     "cls",
   173  							"passageInferenceUrl": "http://passage.inference.url",
   174  							"queryInferenceUrl":   "http://query.inference.url",
   175  						},
   176  					},
   177  				},
   178  			},
   179  
   180  			Properties: []*models.Property{{
   181  				Name: "someProp",
   182  				ModuleConfig: map[string]interface{}{
   183  					"my-module": map[string]interface{}{
   184  						"skip":                  true,
   185  						"vectorizePropertyName": true,
   186  					},
   187  				},
   188  			}},
   189  		}
   190  
   191  		cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant", "withInferenceUrl")
   192  		ic := NewClassSettings(cfg)
   193  
   194  		assert.False(t, ic.PropertyIndexed("someProp"))
   195  		assert.True(t, ic.VectorizePropertyName("someProp"))
   196  		assert.False(t, ic.VectorizeClassName())
   197  		assert.Equal(t, ic.PoolingStrategy(), "cls")
   198  		assert.Equal(t, ic.InferenceURL(), "http://inference.url")
   199  		assert.Empty(t, ic.PassageInferenceURL())
   200  		assert.Empty(t, ic.QueryInferenceURL())
   201  
   202  		cfg = modules.NewClassBasedModuleConfig(class, "my-module", "tenant", "withPassageAndQueryInferenceUrl")
   203  		ic = NewClassSettings(cfg)
   204  
   205  		assert.False(t, ic.PropertyIndexed("someProp"))
   206  		assert.True(t, ic.VectorizePropertyName("someProp"))
   207  		assert.False(t, ic.VectorizeClassName())
   208  		assert.Equal(t, ic.PoolingStrategy(), "cls")
   209  		assert.Empty(t, ic.InferenceURL())
   210  		assert.Equal(t, ic.PassageInferenceURL(), "http://passage.inference.url")
   211  		assert.Equal(t, ic.QueryInferenceURL(), "http://query.inference.url")
   212  	})
   213  }
   214  
   215  func Test_classSettings_Validate(t *testing.T) {
   216  	tests := []struct {
   217  		name       string
   218  		vectorizer map[string]interface{}
   219  		wantErr    error
   220  	}{
   221  		{
   222  			name: "only inference url",
   223  			vectorizer: map[string]interface{}{
   224  				"vectorizeClassName": false,
   225  				"poolingStrategy":    "cls",
   226  				"inferenceUrl":       "http://inference.url",
   227  			},
   228  		},
   229  		{
   230  			name: "only passage and query inference urls",
   231  			vectorizer: map[string]interface{}{
   232  				"vectorizeClassName":  false,
   233  				"poolingStrategy":     "cls",
   234  				"passageInferenceUrl": "http://passage.inference.url",
   235  				"queryInferenceUrl":   "http://query.inference.url",
   236  			},
   237  		},
   238  		{
   239  			name: "error - all inference urls",
   240  			vectorizer: map[string]interface{}{
   241  				"vectorizeClassName":  false,
   242  				"poolingStrategy":     "cls",
   243  				"inferenceUrl":        "http://inference.url",
   244  				"passageInferenceUrl": "http://passage.inference.url",
   245  				"queryInferenceUrl":   "http://query.inference.url",
   246  			},
   247  			wantErr: errors.New("either inferenceUrl or passageInferenceUrl together with queryInferenceUrl needs to be set, not both"),
   248  		},
   249  		{
   250  			name: "error - all inference urls, without passage",
   251  			vectorizer: map[string]interface{}{
   252  				"vectorizeClassName": false,
   253  				"poolingStrategy":    "cls",
   254  				"inferenceUrl":       "http://inference.url",
   255  				"queryInferenceUrl":  "http://query.inference.url",
   256  			},
   257  			wantErr: errors.New("either inferenceUrl or passageInferenceUrl together with queryInferenceUrl needs to be set, not both"),
   258  		},
   259  		{
   260  			name: "error - all inference urls, without query",
   261  			vectorizer: map[string]interface{}{
   262  				"vectorizeClassName":  false,
   263  				"poolingStrategy":     "cls",
   264  				"inferenceUrl":        "http://inference.url",
   265  				"passageInferenceUrl": "http://passage.inference.url",
   266  			},
   267  			wantErr: errors.New("either inferenceUrl or passageInferenceUrl together with queryInferenceUrl needs to be set, not both"),
   268  		},
   269  		{
   270  			name: "error - passage inference url set but not query",
   271  			vectorizer: map[string]interface{}{
   272  				"vectorizeClassName":  false,
   273  				"poolingStrategy":     "cls",
   274  				"passageInferenceUrl": "http://passage.inference.url",
   275  			},
   276  			wantErr: errors.New("passageInferenceUrl is set but queryInferenceUrl is empty, both needs to be set"),
   277  		},
   278  		{
   279  			name: "error - query inference url set but not passage",
   280  			vectorizer: map[string]interface{}{
   281  				"vectorizeClassName": false,
   282  				"poolingStrategy":    "cls",
   283  				"queryInferenceUrl":  "http://passage.inference.url",
   284  			},
   285  			wantErr: errors.New("queryInferenceUrl is set but passageInferenceUrl is empty, both needs to be set"),
   286  		},
   287  	}
   288  	for _, tt := range tests {
   289  		t.Run(tt.name, func(t *testing.T) {
   290  			class := &models.Class{
   291  				Class: "MyClass",
   292  				VectorConfig: map[string]models.VectorConfig{
   293  					"namedVector": {
   294  						Vectorizer: map[string]interface{}{
   295  							"my-module": tt.vectorizer,
   296  						},
   297  					},
   298  				},
   299  				Properties: []*models.Property{{
   300  					Name: "someProp",
   301  					ModuleConfig: map[string]interface{}{
   302  						"my-module": map[string]interface{}{
   303  							"skip":                  true,
   304  							"vectorizePropertyName": true,
   305  						},
   306  					},
   307  				}},
   308  			}
   309  
   310  			cfg := modules.NewClassBasedModuleConfig(class, "my-module", "tenant", "namedVector")
   311  			ic := NewClassSettings(cfg)
   312  			err := ic.Validate(class)
   313  			if tt.wantErr != nil {
   314  				require.Error(t, err)
   315  				assert.EqualError(t, err, tt.wantErr.Error())
   316  			} else {
   317  				assert.Nil(t, err)
   318  			}
   319  		})
   320  	}
   321  }