github.com/weaviate/weaviate@v1.24.6/modules/text2vec-palm/vectorizer/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "testing" 16 17 "github.com/pkg/errors" 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_Validate(t *testing.T) { 23 tests := []struct { 24 name string 25 cfg moduletools.ClassConfig 26 wantApiEndpoint string 27 wantProjectID string 28 wantModelID string 29 wantTitle string 30 wantErr error 31 }{ 32 { 33 name: "happy flow", 34 cfg: fakeClassConfig{ 35 classConfig: map[string]interface{}{ 36 "projectId": "projectId", 37 }, 38 }, 39 wantApiEndpoint: "us-central1-aiplatform.googleapis.com", 40 wantProjectID: "projectId", 41 wantModelID: "textembedding-gecko@001", 42 wantErr: nil, 43 }, 44 { 45 name: "custom values", 46 cfg: fakeClassConfig{ 47 classConfig: map[string]interface{}{ 48 "apiEndpoint": "google.com", 49 "projectId": "projectId", 50 "titleProperty": "title", 51 }, 52 }, 53 wantApiEndpoint: "google.com", 54 wantProjectID: "projectId", 55 wantModelID: "textembedding-gecko@001", 56 wantTitle: "title", 57 wantErr: nil, 58 }, 59 { 60 name: "empty projectId", 61 cfg: fakeClassConfig{ 62 classConfig: map[string]interface{}{ 63 "projectId": "", 64 }, 65 }, 66 wantErr: errors.Errorf("projectId cannot be empty"), 67 }, 68 { 69 name: "wrong modelId", 70 cfg: fakeClassConfig{ 71 classConfig: map[string]interface{}{ 72 "projectId": "projectId", 73 "modelId": "wrong-model", 74 }, 75 }, 76 wantErr: errors.Errorf("wrong modelId available model names are: " + 77 "[textembedding-gecko@001 textembedding-gecko@latest " + 78 "textembedding-gecko-multilingual@latest textembedding-gecko@003 " + 79 "textembedding-gecko@002 textembedding-gecko-multilingual@001 textembedding-gecko@001]"), 80 }, 81 { 82 name: "all wrong", 83 cfg: fakeClassConfig{ 84 classConfig: map[string]interface{}{ 85 "projectId": "", 86 "modelId": "wrong-model", 87 }, 88 }, 89 wantErr: errors.Errorf("projectId cannot be empty, " + 90 "wrong modelId available model names are: " + 91 "[textembedding-gecko@001 textembedding-gecko@latest " + 92 "textembedding-gecko-multilingual@latest textembedding-gecko@003 " + 93 "textembedding-gecko@002 textembedding-gecko-multilingual@001 textembedding-gecko@001]"), 94 }, 95 { 96 name: "Generative AI", 97 cfg: fakeClassConfig{ 98 classConfig: map[string]interface{}{ 99 "apiEndpoint": "generativelanguage.googleapis.com", 100 }, 101 }, 102 wantApiEndpoint: "generativelanguage.googleapis.com", 103 wantProjectID: "", 104 wantModelID: "embedding-gecko-001", 105 wantErr: nil, 106 }, 107 { 108 name: "Generative AI with model", 109 cfg: fakeClassConfig{ 110 classConfig: map[string]interface{}{ 111 "apiEndpoint": "generativelanguage.googleapis.com", 112 "modelId": "embedding-gecko-001", 113 }, 114 }, 115 wantApiEndpoint: "generativelanguage.googleapis.com", 116 wantProjectID: "", 117 wantModelID: "embedding-gecko-001", 118 wantErr: nil, 119 }, 120 { 121 name: "Generative AI with wrong model", 122 cfg: fakeClassConfig{ 123 classConfig: map[string]interface{}{ 124 "apiEndpoint": "generativelanguage.googleapis.com", 125 "modelId": "textembedding-gecko@001", 126 }, 127 }, 128 wantErr: errors.Errorf("wrong modelId available Generative AI model names are: [embedding-gecko-001]"), 129 }, 130 { 131 name: "wrong properties", 132 cfg: fakeClassConfig{ 133 classConfig: map[string]interface{}{ 134 "projectId": "projectId", 135 }, 136 properties: "wrong-properties", 137 }, 138 wantApiEndpoint: "us-central1-aiplatform.googleapis.com", 139 wantProjectID: "projectId", 140 wantModelID: "textembedding-gecko@001", 141 wantErr: errors.New("properties field needs to be of array type, got: string"), 142 }, 143 } 144 for _, tt := range tests { 145 t.Run(tt.name, func(t *testing.T) { 146 ic := NewClassSettings(tt.cfg) 147 if tt.wantErr != nil { 148 assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error()) 149 } else { 150 assert.Equal(t, tt.wantApiEndpoint, ic.ApiEndpoint()) 151 assert.Equal(t, tt.wantProjectID, ic.ProjectID()) 152 assert.Equal(t, tt.wantModelID, ic.ModelID()) 153 assert.Equal(t, tt.wantTitle, ic.TitleProperty()) 154 } 155 }) 156 } 157 }