github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/vectorizer/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"errors"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  )
    21  
    22  func Test_classSettings_getPassageModel(t *testing.T) {
    23  	tests := []struct {
    24  		name             string
    25  		cfg              moduletools.ClassConfig
    26  		wantPassageModel string
    27  		wantQueryModel   string
    28  		wantWaitForModel bool
    29  		wantUseGPU       bool
    30  		wantUseCache     bool
    31  		wantEndpointURL  string
    32  		wantError        error
    33  	}{
    34  		{
    35  			name: "CShorten/CORD-19-Title-Abstracts",
    36  			cfg: fakeClassConfig{
    37  				classConfig: map[string]interface{}{
    38  					"model": "CShorten/CORD-19-Title-Abstracts",
    39  					"options": map[string]interface{}{
    40  						"waitForModel": true,
    41  						"useGPU":       false,
    42  						"useCache":     false,
    43  					},
    44  				},
    45  			},
    46  			wantPassageModel: "CShorten/CORD-19-Title-Abstracts",
    47  			wantQueryModel:   "CShorten/CORD-19-Title-Abstracts",
    48  			wantWaitForModel: true,
    49  			wantUseGPU:       false,
    50  			wantUseCache:     false,
    51  		},
    52  		{
    53  			name: "sentence-transformers/all-MiniLM-L6-v2",
    54  			cfg: fakeClassConfig{
    55  				classConfig: map[string]interface{}{
    56  					"model": "sentence-transformers/all-MiniLM-L6-v2",
    57  				},
    58  			},
    59  			wantPassageModel: "sentence-transformers/all-MiniLM-L6-v2",
    60  			wantQueryModel:   "sentence-transformers/all-MiniLM-L6-v2",
    61  			wantWaitForModel: false,
    62  			wantUseGPU:       false,
    63  			wantUseCache:     true,
    64  		},
    65  		{
    66  			name: "DPR models",
    67  			cfg: fakeClassConfig{
    68  				classConfig: map[string]interface{}{
    69  					"passageModel": "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base",
    70  					"queryModel":   "sentence-transformers/facebook-dpr-question_encoder-single-nq-base",
    71  				},
    72  			},
    73  			wantPassageModel: "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base",
    74  			wantQueryModel:   "sentence-transformers/facebook-dpr-question_encoder-single-nq-base",
    75  			wantWaitForModel: false,
    76  			wantUseGPU:       false,
    77  			wantUseCache:     true,
    78  		},
    79  		{
    80  			name: "Hugging Face Inference API - endpointURL",
    81  			cfg: fakeClassConfig{
    82  				classConfig: map[string]interface{}{
    83  					"endpointURL": "http://endpoint.cloud",
    84  				},
    85  			},
    86  			wantPassageModel: "",
    87  			wantQueryModel:   "",
    88  			wantWaitForModel: false,
    89  			wantUseGPU:       false,
    90  			wantUseCache:     true,
    91  			wantEndpointURL:  "http://endpoint.cloud",
    92  		},
    93  		{
    94  			name: "Hugging Face Inference API - wrong properties",
    95  			cfg: fakeClassConfig{
    96  				classConfig: map[string]interface{}{
    97  					"endpointUrl": "http://endpoint.cloud",
    98  					"properties":  "wrong-properties",
    99  				},
   100  			},
   101  			wantPassageModel: "",
   102  			wantQueryModel:   "",
   103  			wantWaitForModel: false,
   104  			wantUseGPU:       false,
   105  			wantUseCache:     true,
   106  			wantEndpointURL:  "http://endpoint.cloud",
   107  			wantError:        errors.New("properties field needs to be of array type, got: string"),
   108  		},
   109  	}
   110  	for _, tt := range tests {
   111  		t.Run(tt.name, func(t *testing.T) {
   112  			ic := NewClassSettings(tt.cfg)
   113  			assert.Equal(t, tt.wantPassageModel, ic.getPassageModel())
   114  			assert.Equal(t, tt.wantQueryModel, ic.getQueryModel())
   115  			assert.Equal(t, tt.wantWaitForModel, ic.OptionWaitForModel())
   116  			assert.Equal(t, tt.wantUseGPU, ic.OptionUseGPU())
   117  			assert.Equal(t, tt.wantUseCache, ic.OptionUseCache())
   118  			assert.Equal(t, tt.wantEndpointURL, ic.EndpointURL())
   119  			assert.Equal(t, tt.wantError, ic.validateClassSettings())
   120  		})
   121  	}
   122  }