github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/vectorizer/texts_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  // as used in the nearText searcher
    23  func TestVectorizingTexts(t *testing.T) {
    24  	type testCase struct {
    25  		name                     string
    26  		input                    []string
    27  		expectedHuggingFaceModel string
    28  		huggingFaceModel         string
    29  		huggingFaceEndpointURL   string
    30  	}
    31  
    32  	tests := []testCase{
    33  		{
    34  			name:                     "single word",
    35  			input:                    []string{"hello"},
    36  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    37  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    38  		},
    39  		{
    40  			name:                     "multiple words",
    41  			input:                    []string{"hello world, this is me!"},
    42  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    43  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    44  		},
    45  		{
    46  			name:                     "multiple sentences (joined with a dot)",
    47  			input:                    []string{"this is sentence 1", "and here's number 2"},
    48  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    49  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    50  		},
    51  		{
    52  			name:                     "multiple sentences already containing a dot",
    53  			input:                    []string{"this is sentence 1.", "and here's number 2"},
    54  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    55  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    56  		},
    57  		{
    58  			name:                     "multiple sentences already containing a question mark",
    59  			input:                    []string{"this is sentence 1?", "and here's number 2"},
    60  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    61  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    62  		},
    63  		{
    64  			name:                     "multiple sentences already containing an exclamation mark",
    65  			input:                    []string{"this is sentence 1!", "and here's number 2"},
    66  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    67  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    68  		},
    69  		{
    70  			name:                     "multiple sentences already containing comma",
    71  			input:                    []string{"this is sentence 1,", "and here's number 2"},
    72  			huggingFaceModel:         "sentence-transformers/gtr-t5-xl",
    73  			expectedHuggingFaceModel: "sentence-transformers/gtr-t5-xl",
    74  		},
    75  		{
    76  			name:                     "single word with inference url",
    77  			input:                    []string{"hello"},
    78  			huggingFaceEndpointURL:   "http://url.cloud",
    79  			expectedHuggingFaceModel: "sentence-transformers/msmarco-bert-base-dot-v5",
    80  		},
    81  	}
    82  
    83  	for _, test := range tests {
    84  		t.Run(test.name, func(t *testing.T) {
    85  			client := &fakeClient{}
    86  
    87  			v := New(client)
    88  
    89  			settings := &fakeClassConfig{
    90  				model:       test.huggingFaceModel,
    91  				endpointURL: test.huggingFaceEndpointURL,
    92  			}
    93  			vec, err := v.Texts(context.Background(), test.input, settings)
    94  
    95  			require.Nil(t, err)
    96  			assert.Equal(t, []float32{0.1, 1.1, 2.1, 3.1}, vec)
    97  			assert.Equal(t, client.lastConfig.Model, test.expectedHuggingFaceModel)
    98  		})
    99  	}
   100  }