github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/config/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "fmt" 16 "testing" 17 18 "github.com/pkg/errors" 19 "github.com/stretchr/testify/assert" 20 "github.com/weaviate/weaviate/entities/moduletools" 21 ) 22 23 func Test_classSettings_Validate(t *testing.T) { 24 tests := []struct { 25 name string 26 cfg moduletools.ClassConfig 27 wantApiEndpoint string 28 wantProjectID string 29 wantModelID string 30 wantTemperature float64 31 wantTokenLimit int 32 wantTopK int 33 wantTopP float64 34 wantErr error 35 }{ 36 { 37 name: "happy flow", 38 cfg: fakeClassConfig{ 39 classConfig: map[string]interface{}{ 40 "projectId": "projectId", 41 }, 42 }, 43 wantApiEndpoint: "us-central1-aiplatform.googleapis.com", 44 wantProjectID: "projectId", 45 wantModelID: "chat-bison", 46 wantTemperature: 0.2, 47 wantTokenLimit: 256, 48 wantTopK: 40, 49 wantTopP: 0.95, 50 wantErr: nil, 51 }, 52 { 53 name: "custom values", 54 cfg: fakeClassConfig{ 55 classConfig: map[string]interface{}{ 56 "apiEndpoint": "google.com", 57 "projectId": "cloud-project", 58 "modelId": "model-id", 59 "temperature": 0.25, 60 "tokenLimit": 254, 61 "topK": 30, 62 "topP": 0.97, 63 }, 64 }, 65 wantApiEndpoint: "google.com", 66 wantProjectID: "cloud-project", 67 wantModelID: "model-id", 68 wantTemperature: 0.25, 69 wantTokenLimit: 254, 70 wantTopK: 30, 71 wantTopP: 0.97, 72 wantErr: nil, 73 }, 74 { 75 name: "wrong temperature", 76 cfg: fakeClassConfig{ 77 classConfig: map[string]interface{}{ 78 "projectId": "cloud-project", 79 "temperature": 2, 80 }, 81 }, 82 wantErr: errors.Errorf("temperature has to be float value between 0 and 1"), 83 }, 84 { 85 name: "wrong tokenLimit", 86 cfg: fakeClassConfig{ 87 classConfig: map[string]interface{}{ 88 "projectId": "cloud-project", 89 "tokenLimit": 2000, 90 }, 91 }, 92 wantErr: errors.Errorf("tokenLimit has to be an integer value between 1 and 1024"), 93 }, 94 { 95 name: "wrong topK", 96 cfg: fakeClassConfig{ 97 classConfig: map[string]interface{}{ 98 "projectId": "cloud-project", 99 "topK": 2000, 100 }, 101 }, 102 wantErr: errors.Errorf("topK has to be an integer value between 1 and 40"), 103 }, 104 { 105 name: "wrong topP", 106 cfg: fakeClassConfig{ 107 classConfig: map[string]interface{}{ 108 "projectId": "cloud-project", 109 "topP": 3, 110 }, 111 }, 112 wantErr: errors.Errorf("topP has to be float value between 0 and 1"), 113 }, 114 { 115 name: "wrong all", 116 cfg: fakeClassConfig{ 117 classConfig: map[string]interface{}{ 118 "projectId": "", 119 "temperature": 2, 120 "tokenLimit": 2000, 121 "topK": 2000, 122 "topP": 3, 123 }, 124 }, 125 wantErr: errors.Errorf("projectId cannot be empty, " + 126 "temperature has to be float value between 0 and 1, " + 127 "tokenLimit has to be an integer value between 1 and 1024, " + 128 "topK has to be an integer value between 1 and 40, " + 129 "topP has to be float value between 0 and 1"), 130 }, 131 { 132 name: "Generative AI", 133 cfg: fakeClassConfig{ 134 classConfig: map[string]interface{}{ 135 "apiEndpoint": "generativelanguage.googleapis.com", 136 }, 137 }, 138 wantApiEndpoint: "generativelanguage.googleapis.com", 139 wantProjectID: "", 140 wantModelID: "chat-bison-001", 141 wantTemperature: 0.2, 142 wantTokenLimit: 256, 143 wantTopK: 40, 144 wantTopP: 0.95, 145 wantErr: nil, 146 }, 147 { 148 name: "Generative AI with model", 149 cfg: fakeClassConfig{ 150 classConfig: map[string]interface{}{ 151 "apiEndpoint": "generativelanguage.googleapis.com", 152 "modelId": "chat-bison-001", 153 }, 154 }, 155 wantApiEndpoint: "generativelanguage.googleapis.com", 156 wantProjectID: "", 157 wantModelID: "chat-bison-001", 158 wantTemperature: 0.2, 159 wantTokenLimit: 256, 160 wantTopK: 40, 161 wantTopP: 0.95, 162 wantErr: nil, 163 }, 164 { 165 name: "Generative AI with gemini-ultra model", 166 cfg: fakeClassConfig{ 167 classConfig: map[string]interface{}{ 168 "apiEndpoint": "generativelanguage.googleapis.com", 169 "modelId": "gemini-ultra", 170 }, 171 }, 172 wantApiEndpoint: "generativelanguage.googleapis.com", 173 wantProjectID: "", 174 wantModelID: "gemini-ultra", 175 wantTemperature: 0.2, 176 wantTokenLimit: 256, 177 wantTopK: 40, 178 wantTopP: 0.95, 179 wantErr: nil, 180 }, 181 { 182 name: "Generative AI with not supported model", 183 cfg: fakeClassConfig{ 184 classConfig: map[string]interface{}{ 185 "apiEndpoint": "generativelanguage.googleapis.com", 186 "modelId": "unsupported-model", 187 }, 188 }, 189 wantErr: fmt.Errorf("unsupported-model is not supported available models are: [chat-bison-001 gemini-pro gemini-pro-vision gemini-ultra]"), 190 }, 191 } 192 for _, tt := range tests { 193 t.Run(tt.name, func(t *testing.T) { 194 ic := NewClassSettings(tt.cfg) 195 if tt.wantErr != nil { 196 assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error()) 197 } else { 198 assert.Equal(t, tt.wantApiEndpoint, ic.ApiEndpoint()) 199 assert.Equal(t, tt.wantProjectID, ic.ProjectID()) 200 assert.Equal(t, tt.wantModelID, ic.ModelID()) 201 assert.Equal(t, tt.wantTemperature, ic.Temperature()) 202 assert.Equal(t, tt.wantTokenLimit, ic.TokenLimit()) 203 assert.Equal(t, tt.wantTopK, ic.TopK()) 204 assert.Equal(t, tt.wantTopP, ic.TopP()) 205 } 206 }) 207 } 208 } 209 210 type fakeClassConfig struct { 211 classConfig map[string]interface{} 212 } 213 214 func (f fakeClassConfig) Class() map[string]interface{} { 215 return f.classConfig 216 } 217 218 func (f fakeClassConfig) Tenant() string { 219 return "" 220 } 221 222 func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 223 return f.classConfig 224 } 225 226 func (f fakeClassConfig) Property(propName string) map[string]interface{} { 227 return nil 228 } 229 230 func (f fakeClassConfig) TargetVector() string { 231 return "" 232 }