github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/config/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "testing" 16 17 "github.com/pkg/errors" 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_Validate(t *testing.T) { 23 tests := []struct { 24 name string 25 cfg moduletools.ClassConfig 26 wantModel string 27 wantMaxTokens float64 28 wantTemperature float64 29 wantTopP float64 30 wantFrequencyPenalty float64 31 wantPresencePenalty float64 32 wantResourceName string 33 wantDeploymentID string 34 wantIsAzure bool 35 wantErr error 36 wantBaseURL string 37 wantApiVersion string 38 }{ 39 { 40 name: "Happy flow", 41 cfg: fakeClassConfig{ 42 classConfig: map[string]interface{}{}, 43 }, 44 wantModel: "gpt-3.5-turbo", 45 wantMaxTokens: 4097, 46 wantTemperature: 0.0, 47 wantTopP: 1, 48 wantFrequencyPenalty: 0.0, 49 wantPresencePenalty: 0.0, 50 wantErr: nil, 51 wantBaseURL: "https://api.openai.com", 52 wantApiVersion: "2023-05-15", 53 }, 54 { 55 name: "Everything non default configured", 56 cfg: fakeClassConfig{ 57 classConfig: map[string]interface{}{ 58 "model": "gpt-3.5-turbo", 59 "maxTokens": 4097, 60 "temperature": 0.5, 61 "topP": 3, 62 "frequencyPenalty": 0.1, 63 "presencePenalty": 0.9, 64 }, 65 }, 66 wantModel: "gpt-3.5-turbo", 67 wantMaxTokens: 4097, 68 wantTemperature: 0.5, 69 wantTopP: 3, 70 wantFrequencyPenalty: 0.1, 71 wantPresencePenalty: 0.9, 72 wantErr: nil, 73 wantBaseURL: "https://api.openai.com", 74 wantApiVersion: "2023-05-15", 75 }, 76 { 77 name: "OpenAI Proxy", 78 cfg: fakeClassConfig{ 79 classConfig: map[string]interface{}{ 80 "model": "gpt-3.5-turbo", 81 "maxTokens": 4097, 82 "temperature": 0.5, 83 "topP": 3, 84 "frequencyPenalty": 0.1, 85 "presencePenalty": 0.9, 86 "baseURL": "https://proxy.weaviate.dev/", 87 }, 88 }, 89 wantBaseURL: "https://proxy.weaviate.dev/", 90 wantApiVersion: "2023-05-15", 91 wantModel: "gpt-3.5-turbo", 92 wantMaxTokens: 4097, 93 wantTemperature: 0.5, 94 wantTopP: 3, 95 wantFrequencyPenalty: 0.1, 96 wantPresencePenalty: 0.9, 97 wantErr: nil, 98 }, 99 { 100 name: "Legacy config", 101 cfg: fakeClassConfig{ 102 classConfig: map[string]interface{}{ 103 "model": "text-davinci-003", 104 "maxTokens": 1200, 105 "temperature": 0.5, 106 "topP": 3, 107 "frequencyPenalty": 0.1, 108 "presencePenalty": 0.9, 109 }, 110 }, 111 wantModel: "text-davinci-003", 112 wantMaxTokens: 1200, 113 wantTemperature: 0.5, 114 wantTopP: 3, 115 wantFrequencyPenalty: 0.1, 116 wantPresencePenalty: 0.9, 117 wantErr: nil, 118 wantBaseURL: "https://api.openai.com", 119 wantApiVersion: "2023-05-15", 120 }, 121 { 122 name: "Azure OpenAI config", 123 cfg: fakeClassConfig{ 124 classConfig: map[string]interface{}{ 125 "resourceName": "weaviate", 126 "deploymentId": "gpt-3.5-turbo", 127 "maxTokens": 4097, 128 "temperature": 0.5, 129 "topP": 3, 130 "frequencyPenalty": 0.1, 131 "presencePenalty": 0.9, 132 }, 133 }, 134 wantResourceName: "weaviate", 135 wantDeploymentID: "gpt-3.5-turbo", 136 wantIsAzure: true, 137 wantModel: "gpt-3.5-turbo", 138 wantMaxTokens: 4097, 139 wantTemperature: 0.5, 140 wantTopP: 3, 141 wantFrequencyPenalty: 0.1, 142 wantPresencePenalty: 0.9, 143 wantErr: nil, 144 wantBaseURL: "https://api.openai.com", 145 wantApiVersion: "2023-05-15", 146 }, 147 { 148 name: "Azure OpenAI config with baseURL", 149 cfg: fakeClassConfig{ 150 classConfig: map[string]interface{}{ 151 "baseURL": "some-base-url", 152 "resourceName": "weaviate", 153 "deploymentId": "gpt-3.5-turbo", 154 "maxTokens": 4097, 155 "temperature": 0.5, 156 "topP": 3, 157 "frequencyPenalty": 0.1, 158 "presencePenalty": 0.9, 159 }, 160 }, 161 wantResourceName: "weaviate", 162 wantDeploymentID: "gpt-3.5-turbo", 163 wantIsAzure: true, 164 wantModel: "gpt-3.5-turbo", 165 wantMaxTokens: 4097, 166 wantTemperature: 0.5, 167 wantTopP: 3, 168 wantFrequencyPenalty: 0.1, 169 wantPresencePenalty: 0.9, 170 wantErr: nil, 171 wantBaseURL: "some-base-url", 172 wantApiVersion: "2023-05-15", 173 }, 174 { 175 name: "With gpt-3.5-turbo-16k model", 176 cfg: fakeClassConfig{ 177 classConfig: map[string]interface{}{ 178 "model": "gpt-3.5-turbo-16k", 179 "maxTokens": 4097, 180 "temperature": 0.5, 181 "topP": 3, 182 "frequencyPenalty": 0.1, 183 "presencePenalty": 0.9, 184 }, 185 }, 186 wantModel: "gpt-3.5-turbo-16k", 187 wantMaxTokens: 4097, 188 wantTemperature: 0.5, 189 wantTopP: 3, 190 wantFrequencyPenalty: 0.1, 191 wantPresencePenalty: 0.9, 192 wantErr: nil, 193 wantBaseURL: "https://api.openai.com", 194 wantApiVersion: "2023-05-15", 195 }, 196 { 197 name: "Wrong maxTokens configured", 198 cfg: fakeClassConfig{ 199 classConfig: map[string]interface{}{ 200 "maxTokens": true, 201 }, 202 }, 203 wantErr: errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used"), 204 }, 205 { 206 name: "Wrong temperature configured", 207 cfg: fakeClassConfig{ 208 classConfig: map[string]interface{}{ 209 "temperature": true, 210 }, 211 }, 212 wantErr: errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0"), 213 }, 214 { 215 name: "Wrong frequencyPenalty configured", 216 cfg: fakeClassConfig{ 217 classConfig: map[string]interface{}{ 218 "frequencyPenalty": true, 219 }, 220 }, 221 wantErr: errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0"), 222 }, 223 { 224 name: "Wrong presencePenalty configured", 225 cfg: fakeClassConfig{ 226 classConfig: map[string]interface{}{ 227 "presencePenalty": true, 228 }, 229 }, 230 wantErr: errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0"), 231 }, 232 { 233 name: "Wrong topP configured", 234 cfg: fakeClassConfig{ 235 classConfig: map[string]interface{}{ 236 "topP": true, 237 }, 238 }, 239 wantErr: errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5"), 240 }, 241 { 242 name: "Wrong Azure config - empty deploymentId", 243 cfg: fakeClassConfig{ 244 classConfig: map[string]interface{}{ 245 "resourceName": "resource-name", 246 }, 247 }, 248 wantErr: errors.Errorf("both resourceName and deploymentId must be provided"), 249 }, 250 { 251 name: "Wrong Azure config - empty resourceName", 252 cfg: fakeClassConfig{ 253 classConfig: map[string]interface{}{ 254 "deploymentId": "deployment-name", 255 }, 256 }, 257 wantErr: errors.Errorf("both resourceName and deploymentId must be provided"), 258 }, 259 { 260 name: "Wrong Azure config - wrong api version", 261 cfg: fakeClassConfig{ 262 classConfig: map[string]interface{}{ 263 "apiVersion": "wrong-api-version", 264 }, 265 }, 266 wantErr: errors.Errorf("wrong Azure OpenAI apiVersion, available api versions are: " + 267 "[2022-12-01 2023-03-15-preview 2023-05-15 2023-06-01-preview 2023-07-01-preview " + 268 "2023-08-01-preview 2023-09-01-preview 2023-12-01-preview]"), 269 }, 270 } 271 for _, tt := range tests { 272 t.Run(tt.name, func(t *testing.T) { 273 ic := NewClassSettings(tt.cfg) 274 if tt.wantErr != nil { 275 assert.EqualError(t, tt.wantErr, ic.Validate(nil).Error()) 276 } else { 277 assert.Equal(t, tt.wantModel, ic.Model()) 278 assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens()) 279 assert.Equal(t, tt.wantTemperature, ic.Temperature()) 280 assert.Equal(t, tt.wantTopP, ic.TopP()) 281 assert.Equal(t, tt.wantFrequencyPenalty, ic.FrequencyPenalty()) 282 assert.Equal(t, tt.wantPresencePenalty, ic.PresencePenalty()) 283 assert.Equal(t, tt.wantResourceName, ic.ResourceName()) 284 assert.Equal(t, tt.wantDeploymentID, ic.DeploymentID()) 285 assert.Equal(t, tt.wantIsAzure, ic.IsAzure()) 286 assert.Equal(t, tt.wantBaseURL, ic.BaseURL()) 287 assert.Equal(t, tt.wantApiVersion, ic.ApiVersion()) 288 } 289 }) 290 } 291 } 292 293 type fakeClassConfig struct { 294 classConfig map[string]interface{} 295 } 296 297 func (f fakeClassConfig) Class() map[string]interface{} { 298 return f.classConfig 299 } 300 301 func (f fakeClassConfig) Tenant() string { 302 return "" 303 } 304 305 func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 306 return f.classConfig 307 } 308 309 func (f fakeClassConfig) Property(propName string) map[string]interface{} { 310 return nil 311 } 312 313 func (f fakeClassConfig) TargetVector() string { 314 return "" 315 }