github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/vectorizer/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "errors" 16 "testing" 17 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_getPassageModel(t *testing.T) { 23 tests := []struct { 24 name string 25 cfg moduletools.ClassConfig 26 wantPassageModel string 27 wantQueryModel string 28 wantWaitForModel bool 29 wantUseGPU bool 30 wantUseCache bool 31 wantEndpointURL string 32 wantError error 33 }{ 34 { 35 name: "CShorten/CORD-19-Title-Abstracts", 36 cfg: fakeClassConfig{ 37 classConfig: map[string]interface{}{ 38 "model": "CShorten/CORD-19-Title-Abstracts", 39 "options": map[string]interface{}{ 40 "waitForModel": true, 41 "useGPU": false, 42 "useCache": false, 43 }, 44 }, 45 }, 46 wantPassageModel: "CShorten/CORD-19-Title-Abstracts", 47 wantQueryModel: "CShorten/CORD-19-Title-Abstracts", 48 wantWaitForModel: true, 49 wantUseGPU: false, 50 wantUseCache: false, 51 }, 52 { 53 name: "sentence-transformers/all-MiniLM-L6-v2", 54 cfg: fakeClassConfig{ 55 classConfig: map[string]interface{}{ 56 "model": "sentence-transformers/all-MiniLM-L6-v2", 57 }, 58 }, 59 wantPassageModel: "sentence-transformers/all-MiniLM-L6-v2", 60 wantQueryModel: "sentence-transformers/all-MiniLM-L6-v2", 61 wantWaitForModel: false, 62 wantUseGPU: false, 63 wantUseCache: true, 64 }, 65 { 66 name: "DPR models", 67 cfg: fakeClassConfig{ 68 classConfig: map[string]interface{}{ 69 "passageModel": "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base", 70 "queryModel": "sentence-transformers/facebook-dpr-question_encoder-single-nq-base", 71 }, 72 }, 73 wantPassageModel: "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base", 74 wantQueryModel: "sentence-transformers/facebook-dpr-question_encoder-single-nq-base", 75 wantWaitForModel: false, 76 wantUseGPU: false, 77 wantUseCache: true, 78 }, 79 { 80 name: "Hugging Face Inference API - endpointURL", 81 cfg: fakeClassConfig{ 82 classConfig: map[string]interface{}{ 83 "endpointURL": "http://endpoint.cloud", 84 }, 85 }, 86 wantPassageModel: "", 87 wantQueryModel: "", 88 wantWaitForModel: false, 89 wantUseGPU: false, 90 wantUseCache: true, 91 wantEndpointURL: "http://endpoint.cloud", 92 }, 93 { 94 name: "Hugging Face Inference API - wrong properties", 95 cfg: fakeClassConfig{ 96 classConfig: map[string]interface{}{ 97 "endpointUrl": "http://endpoint.cloud", 98 "properties": "wrong-properties", 99 }, 100 }, 101 wantPassageModel: "", 102 wantQueryModel: "", 103 wantWaitForModel: false, 104 wantUseGPU: false, 105 wantUseCache: true, 106 wantEndpointURL: "http://endpoint.cloud", 107 wantError: errors.New("properties field needs to be of array type, got: string"), 108 }, 109 } 110 for _, tt := range tests { 111 t.Run(tt.name, func(t *testing.T) { 112 ic := NewClassSettings(tt.cfg) 113 assert.Equal(t, tt.wantPassageModel, ic.getPassageModel()) 114 assert.Equal(t, tt.wantQueryModel, ic.getQueryModel()) 115 assert.Equal(t, tt.wantWaitForModel, ic.OptionWaitForModel()) 116 assert.Equal(t, tt.wantUseGPU, ic.OptionUseGPU()) 117 assert.Equal(t, tt.wantUseCache, ic.OptionUseCache()) 118 assert.Equal(t, tt.wantEndpointURL, ic.EndpointURL()) 119 assert.Equal(t, tt.wantError, ic.validateClassSettings()) 120 }) 121 } 122 }