github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/config/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "testing" 16 17 "github.com/pkg/errors" 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_Validate(t *testing.T) { 23 tests := []struct { 24 name string 25 cfg moduletools.ClassConfig 26 wantModel string 27 wantMaxTokens float64 28 wantTemperature float64 29 wantTopP float64 30 wantFrequencyPenalty float64 31 wantPresencePenalty float64 32 wantResourceName string 33 wantDeploymentID string 34 wantIsAzure bool 35 wantErr error 36 wantBaseURL string 37 }{ 38 { 39 name: "Happy flow", 40 cfg: fakeClassConfig{ 41 classConfig: map[string]interface{}{}, 42 }, 43 wantModel: "text-ada-001", 44 wantMaxTokens: 16, 45 wantTemperature: 0.0, 46 wantTopP: 1, 47 wantFrequencyPenalty: 0.0, 48 wantPresencePenalty: 0.0, 49 wantErr: nil, 50 wantBaseURL: "https://api.openai.com", 51 }, 52 { 53 name: "Everything non default configured", 54 cfg: fakeClassConfig{ 55 classConfig: map[string]interface{}{ 56 "model": "text-babbage-001", 57 "maxTokens": 100, 58 "temperature": 0.5, 59 "topP": 3, 60 "frequencyPenalty": 0.1, 61 "presencePenalty": 0.9, 62 "baseURL": "https://openai.proxy.dev", 63 }, 64 }, 65 wantModel: "text-babbage-001", 66 wantMaxTokens: 100, 67 wantTemperature: 0.5, 68 wantTopP: 3, 69 wantFrequencyPenalty: 0.1, 70 wantPresencePenalty: 0.9, 71 wantBaseURL: "https://openai.proxy.dev", 72 wantErr: nil, 73 }, 74 { 75 name: "Azure OpenAI config", 76 cfg: fakeClassConfig{ 77 classConfig: map[string]interface{}{ 78 "resourceName": "weaviate", 79 "deploymentId": "text-ada-001", 80 }, 81 }, 82 wantModel: "text-ada-001", 83 wantResourceName: "weaviate", 84 wantDeploymentID: "text-ada-001", 85 wantIsAzure: true, 86 wantMaxTokens: 16, 87 wantTemperature: 0.0, 88 wantTopP: 1, 89 wantFrequencyPenalty: 0.0, 90 wantPresencePenalty: 0.0, 91 wantErr: nil, 92 wantBaseURL: "https://api.openai.com", 93 }, 94 { 95 name: "Wrong model data type configured", 96 cfg: fakeClassConfig{ 97 classConfig: map[string]interface{}{ 98 "model": true, 99 }, 100 }, 101 wantErr: errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels), 102 }, 103 { 104 name: "Wrong model data type configured", 105 cfg: fakeClassConfig{ 106 classConfig: map[string]interface{}{ 107 "model": "this-is-a-non-existing-model", 108 }, 109 }, 110 wantErr: errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels), 111 }, 112 { 113 name: "Wrong maxTokens configured", 114 cfg: fakeClassConfig{ 115 classConfig: map[string]interface{}{ 116 "maxTokens": true, 117 }, 118 }, 119 wantErr: errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used"), 120 }, 121 { 122 name: "Wrong temperature configured", 123 cfg: fakeClassConfig{ 124 classConfig: map[string]interface{}{ 125 "temperature": true, 126 }, 127 }, 128 wantErr: errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0"), 129 }, 130 { 131 name: "Wrong frequencyPenalty configured", 132 cfg: fakeClassConfig{ 133 classConfig: map[string]interface{}{ 134 "frequencyPenalty": true, 135 }, 136 }, 137 wantErr: errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0"), 138 }, 139 { 140 name: "Wrong presencePenalty configured", 141 cfg: fakeClassConfig{ 142 classConfig: map[string]interface{}{ 143 "presencePenalty": true, 144 }, 145 }, 146 wantErr: errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0"), 147 }, 148 { 149 name: "Wrong topP configured", 150 cfg: fakeClassConfig{ 151 classConfig: map[string]interface{}{ 152 "topP": true, 153 }, 154 }, 155 wantErr: errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5"), 156 }, 157 { 158 name: "Wrong Azure OpenAI config - empty deploymentId", 159 cfg: fakeClassConfig{ 160 classConfig: map[string]interface{}{ 161 "resourceName": "resource-name", 162 }, 163 }, 164 wantErr: errors.Errorf("both resourceName and deploymentId must be provided"), 165 }, 166 { 167 name: "Wrong Azure OpenAI config - empty resourceName", 168 cfg: fakeClassConfig{ 169 classConfig: map[string]interface{}{ 170 "deploymentId": "ada", 171 }, 172 }, 173 wantErr: errors.Errorf("both resourceName and deploymentId must be provided"), 174 }, 175 } 176 for _, tt := range tests { 177 t.Run(tt.name, func(t *testing.T) { 178 ic := NewClassSettings(tt.cfg) 179 if tt.wantErr != nil { 180 assert.EqualError(t, tt.wantErr, ic.Validate(nil).Error()) 181 } else { 182 assert.Equal(t, tt.wantModel, ic.Model()) 183 assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens()) 184 assert.Equal(t, tt.wantTemperature, ic.Temperature()) 185 assert.Equal(t, tt.wantTopP, ic.TopP()) 186 assert.Equal(t, tt.wantFrequencyPenalty, ic.FrequencyPenalty()) 187 assert.Equal(t, tt.wantPresencePenalty, ic.PresencePenalty()) 188 assert.Equal(t, tt.wantResourceName, ic.ResourceName()) 189 assert.Equal(t, tt.wantDeploymentID, ic.DeploymentID()) 190 assert.Equal(t, tt.wantIsAzure, ic.IsAzure()) 191 assert.Equal(t, tt.wantBaseURL, ic.BaseURL()) 192 } 193 }) 194 } 195 } 196 197 type fakeClassConfig struct { 198 classConfig map[string]interface{} 199 } 200 201 func (f fakeClassConfig) Class() map[string]interface{} { 202 return f.classConfig 203 } 204 205 func (f fakeClassConfig) Tenant() string { 206 return "" 207 } 208 209 func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 210 return f.classConfig 211 } 212 213 func (f fakeClassConfig) Property(propName string) map[string]interface{} { 214 return nil 215 } 216 217 func (f fakeClassConfig) TargetVector() string { 218 return "" 219 }