github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/clients/bert_embeddings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"reflect"
    16  	"testing"
    17  )
    18  
    19  func Test_bertEmbeddingsDecoder_calculateVector(t *testing.T) {
    20  	tests := []struct {
    21  		name       string
    22  		embeddings [][]float32
    23  		want       []float32
    24  		wantErr    bool
    25  	}{
    26  		{
    27  			name:       "nil",
    28  			embeddings: nil,
    29  			wantErr:    true,
    30  		},
    31  		{
    32  			name:       "empty",
    33  			embeddings: [][]float32{},
    34  			wantErr:    true,
    35  		},
    36  		{
    37  			name:       "just one vector",
    38  			embeddings: [][]float32{{-0.17978577315807343}},
    39  			want:       []float32{-0.17978577315807343},
    40  		},
    41  		{
    42  			name: "distilbert-base-uncased",
    43  			embeddings: [][]float32{
    44  				{-0.17978577315807343, -0.0678672045469284, 0.1706605851650238, -0.1639413982629776, -0.12804915010929108, 0.017568372189998627, 0.1610901951789856, 0.19909054040908813, -0.26103103160858154, -0.14505508542060852},
    45  				{-0.25516796112060547, -0.054695576429367065, 0.13527897000312805, -0.3919253945350647, 0.1900954395532608, 0.5994636416435242, 0.5798457264900208, 0.6522972583770752, -0.08617493510246277, -0.35053199529647827},
    46  				{0.930827260017395, 0.3315476179122925, -0.323006272315979, 0.18198077380657196, -0.3299236297607422, -0.5998684763908386, 0.3299814462661743, -0.6352149844169617, 0.5154204368591309, 0.11740084737539291},
    47  			},
    48  			want: []float32{0.1652911752462387, 0.06966160982847214, -0.005688905715942383, -0.12462866306304932, -0.08929244428873062, 0.005721171852201223, 0.35697245597839355, 0.07205760478973389, 0.05607149004936218, -0.1260620802640915},
    49  		},
    50  	}
    51  	for _, tt := range tests {
    52  		t.Run(tt.name, func(t *testing.T) {
    53  			d := bertEmbeddingsDecoder{}
    54  			got, err := d.calculateVector(tt.embeddings)
    55  			if (err != nil) != tt.wantErr {
    56  				t.Errorf("bertEmbeddingsDecoder.calculateVector() error = %v, wantErr %v", err, tt.wantErr)
    57  				return
    58  			}
    59  			if !reflect.DeepEqual(got, tt.want) {
    60  				t.Errorf("bertEmbeddingsDecoder.calculateVector() = %v, want %v", got, tt.want)
    61  			}
    62  		})
    63  	}
    64  }