github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/clients/bert_embeddings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "reflect" 16 "testing" 17 ) 18 19 func Test_bertEmbeddingsDecoder_calculateVector(t *testing.T) { 20 tests := []struct { 21 name string 22 embeddings [][]float32 23 want []float32 24 wantErr bool 25 }{ 26 { 27 name: "nil", 28 embeddings: nil, 29 wantErr: true, 30 }, 31 { 32 name: "empty", 33 embeddings: [][]float32{}, 34 wantErr: true, 35 }, 36 { 37 name: "just one vector", 38 embeddings: [][]float32{{-0.17978577315807343}}, 39 want: []float32{-0.17978577315807343}, 40 }, 41 { 42 name: "distilbert-base-uncased", 43 embeddings: [][]float32{ 44 {-0.17978577315807343, -0.0678672045469284, 0.1706605851650238, -0.1639413982629776, -0.12804915010929108, 0.017568372189998627, 0.1610901951789856, 0.19909054040908813, -0.26103103160858154, -0.14505508542060852}, 45 {-0.25516796112060547, -0.054695576429367065, 0.13527897000312805, -0.3919253945350647, 0.1900954395532608, 0.5994636416435242, 0.5798457264900208, 0.6522972583770752, -0.08617493510246277, -0.35053199529647827}, 46 {0.930827260017395, 0.3315476179122925, -0.323006272315979, 0.18198077380657196, -0.3299236297607422, -0.5998684763908386, 0.3299814462661743, -0.6352149844169617, 0.5154204368591309, 0.11740084737539291}, 47 }, 48 want: []float32{0.1652911752462387, 0.06966160982847214, -0.005688905715942383, -0.12462866306304932, -0.08929244428873062, 0.005721171852201223, 0.35697245597839355, 0.07205760478973389, 0.05607149004936218, -0.1260620802640915}, 49 }, 50 } 51 for _, tt := range tests { 52 t.Run(tt.name, func(t *testing.T) { 53 d := bertEmbeddingsDecoder{} 54 got, err := d.calculateVector(tt.embeddings) 55 if (err != nil) != tt.wantErr { 56 t.Errorf("bertEmbeddingsDecoder.calculateVector() error = %v, wantErr %v", err, tt.wantErr) 57 return 58 } 59 if !reflect.DeepEqual(got, tt.want) { 60 t.Errorf("bertEmbeddingsDecoder.calculateVector() = %v, want %v", got, tt.want) 61 } 62 }) 63 } 64 }